"""Draw supernet cell graphs (search space and genotype) as PDFs via Graphviz."""

import sys

import numpy as np
from graphviz import Digraph

import genotypes

# Edge id -> (source node name, target node name) for every supernet edge.
# Node names are the two cell inputs ("c_{k-2}", "c_{k-1}") and the
# intermediate nodes "0".."3".
supernet_dict = {
    0: ('c_{k-2}', '0'),
    1: ('c_{k-1}', '0'),
    2: ('c_{k-2}', '1'),
    3: ('c_{k-1}', '1'),
    4: ('0', '1'),
    5: ('c_{k-2}', '2'),
    6: ('c_{k-1}', '2'),
    7: ('0', '2'),
    8: ('1', '2'),
    9: ('c_{k-2}', '3'),
    10: ('c_{k-1}', '3'),
    11: ('0', '3'),
    12: ('1', '3'),
    13: ('2', '3'),
}

# Number of intermediate nodes drawn in each cell.
steps = 4


def _new_cell_graph(fontsize, ratio):
    """Create a left-to-right Digraph holding the input and intermediate nodes.

    Args:
        fontsize: Font size (as a string) used for both node and edge labels.
        ratio: Value (as a string) written into the graph's ``ratio`` attribute.

    Returns:
        A ``Digraph`` containing the two input nodes and ``steps``
        intermediate nodes, with no edges yet.
    """
    g = Digraph(
        format='pdf',
        edge_attr=dict(fontsize=fontsize, fontname="times"),
        node_attr=dict(style='filled', shape='rect', align='center',
                       fontsize=fontsize, height='0.5', width='0.5',
                       penwidth='2', fontname="times"),
        engine='dot')
    g.body.extend(['rankdir=LR'])
    g.body.extend(['ratio={}'.format(ratio)])

    g.node("c_{k-2}", fillcolor='darkseagreen2')
    g.node("c_{k-1}", fillcolor='darkseagreen2')
    for i in range(steps):
        g.node(str(i), fillcolor='lightblue')
    return g


def _render_cell_graph(g, filename):
    """Add the output node, connect every intermediate node to it, and render."""
    g.node("c_{k}", fillcolor='palegoldenrod')
    for i in range(steps):
        g.edge(str(i), "c_{k}", fillcolor="gray")
    g.render(filename, view=False)


def plot_space(primitives, filename):
    """Render a search space: one labelled edge per candidate op on each edge.

    Args:
        primitives: Indexable collection with one iterable of op names per
            edge, ordered by target node (2 entries for node 0, 3 for node 1,
            and so on).
        filename: Output path passed to ``Digraph.render``.
    """
    g = _new_cell_graph(fontsize='20', ratio='50.0')

    # Every already-created node is a possible source for the next target, so
    # the slice of ``primitives`` consumed per target grows by one each step.
    sources = ["c_{k-2}", "c_{k-1}"]
    start = 0
    for i in range(steps):
        target = str(i)
        end = start + len(sources)
        for source, ops in zip(sources, primitives[start:end]):
            for op in ops:
                g.edge(source, target, label=op, fillcolor="gray")
        start = end
        sources.append(target)

    _render_cell_graph(g, filename)


def plot(genotype, filename):
    """Render a genotype: one edge per entry, non-skip ops highlighted in red.

    Args:
        genotype: Collection indexable by edge id ``0..len(genotype)-1``; each
            entry is the op name drawn on the edge given by ``supernet_dict``.
        filename: Output path passed to ``Digraph.render``.

    Raises:
        KeyError: If ``genotype`` has an edge id missing from ``supernet_dict``.
    """
    g = _new_cell_graph(fontsize='100', ratio='0.3')

    # Index by edge id (rather than enumerate) so that int-keyed mappings are
    # handled the same way as sequences.
    for eid in range(len(genotype)):
        op = genotype[eid]
        u, v = supernet_dict[eid]
        if op != 'skip_connect':
            g.edge(u, v, label=op, fillcolor="gray", color='red', fontcolor='red')
        else:
            g.edge(u, v, label=op, fillcolor="gray")

    _render_cell_graph(g, filename)


def main():
    """Command-line entry point: render the normal and reduce cells of a supernet."""
    #### visualize the supernet ####
    if len(sys.argv) != 2:
        print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
        sys.exit(1)

    genotype_name = sys.argv[1]
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if 'supernet' not in genotype_name:
        print('this script only supports supernet visualization')
        sys.exit(1)

    # Attribute lookup replaces ``eval`` on the command-line argument: same
    # result for a plain attribute name, without executing arbitrary code.
    try:
        genotype = getattr(genotypes, genotype_name)
    except AttributeError:
        print("{} is not specified in genotypes.py".format(genotype_name))
        sys.exit(1)

    path = '../../figs/genotypes/cnn_supernet_cue/'
    plot(genotype.normal, path + genotype_name + "_normal")
    plot(genotype.reduce, path + genotype_name + "_reduce")


if __name__ == '__main__':
    main()