Reformulate via black
This commit is contained in:
		@@ -5,110 +5,148 @@ from copy import deepcopy
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 | 
			
		||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
 | 
			
		||||
 | 
			
		||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 | 
			
		||||
if str(lib_dir) not in sys.path:
 | 
			
		||||
    sys.path.insert(0, str(lib_dir))
 | 
			
		||||
 | 
			
		||||
from nas_201_api import NASBench201API as API
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_nas_api():
 | 
			
		||||
  from nas_201_api import ArchResults
 | 
			
		||||
  xdata   = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth')
 | 
			
		||||
  for key in ['full', 'less']:
 | 
			
		||||
    print ('\n------------------------- {:} -------------------------'.format(key))
 | 
			
		||||
    archRes = ArchResults.create_from_state_dict(xdata[key])
 | 
			
		||||
    print(archRes)
 | 
			
		||||
    print(archRes.arch_idx_str())
 | 
			
		||||
    print(archRes.get_dataset_names())
 | 
			
		||||
    print(archRes.get_comput_costs('cifar10-valid'))
 | 
			
		||||
    # get the metrics
 | 
			
		||||
    print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False))
 | 
			
		||||
    print(archRes.get_metrics('cifar10-valid', 'x-valid', None,  True))
 | 
			
		||||
    print(archRes.query('cifar10-valid', 777))
 | 
			
		||||
    from nas_201_api import ArchResults
 | 
			
		||||
 | 
			
		||||
    xdata = torch.load(
 | 
			
		||||
        "/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth"
 | 
			
		||||
    )
 | 
			
		||||
    for key in ["full", "less"]:
 | 
			
		||||
        print("\n------------------------- {:} -------------------------".format(key))
 | 
			
		||||
        archRes = ArchResults.create_from_state_dict(xdata[key])
 | 
			
		||||
        print(archRes)
 | 
			
		||||
        print(archRes.arch_idx_str())
 | 
			
		||||
        print(archRes.get_dataset_names())
 | 
			
		||||
        print(archRes.get_comput_costs("cifar10-valid"))
 | 
			
		||||
        # get the metrics
 | 
			
		||||
        print(archRes.get_metrics("cifar10-valid", "x-valid", None, False))
 | 
			
		||||
        print(archRes.get_metrics("cifar10-valid", "x-valid", None, True))
 | 
			
		||||
        print(archRes.query("cifar10-valid", 777))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
OPS    = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3']
 | 
			
		||||
COLORS = ['chartreuse'  , 'cyan'    , 'navyblue', 'chocolate1']
 | 
			
		||||
OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"]
 | 
			
		||||
