dotviz visulation for cells
This commit is contained in:
		
							
								
								
									
										81
									
								
								visualiser.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								visualiser.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,81 @@ | ||||
| import re | ||||
| from graphviz import Digraph | ||||
| import pandas as pd | ||||
| import time | ||||
| import argparse | ||||
|  | ||||
| parser = argparse.ArgumentParser(description='Fast cell visualisation') | ||||
| parser.add_argument('--arch', default=1, type=int) | ||||
| parser.add_argument('--save', action='store_true') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| def set_none(bit): | ||||
|     print(bit) | ||||
|     tmp = bit.split('~') | ||||
|     tmp[0] = 'none' | ||||
|     print('~'.join(tmp)) | ||||
|     return '~'.join(tmp) | ||||
|  | ||||
| def remove_pointless_ops(archstr): | ||||
|     old = None | ||||
|     new = archstr | ||||
|     while old != new: | ||||
|         old = new | ||||
|         bits = old.strip('|').split('|') | ||||
|         if 'none~' in bits[0]: # node 1 has no connections to it | ||||
|             bits[3] =  set_none(bits[3]) # node 1 -> 2 now none | ||||
|             bits[6] =  set_none(bits[6]) # node 1 -> 3 now none | ||||
|         if 'none~' in bits[2] and 'none~' in bits[3]: # node 2 has no connections to it | ||||
|             bits[7] =  set_none(bits[7]) # node 2 -> 3 now none | ||||
|         if 'none~' in bits[7]: # doesn't matter what comes through node 2 | ||||
|             bits[2] =  set_none(bits[2]) # node 0 -> 2 now none | ||||
|             bits[3] =  set_none(bits[3]) # node 1 -> 2 now none | ||||
|         if 'none~' in bits[6] and 'none~' in bits[7]: # doesn't matter what comes through node 1 | ||||
|             bits[0] =  set_none(bits[0]) # node 0 -> 1 now none | ||||
|         new = '|'.join(bits) | ||||
|     print(new) | ||||
|     return new | ||||
|  | ||||
|  | ||||
| df = pd.read_pickle('results/arch_score_acc.pd') | ||||
|  | ||||
| nodestr = df.iloc[args.arch]['cellstr'] | ||||
| nodestr = nodestr[1:-1] # remove leading and trailing bars | | ||||
|  | ||||
| nodestr = remove_pointless_ops(nodestr) | ||||
| nodes = nodestr.split("|+|") | ||||
|  | ||||
| dot = Digraph( | ||||
|   format='pdf', | ||||
|   edge_attr=dict(fontsize='12'), | ||||
|   node_attr=dict(fixedsize='true',shape="circle", height='0.5', width='0.5'), | ||||
|   engine='dot') | ||||
|  | ||||
| dot.body.extend(['rankdir=LR']) | ||||
|  | ||||
| OPS = ['conv_3x3','avg_pool_3x3','skip_connect','conv_1x1','none'] | ||||
|  | ||||
| dot.node('0', 'in') | ||||
|  | ||||
| ## ops are separated by bars (|) so | ||||
| for i, node in enumerate(nodes): | ||||
|  | ||||
|     # if node 3 then label as output | ||||
|     if (i+1) == 3: | ||||
|         dot.node(str(i+1), 'out') | ||||
|     else: | ||||
|         dot.node(str(i+1)) | ||||
|  | ||||
|     for op_str in node.split('|'): | ||||
|         op_name = [o for o in OPS if o in op_str][0] | ||||
|         if op_name == 'none': | ||||
|             break | ||||
|         connect = re.findall('~[0-9]', op_str)[0] | ||||
|         connect = connect[1:] | ||||
|         dot.edge(connect,str(i+1), label=op_name) | ||||
|  | ||||
| dot.render( view=True) | ||||
|  | ||||
|  | ||||
| if args.save: | ||||
|     dot.render(f'outputs/{args.arch}.gv') | ||||
		Reference in New Issue
	
	Block a user