70 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			70 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import os, sys, time, glob, random, argparse | ||
|  | import numpy as np | ||
|  | from copy import deepcopy | ||
|  | import torch | ||
|  | from pathlib import Path | ||
|  | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||
|  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||
|  | from graphviz import Digraph | ||
|  | 
 | ||
def _build_parser():
  """Build the command-line parser for this script.

  Options (both optional strings, in this order):
    --checkpoint  path of the checkpoint file to read.
    --save_dir    directory that receives the rendered plots.
  """
  # The positional string is argparse's ``prog`` (the name shown in usage).
  cli = argparse.ArgumentParser("Visualize the Networks")
  option_specs = (
    ('--checkpoint', 'The path to the checkpoint.'),
    ('--save_dir',   'The directory to save the network plot.'),
  )
  for flag, text in option_specs:
    cli.add_argument(flag, type=str, help=text)
  return cli


# Parsed at import time: the `__main__` block reads the module-level
# `args` (and `parser`) names.
parser = _build_parser()
args = parser.parse_args()
|  | 
 | ||
|  | 
 | ||
def plot(genotype, filename):
  """Draw one cell genotype as a left-to-right graph and render it to PDF.

  ``genotype`` is a flat sequence holding two entries per intermediate node:
  entries ``2*i`` and ``2*i + 1`` are the two incoming edges of intermediate
  node ``i``. Each entry is an ``(op, j, weight)`` triple, where ``op`` labels
  the edge and ``j`` selects its source node:

    * ``j == 0``  -> cell input ``c_{k-2}``
    * ``j == 1``  -> cell input ``c_{k-1}``
    * otherwise   -> intermediate node ``j - 2``

  ``weight`` is unpacked but not drawn. Every intermediate node is then
  connected to the cell output ``c_{k}``.

  Args:
    genotype: sequence of 3-element ``(op, j, weight)`` entries; its length
      must be even.
    filename: output path (without extension) passed to ``Digraph.render``,
      which saves the DOT source under this name and the ``.pdf`` next to it.

  Raises:
    ValueError: if ``genotype`` has an odd number of entries.
  """
  # Explicit check instead of `assert`: asserts are stripped under `python -O`,
  # and an odd trailing entry would then be silently dropped by the `// 2`.
  if len(genotype) % 2 != 0:
    raise ValueError(
      'genotype must contain an even number of entries (two per node), got {:}'.format(len(genotype)))
  steps = len(genotype) // 2

  g = Digraph(
      format='pdf',
      edge_attr=dict(fontsize='20', fontname="times"),
      node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
      engine='dot')
  g.body.extend(['rankdir=LR'])

  # The two cell inputs.
  g.node("c_{k-2}", fillcolor='darkseagreen2')
  g.node("c_{k-1}", fillcolor='darkseagreen2')

  # Intermediate nodes are named by their index.
  for i in range(steps):
    g.node(str(i), fillcolor='lightblue')

  # Two incoming edges per intermediate node, labelled with the chosen op.
  for i in range(steps):
    for k in (2 * i, 2 * i + 1):
      op, j, _weight = genotype[k]
      if j == 0:
        u = "c_{k-2}"
      elif j == 1:
        u = "c_{k-1}"
      else:
        # Indices 0 and 1 are taken by the cell inputs, so shift by two.
        u = str(j - 2)
      g.edge(u, str(i), label=op, fillcolor="gray")

  # Every intermediate node feeds the cell output.
  g.node("c_{k}", fillcolor='palegoldenrod')
  for i in range(steps):
    g.edge(str(i), "c_{k}", fillcolor="gray")

  g.render(filename, view=False)
|  | 
 | ||
|  | 
 | ||
if __name__ == '__main__':
  import inspect

  checkpoint_path = args.checkpoint
  # Validate explicitly rather than with `assert` (stripped under `python -O`).
  # A missing --checkpoint is None, which must not reach os.path.isfile.
  if checkpoint_path is None or not os.path.isfile(checkpoint_path):
    parser.error('Invalid path for checkpoint : {:}'.format(checkpoint_path))
  if args.save_dir is None:
    parser.error('--save_dir is required')

  # The checkpoint stores pickled genotype objects, not just tensors.
  # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses such
  # objects, so opt out explicitly when the installed torch supports the flag
  # (older versions would reject the keyword).
  # SECURITY: full unpickling can execute arbitrary code -- only load
  # checkpoints from a trusted source.
  load_kwargs = {'map_location': 'cpu'}
  if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
  checkpoint = torch.load(checkpoint_path, **load_kwargs)

  genotypes = checkpoint['genotypes']
  save_dir  = Path(args.save_dir)
  subs      = ['normal', 'reduce']
  for sub in subs:
    # exist_ok makes a separate exists() check redundant (and racy).
    (save_dir / sub).mkdir(parents=True, exist_ok=True)

  # One plot per cell type per stored epoch, normal first then reduce.
  for key, network in genotypes.items():
    for sub in subs:
      save_path = str(save_dir / sub / 'epoch-{:03d}'.format(int(key)))
      print('save into {:}'.format(save_path))
      plot(getattr(network, sub), save_path)