115 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			115 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # python ./exps/vis/test.py
 | |
| import os, sys, random
 | |
| from pathlib import Path
 | |
| from copy import deepcopy
 | |
| import torch
 | |
| import numpy as np
 | |
| from collections import OrderedDict
 | |
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 | |
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
 | |
| 
 | |
| from nas_201_api import NASBench201API as API
 | |
| 
 | |
def test_nas_api(xpath='/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth'):
  """Smoke-test ``ArchResults`` by loading one saved architecture file and printing its queries.

  Args:
    xpath: path of a file readable by ``torch.load`` that holds a dict with
      'full' and 'less' entries, each accepted by
      ``ArchResults.create_from_state_dict``. Defaults to the original
      hard-coded location.
  """
  from nas_201_api import ArchResults
  xdata = torch.load(xpath)
  for key in ['full', 'less']:
    print('\n------------------------- {:} -------------------------'.format(key))
    archRes = ArchResults.create_from_state_dict(xdata[key])
    print(archRes)
    print(archRes.arch_idx_str())
    print(archRes.get_dataset_names())
    print(archRes.get_comput_costs('cifar10-valid'))
    # get the metrics (the fourth argument is toggled False/True to compare both variants)
    print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False))
    print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True))
    print(archRes.query('cifar10-valid', 777))
 | |
| 
 | |
| 
 | |
# Candidate cell operations and the Graphviz colour used to draw each one:
# COLORS[i] is the edge colour for OPS[i], so the two lists must stay aligned.
OPS = [
  'skip-connect',
  'conv-1x1',
  'conv-3x3',
  'pool-3x3',
]
COLORS = [
  'chartreuse',
  'cyan',
  'navyblue',
  'chocolate1',
]
 | |
| 
 | |
def plot(filename, steps=5):
  """Render a randomly-labelled, fully-connected cell DAG to a PNG with Graphviz.

  Node 0 is filled darkseagreen2, node ``steps - 1`` palegoldenrod, and all
  nodes in between lightblue. Every pair ``xin < i`` gets an edge labelled with
  an operation drawn uniformly at random (module-level ``random``) from OPS and
  coloured with the matching COLORS entry.

  Args:
    filename: output path passed to ``Digraph.render`` (format is 'png').
    steps: number of nodes in the cell; defaults to 5, the original value.

  Returns:
    The path returned by ``Digraph.render`` for the rendered file.
  """
  from graphviz import Digraph
  g = Digraph(
      format='png',
      edge_attr=dict(fontsize='20', fontname="times"),
      node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
      engine='dot')
  g.body.extend(['rankdir=LR'])

  for i in range(steps):
    if i == 0:
      fill = 'darkseagreen2'
    elif i + 1 == steps:
      fill = 'palegoldenrod'
    else:
      fill = 'lightblue'
    g.node(str(i), fillcolor=fill)

  for i in range(1, steps):
    for xin in range(i):
      op_i = random.randint(0, len(OPS) - 1)
      g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
  return g.render(filename, cleanup=True, view=False)
 | |
| 
 | |
| 
 | |
def test_auto_grad(in_features=10, batch_size=256):
  """Compute the Hessian of a tiny model's loss by differentiating each gradient entry.

  The loss is ``mean(exp(Linear(in_features, 1)(x)))`` on a ``torch.rand``
  batch of shape ``(batch_size, in_features)``. First-order gradients are taken
  with ``create_graph=True`` so each flattened entry can be differentiated again
  with respect to all parameters.

  Args:
    in_features: input width of the linear layer (default 10).
    batch_size: number of random input rows (default 256).

  Returns:
    A ``(P, P)`` tensor with ``P = in_features + 1`` (weight then bias), whose
    row k is the gradient of the k-th first-order gradient entry.
  """
  class Net(torch.nn.Module):
    def __init__(self, iS):
      super(Net, self).__init__()
      self.layer = torch.nn.Linear(iS, 1)
    def forward(self, inputs):
      outputs = self.layer(inputs)
      outputs = torch.exp(outputs)
      return outputs.mean()
  net = Net(in_features)
  inputs = torch.rand(batch_size, in_features)
  loss = net(inputs)
  # Materialise once: the same parameter order is used for both derivative passes.
  params = list(net.parameters())
  first_order_grads = torch.autograd.grad(loss, params, retain_graph=True, create_graph=True)
  first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
  second_order_grads = []
  for grads in first_order_grads:
    # Every entry backpropagates through the same shared graph; without
    # retain_graph=True the first call frees it and the next one raises.
    s_grads = torch.autograd.grad(grads, params, retain_graph=True)
    second_order_grads.append(torch.cat([g.reshape(-1) for g in s_grads]))
  return torch.stack(second_order_grads)
 | |
| 
 | |
| 
 | |
def test_one_shot_model(ckpath, use_train, api_path='/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth'):
  """Evaluate a saved one-shot search model against the NAS-Bench-201 API.

  Rebuilds a 'SETN' cell-based search model from the ``args`` stored in the
  checkpoint, loads its ``search_model`` weights, moves it to CUDA, and runs
  ``evaluate_one_shot`` on the CIFAR-10 validation split.

  Args:
    ckpath: path to a checkpoint dict holding 'args' and 'search_model'.
    use_train: any value ``int()`` accepts; values > 0 are passed on as True.
    api_path: NAS-Bench-201 benchmark file given to ``API`` (defaults to the
      original hard-coded path).

  Returns:
    The ``(archs, probs, accuracies)`` tuple from ``evaluate_one_shot``.

  Raises:
    ValueError: if the checkpoint's dataset is not 'cifar10'.
  """
  from models import get_cell_based_tiny_net, get_search_spaces
  from datasets import get_datasets
  from config_utils import load_config, dict2config
  from utils.nas_utils import evaluate_one_shot
  use_train = int(use_train) > 0
  print('ckpath : {:}'.format(ckpath))
  # NOTE(review): 'args' is a pickled Python object; torch versions that default
  # torch.load to weights_only=True will refuse this checkpoint.
  ckp = torch.load(ckpath)
  xargs = ckp['args']
  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  if xargs.dataset == 'cifar10':
    # Validation samples are a subset of the *training* set (indices from
    # cifar-split.txt) but use the evaluation transform.
    cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
    xvalid_data = deepcopy(train_data)
    xvalid_data.transform = valid_data.transform
    valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
  else:
    raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
  search_space = get_search_spaces('cell', xargs.search_space_name)
  model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space,
                              'affine'   : False, 'track_running_stats': True}, None)
  search_model = get_cell_based_tiny_net(model_config)
  search_model.load_state_dict(ckp['search_model'])
  search_model = search_model.cuda()
  api = API(api_path)
  archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
  return archs, probs, accuracies
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
  # Other manual checks available in this script:
  #   test_nas_api()
  #   for i in range(200): plot('{:04d}'.format(i))
  #   test_auto_grad()
  # Fail with a usage message instead of an IndexError when arguments are missing.
  if len(sys.argv) < 3:
    sys.exit('usage: python {:} <checkpoint-path> <use-train: int, >0 means True>'.format(sys.argv[0]))
  test_one_shot_model(sys.argv[1], sys.argv[2])
 |