upload
This commit is contained in:
67
sota/cnn/visualize.py
Normal file
67
sota/cnn/visualize.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import sys
|
||||
import genotypes
|
||||
from graphviz import Digraph
|
||||
|
||||
|
||||
def plot(genotype, filename, mode=''):
|
||||
g = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='40', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='40', height='0.5', width='0.5',
|
||||
penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
|
||||
g.body.extend(['rankdir=LR'])
|
||||
|
||||
# g.body.extend(['ratio=0.15'])
|
||||
# g.view()
|
||||
|
||||
g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
assert len(genotype) % 2 == 0
|
||||
steps = len(genotype) // 2
|
||||
|
||||
for i in range(steps):
|
||||
g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
for i in range(steps):
|
||||
for k in [2 * i, 2 * i + 1]:
|
||||
op, j = genotype[k]
|
||||
if j == 0:
|
||||
u = "c_{k-2}"
|
||||
elif j == 1:
|
||||
u = "c_{k-1}"
|
||||
else:
|
||||
u = str(j - 2)
|
||||
v = str(i)
|
||||
|
||||
if mode == 'cue' and op != 'skip_connect' and op != 'noise':
|
||||
g.edge(u, v, label=op, fillcolor='gray', color='red', fontcolor='red')
|
||||
else:
|
||||
g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
for i in range(steps):
|
||||
g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
g.render(filename, view=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 2:
|
||||
print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
|
||||
sys.exit(1)
|
||||
|
||||
genotype_name = sys.argv[1]
|
||||
try:
|
||||
genotype = eval('genotypes.{}'.format(genotype_name))
|
||||
# print(genotype)
|
||||
except AttributeError:
|
||||
print("{} is not specified in genotypes.py".format(genotype_name))
|
||||
sys.exit(1)
|
||||
|
||||
mode = 'cue'
|
||||
path = '../../figs/genotypes/cnn_{}/'.format(mode)
|
||||
# print(genotype.normal)
|
||||
plot(genotype.normal, path + genotype_name + "_normal", mode=mode)
|
||||
plot(genotype.reduce, path + genotype_name + "_reduce", mode=mode)
|
Reference in New Issue
Block a user