| 
									
										
										
										
											2019-12-20 20:41:49 +11:00
										 |  |  | # python ./exps/vis/test.py | 
					
						
							| 
									
										
										
										
											2020-01-02 14:35:58 +11:00
										 |  |  | import os, sys, random | 
					
						
							| 
									
										
										
										
											2019-12-20 20:41:49 +11:00
										 |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2020-01-09 22:26:23 +11:00
										 |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2019-12-20 20:41:49 +11:00
										 |  |  | import torch | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | from collections import OrderedDict | 
					
						
							|  |  |  | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  | from nas_201_api import NASBench201API as API | 
					
						
							| 
									
										
										
										
											2019-12-20 20:41:49 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | def test_nas_api(): | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  |   from nas_201_api import ArchResults | 
					
						
							|  |  |  |   xdata   = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth') | 
					
						
							| 
									
										
										
										
											2019-12-20 20:41:49 +11:00
										 |  |  |   for key in ['full', 'less']: | 
					
						
							|  |  |  |     print ('\n------------------------- {:} -------------------------'.format(key)) | 
					
						
							|  |  |  |     archRes = ArchResults.create_from_state_dict(xdata[key]) | 
					
						
							|  |  |  |     print(archRes) | 
					
						
							|  |  |  |     print(archRes.arch_idx_str()) | 
					
						
							|  |  |  |     print(archRes.get_dataset_names()) | 
					
						
							|  |  |  |     print(archRes.get_comput_costs('cifar10-valid')) | 
					
						
							|  |  |  |     # get the metrics | 
					
						
							|  |  |  |     print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False)) | 
					
						
							|  |  |  |     print(archRes.get_metrics('cifar10-valid', 'x-valid', None,  True)) | 
					
						
							|  |  |  |     print(archRes.query('cifar10-valid', 777)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-02 14:35:58 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | OPS    = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3'] | 
					
						
							|  |  |  | COLORS = ['chartreuse'  , 'cyan'    , 'navyblue', 'chocolate1'] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def plot(filename): | 
					
						
							| 
									
										
										
										
											2020-01-05 22:19:38 +11:00
										 |  |  |   from graphviz import Digraph | 
					
						
							| 
									
										
										
										
											2020-01-02 14:35:58 +11:00
										 |  |  |   g = Digraph( | 
					
						
							|  |  |  |       format='png', | 
					
						
							|  |  |  |       edge_attr=dict(fontsize='20', fontname="times"), | 
					
						
							|  |  |  |       node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), | 
					
						
							|  |  |  |       engine='dot') | 
					
						
							|  |  |  |   g.body.extend(['rankdir=LR']) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   steps = 5 | 
					
						
							|  |  |  |   for i in range(0, steps): | 
					
						
							|  |  |  |     if i == 0: | 
					
						
							|  |  |  |       g.node(str(i), fillcolor='darkseagreen2') | 
					
						
							|  |  |  |     elif i+1 == steps: | 
					
						
							|  |  |  |       g.node(str(i), fillcolor='palegoldenrod') | 
					
						
							|  |  |  |     else: g.node(str(i), fillcolor='lightblue') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   for i in range(1, steps): | 
					
						
							|  |  |  |     for xin in range(i): | 
					
						
							|  |  |  |       op_i = random.randint(0, len(OPS)-1) | 
					
						
							|  |  |  |       #g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i]) | 
					
						
							|  |  |  |       g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i]) | 
					
						
							|  |  |  |       #import pdb; pdb.set_trace() | 
					
						
							|  |  |  |   g.render(filename, cleanup=True, view=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-05 22:19:38 +11:00
										 |  |  | def test_auto_grad(): | 
					
						
							|  |  |  |   class Net(torch.nn.Module): | 
					
						
							|  |  |  |     def __init__(self, iS): | 
					
						
							|  |  |  |       super(Net, self).__init__() | 
					
						
							|  |  |  |       self.layer = torch.nn.Linear(iS, 1) | 
					
						
							|  |  |  |     def forward(self, inputs): | 
					
						
							|  |  |  |       outputs = self.layer(inputs) | 
					
						
							|  |  |  |       outputs = torch.exp(outputs) | 
					
						
							|  |  |  |       return outputs.mean() | 
					
						
							|  |  |  |   net = Net(10) | 
					
						
							|  |  |  |   inputs = torch.rand(256, 10) | 
					
						
							|  |  |  |   loss = net(inputs) | 
					
						
							|  |  |  |   first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True) | 
					
						
							|  |  |  |   first_order_grads = torch.cat([x.view(-1) for x in first_order_grads]) | 
					
						
							|  |  |  |   second_order_grads = [] | 
					
						
							|  |  |  |   for grads in  first_order_grads: | 
					
						
							|  |  |  |     s_grads = torch.autograd.grad(grads, net.parameters()) | 
					
						
							|  |  |  |     second_order_grads.append( s_grads ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-09 22:26:23 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | def test_one_shot_model(ckpath, use_train): | 
					
						
							|  |  |  |   from models import get_cell_based_tiny_net, get_search_spaces | 
					
						
							|  |  |  |   from datasets import get_datasets, SearchDataset | 
					
						
							|  |  |  |   from config_utils import load_config, dict2config | 
					
						
							|  |  |  |   from utils.nas_utils import evaluate_one_shot | 
					
						
							|  |  |  |   use_train = int(use_train) > 0 | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  |   #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' | 
					
						
							|  |  |  |   #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' | 
					
						
							| 
									
										
										
										
											2020-01-09 22:26:23 +11:00
										 |  |  |   print ('ckpath : {:}'.format(ckpath)) | 
					
						
							|  |  |  |   ckp = torch.load(ckpath) | 
					
						
							|  |  |  |   xargs = ckp['args'] | 
					
						
							|  |  |  |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | 
					
						
							| 
									
										
										
										
											2020-01-10 17:26:37 +11:00
										 |  |  |   #config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) | 
					
						
							|  |  |  |   config = load_config('./configs/nas-benchmark/algos/DARTS.config', {'class_num': class_num, 'xshape': xshape}, None) | 
					
						
							| 
									
										
										
										
											2020-01-09 22:26:23 +11:00
										 |  |  |   if xargs.dataset == 'cifar10': | 
					
						
							|  |  |  |     cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None) | 
					
						
							|  |  |  |     xvalid_data = deepcopy(train_data) | 
					
						
							|  |  |  |     xvalid_data.transform = valid_data.transform | 
					
						
							|  |  |  |     valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True) | 
					
						
							|  |  |  |   else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet)) | 
					
						
							|  |  |  |   search_space = get_search_spaces('cell', xargs.search_space_name) | 
					
						
							|  |  |  |   model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells, | 
					
						
							|  |  |  |                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | 
					
						
							|  |  |  |                               'space'    : search_space, | 
					
						
							|  |  |  |                               'affine'   : False, 'track_running_stats': True}, None) | 
					
						
							|  |  |  |   search_model = get_cell_based_tiny_net(model_config) | 
					
						
							|  |  |  |   search_model.load_state_dict( ckp['search_model'] ) | 
					
						
							|  |  |  |   search_model = search_model.cuda() | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  |   api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth') | 
					
						
							| 
									
										
										
										
											2020-01-09 22:26:23 +11:00
										 |  |  |   archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-20 20:41:49 +11:00
										 |  |  | if __name__ == '__main__': | 
					
						
							| 
									
										
										
										
											2020-01-05 22:19:38 +11:00
										 |  |  |   #test_nas_api() | 
					
						
							|  |  |  |   #for i in range(200): plot('{:04d}'.format(i)) | 
					
						
							| 
									
										
										
										
											2020-01-09 22:26:23 +11:00
										 |  |  |   #test_auto_grad() | 
					
						
							|  |  |  |   test_one_shot_model(sys.argv[1], sys.argv[2]) |