Add int search space
This commit is contained in:
		@@ -69,7 +69,13 @@ def plot(filename):
 | 
			
		||||
        for xin in range(i):
 | 
			
		||||
            op_i = random.randint(0, len(OPS) - 1)
 | 
			
		||||
            # g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
 | 
			
		||||
            g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
 | 
			
		||||
            g.edge(
 | 
			
		||||
                str(xin),
 | 
			
		||||
                str(i),
 | 
			
		||||
                label=OPS[op_i],
 | 
			
		||||
                color=COLORS[op_i],
 | 
			
		||||
                fillcolor=COLORS[op_i],
 | 
			
		||||
            )
 | 
			
		||||
            # import pdb; pdb.set_trace()
 | 
			
		||||
    g.render(filename, cleanup=True, view=False)
 | 
			
		||||
 | 
			
		||||
@@ -88,7 +94,9 @@ def test_auto_grad():
 | 
			
		||||
    net = Net(10)
 | 
			
		||||
    inputs = torch.rand(256, 10)
 | 
			
		||||
    loss = net(inputs)
 | 
			
		||||
    first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
 | 
			
		||||
    first_order_grads = torch.autograd.grad(
 | 
			
		||||
        loss, net.parameters(), retain_graph=True, create_graph=True
 | 
			
		||||
    )
 | 
			
		||||
    first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
 | 
			
		||||
    second_order_grads = []
 | 
			
		||||
    for grads in first_order_grads:
 | 
			
		||||
@@ -108,9 +116,15 @@ def test_one_shot_model(ckpath, use_train):
 | 
			
		||||
    print("ckpath : {:}".format(ckpath))
 | 
			
		||||
    ckp = torch.load(ckpath)
 | 
			
		||||
    xargs = ckp["args"]
 | 
			
		||||
    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
 | 
			
		||||
    train_data, valid_data, xshape, class_num = get_datasets(
 | 
			
		||||
        xargs.dataset, xargs.data_path, -1
 | 
			
		||||
    )
 | 
			
		||||
    # config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
 | 
			
		||||
    config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None)
 | 
			
		||||
    config = load_config(
 | 
			
		||||
        "./configs/nas-benchmark/algos/DARTS.config",
 | 
			
		||||
        {"class_num": class_num, "xshape": xshape},
 | 
			
		||||
        None,
 | 
			
		||||
    )
 | 
			
		||||
    if xargs.dataset == "cifar10":
 | 
			
		||||
        cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
 | 
			
		||||
        xvalid_data = deepcopy(train_data)
 | 
			
		||||
@@ -142,7 +156,9 @@ def test_one_shot_model(ckpath, use_train):
 | 
			
		||||
    search_model.load_state_dict(ckp["search_model"])
 | 
			
		||||
    search_model = search_model.cuda()
 | 
			
		||||
    api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
 | 
			
		||||
    archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
 | 
			
		||||
    archs, probs, accuracies = evaluate_one_shot(
 | 
			
		||||
        search_model, valid_loader, api, use_train
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user