COLORS = ["chartreuse", "cyan", "navyblue", "chocolate1"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot(filename):
 | 
			
		||||
  from graphviz import Digraph
 | 
			
		||||
  g = Digraph(
 | 
			
		||||
      format='png',
 | 
			
		||||
      edge_attr=dict(fontsize='20', fontname="times"),
 | 
			
		||||
      node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
 | 
			
		||||
      engine='dot')
 | 
			
		||||
  g.body.extend(['rankdir=LR'])
 | 
			
		||||
    from graphviz import Digraph
 | 
			
		||||
 | 
			
		||||
  steps = 5
 | 
			
		||||
  for i in range(0, steps):
 | 
			
		||||
    if i == 0:
 | 
			
		||||
      g.node(str(i), fillcolor='darkseagreen2')
 | 
			
		||||
    elif i+1 == steps:
 | 
			
		||||
      g.node(str(i), fillcolor='palegoldenrod')
 | 
			
		||||
    else: g.node(str(i), fillcolor='lightblue')
 | 
			
		||||
    g = Digraph(
 | 
			
		||||
        format="png",
 | 
			
		||||
        edge_attr=dict(fontsize="20", fontname="times"),
 | 
			
		||||
        node_attr=dict(
 | 
			
		||||
            style="filled",
 | 
			
		||||
            shape="rect",
 | 
			
		||||
            align="center",
 | 
			
		||||
            fontsize="20",
 | 
			
		||||
            height="0.5",
 | 
			
		||||
            width="0.5",
 | 
			
		||||
            penwidth="2",
 | 
			
		||||
            fontname="times",
 | 
			
		||||
        ),
 | 
			
		||||
        engine="dot",
 | 
			
		||||
    )
 | 
			
		||||
    g.body.extend(["rankdir=LR"])
 | 
			
		||||
 | 
			
		||||
  for i in range(1, steps):
 | 
			
		||||
    for xin in range(i):
 | 
			
		||||
      op_i = random.randint(0, len(OPS)-1)
 | 
			
		||||
      #g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
 | 
			
		||||
      g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
 | 
			
		||||
      #import pdb; pdb.set_trace()
 | 
			
		||||
  g.render(filename, cleanup=True, view=False)
 | 
			
		||||
    steps = 5
 | 
			
		||||
    for i in range(0, steps):
 | 
			
		||||
        if i == 0:
 | 
			
		||||
            g.node(str(i), fillcolor="darkseagreen2")
 | 
			
		||||
        elif i + 1 == steps:
 | 
			
		||||
            g.node(str(i), fillcolor="palegoldenrod")
 | 
			
		||||
        else:
 | 
			
		||||
            g.node(str(i), fillcolor="lightblue")
 | 
			
		||||
 | 
			
		||||
    for i in range(1, steps):
 | 
			
		||||
        for xin in range(i):
 | 
			
		||||
            op_i = random.randint(0, len(OPS) - 1)
 | 
			
		||||
            # g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
 | 
			
		||||
            g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
 | 
			
		||||
            # import pdb; pdb.set_trace()
 | 
			
		||||
    g.render(filename, cleanup=True, view=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_auto_grad():
 | 
			
		||||
  class Net(torch.nn.Module):
 | 
			
		||||
    def __init__(self, iS):
 | 
			
		||||
      super(Net, self).__init__()
 | 
			
		||||
      self.layer = torch.nn.Linear(iS, 1)
 | 
			
		||||
    def forward(self, inputs):
 | 
			
		||||
      outputs = self.layer(inputs)
 | 
			
		||||
      outputs = torch.exp(outputs)
 | 
			
		||||
      return outputs.mean()
 | 
			
		||||
  net = Net(10)
 | 
			
		||||
  inputs = torch.rand(256, 10)
 | 
			
		||||
  loss = net(inputs)
 | 
			
		||||
  first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
 | 
			
		||||
  first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
 | 
			
		||||
  second_order_grads = []
 | 
			
		||||
  for grads in  first_order_grads:
 | 
			
		||||
    s_grads = torch.autograd.grad(grads, net.parameters())
 | 
			
		||||
    second_order_grads.append( s_grads )
 | 
			
		||||
    class Net(torch.nn.Module):
 | 
			
		||||
        def __init__(self, iS):
 | 
			
		||||
            super(Net, self).__init__()
 | 
			
		||||
            self.layer = torch.nn.Linear(iS, 1)
 | 
			
		||||
 | 
			
		||||
        def forward(self, inputs):
 | 
			
		||||
            outputs = self.layer(inputs)
 | 
			
		||||
            outputs = torch.exp(outputs)
 | 
			
		||||
            return outputs.mean()
 | 
			
		||||
 | 
			
		||||
    net = Net(10)
 | 
			
		||||
    inputs = torch.rand(256, 10)
 | 
			
		||||
    loss = net(inputs)
 | 
			
		||||
    first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
 | 
			
		||||
    first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
 | 
			
		||||
    second_order_grads = []
 | 
			
		||||
    for grads in first_order_grads:
 | 
			
		||||
        s_grads = torch.autograd.grad(grads, net.parameters())
 | 
			
		||||
        second_order_grads.append(s_grads)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_one_shot_model(ckpath, use_train):
 | 
			
		||||
  from models import get_cell_based_tiny_net, get_search_spaces
 | 
			
		||||
  from datasets import get_datasets, SearchDataset
 | 
			
		||||
  from config_utils import load_config, dict2config
 | 
			
		||||
  from utils.nas_utils import evaluate_one_shot
 | 
			
		||||
  use_train = int(use_train) > 0
 | 
			
		||||
  #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
 | 
			
		||||
  #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
 | 
			
		||||
  print ('ckpath : {:}'.format(ckpath))
 | 
			
		||||
  ckp = torch.load(ckpath)
 | 
			
		||||
  xargs = ckp['args']
 | 
			
		||||
  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
 | 
			
		||||
  #config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
 | 
			
		||||
  config = load_config('./configs/nas-benchmark/algos/DARTS.config', {'class_num': class_num, 'xshape': xshape}, None)
 | 
			
		||||
  if xargs.dataset == 'cifar10':
 | 
			
		||||
    cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
 | 
			
		||||
    xvalid_data = deepcopy(train_data)
 | 
			
		||||
    xvalid_data.transform = valid_data.transform
 | 
			
		||||
    valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
 | 
			
		||||
  else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet))
 | 
			
		||||
  search_space = get_search_spaces('cell', xargs.search_space_name)
 | 
			
		||||
  model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
 | 
			
		||||
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
 | 
			
		||||
                              'space'    : search_space,
 | 
			
		||||
                              'affine'   : False, 'track_running_stats': True}, None)
 | 
			
		||||
  search_model = get_cell_based_tiny_net(model_config)
 | 
			
		||||
  search_model.load_state_dict( ckp['search_model'] )
 | 
			
		||||
  search_model = search_model.cuda()
 | 
			
		||||
  api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth')
 | 
			
		||||
  archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
 | 
			
		||||
    from models import get_cell_based_tiny_net, get_search_spaces
 | 
			
		||||
    from datasets import get_datasets, SearchDataset
 | 
			
		||||
    from config_utils import load_config, dict2config
 | 
			
		||||
    from utils.nas_utils import evaluate_one_shot
 | 
			
		||||
 | 
			
		||||
    use_train = int(use_train) > 0
 | 
			
		||||
    # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
 | 
			
		||||
    # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
 | 
			
		||||
    print("ckpath : {:}".format(ckpath))
 | 
			
		||||
    ckp = torch.load(ckpath)
 | 
			
		||||
    xargs = ckp["args"]
 | 
			
		||||
    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
 | 
			
		||||
    # config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
 | 
			
		||||
    config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None)
 | 
			
		||||
    if xargs.dataset == "cifar10":
 | 
			
		||||
        cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
 | 
			
		||||
        xvalid_data = deepcopy(train_data)
 | 
			
		||||
        xvalid_data.transform = valid_data.transform
 | 
			
		||||
        valid_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            xvalid_data,
 | 
			
		||||
            batch_size=2048,
 | 
			
		||||
            sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid),
 | 
			
		||||
            num_workers=12,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("invalid dataset : {:}".format(xargs.dataseet))
 | 
			
		||||
    search_space = get_search_spaces("cell", xargs.search_space_name)
 | 
			
		||||
    model_config = dict2config(
 | 
			
		||||
        {
 | 
			
		||||
            "name": "SETN",
 | 
			
		||||
            "C": xargs.channel,
 | 
			
		||||
            "N": xargs.num_cells,
 | 
			
		||||
            "max_nodes": xargs.max_nodes,
 | 
			
		||||
            "num_classes": class_num,
 | 
			
		||||
            "space": search_space,
 | 
			
		||||
            "affine": False,
 | 
			
		||||
            "track_running_stats": True,
 | 
			
		||||
        },
 | 
			
		||||
        None,
 | 
			
		||||
    )
 | 
			
		||||
    search_model = get_cell_based_tiny_net(model_config)
 | 
			
		||||
    search_model.load_state_dict(ckp["search_model"])
 | 
			
		||||
    search_model = search_model.cuda()
 | 
			
		||||
    api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
 | 
			
		||||
    archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
  #test_nas_api()
 | 
			
		||||
  #for i in range(200): plot('{:04d}'.format(i))
 | 
			
		||||
  #test_auto_grad()
 | 
			
		||||
  test_one_shot_model(sys.argv[1], sys.argv[2])
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # test_nas_api()
 | 
			
		||||
    # for i in range(200): plot('{:04d}'.format(i))
 | 
			
		||||
    # test_auto_grad()
 | 
			
		||||
    test_one_shot_model(sys.argv[1], sys.argv[2])
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user