first commit
8
NAS-Bench-201/all_path.py
Normal file
@@ -0,0 +1,8 @@
SCORENET_CKPT_PATH="./checkpoints/scorenet/checkpoint.pth.tar"
META_SURROGATE_CKPT_PATH="./checkpoints/meta_surrogate/checkpoint.pth.tar"
META_SURROGATE_UNNOISED_CKPT_PATH = "./checkpoints/meta_surrogate/unnoised_checkpoint.pth.tar"
NASBENCH201="./data/transfer_nag/nasbench201.pt"
NASBENCH201_INFO="./data/transfer_nag/nasbench201_info.pt"
META_TEST_PATH="./data/transfer_nag/test"
RAW_DATA_PATH="./data/raw_data"
DATA_PATH = "./data/transfer_nag"
347
NAS-Bench-201/analysis/arch_functions.py
Normal file
@@ -0,0 +1,347 @@
import numpy as np
|
||||
import torch
|
||||
from all_path import *
|
||||
|
||||
|
||||
class BasicArchMetrics(object):
|
||||
def __init__(self, train_ds=None, train_arch_str_list=None):
|
||||
if train_ds is None:
|
||||
self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
else:
|
||||
self.ops_decoder = train_ds.ops_decoder
|
||||
self.nasbench201 = torch.load(NASBENCH201_INFO)
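# nasbench201_info.pt bundles, per architecture, its arch string plus per-dataset accuracy,
# flops, params and latency tables (indexed the same way in get_arch_acc_info and datasets_nas.py).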
|
||||
self.train_arch_str_list = train_arch_str_list
|
||||
|
||||
|
||||
def compute_validity(self, generated):
|
||||
START_TYPE = self.ops_decoder.index('input')
|
||||
END_TYPE = self.ops_decoder.index('output')
|
||||
|
||||
valid = []
|
||||
valid_arch_str = []
|
||||
all_arch_str = []
|
||||
for x in generated:
|
||||
is_valid, error_types = is_valid_NAS201_x(x, START_TYPE, END_TYPE)
|
||||
if is_valid:
|
||||
valid.append(x)
|
||||
arch_str = decode_x_to_NAS_BENCH_201_string(x, self.ops_decoder)
|
||||
valid_arch_str.append(arch_str)
|
||||
else:
|
||||
arch_str = None
|
||||
all_arch_str.append(arch_str)
|
||||
validity = 0 if len(generated) == 0 else (len(valid)/len(generated))
|
||||
return valid, validity, valid_arch_str, all_arch_str
|
||||
|
||||
|
||||
def compute_uniqueness(self, valid_arch_str):
|
||||
return list(set(valid_arch_str)), len(set(valid_arch_str)) / len(valid_arch_str)
|
||||
|
||||
|
||||
def compute_novelty(self, unique):
|
||||
num_novel = 0
|
||||
novel = []
|
||||
if self.train_arch_str_list is None:
|
||||
print("Dataset arch_str is None, novelty computation skipped")
|
||||
return 1, 1
|
||||
for arch_str in unique:
|
||||
if arch_str not in self.train_arch_str_list:
|
||||
novel.append(arch_str)
|
||||
num_novel += 1
|
||||
return novel, num_novel / len(unique)
|
||||
|
||||
|
||||
def evaluate(self, generated, check_dataname='cifar10'):
|
||||
valid, validity, valid_arch_str, all_arch_str = self.compute_validity(generated)
|
||||
|
||||
if validity > 0:
|
||||
unique, uniqueness = self.compute_uniqueness(valid_arch_str)
|
||||
if self.train_arch_str_list is not None:
|
||||
_, novelty = self.compute_novelty(unique)
|
||||
else:
|
||||
novelty = -1.0
|
||||
else:
|
||||
novelty = -1.0
|
||||
uniqueness = 0.0
|
||||
unique = []
|
||||
|
||||
if uniqueness > 0.:
|
||||
arch_idx_list, flops_list, params_list, latency_list = list(), list(), list(), list()
|
||||
for arch in unique:
|
||||
arch_index, flops, params, latency = \
|
||||
get_arch_acc_info(self.nasbench201, arch=arch, dataname=check_dataname)
|
||||
arch_idx_list.append(arch_index)
|
||||
flops_list.append(flops)
|
||||
params_list.append(params)
|
||||
latency_list.append(latency)
|
||||
else:
|
||||
arch_idx_list, flops_list, params_list, latency_list = [-1], [0], [0], [0]
|
||||
|
||||
return ([validity, uniqueness, novelty],
|
||||
unique,
|
||||
dict(arch_idx_list=arch_idx_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
|
||||
all_arch_str)
|
||||
|
||||
|
||||
class BasicArchMetricsMeta(object):
|
||||
def __init__(self, train_ds=None, train_arch_str_list=None):
|
||||
if train_ds is None:
|
||||
self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
else:
|
||||
self.ops_decoder = train_ds.ops_decoder
|
||||
self.nasbench201 = torch.load(NASBENCH201_INFO)
|
||||
self.train_arch_str_list = train_arch_str_list
|
||||
|
||||
|
||||
def compute_validity(self, generated):
|
||||
START_TYPE = self.ops_decoder.index('input')
|
||||
END_TYPE = self.ops_decoder.index('output')
|
||||
|
||||
valid = []
|
||||
valid_arch_str = []
|
||||
all_arch_str = []
|
||||
error_types = []
|
||||
|
||||
for x in generated:
|
||||
is_valid, error_type = is_valid_NAS201_x(x, START_TYPE, END_TYPE)
|
||||
if is_valid:
|
||||
valid.append(x)
|
||||
arch_str = decode_x_to_NAS_BENCH_201_string(x, self.ops_decoder)
|
||||
valid_arch_str.append(arch_str)
|
||||
else:
|
||||
arch_str = None
|
||||
error_types.append(error_type)
|
||||
all_arch_str.append(arch_str)
|
||||
|
||||
# exceptional case
|
||||
validity = 0 if len(generated) == 0 else (len(valid)/len(generated))
|
||||
if len(valid) == 0:
|
||||
validity = 0
|
||||
valid_arch_str = []
|
||||
|
||||
return valid, validity, valid_arch_str, all_arch_str
|
||||
|
||||
|
||||
def compute_uniqueness(self, valid_arch_str):
|
||||
return list(set(valid_arch_str)), len(set(valid_arch_str)) / len(valid_arch_str)
|
||||
|
||||
|
||||
def compute_novelty(self, unique):
|
||||
num_novel = 0
|
||||
novel = []
|
||||
if self.train_arch_str_list is None:
|
||||
print("Dataset arch_str is None, novelty computation skipped")
|
||||
return 1, 1
|
||||
for arch_str in unique:
|
||||
if arch_str not in self.train_arch_str_list:
|
||||
novel.append(arch_str)
|
||||
num_novel += 1
|
||||
return novel, num_novel / len(unique)
|
||||
|
||||
|
||||
def evaluate(self, generated, check_dataname='cifar10'):
|
||||
valid, validity, valid_arch_str, all_arch_str = self.compute_validity(generated)
|
||||
|
||||
if validity > 0:
|
||||
unique, uniqueness = self.compute_uniqueness(valid_arch_str)
|
||||
if self.train_arch_str_list is not None:
|
||||
_, novelty = self.compute_novelty(unique)
|
||||
else:
|
||||
novelty = -1.0
|
||||
else:
|
||||
novelty = -1.0
|
||||
uniqueness = 0.0
|
||||
unique = []
|
||||
|
||||
if uniqueness > 0.:
|
||||
arch_idx_list, flops_list, params_list, latency_list = list(), list(), list(), list()
|
||||
for arch in unique:
|
||||
arch_index, flops, params, latency = \
|
||||
get_arch_acc_info_meta(self.nasbench201, arch=arch, dataname=check_dataname)
|
||||
arch_idx_list.append(arch_index)
|
||||
flops_list.append(flops)
|
||||
params_list.append(params)
|
||||
latency_list.append(latency)
|
||||
else:
|
||||
arch_idx_list, flops_list, params_list, latency_list = [-1], [0], [0], [0]
|
||||
|
||||
return ([validity, uniqueness, novelty],
|
||||
unique,
|
||||
dict(arch_idx_list=arch_idx_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
|
||||
all_arch_str)
|
||||
|
||||
|
||||
def get_arch_acc_info(nasbench201, arch, dataname='cifar10'):
|
||||
arch_index = nasbench201['str'].index(arch)
|
||||
flops = nasbench201['flops'][dataname][arch_index]
|
||||
params = nasbench201['params'][dataname][arch_index]
|
||||
latency = nasbench201['latency'][dataname][arch_index]
|
||||
return arch_index, flops, params, latency
|
||||
|
||||
|
||||
def get_arch_acc_info_meta(nasbench201, arch, dataname='cifar10'):
|
||||
arch_index = nasbench201['str'].index(arch)
|
||||
flops = nasbench201['flops'][dataname][arch_index]
|
||||
params = nasbench201['params'][dataname][arch_index]
|
||||
latency = nasbench201['latency'][dataname][arch_index]
|
||||
return arch_index, flops, params, latency
|
||||
|
||||
|
||||
def decode_igraph_to_NAS_BENCH_201_string(g):
|
||||
if not is_valid_NAS201(g):
|
||||
return None
|
||||
m = decode_igraph_to_NAS201_matrix(g)
|
||||
types = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
|
||||
format(types[int(m[1][0])],
|
||||
types[int(m[2][0])], types[int(m[2][1])],
|
||||
types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
|
||||
|
||||
|
||||
def decode_igraph_to_NAS201_matrix(g):
|
||||
m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
|
||||
xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
|
||||
for i, xy in enumerate(xys):
|
||||
m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
|
||||
import numpy
|
||||
return numpy.array(m)
|
||||
|
||||
|
||||
def decode_x_to_NAS_BENCH_201_matrix(x):
|
||||
m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
|
||||
xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
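# The (row, col) pairs above are the six edges of the 4-node NAS-Bench-201 cell,
# listed in the same order as the '|op~0|+|op~0|op~1|+|op~0|op~1|op~2|' arch string.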
|
||||
for i, xy in enumerate(xys):
|
||||
# m[xy[0]][xy[1]] = int(torch.argmax(torch.tensor(x[i+1])).item()) - 2
|
||||
m[xy[0]][xy[1]] = int(torch.argmax(torch.tensor(x[i+1])).item())
|
||||
import numpy
|
||||
return numpy.array(m)
|
||||
|
||||
|
||||
def decode_x_to_NAS_BENCH_201_string(x, ops_decoder):
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): x_elem [8, 7]
|
||||
|
||||
Returns:
|
||||
arch_str
|
||||
"""
|
||||
is_valid, error_type = is_valid_NAS201_x(x)
|
||||
if not is_valid:
|
||||
return None
|
||||
m = decode_x_to_NAS_BENCH_201_matrix(x)
|
||||
types = ops_decoder
|
||||
arch_str = '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
|
||||
format(types[int(m[1][0])],
|
||||
types[int(m[2][0])], types[int(m[2][1])],
|
||||
types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
|
||||
return arch_str
|
||||
|
||||
|
||||
def decode_x_to_NAS_BENCH_201_string(x, ops_decoder):
|
||||
"""_summary_
|
||||
Args:
|
||||
x (torch.Tensor): x_elem [8, 7]
|
||||
Returns:
|
||||
arch_str
|
||||
"""
|
||||
|
||||
if not is_valid_NAS201_x(x)[0]:
|
||||
return None
|
||||
m = decode_x_to_NAS_BENCH_201_matrix(x)
|
||||
types = ops_decoder
|
||||
arch_str = '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
|
||||
format(types[int(m[1][0])],
|
||||
types[int(m[2][0])], types[int(m[2][1])],
|
||||
types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
|
||||
return arch_str
|
||||
|
||||
|
||||
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
|
||||
res = g.is_dag()
|
||||
n_start, n_end = 0, 0
|
||||
for v in g.vs:
|
||||
if v['type'] == START_TYPE:
|
||||
n_start += 1
|
||||
elif v['type'] == END_TYPE:
|
||||
n_end += 1
|
||||
if v.indegree() == 0 and v['type'] != START_TYPE:
|
||||
return False
|
||||
if v.outdegree() == 0 and v['type'] != END_TYPE:
|
||||
return False
|
||||
return res and n_start == 1 and n_end == 1
|
||||
|
||||
|
||||
def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
|
||||
# first need to be a valid DAG computation graph
|
||||
res = is_valid_DAG(g, START_TYPE, END_TYPE)
|
||||
# in addition, node i must connect to node i+1
|
||||
res = res and len(g.vs['type']) == 8
|
||||
res = res and not (START_TYPE in g.vs['type'][1:-1])
|
||||
res = res and not (END_TYPE in g.vs['type'][1:-1])
|
||||
return res
|
||||
|
||||
|
||||
def check_single_node_type(x):
|
||||
for x_elem in x:
|
||||
if int(np.sum(x_elem)) != 1:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_start_end_nodes(x, START_TYPE, END_TYPE):
|
||||
if x[0][START_TYPE] != 1:
|
||||
return False
|
||||
if x[-1][END_TYPE] != 1:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_interm_node_types(x, START_TYPE, END_TYPE):
|
||||
for x_elem in x[1:-1]:
|
||||
if x_elem[START_TYPE] == 1:
|
||||
return False
|
||||
if x_elem[END_TYPE] == 1:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
ERROR_NB201 = {
|
||||
'MULTIPLE_NODE_TYPES': 1,
|
||||
'No_START_END': 2,
|
||||
'INTERM_START_END': 3,
|
||||
'NO_ERROR': -1
|
||||
}
|
||||
|
||||
|
||||
def is_valid_NAS201_x(x, START_TYPE=0, END_TYPE=1):
|
||||
# first need to be a valid DAG computation graph
|
||||
assert len(x.shape) == 2
|
||||
|
||||
if not check_single_node_type(x):
|
||||
return False, ERROR_NB201['MULTIPLE_NODE_TYPES']
|
||||
|
||||
if not check_start_end_nodes(x, START_TYPE, END_TYPE):
|
||||
return False, ERROR_NB201['No_START_END']
|
||||
|
||||
if not check_interm_node_types(x, START_TYPE, END_TYPE):
|
||||
return False, ERROR_NB201['INTERM_START_END']
|
||||
|
||||
return True, ERROR_NB201['NO_ERROR']
|
||||
|
||||
|
||||
def compute_arch_metrics(arch_list,
|
||||
train_arch_str_list,
|
||||
train_ds,
|
||||
check_dataname='cifar10'):
|
||||
metrics = BasicArchMetrics(train_ds, train_arch_str_list)
|
||||
arch_metrics = metrics.evaluate(arch_list, check_dataname=check_dataname)
|
||||
all_arch_str = arch_metrics[-1]
|
||||
return arch_metrics, all_arch_str
|
||||
|
||||
def compute_arch_metrics_meta(arch_list,
|
||||
train_arch_str_list,
|
||||
train_ds,
|
||||
check_dataname='cifar10'):
|
||||
metrics = BasicArchMetricsMeta(train_ds, train_arch_str_list)
|
||||
arch_metrics = metrics.evaluate(arch_list, check_dataname=check_dataname)
|
||||
return arch_metrics
|
||||
77
NAS-Bench-201/analysis/arch_metrics.py
Normal file
@@ -0,0 +1,77 @@
from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SamplingArchMetrics(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
train_ds,
|
||||
exp_name,):
|
||||
|
||||
super().__init__()
|
||||
self.exp_name = exp_name
|
||||
self.train_ds = train_ds
|
||||
self.train_arch_str_list = train_ds.arch_str_list_
|
||||
|
||||
|
||||
def forward(self,
|
||||
arch_list: list,
|
||||
this_sample_dir,
|
||||
check_dataname='cifar10'):
|
||||
|
||||
arch_metrics, all_arch_str = compute_arch_metrics(arch_list=arch_list,
|
||||
train_arch_str_list=self.train_arch_str_list,
|
||||
train_ds=self.train_ds,
|
||||
check_dataname=check_dataname)
|
||||
|
||||
valid_unique_arch = arch_metrics[1] # arch_str
|
||||
valid_unique_arch_prop_dict = arch_metrics[2] # flops, params, latency
|
||||
textfile = open(f'{this_sample_dir}/valid_unique_archs.txt', "w")
|
||||
for i in range(len(valid_unique_arch)):
|
||||
textfile.write(f"Arch: {valid_unique_arch[i]} \n")
|
||||
textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
|
||||
textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
|
||||
textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
|
||||
textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n\n")
|
||||
textfile.writelines(valid_unique_arch)
|
||||
textfile.close()
|
||||
|
||||
return arch_metrics
|
||||
|
||||
|
||||
class SamplingArchMetricsMeta(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
train_ds,
|
||||
exp_name):
|
||||
|
||||
super().__init__()
|
||||
self.exp_name = exp_name
|
||||
self.train_ds = train_ds
|
||||
self.search_space = config.data.name
|
||||
self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]
|
||||
|
||||
|
||||
def forward(self,
|
||||
arch_list: list,
|
||||
this_sample_dir,
|
||||
check_dataname='cifar10'):
|
||||
|
||||
arch_metrics = compute_arch_metrics_meta(arch_list=arch_list,
|
||||
train_arch_str_list=self.train_arch_str_list,
|
||||
train_ds=self.train_ds,
|
||||
check_dataname=check_dataname)
|
||||
|
||||
valid_unique_arch = arch_metrics[1] # arch_str
|
||||
valid_unique_arch_prop_dict = arch_metrics[2] # flops, params, latency
|
||||
textfile = open(f'{this_sample_dir}/valid_unique_archs.txt', "w")
|
||||
for i in range(len(valid_unique_arch)):
|
||||
textfile.write(f"Arch: {valid_unique_arch[i]} \n")
|
||||
textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
|
||||
textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
|
||||
textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
|
||||
textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n\n")
|
||||
textfile.writelines(valid_unique_arch)
|
||||
textfile.close()
|
||||
|
||||
return arch_metrics
|
||||
72
NAS-Bench-201/configs/eval_scorenet.py
Normal file
@@ -0,0 +1,72 @@
"""Evaluate trained score network"""
|
||||
|
||||
import ml_collections
|
||||
import torch
|
||||
|
||||
from all_path import SCORENET_CKPT_PATH
|
||||
|
||||
def get_config():
|
||||
config = ml_collections.ConfigDict()
|
||||
|
||||
# general
|
||||
config.folder_name = 'test'
|
||||
config.model_type = 'scorenet'
|
||||
config.task = 'eval_scorenet'
|
||||
config.exp_name = None
|
||||
config.seed = 42
|
||||
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
config.resume = False
|
||||
config.scorenet_ckpt_path = SCORENET_CKPT_PATH
|
||||
|
||||
# training
|
||||
config.training = training = ml_collections.ConfigDict()
|
||||
training.sde = 'vesde'
|
||||
training.continuous = True
|
||||
training.reduce_mean = True
|
||||
training.noised = True
|
||||
|
||||
# sampling
|
||||
config.sampling = sampling = ml_collections.ConfigDict()
|
||||
sampling.method = 'pc'
|
||||
sampling.predictor = 'euler_maruyama'
|
||||
sampling.corrector = 'langevin'
|
||||
sampling.n_steps_each = 1
|
||||
sampling.noise_removal = True
|
||||
sampling.probability_flow = False
|
||||
sampling.snr = 0.16
|
||||
|
||||
# evaluation
|
||||
config.eval = evaluate = ml_collections.ConfigDict()
|
||||
evaluate.batch_size = 256
|
||||
evaluate.enable_sampling = True
|
||||
evaluate.num_samples = 256
|
||||
|
||||
# data
|
||||
config.data = data = ml_collections.ConfigDict()
|
||||
data.centered = True
|
||||
data.dequantization = False
|
||||
|
||||
data.root = '../data/transfer_nag/nasbench201_info.pt'
|
||||
data.name = 'NASBench201'
|
||||
data.split_ratio = 1.0
|
||||
data.dataset_idx = 'random' # 'sorted' | 'random'
|
||||
data.max_node = 8
|
||||
data.n_vocab = 7 # number of operations
|
||||
data.START_TYPE = 0
|
||||
data.END_TYPE = 1
|
||||
data.num_graphs = 15625
|
||||
data.num_channels = 1
|
||||
data.label_list = ['test-acc']
|
||||
data.tg_dataset = 'cifar10'
|
||||
# aug_mask
|
||||
data.aug_mask_algo = 'floyd' # 'long_range' | 'floyd'
|
||||
|
||||
# model
|
||||
config.model = model = ml_collections.ConfigDict()
|
||||
model.num_scales = 1000
|
||||
model.beta_min = 0.1
|
||||
model.beta_max = 5.0
|
||||
model.sigma_min = 0.1
|
||||
model.sigma_max = 5.0
|
||||
|
||||
return config
|
||||
125
NAS-Bench-201/configs/tr_meta_surrogate.py
Normal file
@@ -0,0 +1,125 @@
"""Training PGSN on Community Small Dataset with GraphGDP"""
|
||||
|
||||
import ml_collections
|
||||
import torch
|
||||
from all_path import SCORENET_CKPT_PATH
|
||||
from all_path import NASBENCH201_INFO
|
||||
|
||||
|
||||
def get_config():
|
||||
config = ml_collections.ConfigDict()
|
||||
|
||||
# config.search_space = None
|
||||
|
||||
# general
|
||||
config.folder_name = 'test'
|
||||
config.model_type = 'meta_surrogate'
|
||||
config.task = 'tr_meta_surrogate'
|
||||
config.exp_name = None
|
||||
config.seed = 42
|
||||
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
config.resume = False
|
||||
config.scorenet_ckpt_path = SCORENET_CKPT_PATH
|
||||
|
||||
# training
|
||||
config.training = training = ml_collections.ConfigDict()
|
||||
training.sde = 'vesde'
|
||||
training.continuous = True
|
||||
training.reduce_mean = True
|
||||
training.noised = True
|
||||
training.batch_size = 256
|
||||
training.eval_batch_size = 100
|
||||
training.n_iters = 10000
|
||||
training.snapshot_freq = 500
|
||||
training.log_freq = 100
|
||||
training.eval_freq = 100
|
||||
training.snapshot_sampling = True
|
||||
training.likelihood_weighting = False
|
||||
|
||||
# sampling
|
||||
config.sampling = sampling = ml_collections.ConfigDict()
|
||||
sampling.method = 'pc'
|
||||
sampling.predictor = 'euler_maruyama'
|
||||
sampling.corrector = 'langevin'
|
||||
sampling.n_steps_each = 1
|
||||
sampling.noise_removal = True
|
||||
sampling.probability_flow = False
|
||||
sampling.snr = 0.16
|
||||
|
||||
# for conditional sampling
|
||||
sampling.classifier_scale = 10000.0
|
||||
sampling.regress = True
|
||||
sampling.labels = 'max'
|
||||
sampling.weight_ratio = False
|
||||
sampling.weight_scheduling = False
|
||||
sampling.check_dataname = 'cifar10'
|
||||
|
||||
# evaluation
|
||||
config.eval = evaluate = ml_collections.ConfigDict()
|
||||
evaluate.batch_size = 512
|
||||
evaluate.enable_sampling = True
|
||||
evaluate.num_samples = 1024
|
||||
|
||||
# data
|
||||
config.data = data = ml_collections.ConfigDict()
|
||||
data.centered = True
|
||||
data.dequantization = False
|
||||
|
||||
data.root = NASBENCH201_INFO
|
||||
data.name = 'NASBench201'
|
||||
data.max_node = 8
|
||||
data.n_vocab = 7
|
||||
data.START_TYPE = 0
|
||||
data.END_TYPE = 1
|
||||
data.num_channels = 1
|
||||
data.label_list = ['meta-acc']
|
||||
# aug_mask
|
||||
data.aug_mask_algo = 'floyd' # 'long_range' | 'floyd'
|
||||
|
||||
# model
|
||||
config.model = model = ml_collections.ConfigDict()
|
||||
model.name = 'MetaNeuralPredictor'
|
||||
model.ema_rate = 0.9999
|
||||
model.normalization = 'GroupNorm'
|
||||
model.nonlinearity = 'swish'
|
||||
model.nf = 128
|
||||
model.num_gnn_layers = 4
|
||||
model.size_cond = False
|
||||
model.embedding_type = 'positional'
|
||||
model.rw_depth = 16
|
||||
model.graph_layer = 'PosTransLayer'
|
||||
model.edge_th = -1.
|
||||
model.heads = 8
|
||||
model.attn_clamp = False
|
||||
|
||||
# meta-predictor
|
||||
model.input_type = 'DA'
|
||||
model.hs = 32
|
||||
model.nz = 56
|
||||
model.num_sample = 20
|
||||
|
||||
model.num_scales = 1000
|
||||
model.beta_min = 0.1
|
||||
model.beta_max = 5.0
|
||||
model.sigma_min = 0.1
|
||||
model.sigma_max = 5.0
|
||||
model.dropout = 0.1
|
||||
|
||||
# graph encoder
|
||||
config.model.graph_encoder = graph_encoder = ml_collections.ConfigDict()
|
||||
graph_encoder.initial_hidden = 7
|
||||
graph_encoder.gcn_hidden = 144
|
||||
graph_encoder.gcn_layers = 4
|
||||
graph_encoder.linear_hidden = 128
|
||||
|
||||
# optimization
|
||||
config.optim = optim = ml_collections.ConfigDict()
|
||||
optim.weight_decay = 0
|
||||
optim.optimizer = 'Adam'
|
||||
optim.lr = 0.001
|
||||
optim.beta1 = 0.9
|
||||
optim.eps = 1e-8
|
||||
optim.warmup = 1000
|
||||
optim.grad_clip = 1.
|
||||
|
||||
return config
|
||||
113
NAS-Bench-201/configs/tr_scorenet.py
Normal file
@@ -0,0 +1,113 @@
"""Training Score Network"""
|
||||
|
||||
import ml_collections
|
||||
import torch
|
||||
|
||||
|
||||
def get_config():
|
||||
config = ml_collections.ConfigDict()
|
||||
|
||||
# general
|
||||
config.folder_name = 'test'
|
||||
config.model_type = 'scorenet'
|
||||
config.task = 'tr_scorenet'
|
||||
config.exp_name = None
|
||||
config.seed = 42
|
||||
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
config.resume = False
|
||||
config.resume_ckpt_path = ''
|
||||
|
||||
# training
|
||||
config.training = training = ml_collections.ConfigDict()
|
||||
training.sde = 'vesde'
|
||||
training.continuous = True
|
||||
training.reduce_mean = True
|
||||
|
||||
training.batch_size = 256
|
||||
training.eval_batch_size = 1000
|
||||
training.n_iters = 250000
|
||||
training.snapshot_freq = 10000
|
||||
training.log_freq = 200
|
||||
training.eval_freq = 10000
|
||||
training.snapshot_sampling = True
|
||||
training.likelihood_weighting = False
|
||||
|
||||
# sampling
|
||||
config.sampling = sampling = ml_collections.ConfigDict()
|
||||
sampling.method = 'pc'
|
||||
sampling.predictor = 'euler_maruyama'
|
||||
sampling.corrector = 'langevin'
|
||||
sampling.n_steps_each = 1
|
||||
sampling.noise_removal = True
|
||||
sampling.probability_flow = False
|
||||
sampling.snr = 0.16
|
||||
|
||||
# evaluation
|
||||
config.eval = evaluate = ml_collections.ConfigDict()
|
||||
evaluate.batch_size = 1024
|
||||
evaluate.enable_sampling = True
|
||||
evaluate.num_samples = 1024
|
||||
|
||||
# data
|
||||
config.data = data = ml_collections.ConfigDict()
|
||||
data.centered = True
|
||||
data.dequantization = False
|
||||
|
||||
data.root = '../data/transfer_nag/nasbench201_info.pt'
|
||||
data.name = 'NASBench201'
|
||||
data.split_ratio = 1.0
|
||||
data.dataset_idx = 'random' # 'sorted' | 'random'
|
||||
data.max_node = 8
|
||||
data.n_vocab = 7 # number of operations
|
||||
data.START_TYPE = 0
|
||||
data.END_TYPE = 1
|
||||
data.num_graphs = 15625
|
||||
data.num_channels = 1
|
||||
data.label_list = None
|
||||
data.tg_dataset = None
|
||||
# aug_mask
|
||||
data.aug_mask_algo = 'floyd' # 'long_range' | 'floyd'
|
||||
|
||||
# model
|
||||
config.model = model = ml_collections.ConfigDict()
|
||||
model.name = 'CATE'
|
||||
model.ema_rate = 0.9999
|
||||
model.normalization = 'GroupNorm'
|
||||
model.nonlinearity = 'swish'
|
||||
model.nf = 128
|
||||
model.num_gnn_layers = 4
|
||||
model.size_cond = False
|
||||
model.embedding_type = 'positional'
|
||||
model.rw_depth = 16
|
||||
model.graph_layer = 'PosTransLayer'
|
||||
model.edge_th = -1.
|
||||
model.heads = 8
|
||||
model.attn_clamp = False
|
||||
# for pos emb
|
||||
model.pos_enc_type = 2
|
||||
|
||||
model.num_scales = 1000
|
||||
model.sigma_min = 0.1
|
||||
model.sigma_max = 5.0
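# sigma_min / sigma_max define the VE-SDE noise scale range (training.sde = 'vesde' above);
# num_scales is the number of discretization steps used at sampling time.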
|
||||
model.dropout = 0.1
|
||||
|
||||
# graph encoder
|
||||
config.model.graph_encoder = graph_encoder = ml_collections.ConfigDict()
|
||||
graph_encoder.n_layers = 12
|
||||
graph_encoder.d_model = 64
|
||||
graph_encoder.n_head = 8
|
||||
graph_encoder.d_ff = 128
|
||||
graph_encoder.dropout = 0.1
|
||||
graph_encoder.n_vocab = 7
|
||||
|
||||
# optimization
|
||||
config.optim = optim = ml_collections.ConfigDict()
|
||||
optim.weight_decay = 0
|
||||
optim.optimizer = 'Adam'
|
||||
optim.lr = 2e-5
|
||||
optim.beta1 = 0.9
|
||||
optim.eps = 1e-8
|
||||
optim.warmup = 1000
|
||||
optim.grad_clip = 1.
|
||||
|
||||
return config
|
||||
469
NAS-Bench-201/datasets_nas.py
Normal file
@@ -0,0 +1,469 @@
from __future__ import print_function
|
||||
import torch
|
||||
import os
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from analysis.arch_functions import decode_x_to_NAS_BENCH_201_matrix, decode_x_to_NAS_BENCH_201_string
|
||||
from all_path import *
|
||||
|
||||
|
||||
def get_data_scaler(config):
|
||||
"""Data normalizer. Assume data are always in [0, 1]."""
|
||||
|
||||
if config.data.centered:
|
||||
# Rescale to [-1, 1]
|
||||
return lambda x: x * 2. - 1.
|
||||
else:
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def get_data_inverse_scaler(config):
|
||||
"""Inverse data normalizer."""
|
||||
|
||||
if config.data.centered:
|
||||
# Rescale [-1, 1] to [0, 1]
|
||||
return lambda x: (x + 1.) / 2.
|
||||
else:
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def is_triu(mat):
|
||||
is_triu_ = np.allclose(mat, np.triu(mat))
|
||||
return is_triu_
|
||||
|
||||
|
||||
def get_dataset(config):
|
||||
train_dataset = NASBench201Dataset(
|
||||
data_path=NASBENCH201_INFO,
|
||||
mode='train')
|
||||
|
||||
eval_dataset = NASBench201Dataset(
|
||||
data_path=NASBENCH201_INFO,
|
||||
mode='eval')
|
||||
|
||||
test_dataset = NASBench201Dataset(
|
||||
data_path=NASBENCH201_INFO,
|
||||
mode='test')
|
||||
|
||||
return train_dataset, eval_dataset, test_dataset
|
||||
|
||||
|
||||
def get_dataloader(config, train_dataset, eval_dataset, test_dataset):
|
||||
train_loader = DataLoader(dataset=train_dataset,
|
||||
batch_size=config.training.batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=None)
|
||||
|
||||
eval_loader = DataLoader(dataset=eval_dataset,
|
||||
batch_size=config.training.batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=None)
|
||||
|
||||
test_loader = DataLoader(dataset=test_dataset,
|
||||
batch_size=config.training.batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=None)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
|
||||
|
||||
class NASBench201Dataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_path,
|
||||
split_ratio=1.0,
|
||||
mode='train',
|
||||
label_list=None,
|
||||
tg_dataset=None):
|
||||
|
||||
self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
|
||||
# ---------- entire dataset ---------- #
|
||||
self.data = torch.load(data_path)
|
||||
# ---------- igraph ---------- #
|
||||
self.igraph_list = self.data['g']
|
||||
# ---------- x ---------- #
|
||||
self.x_list = self.data['x']
|
||||
# ---------- adj ---------- #
|
||||
adj = self.get_adj()
|
||||
self.adj_list = [adj] * len(self.igraph_list)
|
||||
# ---------- matrix ---------- #
|
||||
self.matrix_list = self.data['matrix']
|
||||
# ---------- arch_str ---------- #
|
||||
self.arch_str_list = self.data['str']
|
||||
# ---------- labels ---------- #
|
||||
self.label_list = label_list
|
||||
if self.label_list is not None:
|
||||
self.val_acc_list = self.data['val-acc'][tg_dataset]
|
||||
self.test_acc_list = self.data['test-acc'][tg_dataset]
|
||||
self.flops_list = self.data['flops'][tg_dataset]
|
||||
self.params_list = self.data['params'][tg_dataset]
|
||||
self.latency_list = self.data['latency'][tg_dataset]
|
||||
|
||||
# ----------- split dataset ---------- #
|
||||
self.ds_idx = list(torch.load(DATA_PATH + '/ridx.pt'))
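# ridx.pt is assumed to hold a fixed random permutation of all architecture indices,
# so the train/eval/test splits taken below are reproducible across runs.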
|
||||
self.split_ratio = split_ratio
|
||||
num_train = int(len(self.x_list) * self.split_ratio)
|
||||
num_test = len(self.x_list) - num_train
|
||||
|
||||
# ----------- compute mean and std w/ training dataset ---------- #
|
||||
if self.label_list is not None:
|
||||
self.train_idx_list = self.ds_idx[:num_train]
|
||||
print('>>> Computing mean and std of the training set...')
|
||||
LABEL_TO_MEAN_STD = defaultdict(dict)
|
||||
assert type(self.label_list) == list, f"self.label_list is {type(self.label_list)}"
|
||||
for label in self.label_list:
|
||||
if label == 'val-acc':
|
||||
self.val_acc_list_tr = [self.val_acc_list[i] for i in self.train_idx_list]
|
||||
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.val_acc_list_tr))
|
||||
elif label == 'test-acc':
|
||||
self.test_acc_list_tr = [self.test_acc_list[i] for i in self.train_idx_list]
|
||||
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.test_acc_list_tr))
|
||||
elif label == 'flops':
|
||||
self.flops_list_tr = [self.flops_list[i] for i in self.train_idx_list]
|
||||
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.flops_list_tr))
|
||||
elif label == 'params':
|
||||
self.params_list_tr = [self.params_list[i] for i in self.train_idx_list]
|
||||
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.params_list_tr))
|
||||
elif label == 'latency':
|
||||
self.latency_list_tr = [self.latency_list[i] for i in self.train_idx_list]
|
||||
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.latency_list_tr))
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
self.mode = mode
|
||||
if self.mode in ['train']:
|
||||
self.idx_list = self.ds_idx[:num_train]
|
||||
elif self.mode in ['eval']:
|
||||
if num_test == 0:
|
||||
self.idx_list = self.ds_idx[:100]
|
||||
else:
|
||||
self.idx_list = self.ds_idx[:num_test]
|
||||
elif self.mode in ['test']:
|
||||
if num_test == 0:
|
||||
self.idx_list = self.ds_idx[15000:]
|
||||
else:
|
||||
self.idx_list = self.ds_idx[num_train:]
|
||||
|
||||
self.igraph_list_ = [self.igraph_list[i] for i in self.idx_list]
|
||||
self.x_list_ = [self.x_list[i] for i in self.idx_list]
|
||||
self.adj_list_ = [self.adj_list[i] for i in self.idx_list]
|
||||
self.matrix_list_ = [self.matrix_list[i] for i in self.idx_list]
|
||||
self.arch_str_list_ = [self.arch_str_list[i] for i in self.idx_list]
|
||||
|
||||
if self.label_list is not None:
|
||||
assert type(self.label_list) == list
|
||||
for label in self.label_list:
|
||||
if label == 'val-acc':
|
||||
self.val_acc_list_ = [self.val_acc_list[i] for i in self.idx_list]
|
||||
self.val_acc_list_ = self.normalize(self.val_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
|
||||
elif label == 'test-acc':
|
||||
self.test_acc_list_ = [self.test_acc_list[i] for i in self.idx_list]
|
||||
self.test_acc_list_ = self.normalize(self.test_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
|
||||
elif label == 'flops':
|
||||
self.flops_list_ = [self.flops_list[i] for i in self.idx_list]
|
||||
self.flops_list_ = self.normalize(self.flops_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
|
||||
elif label == 'params':
|
||||
self.params_list_ = [self.params_list[i] for i in self.idx_list]
|
||||
self.params_list_ = self.normalize(self.params_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
|
||||
elif label == 'latency':
|
||||
self.latency_list_ = [self.latency_list[i] for i in self.idx_list]
|
||||
self.latency_list_ = self.normalize(self.latency_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def normalize(self, original, mean, std):
|
||||
return [(i-mean)/std for i in original]
|
||||
|
||||
|
||||
# def get_not_connect_prev_adj(self):
|
||||
def get_adj(self):
|
||||
adj = np.asarray(
|
||||
[[0, 1, 1, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]]
|
||||
)
|
||||
adj = torch.tensor(adj, dtype=torch.float32, device=torch.device('cpu'))
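# Fixed 8-node DAG shared by every NAS-Bench-201 cell: node 0 is the input, node 7 the output,
# and nodes 1-6 carry the operations; only the node ops differ between architectures.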
|
||||
return adj
|
||||
|
||||
|
||||
@property
|
||||
def adj(self):
|
||||
return self.adj_list_[0]
|
||||
|
||||
|
||||
def mask(self, algo='floyd'):
|
||||
from utils import aug_mask
|
||||
return aug_mask(self.adj, algo=algo)[0]
|
||||
|
||||
|
||||
def get_unnoramlized_entire_data(self, label, tg_dataset):
|
||||
entire_val_acc_list = self.data['val-acc'][tg_dataset]
|
||||
entire_test_acc_list = self.data['test-acc'][tg_dataset]
|
||||
entire_flops_list = self.data['flops'][tg_dataset]
|
||||
entire_params_list = self.data['params'][tg_dataset]
|
||||
entire_latency_list = self.data['latency'][tg_dataset]
|
||||
|
||||
if label == 'val-acc':
|
||||
return entire_val_acc_list
|
||||
elif label == 'test-acc':
|
||||
return entire_test_acc_list
|
||||
elif label == 'flops':
|
||||
return entire_flops_list
|
||||
elif label == 'params':
|
||||
return entire_params_list
|
||||
elif label == 'latency':
|
||||
return entire_latency_list
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def get_unnoramlized_data(self, label, tg_dataset):
|
||||
entire_val_acc_list = self.data['val-acc'][tg_dataset]
|
||||
entire_test_acc_list = self.data['test-acc'][tg_dataset]
|
||||
entire_flops_list = self.data['flops'][tg_dataset]
|
||||
entire_params_list = self.data['params'][tg_dataset]
|
||||
entire_latency_list = self.data['latency'][tg_dataset]
|
||||
|
||||
if label == 'val-acc':
|
||||
return [entire_val_acc_list[i] for i in self.idx_list]
|
||||
elif label == 'test-acc':
|
||||
return [entire_test_acc_list[i] for i in self.idx_list]
|
||||
elif label == 'flops':
|
||||
return [entire_flops_list[i] for i in self.idx_list]
|
||||
elif label == 'params':
|
||||
return [entire_params_list[i] for i in self.idx_list]
|
||||
elif label == 'latency':
|
||||
return [entire_latency_list[i] for i in self.idx_list]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x_list_)
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
label_dict = {}
|
||||
if self.label_list is not None:
|
||||
assert type(self.label_list) == list
|
||||
for label in self.label_list:
|
||||
if label == 'val-acc':
|
||||
label_dict[f"{label}"] = self.val_acc_list_[index]
|
||||
elif label == 'test-acc':
|
||||
label_dict[f"{label}"] = self.test_acc_list_[index]
|
||||
elif label == 'flops':
|
||||
label_dict[f"{label}"] = self.flops_list_[index]
|
||||
elif label == 'params':
|
||||
label_dict[f"{label}"] = self.params_list_[index]
|
||||
elif label == 'latency':
|
||||
label_dict[f"{label}"] = self.latency_list_[index]
|
||||
else:
|
||||
raise ValueError
|
||||
return self.x_list_[index], self.adj_list_[index], label_dict
|
||||
|
||||
|
||||
# ---------- Meta-Dataset ---------- #
|
||||
def get_meta_dataset(config):
|
||||
train_dataset = MetaTrainDatabase(
|
||||
data_path=DATA_PATH,
|
||||
num_sample=config.model.num_sample,
|
||||
label_list=config.data.label_list,
|
||||
mode='train')
|
||||
|
||||
eval_dataset = MetaTrainDatabase(
|
||||
data_path=DATA_PATH,
|
||||
num_sample=config.model.num_sample,
|
||||
label_list=config.data.label_list,
|
||||
mode='eval')
|
||||
|
||||
test_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset, test_dataset
|
||||
|
||||
|
||||
def get_meta_dataloader(config, train_dataset, eval_dataset, test_dataset):
|
||||
train_loader = DataLoader(dataset=train_dataset,
|
||||
batch_size=config.training.batch_size,
|
||||
shuffle=True)
|
||||
|
||||
eval_loader = DataLoader(dataset=eval_dataset,
|
||||
batch_size=config.training.batch_size,
|
||||
shuffle=False)
|
||||
|
||||
test_loader = None
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
|
||||
|
||||
class MetaTrainDatabase(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_path,
|
||||
num_sample,
|
||||
label_list,
|
||||
mode='train'):
|
||||
|
||||
self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
|
||||
self.mode = mode
|
||||
self.acc_norm = True
|
||||
self.num_sample = num_sample
|
||||
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
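# imgnet32bylabel.pt stores images grouped by class label; __getitem__ builds the task input
# by sampling num_sample images from each class of the selected meta-training task.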
|
||||
|
||||
mtr_data_path = os.path.join(data_path, 'meta_train_tasks_predictor.pt')
|
||||
idx_path = os.path.join(data_path, 'meta_train_tasks_predictor_idx.pt')
|
||||
data = torch.load(mtr_data_path)
|
||||
|
||||
self.acc_list = data['acc']
|
||||
self.task = data['task']
|
||||
|
||||
# ---------- igraph ---------- #
|
||||
self.igraph_list = data['g']
|
||||
# ---------- x ---------- #
|
||||
self.x_list = data['x']
|
||||
# ---------- adj ---------- #
|
||||
adj = self.get_adj()
|
||||
self.adj_list = [adj] * len(self.igraph_list)
|
||||
# ---------- matrix ----------- #
|
||||
if 'matrix' in data:
|
||||
self.matrix_list = data['matrix']
|
||||
else:
|
||||
self.matrix_list = [decode_x_to_NAS_BENCH_201_matrix(i) for i in self.x_list]
|
||||
# ---------- arch_str ---------- #
|
||||
if 'str' in data:
|
||||
self.arch_str_list = data['str']
|
||||
else:
|
||||
self.arch_str_list = [decode_x_to_NAS_BENCH_201_string(i, self.ops_decoder) for i in self.x_list]
|
||||
# ---------- label ---------- #
|
||||
self.label_list = label_list
|
||||
if self.label_list is not None:
|
||||
self.flops_list = torch.tensor(data['flops'])
|
||||
self.params_list = torch.tensor(data['params'])
|
||||
self.latency_list = torch.tensor(data['latency'])
|
||||
|
||||
random_idx_lst = torch.load(idx_path)
|
||||
self.idx_lst = {}
|
||||
self.idx_lst['eval'] = random_idx_lst[:400]
|
||||
self.idx_lst['train'] = random_idx_lst[400:]
|
||||
self.acc_list = torch.tensor(self.acc_list)
|
||||
self.mean = torch.mean(self.acc_list[self.idx_lst['train']]).item()
|
||||
self.std = torch.std(self.acc_list[self.idx_lst['train']]).item()
|
||||
self.task_lst = torch.load(os.path.join(data_path, 'meta_train_task_lst.pt'))
|
||||
|
||||
|
||||
def get_adj(self):
|
||||
adj = np.asarray(
|
||||
[[0, 1, 1, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]]
|
||||
)
|
||||
adj = torch.tensor(adj, dtype=torch.float32, device=torch.device('cpu'))
|
||||
return adj
|
||||
|
||||
|
||||
@property
|
||||
def adj(self):
|
||||
return self.adj_list[0]
|
||||
|
||||
|
||||
def mask(self, algo='floyd'):
|
||||
from utils import aug_mask
|
||||
return aug_mask(self.adj, algo=algo)[0]
|
||||
|
||||
|
||||
def set_mode(self, mode):
|
||||
self.mode = mode
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.idx_lst[self.mode])
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
ridx = self.idx_lst[self.mode]
|
||||
tidx = self.task[ridx[index]]
|
||||
classes = self.task_lst[tidx]
|
||||
|
||||
# ---------- igraph -----------
|
||||
graph = self.igraph_list[ridx[index]]
|
||||
# ---------- x -----------
|
||||
x = self.x_list[ridx[index]]
|
||||
# ---------- adj ----------
|
||||
adj = self.adj_list[ridx[index]]
|
||||
|
||||
acc = self.acc_list[ridx[index]]
|
||||
for cls in classes:
|
||||
cx = self.x[cls-1][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
task = torch.cat(data)
|
||||
if self.acc_norm:
|
||||
acc = ((acc - self.mean) / self.std) / 100.0
|
||||
else:
|
||||
acc = acc / 100.0
|
||||
|
||||
label_dict = {}
|
||||
if self.label_list is not None:
|
||||
assert type(self.label_list) == list
|
||||
for label in self.label_list:
|
||||
if label == 'meta-acc':
|
||||
label_dict[f"{label}"] = acc
|
||||
elif label == 'flops':
|
||||
label_dict[f"{label}"] = self.flops_list[ridx[index]]
|
||||
elif label == 'params':
|
||||
label_dict[f"{label}"] = self.params_list[ridx[index]]
|
||||
elif label == 'latency':
|
||||
label_dict[f"{label}"] = self.latency_list[ridx[index]]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return x, adj, label_dict, task
|
||||
|
||||
|
||||
class MetaTestDataset(Dataset):
|
||||
def __init__(self, data_path, data_name, num_sample, num_class=None):
|
||||
self.num_sample = num_sample
|
||||
self.data_name = data_name
|
||||
|
||||
num_class_dict = {
|
||||
'cifar100': 100,
|
||||
'cifar10': 10,
|
||||
'aircraft': 30,
|
||||
'pets': 37
|
||||
}
|
||||
|
||||
if num_class is not None:
|
||||
self.num_class = num_class
|
||||
else:
|
||||
self.num_class = num_class_dict[data_name]
|
||||
|
||||
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return 1000000
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
classes = list(range(self.num_class))
|
||||
for cls in classes:
|
||||
cx = self.x[cls][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
return x
|
||||
137
NAS-Bench-201/logger.py
Normal file
@@ -0,0 +1,137 @@
import os
|
||||
import wandb
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(
|
||||
self,
|
||||
log_dir=None,
|
||||
write_textfile=True
|
||||
):
|
||||
|
||||
self.log_dir = log_dir
|
||||
self.write_textfile = write_textfile
|
||||
|
||||
self.logs_for_save = {}
|
||||
self.logs = {}
|
||||
|
||||
if self.write_textfile:
|
||||
self.f = open(os.path.join(log_dir, 'logs.txt'), 'w')
|
||||
|
||||
|
||||
def write_str(self, log_str):
|
||||
self.f.write(log_str+'\n')
|
||||
self.f.flush()
|
||||
|
||||
|
||||
def update_config(self, v, is_args=False):
|
||||
if is_args:
|
||||
self.logs_for_save.update({'args': v})
|
||||
else:
|
||||
self.logs_for_save.update(v)
|
||||
|
||||
|
||||
def write_log(self, element, step, return_log_dict=False):
|
||||
log_str = f"{step} | "
|
||||
log_dict = {}
|
||||
for head, keys in element.items():
|
||||
for k in keys:
|
||||
if k in self.logs:
|
||||
v = self.logs[k].avg
|
||||
if not k in self.logs_for_save:
|
||||
self.logs_for_save[k] = []
|
||||
self.logs_for_save[k].append(v)
|
||||
log_str += f'{k} {v}| '
|
||||
log_dict[f'{head}/{k}'] = v
|
||||
|
||||
if self.write_textfile:
|
||||
self.f.write(log_str+'\n')
|
||||
self.f.flush()
|
||||
|
||||
if return_log_dict:
|
||||
return log_dict
|
||||
|
||||
|
||||
def save_log(self, name=None):
|
||||
name = 'logs.pt' if name is None else name
|
||||
torch.save(self.logs_for_save, os.path.join(self.log_dir, name))
|
||||
|
||||
|
||||
def update(self, key, v, n=1):
|
||||
if not key in self.logs:
|
||||
self.logs[key] = AverageMeter()
|
||||
self.logs[key].update(v, n)
|
||||
|
||||
|
||||
def reset(self, keys=None, except_keys=[]):
|
||||
if keys is not None:
|
||||
if isinstance(keys, list):
|
||||
for key in keys:
|
||||
self.logs[key] = AverageMeter()
|
||||
else:
|
||||
self.logs[keys] = AverageMeter()
|
||||
else:
|
||||
for key in self.logs.keys():
|
||||
if not key in except_keys:
|
||||
self.logs[key] = AverageMeter()
|
||||
|
||||
|
||||
def avg(self, keys=None, except_keys=[]):
|
||||
if keys is not None:
|
||||
if isinstance(keys, list):
|
||||
return {key: self.logs[key].avg for key in keys if key in self.logs.keys()}
|
||||
else:
|
||||
return self.logs[keys].avg
|
||||
else:
|
||||
avg_dict = {}
|
||||
for key in self.logs.keys():
|
||||
if not key in except_keys:
|
||||
avg_dict[key] = self.logs[key].avg
|
||||
return avg_dict
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""
|
||||
Computes and stores the average and current value
|
||||
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def get_metrics(g_embeds, x_embeds, logit_scale, prefix='train'):
|
||||
metrics = {}
|
||||
logits_per_g = (logit_scale * g_embeds @ x_embeds.t()).detach().cpu()
|
||||
logits_per_x = logits_per_g.t().detach().cpu()
|
||||
|
||||
logits = {"g_to_x": logits_per_g, "x_to_g": logits_per_x}
|
||||
ground_truth = torch.arange(len(x_embeds)).view(-1, 1)
|
||||
|
||||
for name, logit in logits.items():
|
||||
ranking = torch.argsort(logit, descending=True)
|
||||
preds = torch.where(ranking == ground_truth)[1]
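# preds[i] is the rank at which the true match for sample i is retrieved; the mean/median rank
# and R@k values below are standard cross-modal retrieval metrics.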
|
||||
preds = preds.detach().cpu().numpy()
|
||||
metrics[f"{prefix}_{name}_mean_rank"] = preds.mean() + 1
|
||||
metrics[f"{prefix}_{name}_median_rank"] = np.floor(np.median(preds)) + 1
|
||||
for k in [1, 5, 10]:
|
||||
metrics[f"{prefix}_{name}_R@{k}"] = np.mean(preds < k)
|
||||
|
||||
return metrics
|
||||
369
NAS-Bench-201/losses.py
Normal file
@@ -0,0 +1,369 @@
"""All functions related to loss computation and optimization."""
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from models import utils as mutils
|
||||
from sde_lib import VPSDE, VESDE
|
||||
|
||||
|
||||
def get_optimizer(config, params):
|
||||
"""Return a flax optimizer object based on `config`."""
|
||||
if config.optim.optimizer == 'Adam':
|
||||
optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
|
||||
weight_decay=config.optim.weight_decay)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Optimizer {config.optim.optimizer} not supported yet!'
|
||||
)
|
||||
return optimizer
|
||||
|
||||
|
||||
def optimization_manager(config):
|
||||
"""Return an optimize_fn based on `config`."""
|
||||
|
||||
def optimize_fn(optimizer, params, step, lr=config.optim.lr,
|
||||
warmup=config.optim.warmup,
|
||||
grad_clip=config.optim.grad_clip):
|
||||
"""Optimize with warmup and gradient clipping (disabled if negative)."""
|
||||
if warmup > 0:
|
||||
for g in optimizer.param_groups:
|
||||
g['lr'] = lr * np.minimum(step / warmup, 1.0)
|
||||
if grad_clip >= 0:
|
||||
torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
return optimize_fn
|
||||
|
||||
|
||||
def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
|
||||
"""Create a loss function for training with arbitrary SDEs.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
train: `True` for training loss and `False` for evaluation loss.
|
||||
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
|
||||
continuous: `True` indicates that the model is defined to take continuous time steps.
|
||||
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
|
||||
likelihood_weighting: If `True`, weight the mixture of score matching losses according
|
||||
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
|
||||
eps: A `float` number. The smallest time step to sample from.
|
||||
|
||||
Returns:
|
||||
A loss function.
|
||||
"""
|
||||
|
||||
# reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
|
||||
|
||||
def loss_fn(model, batch):
|
||||
"""Compute the loss function.
|
||||
|
||||
Args:
|
||||
model: A score model.
|
||||
batch: A mini-batch of training data, including adjacency matrices and mask.
|
||||
|
||||
Returns:
|
||||
loss: A scalar that represents the average loss value across the mini-batch.
|
||||
"""
|
||||
adj, mask = batch
|
||||
score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
|
||||
t = torch.rand(adj.shape[0], device=adj.device) * (sde.T - eps) + eps
|
||||
|
||||
z = torch.randn_like(adj) # [B, C, N, N]
|
||||
z = torch.tril(z, -1)
|
||||
z = z + z.transpose(2, 3)
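# Noise is drawn on the strictly lower triangle and symmetrized (as is the mean below),
# so the perturbed adjacency matrix stays symmetric with a zero diagonal.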
|
||||
|
||||
mean, std = sde.marginal_prob(adj, t)
|
||||
mean = torch.tril(mean, -1)
|
||||
mean = mean + mean.transpose(2, 3)
|
||||
|
||||
perturbed_data = mean + std[:, None, None, None] * z
|
||||
score = score_fn(perturbed_data, t, mask=mask)
|
||||
|
||||
mask = torch.tril(mask, -1)
|
||||
mask = mask + mask.transpose(2, 3)
|
||||
mask = mask.reshape(mask.shape[0], -1) # low triangular part of adj matrices
|
||||
|
||||
if not likelihood_weighting:
|
||||
losses = torch.square(score * std[:, None, None, None] + z)
|
||||
losses = losses.reshape(losses.shape[0], -1)
|
||||
if reduce_mean:
|
||||
losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
|
||||
else:
|
||||
losses = 0.5 * torch.sum(losses * mask, dim=-1)
|
||||
loss = losses.mean()
|
||||
else:
|
||||
g2 = sde.sde(torch.zeros_like(adj), t)[1] ** 2
|
||||
losses = torch.square(score + z / std[:, None, None, None])
|
||||
losses = losses.reshape(losses.shape[0], -1)
|
||||
if reduce_mean:
|
||||
losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
|
||||
else:
|
||||
losses = 0.5 * torch.sum(losses * mask, dim=-1)
|
||||
loss = (losses * g2).mean()
|
||||
|
||||
return loss
|
||||
|
||||
return loss_fn
|
||||
|
||||
|
||||
def get_sde_loss_fn_nas(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
|
||||
"""Create a loss function for training with arbitrary SDEs.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
train: `True` for training loss and `False` for evaluation loss.
|
||||
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
|
||||
continuous: `True` indicates that the model is defined to take continuous time steps.
|
||||
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
|
||||
likelihood_weighting: If `True`, weight the mixture of score matching losses according
|
||||
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
|
||||
eps: A `float` number. The smallest time step to sample from.
|
||||
|
||||
Returns:
|
||||
A loss function.
|
||||
"""
|
||||
|
||||
def loss_fn(model, batch):
|
||||
"""Compute the loss function.
|
||||
|
||||
Args:
|
||||
model: A score model.
|
||||
batch: A mini-batch of training data, including adjacency matrices and mask.
|
||||
|
||||
Returns:
|
||||
loss: A scalar that represents the average loss value across the mini-batch.
|
||||
"""
|
||||
x, adj, mask = batch
|
||||
score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
|
||||
t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps
|
||||
|
||||
z = torch.randn_like(x) # [B, C, N, N]
|
||||
|
||||
mean, std = sde.marginal_prob(x, t)
|
||||
|
||||
perturbed_data = mean + std[:, None, None] * z
|
||||
score = score_fn(perturbed_data, t, mask)
|
||||
|
||||
if not likelihood_weighting:
|
||||
losses = torch.square(score * std[:, None, None] + z)
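# Denoising score matching: the target score of the perturbation kernel is -z / std,
# so (score * std + z)^2 is the squared error weighted by sigma(t)^2.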
|
||||
losses = losses.reshape(losses.shape[0], -1)
|
||||
if reduce_mean:
|
||||
losses = torch.mean(losses, dim=-1)
|
||||
else:
|
||||
losses = 0.5 * torch.sum(losses, dim=-1)
|
||||
loss = losses.mean()
|
||||
else:
|
||||
g2 = sde.sde(torch.zeros_like(x), t)[1] ** 2
|
||||
losses = torch.square(score + z / std[:, None, None])
|
||||
losses = losses.reshape(losses.shape[0], -1)
|
||||
if reduce_mean:
|
||||
losses = torch.mean(losses, dim=-1)
|
||||
else:
|
||||
losses = 0.5 * torch.sum(losses, dim=-1)
|
||||
loss = (losses * g2).mean()
|
||||
|
||||
return loss
|
||||
|
||||
return loss_fn
|
||||
|
||||
|
||||
def get_step_fn(sde,
|
||||
train,
|
||||
optimize_fn=None,
|
||||
reduce_mean=False,
|
||||
continuous=True,
|
||||
likelihood_weighting=False,
|
||||
data='NASBench201'):
|
||||
"""Create a one-step training/evaluation function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
Tuple (`sde_lib.SDE`, `sde_lib.SDE`) that represents the forward node SDE and edge SDE.
|
||||
optimize_fn: An optimization function.
|
||||
reduce_mean: If `True`, average the loss across data dimensions.
|
||||
Otherwise, sum the loss across data dimensions.
|
||||
continuous: `True` indicates that the model is defined to take continuous time steps.
|
||||
likelihood_weighting: If `True`, weight the mixture of score matching losses according to
|
||||
https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde.
|
||||
|
||||
Returns:
|
||||
A one-step function for training or evaluation.
|
||||
"""
|
||||
|
||||
if continuous:
|
||||
if data in ['NASBench201', 'ofa']:
|
||||
loss_fn = get_sde_loss_fn_nas(sde, train, reduce_mean=reduce_mean,
|
||||
continuous=True, likelihood_weighting=likelihood_weighting)
|
||||
else:
|
||||
raise NotImplementedError(f"Data {data} (search space) is not supported yet.")
|
||||
else:
|
||||
raise NotImplementedError(f"Discrete training for {sde.__class__.__name__} is not implemented.")
|
||||
|
||||
|
||||
def step_fn(state, batch):
|
||||
"""Running one step of training or evaluation.
|
||||
|
||||
For jax version: This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and
|
||||
jit-compiled together for faster execution.
|
||||
|
||||
Args:
|
||||
state: A dictionary of training information, containing the score model, optimizer,
|
||||
EMA status, and number of optimization steps.
|
||||
batch: A mini-batch of training/evaluation data, including min-batch adjacency matrices and mask.
|
||||
|
||||
Returns:
|
||||
loss: The average loss value of this state.
|
||||
"""
|
||||
model = state['model']
|
||||
if train:
|
||||
optimizer = state['optimizer']
|
||||
optimizer.zero_grad()
|
||||
loss = loss_fn(model, batch)
|
||||
loss.backward()
|
||||
optimize_fn(optimizer, model.parameters(), step=state['step'])
|
||||
state['step'] += 1
|
||||
state['ema'].update(model.parameters())
|
||||
else:
|
||||
with torch.no_grad():
|
||||
ema = state['ema']
|
||||
ema.store(model.parameters())
|
||||
ema.copy_to(model.parameters())
|
||||
loss = loss_fn(model, batch)
|
||||
ema.restore(model.parameters())
|
||||
|
||||
return loss
|
||||
|
||||
return step_fn
|
||||
|
||||
|
||||
# ------------------- predictor -------------------
|
||||
def get_meta_predictor_loss_fn_nas(sde,
|
||||
train,
|
||||
reduce_mean=True,
|
||||
continuous=True,
|
||||
likelihood_weighting=True,
|
||||
eps=1e-5,
|
||||
label_list=None,
|
||||
noised=True):
|
||||
"""Create a loss function for training with arbitrary SDEs.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
train: `True` for training loss and `False` for evaluation loss.
|
||||
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
|
||||
continuous: `True` indicates that the model is defined to take continuous time steps.
|
||||
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
|
||||
likelihood_weighting: If `True`, weight the mixture of score matching losses according
|
||||
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
|
||||
eps: A `float` number. The smallest time step to sample from.
label_list: A list of property names; the last entry selects the regression target in `extra`.
noised: If `True`, train the predictor on SDE-perturbed inputs; otherwise on clean inputs at t = eps.
|
||||
|
||||
Returns:
|
||||
A loss function.
|
||||
"""
|
||||
|
||||
def loss_fn(model, batch):
|
||||
"""Compute the loss function.
|
||||
|
||||
Args:
|
||||
model: A score model.
|
||||
batch: A mini-batch of training data, including adjacency matrices and mask.
|
||||
|
||||
Returns:
|
||||
loss: A scalar that represents the average loss value across the mini-batch.
|
||||
"""
|
||||
x, adj, mask, extra, task = batch
|
||||
predictor_fn = mutils.get_predictor_fn(sde, model, train=train, continuous=continuous)
|
||||
if noised:
|
||||
t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps
|
||||
z = torch.randn_like(x) # [B, C, N, N]
|
||||
|
||||
mean, std = sde.marginal_prob(x, t)
|
||||
|
||||
perturbed_data = mean + std[:, None, None] * z
|
||||
pred = predictor_fn(perturbed_data, t, mask, task)
|
||||
else:
|
||||
t = eps * torch.ones(x.shape[0], device=adj.device)
|
||||
pred = predictor_fn(x, t, mask, task)
|
||||
|
||||
labels = extra[f"{label_list[-1]}"]
|
||||
labels = labels.to(pred.device).unsqueeze(1).type(pred.dtype)
|
||||
|
||||
loss = torch.nn.MSELoss()(pred, labels)
|
||||
|
||||
return loss, pred, labels
|
||||
|
||||
return loss_fn
|
||||
|
||||
|
||||
def get_step_fn_predictor(sde,
|
||||
train,
|
||||
optimize_fn=None,
|
||||
reduce_mean=False,
|
||||
continuous=True,
|
||||
likelihood_weighting=False,
|
||||
data='NASBench201',
|
||||
label_list=None,
|
||||
noised=True):
|
||||
"""Create a one-step training/evaluation function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE, or a tuple
(`sde_lib.SDE`, `sde_lib.SDE`) that represents the forward node SDE and edge SDE.
train: `True` for a training step and `False` for an evaluation step.
optimize_fn: An optimization function.
|
||||
reduce_mean: If `True`, average the loss across data dimensions.
|
||||
Otherwise, sum the loss across data dimensions.
|
||||
continuous: `True` indicates that the model is defined to take continuous time steps.
|
||||
likelihood_weighting: If `True`, weight the mixture of score matching losses according to
|
||||
https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde.
|
||||
|
||||
Returns:
|
||||
A one-step function for training or evaluation.
|
||||
"""
|
||||
|
||||
if continuous:
|
||||
if data in ['NASBench201', 'ofa']:
|
||||
loss_fn = get_meta_predictor_loss_fn_nas(sde,
|
||||
train,
|
||||
reduce_mean=reduce_mean,
|
||||
continuous=True,
|
||||
likelihood_weighting=likelihood_weighting,
|
||||
label_list=label_list,
|
||||
noised=noised)
|
||||
else:
|
||||
raise NotImplementedError(f"Data {data} (search space) is not supported yet.")
|
||||
else:
|
||||
raise NotImplementedError(f"Discrete training for {sde.__class__.__name__} is not implemented.")
|
||||
|
||||
|
||||
def step_fn(state, batch):
|
||||
"""Running one step of training or evaluation.
|
||||
|
||||
For jax version: This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and
|
||||
jit-compiled together for faster execution.
|
||||
|
||||
Args:
|
||||
state: A dictionary of training information, containing the score model, optimizer,
|
||||
EMA status, and number of optimization steps.
|
||||
batch: A mini-batch of training/evaluation data, including the mini-batch adjacency matrices and mask.
|
||||
|
||||
Returns:
|
||||
loss: The average loss value of this state.
|
||||
"""
|
||||
model = state['model']
|
||||
if train:
|
||||
model.train()
|
||||
optimizer = state['optimizer']
|
||||
optimizer.zero_grad()
|
||||
loss, pred, labels = loss_fn(model, batch)
|
||||
loss.backward()
|
||||
optimize_fn(optimizer, model.parameters(), step=state['step'])
|
||||
state['step'] += 1
|
||||
else:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
loss, pred, labels = loss_fn(model, batch)
|
||||
|
||||
return loss, pred, labels
|
||||
|
||||
return step_fn
|
||||
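# --- Illustrative usage (hedged sketch, not part of the original sources) ----
# The predictor step function additionally returns predictions and labels, so
# an evaluation loop can aggregate a rank-correlation metric; `valid_loader`,
# `state`, and the scipy call below are assumptions, not repository code.
#
#   eval_step_fn = get_step_fn_predictor(sde, train=False, optimize_fn=None,
#                                        data='NASBench201',
#                                        label_list=['meta-acc'], noised=True)
#   preds, labels = [], []
#   for batch in valid_loader:
#       loss, pred, label = eval_step_fn(state, batch)
#       preds.append(pred.detach().cpu())
#       labels.append(label.detach().cpu())
#   # e.g. scipy.stats.spearmanr(torch.cat(preds).squeeze(), torch.cat(labels).squeeze())
# ------------------------------------------------------------------------------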
37
NAS-Bench-201/main.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Training and evaluation"""
|
||||
|
||||
import run_lib
|
||||
from absl import app, flags
|
||||
from ml_collections.config_flags import config_flags
|
||||
import logging
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
config_flags.DEFINE_config_file(
|
||||
'config', None, 'Training configuration.', lock_config=True
|
||||
)
|
||||
config_flags.DEFINE_config_file(
|
||||
'classifier_config_nf', None, 'Training configuration.', lock_config=True
|
||||
)
|
||||
flags.DEFINE_enum('mode', None, ['train', 'eval'],
|
||||
'Running mode: train or eval')
|
||||
|
||||
|
||||
def main(argv):
|
||||
## Set random seed
|
||||
run_lib.set_random_seed(FLAGS.config)
|
||||
|
||||
if FLAGS.mode == 'train':
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel('INFO')
|
||||
run_lib.train(FLAGS.config)
|
||||
elif FLAGS.mode == 'eval':
|
||||
run_lib.evaluate(FLAGS.config)
|
||||
else:
|
||||
raise ValueError(f"Mode {FLAGS.mode} not recognized.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
286
NAS-Bench-201/main_exp/diffusion/run_lib.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import torch
|
||||
import sys
|
||||
import numpy as np
|
||||
import random
import os
import shutil
import logging
|
||||
|
||||
sys.path.append('.')
|
||||
import sampling
|
||||
import datasets_nas
|
||||
from models import cate
|
||||
from models import digcn
|
||||
from models import digcn_meta
|
||||
from models import utils as mutils
|
||||
from models.ema import ExponentialMovingAverage
|
||||
import sde_lib
|
||||
from utils import *
|
||||
from analysis.arch_functions import BasicArchMetricsMeta
|
||||
from all_path import *
|
||||
|
||||
|
||||
def get_sampling_fn_meta(config):
|
||||
## Set SDE
|
||||
if config.training.sde.lower() == 'vpsde':
|
||||
sde = sde_lib.VPSDE(
|
||||
beta_min=config.model.beta_min,
|
||||
beta_max=config.model.beta_max,
|
||||
N=config.model.num_scales)
|
||||
sampling_eps = 1e-3
|
||||
elif config.training.sde.lower() == 'subvpsde':
|
||||
sde = sde_lib.subVPSDE(
|
||||
beta_min=config.model.beta_min,
|
||||
beta_max=config.model.beta_max,
|
||||
N=config.model.num_scales)
|
||||
sampling_eps = 1e-3
|
||||
elif config.training.sde.lower() == 'vesde':
|
||||
sde = sde_lib.VESDE(
|
||||
sigma_min=config.model.sigma_min,
|
||||
sigma_max=config.model.sigma_max,
|
||||
N=config.model.num_scales)
|
||||
sampling_eps = 1e-5
|
||||
else:
|
||||
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
|
||||
|
||||
## Get data normalizer inverse
|
||||
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
|
||||
|
||||
## Get sampling function
|
||||
sampling_shape = (config.eval.batch_size, config.data.max_node, config.data.n_vocab)
|
||||
sampling_fn = sampling.get_sampling_fn(
|
||||
config=config,
|
||||
sde=sde,
|
||||
shape=sampling_shape,
|
||||
inverse_scaler=inverse_scaler,
|
||||
eps=sampling_eps,
|
||||
conditional=True,
|
||||
data_name=config.sampling.check_dataname,
|
||||
num_sample=config.model.num_sample)
|
||||
|
||||
return sampling_fn, sde
|
||||
|
||||
|
||||
def get_score_model(config):
|
||||
try:
|
||||
score_config = torch.load(config.scorenet_ckpt_path)['config']
|
||||
ckpt_path = config.scorenet_ckpt_path
|
||||
except:
|
||||
config.scorenet_ckpt_path = SCORENET_CKPT_PATH
|
||||
score_config = torch.load(config.scorenet_ckpt_path)['config']
|
||||
ckpt_path = config.scorenet_ckpt_path
|
||||
|
||||
score_model = mutils.create_model(score_config)
|
||||
score_ema = ExponentialMovingAverage(
|
||||
score_model.parameters(), decay=score_config.model.ema_rate)
|
||||
score_state = dict(
|
||||
model=score_model, ema=score_ema, step=0, config=score_config)
|
||||
score_state = restore_checkpoint(
|
||||
ckpt_path, score_state,
|
||||
device=config.device, resume=True)
|
||||
score_ema.copy_to(score_model.parameters())
|
||||
return score_model, score_ema, score_config
|
||||
|
||||
|
||||
def get_surrogate(config):
|
||||
surrogate_model = mutils.create_model(config)
|
||||
return surrogate_model
|
||||
|
||||
|
||||
def get_adj(except_inout=False):
|
||||
_adj = np.asarray(
|
||||
[[0, 1, 1, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]]
|
||||
)
|
||||
_adj = torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
|
||||
if except_inout: _adj = _adj[1:-1, 1:-1]
|
||||
return _adj
|
||||
|
||||
|
||||
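# Note (added commentary): the 8x8 matrix above is the fixed DAG of a
# NAS-Bench-201 cell in this encoding -- node 0 is the input, node 7 the
# output, and the six intermediate nodes carry the operations on the six
# cell edges.  `aug_mask` defined below closes it transitively, e.g.
#   mask = aug_mask(get_adj())[0]   # hedged usage sketch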
def generate_archs_meta(
|
||||
config,
|
||||
sampling_fn,
|
||||
score_model,
|
||||
score_ema,
|
||||
meta_surrogate_model,
|
||||
num_samples,
|
||||
args=None,
|
||||
task=None,
|
||||
patient_factor=20,
|
||||
batch_size=256,):
|
||||
|
||||
metrics = BasicArchMetricsMeta()
|
||||
|
||||
## Get the adj and mask
|
||||
adj_s = get_adj()
|
||||
mask_s = aug_mask(adj_s)[0]
|
||||
adj_c = get_adj()
|
||||
mask_c = aug_mask(adj_c)[0]
|
||||
assert (adj_s == adj_c).all() and (mask_s == mask_c).all()
|
||||
adj_s, mask_s, adj_c, mask_c = \
|
||||
adj_s.to(config.device), mask_s.to(config.device), adj_c.to(config.device), mask_c.to(config.device)
|
||||
|
||||
score_ema.copy_to(score_model.parameters())
|
||||
score_model.eval()
|
||||
meta_surrogate_model.eval()
|
||||
c_scale = args.classifier_scale
|
||||
|
||||
num_sampling_rounds = int(np.ceil(num_samples / batch_size) * patient_factor) if num_samples > batch_size else int(patient_factor)
|
||||
round = 0
|
||||
all_samples = []
|
||||
while round < num_sampling_rounds:
|
||||
round += 1
|
||||
sample = sampling_fn(score_model,
|
||||
mask_s,
|
||||
meta_surrogate_model,
|
||||
classifier_scale=c_scale,
|
||||
task=task)
|
||||
quantized_sample = quantize(sample)
|
||||
_, _, valid_arch_str, _ = metrics.compute_validity(quantized_sample)
|
||||
if len(valid_arch_str) > 0: all_samples += valid_arch_str
|
||||
# to sample various architectures
|
||||
c_scale -= args.scale_step
|
||||
seed = int(random.random() * 10000)
|
||||
reset_seed(seed)
|
||||
# stop sampling if we have enough samples
|
||||
if (len(set(all_samples)) >= num_samples):
|
||||
break
|
||||
|
||||
return list(set(all_samples))
|
||||
|
||||
|
||||
def save_checkpoint(ckpt_dir, state, epoch, is_best):
|
||||
saved_state = {}
|
||||
for k in state:
|
||||
if k in ['optimizer', 'model', 'ema']:
|
||||
saved_state.update({k: state[k].state_dict()})
|
||||
else:
|
||||
saved_state.update({k: state[k]})
|
||||
os.makedirs(ckpt_dir, exist_ok=True)
|
||||
torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'))
|
||||
if is_best:
|
||||
shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'), os.path.join(ckpt_dir, 'model_best.pth.tar'))
|
||||
|
||||
# remove all saved checkpoints except the best one
|
||||
for ckpt_file in sorted(os.listdir(ckpt_dir)):
|
||||
if not ckpt_file.startswith('checkpoint'):
|
||||
continue
|
||||
if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
|
||||
os.remove(os.path.join(ckpt_dir, ckpt_file))
|
||||
|
||||
|
||||
def restore_checkpoint(ckpt_dir, state, device, resume=False):
|
||||
if not resume:
|
||||
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
|
||||
return state
|
||||
elif not os.path.exists(ckpt_dir):
|
||||
if not os.path.exists(os.path.dirname(ckpt_dir)):
|
||||
os.makedirs(os.path.dirname(ckpt_dir))
|
||||
logging.warning(f"No checkpoint found at {ckpt_dir}. "
|
||||
f"Returned the same state as input")
|
||||
return state
|
||||
else:
|
||||
loaded_state = torch.load(ckpt_dir, map_location=device)
|
||||
for k in state:
|
||||
if k in ['optimizer', 'model', 'ema']:
|
||||
state[k].load_state_dict(loaded_state[k])
|
||||
else:
|
||||
state[k] = loaded_state[k]
|
||||
return state
|
||||
|
||||
|
||||
def floyed(r):
|
||||
"""
|
||||
:param r: a numpy NxN matrix with float 0,1
|
||||
:return: a numpy NxN matrix with float 0,1
|
||||
"""
|
||||
if type(r) == torch.Tensor:
|
||||
r = r.cpu().numpy()
|
||||
N = r.shape[0]
|
||||
for k in range(N):
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if r[i, k] > 0 and r[k, j] > 0:
|
||||
r[i, j] = 1
|
||||
return r
|
||||
|
||||
|
||||
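# Worked example (added for illustration): for a 3-node chain 0 -> 1 -> 2 the
# closure adds the implied long-range edge 0 -> 2:
#   floyed(np.array([[0, 1, 0],
#                    [0, 0, 1],
#                    [0, 0, 0]], dtype=float))
#   # -> [[0, 1, 1],
#   #     [0, 0, 1],
#   #     [0, 0, 0]]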
def aug_mask(adj, algo='floyed', data='NASBench201'):
|
||||
if len(adj.shape) == 2:
|
||||
adj = adj.unsqueeze(0)
|
||||
|
||||
if data.lower() in ['nasbench201', 'ofa']:
|
||||
assert len(adj.shape) == 3
|
||||
r = adj[0].clone().detach()
|
||||
if algo == 'long_range':
|
||||
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
|
||||
elif algo == 'floyed':
|
||||
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
|
||||
else:
|
||||
mask_i = r
|
||||
masks = [mask_i] * adj.size(0)
|
||||
return torch.stack(masks)
|
||||
else:
|
||||
masks = []
|
||||
for r in adj:
|
||||
if algo == 'long_range':
|
||||
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
|
||||
elif algo == 'floyed':
|
||||
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
|
||||
else:
|
||||
mask_i = r
|
||||
masks.append(mask_i)
|
||||
return torch.stack(masks)
|
||||
|
||||
|
||||
def long_range(r):
|
||||
"""
|
||||
:param r: a numpy NxN matrix with float 0,1
|
||||
:return: a numpy NxN matrix with float 0,1
|
||||
"""
|
||||
# r = np.array(r)
|
||||
if type(r) == torch.Tensor:
|
||||
r = r.cpu().numpy()
|
||||
N = r.shape[0]
|
||||
for j in range(1, N):
|
||||
col_j = r[:, j][:j]
|
||||
in_to_j = [i for i, val in enumerate(col_j) if val > 0]
|
||||
if len(in_to_j) > 0:
|
||||
for i in in_to_j:
|
||||
col_i = r[:, i][:i]
|
||||
in_to_i = [i for i, val in enumerate(col_i) if val > 0]
|
||||
if len(in_to_i) > 0:
|
||||
for k in in_to_i:
|
||||
r[k, j] = 1
|
||||
return r
|
||||
|
||||
|
||||
def quantize(x):
|
||||
"""Covert the PyTorch tensor x, adj matrices to numpy array.
|
||||
|
||||
Args:
|
||||
x: [Batch_size, Max_node, N_vocab]
|
||||
"""
|
||||
x_list = []
|
||||
|
||||
# discretization
|
||||
x[x >= 0.5] = 1.
|
||||
x[x < 0.5] = 0.
|
||||
|
||||
for i in range(x.shape[0]):
|
||||
x_tmp = x[i]
|
||||
x_tmp = x_tmp.cpu().numpy()
|
||||
x_list.append(x_tmp)
|
||||
|
||||
return x_list
|
||||
|
||||
|
||||
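# Example (added for illustration): each entry is thresholded at 0.5, so
#   quantize(torch.tensor([[[0.7, 0.2], [0.4, 0.9]]]))
#   # -> [array([[1., 0.], [0., 1.]], dtype=float32)]   one numpy array per sample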
def reset_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
137
NAS-Bench-201/main_exp/logger.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import wandb
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(
|
||||
self,
|
||||
log_dir=None,
|
||||
write_textfile=True
|
||||
):
|
||||
|
||||
self.log_dir = log_dir
|
||||
self.write_textfile = write_textfile
|
||||
|
||||
self.logs_for_save = {}
|
||||
self.logs = {}
|
||||
|
||||
if self.write_textfile:
|
||||
self.f = open(os.path.join(log_dir, 'logs.txt'), 'w')
|
||||
|
||||
|
||||
def write_str(self, log_str):
|
||||
self.f.write(log_str+'\n')
|
||||
self.f.flush()
|
||||
|
||||
|
||||
def update_config(self, v, is_args=False):
|
||||
if is_args:
|
||||
self.logs_for_save.update({'args': v})
|
||||
else:
|
||||
self.logs_for_save.update(v)
|
||||
|
||||
|
||||
def write_log(self, element, step, return_log_dict=False):
|
||||
log_str = f"{step} | "
|
||||
log_dict = {}
|
||||
for head, keys in element.items():
|
||||
for k in keys:
|
||||
if k in self.logs:
|
||||
v = self.logs[k].avg
|
||||
if not k in self.logs_for_save:
|
||||
self.logs_for_save[k] = []
|
||||
self.logs_for_save[k].append(v)
|
||||
log_str += f'{k} {v}| '
|
||||
log_dict[f'{head}/{k}'] = v
|
||||
|
||||
if self.write_textfile:
|
||||
self.f.write(log_str+'\n')
|
||||
self.f.flush()
|
||||
|
||||
if return_log_dict:
|
||||
return log_dict
|
||||
|
||||
|
||||
def save_log(self, name=None):
|
||||
name = 'logs.pt' if name is None else name
|
||||
torch.save(self.logs_for_save, os.path.join(self.log_dir, name))
|
||||
|
||||
|
||||
def update(self, key, v, n=1):
|
||||
if not key in self.logs:
|
||||
self.logs[key] = AverageMeter()
|
||||
self.logs[key].update(v, n)
|
||||
|
||||
|
||||
def reset(self, keys=None, except_keys=[]):
|
||||
if keys is not None:
|
||||
if isinstance(keys, list):
|
||||
for key in keys:
|
||||
self.logs[key] = AverageMeter()
|
||||
else:
|
||||
self.logs[keys] = AverageMeter()
|
||||
else:
|
||||
for key in self.logs.keys():
|
||||
if not key in except_keys:
|
||||
self.logs[key] = AverageMeter()
|
||||
|
||||
|
||||
def avg(self, keys=None, except_keys=[]):
|
||||
if keys is not None:
|
||||
if isinstance(keys, list):
|
||||
return {key: self.logs[key].avg for key in keys if key in self.logs.keys()}
|
||||
else:
|
||||
return self.logs[keys].avg
|
||||
else:
|
||||
avg_dict = {}
|
||||
for key in self.logs.keys():
|
||||
if not key in except_keys:
|
||||
avg_dict[key] = self.logs[key].avg
|
||||
return avg_dict
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""
|
||||
Computes and stores the average and current value
|
||||
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
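# --- Illustrative usage (hedged sketch, not part of the original sources) ----
#   logger = Logger(log_dir='./results/debug')        # directory must already exist
#   logger.update('loss', 0.42, n=128)                 # running average over 128 samples
#   logger.write_log({'train': ['loss']}, step=1)      # writes "1 | loss 0.42| " to logs.txt
#   logger.save_log()                                  # dumps logs_for_save to logs.pt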
def get_metrics(g_embeds, x_embeds, logit_scale, prefix='train'):
|
||||
metrics = {}
|
||||
logits_per_g = (logit_scale * g_embeds @ x_embeds.t()).detach().cpu()
|
||||
logits_per_x = logits_per_g.t().detach().cpu()
|
||||
|
||||
logits = {"g_to_x": logits_per_g, "x_to_g": logits_per_x}
|
||||
ground_truth = torch.arange(len(x_embeds)).view(-1, 1)
|
||||
|
||||
for name, logit in logits.items():
|
||||
ranking = torch.argsort(logit, descending=True)
|
||||
preds = torch.where(ranking == ground_truth)[1]
|
||||
preds = preds.detach().cpu().numpy()
|
||||
metrics[f"{prefix}_{name}_mean_rank"] = preds.mean() + 1
|
||||
metrics[f"{prefix}_{name}_median_rank"] = np.floor(np.median(preds)) + 1
|
||||
for k in [1, 5, 10]:
|
||||
metrics[f"{prefix}_{name}_R@{k}"] = np.mean(preds < k)
|
||||
|
||||
return metrics
|
||||
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
@author: Hayeon Lee
|
||||
2020/02/19
|
||||
Script for downloading and reorganizing the FGVC-Aircraft dataset
for few-shot classification.
|
||||
Run this file as follows:
|
||||
python get_data.py
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import os
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import tarfile
|
||||
from PIL import Image
|
||||
import glob
|
||||
import shutil
|
||||
import pickle
|
||||
import collections
|
||||
import sys
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from all_path import RAW_DATA_PATH
|
||||
|
||||
def download_file(url, filename):
|
||||
"""
|
||||
Helper method handling downloading large files from `url`
|
||||
to `filename`. Returns a pointer to `filename`.
|
||||
"""
|
||||
chunkSize = 1024
|
||||
r = requests.get(url, stream=True)
|
||||
with open(filename, 'wb') as f:
|
||||
pbar = tqdm( unit="B", total=int( r.headers['Content-Length'] ) )
|
||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
pbar.update (len(chunk))
|
||||
f.write(chunk)
|
||||
return filename
|
||||
|
||||
dir_path = RAW_DATA_PATH
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
file_name = os.path.join(dir_path, 'fgvc-aircraft-2013b.tar.gz')
|
||||
|
||||
if not os.path.exists(file_name):
|
||||
print(f"Downloading {file_name}\n")
|
||||
download_file(
|
||||
'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz',
|
||||
file_name)
|
||||
print("\nDownloading done.\n")
|
||||
else:
|
||||
print("fgvc-aircraft-2013b.tar.gz has already been downloaded. Did not download twice.\n")
|
||||
|
||||
untar_file_name = os.path.join(dir_path, 'aircraft')
|
||||
if not os.path.exists(untar_file_name):
|
||||
tarname = file_name
|
||||
print("Untarring: {}".format(tarname))
|
||||
tar = tarfile.open(tarname)
|
||||
tar.extractall(untar_file_name)
|
||||
tar.close()
|
||||
else:
|
||||
print(f"{untar_file_name} folder already exists. Did not untarring twice\n")
|
||||
os.remove(file_name)
|
||||
50
NAS-Bench-201/main_exp/transfer_nag/get_files/get_pets.py
Normal file
@@ -0,0 +1,50 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import zipfile
|
||||
import sys
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from all_path import RAW_DATA_PATH
|
||||
|
||||
|
||||
def download_file(url, filename):
|
||||
"""
|
||||
Helper method handling downloading large files from `url`
|
||||
to `filename`. Returns a pointer to `filename`.
|
||||
"""
|
||||
chunkSize = 1024
|
||||
r = requests.get(url, stream=True)
|
||||
with open(filename, 'wb') as f:
|
||||
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
|
||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
pbar.update(len(chunk))
|
||||
f.write(chunk)
|
||||
return filename
|
||||
|
||||
|
||||
dir_path = os.path.join(RAW_DATA_PATH, 'pets')
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
full_name = os.path.join(dir_path, 'test15.pth')
|
||||
if not os.path.exists(full_name):
|
||||
print(f"Downloading {full_name}\n")
|
||||
download_file(
|
||||
'https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
|
||||
print("Downloading done.\n")
|
||||
else:
|
||||
print(f"{full_name} has already been downloaded. Did not download twice.\n")
|
||||
|
||||
full_name = os.path.join(dir_path, 'train85.pth')
|
||||
if not os.path.exists(full_name):
|
||||
print(f"Downloading {full_name}\n")
|
||||
download_file(
|
||||
'https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
|
||||
print("Downloading done.\n")
|
||||
else:
|
||||
print(f"{full_name} has already been downloaded. Did not download twice.\n")
|
||||
@@ -0,0 +1,47 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
|
||||
|
||||
DATA_PATH = "./data/transfer_nag"
|
||||
dir_path = DATA_PATH
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
|
||||
def download_file(url, filename):
|
||||
"""
|
||||
Helper method handling downloading large files from `url`
|
||||
to `filename`. Returns a pointer to `filename`.
|
||||
"""
|
||||
chunkSize = 1024
|
||||
r = requests.get(url, stream=True)
|
||||
with open(filename, 'wb') as f:
|
||||
pbar = tqdm( unit="B", total=int( r.headers['Content-Length'] ) )
|
||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
pbar.update (len(chunk))
|
||||
f.write(chunk)
|
||||
return filename
|
||||
|
||||
|
||||
def get_preprocessed_data(file_name, url):
|
||||
print(f"Downloading {file_name} datasets\n")
|
||||
full_name = os.path.join(dir_path, file_name)
|
||||
download_file(url, full_name)
|
||||
print("Downloading done.\n")
|
||||
|
||||
|
||||
for file_name, url in [
|
||||
('aircraftbylabel.pt', 'https://www.dropbox.com/s/mb66kitv20ykctp/aircraftbylabel.pt?dl=1'),
|
||||
('cifar100bylabel.pt', 'https://www.dropbox.com/s/y0xahxgzj29kffk/cifar100bylabel.pt?dl=1'),
|
||||
('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
|
||||
('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
|
||||
('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
|
||||
]:
|
||||
|
||||
get_preprocessed_data(file_name, url)
|
||||
130
NAS-Bench-201/main_exp/transfer_nag/loader.py
Normal file
@@ -0,0 +1,130 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=True):
|
||||
dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
|
||||
print(f'==> The number of tasks for meta-training: {len(dataset)}')
|
||||
|
||||
loader = DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=collate_fn)
|
||||
return loader
|
||||
|
||||
|
||||
def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
|
||||
print(f'==> Meta-Test dataset {data_name}')
|
||||
|
||||
loader = DataLoader(dataset=dataset,
|
||||
batch_size=100,
|
||||
shuffle=False,
|
||||
num_workers=0)
|
||||
return loader
|
||||
|
||||
|
||||
class MetaTrainDatabase(Dataset):
|
||||
def __init__(self, data_path, num_sample, is_pred=True):
|
||||
self.mode = 'train'
|
||||
self.acc_norm = True
|
||||
self.num_sample = num_sample
|
||||
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
|
||||
|
||||
mtr_data_path = os.path.join(
|
||||
data_path, 'meta_train_tasks_predictor.pt')
|
||||
idx_path = os.path.join(
|
||||
data_path, 'meta_train_tasks_predictor_idx.pt')
|
||||
data = torch.load(mtr_data_path)
|
||||
self.acc = data['acc']
|
||||
self.task = data['task']
|
||||
self.graph = data['g']
|
||||
|
||||
random_idx_lst = torch.load(idx_path)
|
||||
self.idx_lst = {}
|
||||
self.idx_lst['valid'] = random_idx_lst[:400]
|
||||
self.idx_lst['train'] = random_idx_lst[400:]
|
||||
self.acc = torch.tensor(self.acc)
|
||||
self.mean = torch.mean(self.acc[self.idx_lst['train']]).item()
|
||||
self.std = torch.std(self.acc[self.idx_lst['train']]).item()
|
||||
self.task_lst = torch.load(os.path.join(
|
||||
data_path, 'meta_train_task_lst.pt'))
|
||||
|
||||
|
||||
def set_mode(self, mode):
|
||||
self.mode = mode
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.idx_lst[self.mode])
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
ridx = self.idx_lst[self.mode]
|
||||
tidx = self.task[ridx[index]]
|
||||
classes = self.task_lst[tidx]
|
||||
graph = self.graph[ridx[index]]
|
||||
acc = self.acc[ridx[index]]
|
||||
for cls in classes:
|
||||
cx = self.x[cls-1][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
if self.acc_norm:
|
||||
acc = ((acc - self.mean) / self.std) / 100.0
|
||||
else:
|
||||
acc = acc / 100.0
|
||||
return x, graph, acc
|
||||
|
||||
|
||||
class MetaTestDataset(Dataset):
|
||||
def __init__(self, data_path, data_name, num_sample, num_class=None):
|
||||
self.num_sample = num_sample
|
||||
self.data_name = data_name
|
||||
|
||||
num_class_dict = {
|
||||
'cifar100': 100,
|
||||
'cifar10': 10,
|
||||
'mnist': 10,
|
||||
'svhn': 10,
|
||||
'aircraft': 30,
|
||||
'pets': 37
|
||||
}
|
||||
|
||||
if num_class is not None:
|
||||
self.num_class = num_class
|
||||
else:
|
||||
self.num_class = num_class_dict[data_name]
|
||||
|
||||
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return 1000000
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
classes = list(range(self.num_class))
|
||||
for cls in classes:
|
||||
cx = self.x[cls][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
return x
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
x = torch.stack([item[0] for item in batch])
|
||||
graph = [item[1] for item in batch]
|
||||
acc = torch.stack([item[2] for item in batch])
|
||||
return [x, graph, acc]
|
||||
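# --- Illustrative usage (hedged sketch, not part of the original sources) ----
#   loader = get_meta_train_loader(batch_size=32,
#                                  data_path='./data/transfer_nag',
#                                  num_sample=20)
#   for x, graph, acc in loader:
#       ...  # x: stacked image samples, graph: list of igraphs, acc: normalized accuracies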
91
NAS-Bench-201/main_exp/transfer_nag/main.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import numpy as np
|
||||
import argparse
|
||||
import torch
|
||||
import os
|
||||
from nag import NAG
|
||||
sys.path.append(os.getcwd())
|
||||
save_path = "results"
|
||||
data_path = os.path.join('MetaD2A_nas_bench_201', 'data')
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ['t', 'true', True]
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('--seed', type=int, default=444)
|
||||
parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
|
||||
parser.add_argument('--save-path', type=str, default=save_path, help='the path of save directory')
|
||||
parser.add_argument('--data-path', type=str, default=data_path, help='the path of save directory')
|
||||
parser.add_argument('--model-load-path', type=str, default='', help='')
|
||||
parser.add_argument('--save-epoch', type=int, default=20, help='how many epochs to wait each time to save model states')
|
||||
parser.add_argument('--max-epoch', type=int, default=1000, help='number of epochs to train')
|
||||
parser.add_argument('--batch_size', type=int, default=1024, help='batch size for generator')
|
||||
parser.add_argument('--graph-data-name', default='nasbench201', help='graph dataset name')
|
||||
parser.add_argument('--nvt', type=int, default=7, help='number of different node types, 7: NAS-Bench-201 including in/out node')
|
||||
# set encoder
|
||||
parser.add_argument('--num-sample', type=int, default=20, help='the number of images as input for set encoder')
|
||||
# graph encoder
|
||||
parser.add_argument('--hs', type=int, default=512, help='hidden size of GRUs')
|
||||
parser.add_argument('--nz', type=int, default=56, help='the number of dimensions of latent vectors z')
|
||||
# test
|
||||
parser.add_argument('--test', action='store_true', default=True, help='turn on test mode')
|
||||
parser.add_argument('--load-epoch', type=int, default=100, help='checkpoint epoch loaded for meta-test')
|
||||
parser.add_argument('--data-name', type=str, default='pets', help='meta-test dataset name')
|
||||
parser.add_argument('--trials', type=int, default=20)
|
||||
parser.add_argument('--num-class', type=int, default=None, help='the number of class of dataset')
|
||||
parser.add_argument('--num-gen-arch', type=int, default=500, help='the number of candidate architectures generated by the generator')
|
||||
parser.add_argument('--train-arch', type=str2bool, default=True, help='whether to train the searched architecture')
|
||||
parser.add_argument('--n_init', type=int, default=10)
|
||||
parser.add_argument('--N', type=int, default=1)
|
||||
# DiffusionNAG
|
||||
parser.add_argument('--folder_name', type=str, default='debug')
|
||||
parser.add_argument('--exp_name', type=str, default='')
|
||||
parser.add_argument('--classifier_scale', type=float, default=10000., help='classifier scale')
|
||||
parser.add_argument('--scale_step', type=float, default=300.)
|
||||
parser.add_argument('--eval_batch_size', type=int, default=256)
|
||||
parser.add_argument('--predictor', type=str, default='euler_maruyama', choices=['euler_maruyama', 'reverse_diffusion', 'none'])
|
||||
parser.add_argument('--corrector', type=str, default='langevin', choices=['none', 'langevin'])
|
||||
parser.add_argument('--patient_factor', type=int, default=20)
|
||||
parser.add_argument('--n_gen_samples', type=int, default=10)
|
||||
parser.add_argument('--multi_proc', type=str2bool, default=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
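# --- Illustrative invocation (hedged sketch, not part of the original sources) ----
# Meta-test on a single dataset with the defaults defined above, e.g.
#   python main_exp/transfer_nag/main.py --data-name pets --n_gen_samples 10 --gpu 0
# (the relative script path is an assumption about how the repository is laid out)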
def set_exp_name(args):
|
||||
exp_name = f'./results/transfer_nag/{args.folder_name}/{args.data_name}'
|
||||
os.makedirs(exp_name, exist_ok=True)
|
||||
args.exp_name = exp_name
|
||||
|
||||
|
||||
def main():
|
||||
## Get arguments
|
||||
args = get_parser()
|
||||
|
||||
## Set gpus and seeds
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
||||
## Set experiment name
|
||||
set_exp_name(args)
|
||||
|
||||
## Run
|
||||
nag = NAG(args)
|
||||
nag.meta_test()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
305
NAS-Bench-201/main_exp/transfer_nag/nag.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from __future__ import print_function
|
||||
import torch
|
||||
import os
|
||||
import gc
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from nag_utils import mean_confidence_interval
|
||||
from nag_utils import restore_checkpoint
|
||||
from nag_utils import load_graph_config
|
||||
from nag_utils import load_model
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from nas_bench_201 import train_single_model
|
||||
from unnoised_model import MetaSurrogateUnnoisedModel
|
||||
from diffusion.run_lib import generate_archs_meta
|
||||
from diffusion.run_lib import get_sampling_fn_meta
|
||||
from diffusion.run_lib import get_score_model
|
||||
from diffusion.run_lib import get_surrogate
|
||||
from loader import MetaTestDataset
|
||||
from logger import Logger
|
||||
from all_path import *
|
||||
|
||||
|
||||
class NAG:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
## Target dataset information
|
||||
self.raw_data_path = RAW_DATA_PATH
|
||||
self.data_path = DATA_PATH
|
||||
self.data_name = args.data_name
|
||||
self.num_class = args.num_class
|
||||
self.num_sample = args.num_sample
|
||||
|
||||
graph_config = load_graph_config(args.graph_data_name, args.nvt, NASBENCH201)
|
||||
self.meta_surrogate_unnoised_model = MetaSurrogateUnnoisedModel(args, graph_config)
|
||||
load_model(model=self.meta_surrogate_unnoised_model,
|
||||
ckpt_path=META_SURROGATE_UNNOISED_CKPT_PATH)
|
||||
self.meta_surrogate_unnoised_model.to(self.device)
|
||||
|
||||
## Load pre-trained meta-surrogate model
|
||||
self.meta_surrogate_ckpt_path = META_SURROGATE_CKPT_PATH
|
||||
|
||||
## Load score network model (base diffusion model)
|
||||
self.load_diffusion_model(args=args)
|
||||
|
||||
## Check config
|
||||
self.check_config()
|
||||
|
||||
## Set logger
|
||||
self.logger = Logger(
|
||||
log_dir=args.exp_name,
|
||||
write_textfile=True
|
||||
)
|
||||
self.logger.update_config(args, is_args=True)
|
||||
self.logger.write_str(str(vars(args)))
|
||||
self.logger.write_str('-' * 100)
|
||||
|
||||
|
||||
def check_config(self):
|
||||
"""
|
||||
Check if the configuration of the pre-trained score network model matches that of the meta surrogate model.
|
||||
"""
|
||||
scorenet_config = torch.load(self.config.scorenet_ckpt_path)['config']
|
||||
meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
|
||||
assert scorenet_config.model.sigma_min == meta_surrogate_config.model.sigma_min
|
||||
assert scorenet_config.model.sigma_max == meta_surrogate_config.model.sigma_max
|
||||
assert scorenet_config.training.sde == meta_surrogate_config.training.sde
|
||||
assert scorenet_config.training.continuous == meta_surrogate_config.training.continuous
|
||||
assert scorenet_config.data.centered == meta_surrogate_config.data.centered
|
||||
assert scorenet_config.data.max_node == meta_surrogate_config.data.max_node
|
||||
assert scorenet_config.data.n_vocab == meta_surrogate_config.data.n_vocab
|
||||
|
||||
|
||||
def forward(self, x, arch):
|
||||
D_mu = self.meta_surrogate_unnoised_model.set_encode(x.to(self.device))
|
||||
G_mu = self.meta_surrogate_unnoised_model.graph_encode(arch)
|
||||
y_pred = self.meta_surrogate_unnoised_model.predict(D_mu, G_mu)
|
||||
return y_pred
|
||||
|
||||
|
||||
def meta_test(self):
|
||||
if self.data_name == 'all':
|
||||
for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
|
||||
self.meta_test_per_dataset(data_name)
|
||||
else:
|
||||
self.meta_test_per_dataset(self.data_name)
|
||||
|
||||
|
||||
def meta_test_per_dataset(self, data_name):
|
||||
## Load NASBench201
|
||||
self.nasbench201 = torch.load(NASBENCH201)
|
||||
all_arch_str = np.array(self.nasbench201['arch']['str'])
|
||||
|
||||
## Load meta-test dataset
|
||||
self.test_dataset = MetaTestDataset(self.data_path, data_name, self.num_sample, self.num_class)
|
||||
|
||||
## Set save path
|
||||
meta_test_path = os.path.join(META_TEST_PATH, data_name)
|
||||
os.makedirs(meta_test_path, exist_ok=True)
|
||||
f_arch_str = open(os.path.join(self.args.exp_name, 'architecture.txt'), 'w')
|
||||
f_arch_acc = open(os.path.join(self.args.exp_name, 'accuracy.txt'), 'w')
|
||||
|
||||
## Generate architectures
|
||||
gen_arch_str = self.get_gen_arch_str()
|
||||
gen_arch_igraph = self.get_items(
|
||||
full_target=self.nasbench201['arch']['igraph'],
|
||||
full_source=self.nasbench201['arch']['str'],
|
||||
source=gen_arch_str)
|
||||
|
||||
## Sort with unnoised meta-surrogate model
|
||||
y_pred_all = []
|
||||
self.meta_surrogate_unnoised_model.eval()
|
||||
self.meta_surrogate_unnoised_model.to(self.device)
|
||||
with torch.no_grad():
|
||||
for arch_igraph in gen_arch_igraph:
|
||||
x, g = self.collect_data(arch_igraph)
|
||||
y_pred = self.forward(x, g)
|
||||
y_pred = torch.mean(y_pred)
|
||||
y_pred_all.append(y_pred.cpu().detach().item())
|
||||
sorted_arch_lst = self.sort_arch(data_name, torch.tensor(y_pred_all), gen_arch_str)
|
||||
|
||||
## Record the information of the architecture generated in sorted order
|
||||
for _, arch_str in enumerate(sorted_arch_lst):
|
||||
f_arch_str.write(f'{arch_str}\n')
|
||||
arch_idx_lst = [self.nasbench201['arch']['str'].index(i) for i in sorted_arch_lst]
|
||||
arch_str_lst = []
|
||||
arch_acc_lst = []
|
||||
|
||||
## Get the accuracy of the architecture
|
||||
if 'cifar' in data_name:
|
||||
sorted_acc_lst = self.get_items(
|
||||
full_target=self.nasbench201['test-acc'][data_name],
|
||||
full_source=self.nasbench201['arch']['str'],
|
||||
source=sorted_arch_lst)
|
||||
arch_str_lst += sorted_arch_lst
|
||||
arch_acc_lst += sorted_acc_lst
|
||||
for arch_idx, acc in zip(arch_idx_lst, sorted_acc_lst):
|
||||
msg = f'Avg {acc:.4f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
else:
|
||||
if self.args.multi_proc:
|
||||
## Run multiple processes in parallel
|
||||
run_file = os.path.join(os.getcwd(), 'main_exp', 'transfer_nag', 'run_multi_proc.py')
|
||||
MAX_CAP = 5 # hard-coded for available GPUs
|
||||
if len(arch_idx_lst) <= MAX_CAP:
|
||||
arch_idx_lst_ = [arch_idx for arch_idx in arch_idx_lst if not os.path.exists(os.path.join(meta_test_path, str(arch_idx)))]
|
||||
support_ = ','.join([str(i) for i in arch_idx_lst_])
|
||||
num_split = int(3 * len(arch_idx_lst_)) # why 3? => running for 3 seeds
|
||||
cmd = f"python {run_file} --num_split {num_split} --arch_idx_lst {support_} --meta_test_path {meta_test_path} --data_name {data_name} --raw_data_path {self.raw_data_path}"
|
||||
subprocess.run([cmd], shell=True)
|
||||
else:
|
||||
arch_idx_lst_ = []
|
||||
for j, arch_idx in enumerate(arch_idx_lst):
|
||||
if not os.path.exists(os.path.join(meta_test_path, str(arch_idx))):
|
||||
arch_idx_lst_.append(arch_idx)
|
||||
if (len(arch_idx_lst_) == MAX_CAP) or (j == len(arch_idx_lst) - 1):
|
||||
support_ = ','.join([str(i) for i in arch_idx_lst_])
|
||||
num_split = int(3 * len(arch_idx_lst_))
|
||||
cmd = f"python {run_file} --num_split {num_split} --arch_idx_lst {support_} --meta_test_path {meta_test_path} --data_name {data_name} --raw_data_path {self.raw_data_path}"
|
||||
subprocess.run([cmd], shell=True)
|
||||
arch_idx_lst_ = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
acc_runs_lst = []
|
||||
epoch = 199
|
||||
seeds = (777, 888, 999)
|
||||
for arch_idx in arch_idx_lst:
|
||||
acc_runs = []
|
||||
save_path_ = os.path.join(meta_test_path, str(arch_idx))
|
||||
for seed in seeds:
|
||||
result = torch.load(os.path.join(save_path_, f'seed-0{seed}.pth'))
|
||||
acc_runs.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
|
||||
acc_runs_lst.append(acc_runs)
|
||||
break
|
||||
except:
|
||||
pass
|
||||
for i in acc_runs_lst:print(np.mean(i))
|
||||
for arch_idx, acc_runs in zip(arch_idx_lst, acc_runs_lst):
|
||||
for r, acc in enumerate(acc_runs):
|
||||
msg = f'run {r+1} {acc:.2f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
m, h = mean_confidence_interval(acc_runs)
|
||||
msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
arch_acc_lst.append(np.mean(acc_runs))
|
||||
arch_str_lst.append(all_arch_str[arch_idx])
|
||||
|
||||
else:
|
||||
for arch_idx in arch_idx_lst:
|
||||
acc_runs = self.train_single_arch(
|
||||
data_name, self.nasbench201['arch']['str'][arch_idx], meta_test_path)
|
||||
for r, acc in enumerate(acc_runs):
|
||||
msg = f'run {r+1} {acc:.2f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
m, h = mean_confidence_interval(acc_runs)
|
||||
msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
arch_acc_lst.append(np.mean(acc_runs))
|
||||
arch_str_lst.append(all_arch_str[arch_idx])
|
||||
|
||||
# Save results
|
||||
results_path = os.path.join(self.args.exp_name, 'results.pt')
|
||||
torch.save({
|
||||
'arch_idx_lst': arch_idx_lst,
|
||||
'arch_str_lst': arch_str_lst,
|
||||
'arch_acc_lst': arch_acc_lst
|
||||
}, results_path)
|
||||
print(f">>> Save the results at {results_path}...")
|
||||
|
||||
|
||||
def train_single_arch(self, data_name, arch_str, meta_test_path):
|
||||
save_path = os.path.join(meta_test_path, arch_str)
|
||||
seeds = (777, 888, 999)
|
||||
train_single_model(save_dir=save_path,
|
||||
workers=24,
|
||||
datasets=[data_name],
|
||||
xpaths=[f'{self.raw_data_path}/{data_name}'],
|
||||
splits=[0],
|
||||
use_less=False,
|
||||
seeds=seeds,
|
||||
model_str=arch_str,
|
||||
arch_config={'channel': 16, 'num_cells': 5})
|
||||
epoch = 199
|
||||
test_acc_lst = []
|
||||
for seed in seeds:
|
||||
result = torch.load(os.path.join(save_path, f'seed-0{seed}.pth'))
|
||||
test_acc_lst.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
|
||||
return test_acc_lst
|
||||
|
||||
|
||||
def sort_arch(self, data_name, y_pred_all, gen_arch_str):
|
||||
_, sorted_idx = torch.sort(y_pred_all, descending=True)
|
||||
sorted_gen_arch_str = [gen_arch_str[_] for _ in sorted_idx]
return sorted_gen_arch_str
|
||||
|
||||
|
||||
def collect_data_only(self):
|
||||
x_batch = []
|
||||
x_batch.append(self.test_dataset[0])
|
||||
return torch.stack(x_batch).to(self.device)
|
||||
|
||||
|
||||
def collect_data(self, arch_igraph):
|
||||
x_batch, g_batch = [], []
|
||||
for _ in range(10):
|
||||
x_batch.append(self.test_dataset[0])
|
||||
g_batch.append(arch_igraph)
|
||||
return torch.stack(x_batch).to(self.device), g_batch
|
||||
|
||||
|
||||
def get_items(self, full_target, full_source, source):
|
||||
return [full_target[full_source.index(_)] for _ in source]
|
||||
|
||||
|
||||
def load_diffusion_model(self, args):
|
||||
self.config = torch.load('./configs/transfer_nag_config.pt')
|
||||
self.config.device = torch.device('cuda')
|
||||
self.config.data.label_list = ['meta-acc']
|
||||
self.config.scorenet_ckpt_path = SCORENET_CKPT_PATH
|
||||
self.config.sampling.classifier_scale = args.classifier_scale
|
||||
self.config.eval.batch_size = args.eval_batch_size
|
||||
self.config.sampling.predictor = args.predictor
|
||||
self.config.sampling.corrector = args.corrector
|
||||
self.config.sampling.check_dataname = self.data_name
|
||||
self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
|
||||
self.score_model, self.score_ema, self.score_config = get_score_model(self.config)
|
||||
|
||||
|
||||
def get_gen_arch_str(self):
|
||||
## Load meta-surrogate model
|
||||
meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
|
||||
meta_surrogate_model = get_surrogate(meta_surrogate_config)
|
||||
meta_surrogate_state = dict(model=meta_surrogate_model, step=0, config=meta_surrogate_config)
|
||||
meta_surrogate_state = restore_checkpoint(
|
||||
self.meta_surrogate_ckpt_path,
|
||||
meta_surrogate_state,
|
||||
device=self.config.device,
|
||||
resume=True)
|
||||
|
||||
## Get dataset embedding, x
|
||||
with torch.no_grad():
|
||||
x = self.collect_data_only()
|
||||
|
||||
## Generate architectures
|
||||
generated_arch_str = generate_archs_meta(
|
||||
config=self.config,
|
||||
sampling_fn=self.sampling_fn,
|
||||
score_model=self.score_model,
|
||||
score_ema=self.score_ema,
|
||||
meta_surrogate_model=meta_surrogate_model,
|
||||
num_samples=self.args.n_gen_samples,
|
||||
args=self.args,
|
||||
task=x)
|
||||
|
||||
## Clean up
|
||||
meta_surrogate_model = None
|
||||
gc.collect()
|
||||
|
||||
return generated_arch_str
|
||||
301
NAS-Bench-201/main_exp/transfer_nag/nag_utils.py
Normal file
@@ -0,0 +1,301 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
import igraph
|
||||
import random
|
||||
import numpy as np
|
||||
import scipy.stats
|
||||
import torch
|
||||
import logging
|
||||
|
||||
|
||||
def reset_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def restore_checkpoint(ckpt_dir, state, device, resume=False):
|
||||
if not resume:
|
||||
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
|
||||
return state
|
||||
elif not os.path.exists(ckpt_dir):
|
||||
if not os.path.exists(os.path.dirname(ckpt_dir)):
|
||||
os.makedirs(os.path.dirname(ckpt_dir))
|
||||
logging.warning(f"No checkpoint found at {ckpt_dir}. "
|
||||
f"Returned the same state as input")
|
||||
return state
|
||||
else:
|
||||
loaded_state = torch.load(ckpt_dir, map_location=device)
|
||||
for k in state:
|
||||
if k in ['optimizer', 'model', 'ema']:
|
||||
state[k].load_state_dict(loaded_state[k])
|
||||
else:
|
||||
state[k] = loaded_state[k]
|
||||
return state
|
||||
|
||||
|
||||
def load_graph_config(graph_data_name, nvt, data_path):
|
||||
if graph_data_name != 'nasbench201':
|
||||
raise NotImplementedError(graph_data_name)
|
||||
g_list = []
|
||||
max_n = 0 # maximum number of nodes
|
||||
ms = torch.load(data_path)['arch']['matrix']
|
||||
for i in range(len(ms)):
|
||||
g, n = decode_NAS_BENCH_201_8_to_igraph(ms[i])
|
||||
max_n = max(max_n, n)
|
||||
g_list.append((g, 0))
|
||||
# number of different node types including in/out node
|
||||
graph_config = {}
|
||||
graph_config['num_vertex_type'] = nvt # original types + start/end types
|
||||
graph_config['max_n'] = max_n # maximum number of nodes
|
||||
graph_config['START_TYPE'] = 0 # predefined start vertex type
|
||||
graph_config['END_TYPE'] = 1 # predefined end vertex type
|
||||
|
||||
return graph_config
|
||||
|
||||
|
||||
def decode_NAS_BENCH_201_8_to_igraph(row):
|
||||
if type(row) == str:
|
||||
row = eval(row) # convert string to list of lists
|
||||
n = len(row)
|
||||
g = igraph.Graph(directed=True)
|
||||
g.add_vertices(n)
|
||||
for i, node in enumerate(row):
|
||||
g.vs[i]['type'] = node[0]
|
||||
if i < (n - 2) and i > 0:
|
||||
g.add_edge(i, i + 1) # always connect from last node
|
||||
for j, edge in enumerate(node[1:]):
|
||||
if edge == 1:
|
||||
g.add_edge(j, i)
|
||||
return g, n
|
||||
|
||||
|
||||
def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
|
||||
# first need to be a valid DAG computation graph
|
||||
res = is_valid_DAG(g, START_TYPE, END_TYPE)
|
||||
# in addition, node i must connect to node i+1
|
||||
res = res and len(g.vs['type']) == 8
|
||||
res = res and not (0 in g.vs['type'][1:-1])
|
||||
res = res and not (1 in g.vs['type'][1:-1])
|
||||
return res
|
||||
|
||||
|
||||
def decode_igraph_to_NAS201_matrix(g):
|
||||
m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
|
||||
xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
|
||||
for i, xy in enumerate(xys):
|
||||
m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
return np.array(m)
|
||||
|
||||
|
||||
def decode_igraph_to_NAS_BENCH_201_string(g):
|
||||
if not is_valid_NAS201(g):
|
||||
return None
|
||||
m = decode_igraph_to_NAS201_matrix(g)
|
||||
types = ['none', 'skip_connect', 'nor_conv_1x1',
|
||||
'nor_conv_3x3', 'avg_pool_3x3']
|
||||
return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
|
||||
format(types[int(m[1][0])],
|
||||
types[int(m[2][0])], types[int(m[2][1])],
|
||||
types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
|
||||
|
||||
|
||||
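# Example (added for illustration): a valid 8-node igraph decodes to a string
# of the standard NAS-Bench-201 form, e.g.
#   '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|'
# where `op~k` means the operation applied to the output of node k.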
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
|
||||
res = g.is_dag()
|
||||
n_start, n_end = 0, 0
|
||||
for v in g.vs:
|
||||
if v['type'] == START_TYPE:
|
||||
n_start += 1
|
||||
elif v['type'] == END_TYPE:
|
||||
n_end += 1
|
||||
if v.indegree() == 0 and v['type'] != START_TYPE:
|
||||
return False
|
||||
if v.outdegree() == 0 and v['type'] != END_TYPE:
|
||||
return False
|
||||
return res and n_start == 1 and n_end == 1
|
||||
|
||||
|
||||
class Accumulator():
|
||||
def __init__(self, *args):
|
||||
self.args = args
|
||||
self.argdict = {}
|
||||
for i, arg in enumerate(args):
|
||||
self.argdict[arg] = i
|
||||
self.sums = [0] * len(args)
|
||||
self.cnt = 0
|
||||
|
||||
def accum(self, val):
|
||||
val = [val] if type(val) is not list else val
|
||||
val = [v for v in val if v is not None]
|
||||
assert (len(val) == len(self.args))
|
||||
for i in range(len(val)):
|
||||
if torch.is_tensor(val[i]):
|
||||
val[i] = val[i].item()
|
||||
self.sums[i] += val[i]
|
||||
self.cnt += 1
|
||||
|
||||
def clear(self):
|
||||
self.sums = [0] * len(self.args)
|
||||
self.cnt = 0
|
||||
|
||||
def get(self, arg, avg=True):
|
||||
i = self.argdict.get(arg, -1)
|
||||
assert i != -1
|
||||
if avg:
|
||||
return self.sums[i] / (self.cnt + 1e-8)
|
||||
else:
|
||||
return self.sums[i]
|
||||
|
||||
def print_(self, header=None, time=None,
|
||||
logfile=None, do_not_print=[], as_int=[],
|
||||
avg=True):
|
||||
msg = '' if header is None else header + ': '
|
||||
if time is not None:
|
||||
msg += ('(%.3f secs), ' % time)
|
||||
|
||||
args = [arg for arg in self.args if arg not in do_not_print]
|
||||
arg = []
|
||||
for arg in args:
|
||||
val = self.sums[self.argdict[arg]]
|
||||
if avg:
|
||||
val /= (self.cnt + 1e-8)
|
||||
if arg in as_int:
|
||||
msg += ('%s %d, ' % (arg, int(val)))
|
||||
else:
|
||||
msg += ('%s %.4f, ' % (arg, val))
|
||||
print(msg)
|
||||
|
||||
if logfile is not None:
|
||||
logfile.write(msg + '\n')
|
||||
logfile.flush()
|
||||
|
||||
def add_scalars(self, summary, header=None, tag_scalar=None,
|
||||
step=None, avg=True, args=None):
|
||||
for arg in self.args:
|
||||
val = self.sums[self.argdict[arg]]
|
||||
if avg:
|
||||
val /= (self.cnt + 1e-8)
|
||||
else:
|
||||
val = val
|
||||
tag = f'{header}/{arg}' if header is not None else arg
|
||||
if tag_scalar is not None:
|
||||
summary.add_scalars(main_tag=tag,
|
||||
tag_scalar_dict={tag_scalar: val},
|
||||
global_step=step)
|
||||
else:
|
||||
summary.add_scalar(tag=tag,
|
||||
scalar_value=val,
|
||||
global_step=step)
|
||||
|
||||
|
||||
class Log:
|
||||
def __init__(self, args, logf, summary=None):
|
||||
self.args = args
|
||||
self.logf = logf
|
||||
self.summary = summary
|
||||
self.stime = time.time()
|
||||
self.ep_sttime = None
|
||||
|
||||
def print(self, logger, epoch, tag=None, avg=True):
|
||||
if tag == 'train':
|
||||
ct = time.time() - self.ep_sttime
|
||||
tt = time.time() - self.stime
|
||||
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
|
||||
print(msg)
|
||||
self.logf.write(msg+'\n')
|
||||
logger.print_(header=tag, logfile=self.logf, avg=avg)
|
||||
|
||||
if self.summary is not None:
|
||||
logger.add_scalars(
|
||||
self.summary, header=tag, step=epoch, avg=avg)
|
||||
logger.clear()
|
||||
|
||||
def print_args(self):
|
||||
argdict = vars(self.args)
|
||||
print(argdict)
|
||||
for k, v in argdict.items():
|
||||
self.logf.write(k + ': ' + str(v) + '\n')
|
||||
self.logf.write('\n')
|
||||
|
||||
def set_time(self):
|
||||
self.stime = time.time()
|
||||
|
||||
def save_time_log(self):
|
||||
ct = time.time() - self.stime
|
||||
msg = f'({ct:6.2f}s) meta-training phase done'
|
||||
print(msg)
|
||||
self.logf.write(msg+'\n')
|
||||
|
||||
def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
|
||||
if tag == 'train':
|
||||
ct = time.time() - self.ep_sttime
|
||||
tt = time.time() - self.stime
|
||||
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
|
||||
self.logf.write(msg+'\n')
|
||||
print(msg)
|
||||
self.logf.flush()
|
||||
# msg = f'ep {epoch:3d} ep time {time.time() - ep_sttime:8.2f} '
|
||||
# msg += f'time {time.time() - sttime:6.2f} '
|
||||
if max_corr_dict is not None:
|
||||
max_corr = max_corr_dict['corr']
|
||||
max_loss = max_corr_dict['loss']
|
||||
msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
|
||||
msg += f'corr {corr:.4f} ({max_corr:.4f})'
|
||||
else:
|
||||
msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
|
||||
self.logf.write(msg+'\n')
|
||||
print(msg)
|
||||
self.logf.flush()
|
||||
|
||||
def max_corr_log(self, max_corr_dict):
|
||||
corr = max_corr_dict['corr']
|
||||
loss = max_corr_dict['loss']
|
||||
epoch = max_corr_dict['epoch']
|
||||
msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
|
||||
self.logf.write(msg+'\n')
|
||||
print(msg)
|
||||
self.logf.flush()
|
||||
|
||||
|
||||
def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
|
||||
msg = f'[{tag}] Ep {epoch} loss {loss.item()/len(y):0.4f} '
|
||||
if type(y_pred) == list:
|
||||
msg += f'pacc {y_pred[0]:0.4f}'
|
||||
msg += f'({y_pred[0]*100.0*acc_std+acc_mean:0.4f}) '
|
||||
else:
|
||||
msg += f'pacc {y_pred:0.4f}'
|
||||
msg += f'({y_pred*100.0*acc_std+acc_mean:0.4f}) '
|
||||
msg += f'acc {y[0]:0.4f}({y[0]*100*acc_std+acc_mean:0.4f})'
|
||||
return msg
|
||||
|
||||
|
||||
def load_model(model, ckpt_path):
|
||||
model.cpu()
|
||||
model.load_state_dict(torch.load(ckpt_path))
|
||||
|
||||
|
||||
def save_model(epoch, model, model_path, max_corr=None):
|
||||
print("==> save current model...")
|
||||
if max_corr is not None:
|
||||
torch.save(model.cpu().state_dict(),
|
||||
os.path.join(model_path, 'ckpt_max_corr.pt'))
|
||||
else:
|
||||
torch.save(model.cpu().state_dict(),
|
||||
os.path.join(model_path, f'ckpt_{epoch}.pt'))
|
||||
|
||||
|
||||
def mean_confidence_interval(data, confidence=0.95):
|
||||
a = 1.0 * np.array(data)
|
||||
n = len(a)
|
||||
m, se = np.mean(a), scipy.stats.sem(a)
|
||||
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
|
||||
return m, h
|
||||
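# Illustrative usage (added note, not in the original file; assumes `scipy` and
# `numpy` are imported at the top of this module):
#   m, h = mean_confidence_interval([0.71, 0.73, 0.69, 0.72])
#   # m is the sample mean and h = t_{(1+0.95)/2, n-1} * SEM, so the 95% CI is [m - h, m + h]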
@@ -0,0 +1,6 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
dir_path = (Path(__file__).parent).resolve()
|
||||
if str(dir_path) not in sys.path: sys.path.insert(0, str(dir_path))
|
||||
|
||||
from .architecture import train_single_model
|
||||
@@ -0,0 +1,173 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
###############################################################
|
||||
from functions import evaluate_for_seed
|
||||
from nas_bench_201_models import CellStructure, CellArchitectures, get_search_spaces
|
||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
from nas_bench_201_datasets import get_datasets
|
||||
from procedures import get_machine_info
|
||||
from procedures import save_checkpoint, copy_checkpoint
|
||||
from config_utils import load_config
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import random
|
||||
import argparse
|
||||
from PIL import ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
NASBENCH201_CONFIG_PATH = os.path.join(
|
||||
os.getcwd(), 'main_exp', 'transfer_nag')
|
||||
|
||||
|
||||
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
|
||||
arch_config, workers, logger):
|
||||
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
|
||||
all_infos = {'info': machine_info}
|
||||
all_dataset_keys = []
|
||||
# loop over all the datasets
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
task = None
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
dataset, xpath, -1, task)
|
||||
|
||||
# load the configuration
|
||||
if dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
|
||||
if use_less:
|
||||
config_path = os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
|
||||
else:
|
||||
config_path = os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
|
||||
|
||||
p = os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset))
|
||||
if not os.path.exists(p):
|
||||
import json
|
||||
label_list = list(range(len(train_data)))
|
||||
random.shuffle(label_list)
|
||||
strlist = [str(label_list[i]) for i in range(len(label_list))]
|
||||
splited = {'train': ["int", strlist[:len(train_data) // 2]],
|
||||
'valid': ["int", strlist[len(train_data) // 2:]]}
|
||||
with open(p, 'w') as f:
|
||||
f.write(json.dumps(splited))
|
||||
split_info = load_config(os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset)), None, None)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
|
||||
config = load_config(
|
||||
config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
|
||||
shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||
shuffle=False, num_workers=workers, pin_memory=True)
|
||||
splits = load_config(os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}-test-split.txt'.format(dataset)), None, None)
|
||||
ValLoaders = {'ori-test': valid_loader,
|
||||
'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
splits.xvalid),
|
||||
num_workers=workers, pin_memory=True),
|
||||
'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
splits.xtest),
|
||||
num_workers=workers, pin_memory=True)
|
||||
}
|
||||
dataset_key = '{:}'.format(dataset)
|
||||
if bool(split):
|
||||
dataset_key = dataset_key + '-valid'
|
||||
logger.log(
|
||||
'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
|
||||
format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
|
||||
dataset_key, config))
|
||||
for key, value in ValLoaders.items():
|
||||
logger.log(
|
||||
'Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
|
||||
|
||||
results = evaluate_for_seed(
|
||||
arch_config, config, arch, train_loader, ValLoaders, seed, logger)
|
||||
all_infos[dataset_key] = results
|
||||
all_dataset_keys.append(dataset_key)
|
||||
all_infos['all_dataset_keys'] = all_dataset_keys
|
||||
return all_infos
|
||||
|
||||
|
||||
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
|
||||
seeds, model_str, arch_config):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(workers)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
logger = Logger(str(save_dir), 0, False)
|
||||
|
||||
if model_str in CellArchitectures:
|
||||
arch = CellArchitectures[model_str]
|
||||
logger.log(
|
||||
'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
|
||||
else:
|
||||
try:
|
||||
arch = CellStructure.str2structure(model_str)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
'Invalid model string : {:}. It cannot be found or parsed.'.format(model_str))
|
||||
|
||||
assert arch.check_valid_op(get_search_spaces(
|
||||
'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
|
||||
# assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
|
||||
logger.log('Start train-evaluate {:}'.format(arch.tostr()))
|
||||
logger.log('arch_config : {:}'.format(arch_config))
|
||||
|
||||
start_time, seed_time = time.time(), AverageMeter()
|
||||
for _is, seed in enumerate(seeds):
|
||||
logger.log(
|
||||
'\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
|
||||
seed))
|
||||
to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
|
||||
if to_save_name.exists():
|
||||
logger.log(
|
||||
'Found the existing file {:}, loading it directly!'.format(to_save_name))
|
||||
checkpoint = torch.load(to_save_name)
|
||||
else:
|
||||
logger.log(
|
||||
'Did not find the existing file {:}, training and evaluating!'.format(to_save_name))
|
||||
checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
|
||||
seed, arch_config, workers, logger)
|
||||
torch.save(checkpoint, to_save_name)
|
||||
# log information
|
||||
logger.log('{:}'.format(checkpoint['info']))
|
||||
all_dataset_keys = checkpoint['all_dataset_keys']
|
||||
for dataset_key in all_dataset_keys:
|
||||
logger.log('\n{:} dataset : {:} {:}'.format(
|
||||
'-' * 15, dataset_key, '-' * 15))
|
||||
dataset_info = checkpoint[dataset_key]
|
||||
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
|
||||
logger.log('Flops = {:} MB, Params = {:} MB'.format(
|
||||
dataset_info['flop'], dataset_info['param']))
|
||||
logger.log('config : {:}'.format(dataset_info['config']))
|
||||
logger.log('Training State (finish) = {:}'.format(
|
||||
dataset_info['finish-train']))
|
||||
last_epoch = dataset_info['total_epoch'] - 1
|
||||
train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
|
||||
valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
|
||||
# measure elapsed time
|
||||
seed_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format(convert_secs2time(
|
||||
seed_time.avg * (len(seeds) - _is - 1), True))
|
||||
logger.log(
|
||||
'\n<<<***>>> The {:02d}/{:02d}-th seed {:} is finished; the remaining seeds need {:}'.format(_is, len(seeds), seed,
|
||||
need_time))
|
||||
logger.close()
|
||||
@@ -0,0 +1,13 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .configure_utils import load_config, dict2config#, configure2str
|
||||
#from .basic_args import obtain_basic_args
|
||||
#from .attention_args import obtain_attention_args
|
||||
#from .random_baseline import obtain_RandomSearch_args
|
||||
#from .cls_kd_args import obtain_cls_kd_args
|
||||
#from .cls_init_args import obtain_cls_init_args
|
||||
#from .search_single_args import obtain_search_single_args
|
||||
#from .search_args import obtain_search_args
|
||||
# for network pruning
|
||||
#from .pruning_args import obtain_pruning_args
|
||||
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
import os, json
|
||||
from os import path as osp
|
||||
from pathlib import Path
|
||||
from collections import namedtuple
|
||||
|
||||
support_types = ('str', 'int', 'bool', 'float', 'none')
|
||||
|
||||
|
||||
def convert_param(original_lists):
|
||||
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
|
||||
ctype, value = original_lists[0], original_lists[1]
|
||||
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
|
||||
is_list = isinstance(value, list)
|
||||
if not is_list: value = [value]
|
||||
outs = []
|
||||
for x in value:
|
||||
if ctype == 'int':
|
||||
x = int(x)
|
||||
elif ctype == 'str':
|
||||
x = str(x)
|
||||
elif ctype == 'bool':
|
||||
x = bool(int(x))
|
||||
elif ctype == 'float':
|
||||
x = float(x)
|
||||
elif ctype == 'none':
|
||||
if x.lower() != 'none':
|
||||
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
|
||||
x = None
|
||||
else:
|
||||
raise TypeError('Does not know this type : {:}'.format(ctype))
|
||||
outs.append(x)
|
||||
if not is_list: outs = outs[0]
|
||||
return outs
|
||||
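# Illustrative examples (added note, not in the original file): each raw config
# entry is a [ctype, value] pair, where value may be a scalar or a list, e.g.
#   convert_param(["int", "256"])               # -> 256
#   convert_param(["float", ["0.1", "0.01"]])   # -> [0.1, 0.01]
#   convert_param(["none", "None"])             # -> None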
|
||||
|
||||
def load_config(path, extra, logger):
|
||||
path = str(path)
|
||||
if hasattr(logger, 'log'): logger.log(path)
|
||||
assert os.path.exists(path), 'Can not find {:}'.format(path)
|
||||
# Reading data back
|
||||
with open(path, 'r') as f:
|
||||
data = json.load(f)
|
||||
content = { k: convert_param(v) for k,v in data.items()}
|
||||
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
|
||||
if isinstance(extra, dict): content = {**content, **extra}
|
||||
Arguments = namedtuple('Configure', ' '.join(content.keys()))
|
||||
content = Arguments(**content)
|
||||
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
|
||||
return content
|
||||
|
||||
|
||||
def configure2str(config, xpath=None):
|
||||
if not isinstance(config, dict):
|
||||
config = config._asdict()
|
||||
def cstring(x):
|
||||
return "\"{:}\"".format(x)
|
||||
def gtype(x):
|
||||
if isinstance(x, list): x = x[0]
|
||||
if isinstance(x, str) : return 'str'
|
||||
elif isinstance(x, bool) : return 'bool'
|
||||
elif isinstance(x, int): return 'int'
|
||||
elif isinstance(x, float): return 'float'
|
||||
elif x is None : return 'none'
|
||||
else: raise ValueError('invalid : {:}'.format(x))
|
||||
def cvalue(x, xtype):
|
||||
if isinstance(x, list): is_list = True
|
||||
else:
|
||||
is_list, x = False, [x]
|
||||
temps = []
|
||||
for temp in x:
|
||||
if xtype == 'bool' : temp = cstring(int(temp))
|
||||
elif xtype == 'none': temp = cstring('None')
|
||||
else : temp = cstring(temp)
|
||||
temps.append( temp )
|
||||
if is_list:
|
||||
return "[{:}]".format( ', '.join( temps ) )
|
||||
else:
|
||||
return temps[0]
|
||||
|
||||
xstrings = []
|
||||
for key, value in config.items():
|
||||
xtype = gtype(value)
|
||||
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
|
||||
xstrings.append(string)
|
||||
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
|
||||
if xpath is not None:
|
||||
parent = Path(xpath).resolve().parent
|
||||
parent.mkdir(parents=True, exist_ok=True)
|
||||
if osp.isfile(xpath): os.remove(xpath)
|
||||
with open(xpath, "w") as text_file:
|
||||
text_file.write('{:}'.format(Fstring))
|
||||
return Fstring
|
||||
|
||||
|
||||
def dict2config(xdict, logger):
|
||||
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
|
||||
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
|
||||
content = Arguments(**xdict)
|
||||
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
|
||||
return content
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"eta_min" : ["float", "0.0"],
|
||||
"epochs" : ["int", "200"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.1"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "256"]
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"eta_min" : ["float", "0.0"],
|
||||
"epochs" : ["int", "50"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.1"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "256"]
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"eta_min" : ["float", "0.0"],
|
||||
"epochs" : ["int", "200"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.1"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "256"]
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"eta_min" : ["float", "0.0"],
|
||||
"epochs" : ["int", "200"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.1"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "256"]
|
||||
}
|
||||
153
NAS-Bench-201/main_exp/transfer_nag/nas_bench_201/functions.py
Normal file
153
NAS-Bench-201/main_exp/transfer_nag/nas_bench_201/functions.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import time
|
||||
import torch
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from nasbench_utils import get_model_infos, obtain_accuracy
|
||||
from config_utils import dict2config
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from nas_bench_201_models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate']
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
latencies = []
|
||||
network.eval()
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
inputs = inputs.cuda(non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
batch_time.update(time.time() - end)
|
||||
if batch is None or batch == inputs.size(0):
|
||||
batch = inputs.size(0)
|
||||
latencies.append(batch_time.val - data_time.val)
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(
|
||||
logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
end = time.time()
|
||||
if len(latencies) > 2:
|
||||
latencies = latencies[1:]
|
||||
return losses.avg, top1.avg, top5.avg, latencies
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == 'train':
|
||||
network.train()
|
||||
elif mode == 'valid':
|
||||
network.eval()
|
||||
else:
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train':
|
||||
scheduler.update(None, 1.0 * i / len(xloader))
|
||||
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
if mode == 'train':
|
||||
optimizer.zero_grad()
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# backward
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
# count time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
|
||||
|
||||
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
|
||||
prepare_seed(seed) # random seed
|
||||
net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
|
||||
'C': arch_config['channel'], 'N': arch_config['num_cells'],
|
||||
'genotype': arch, 'num_classes': config.class_num}, None)
|
||||
)
|
||||
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||
if 'ckpt_path' in arch_config.keys():
|
||||
ckpt = torch.load(arch_config['ckpt_path'])
|
||||
ckpt['classifier.weight'] = net.state_dict()['classifier.weight']
|
||||
ckpt['classifier.bias'] = net.state_dict()['classifier.bias']
|
||||
net.load_state_dict(ckpt)
|
||||
|
||||
flop, param = get_model_infos(net, config.xshape)
|
||||
logger.log('Network : {:}'.format(net.get_message()), False)
|
||||
logger.log(
|
||||
'{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
|
||||
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
|
||||
# train and valid
|
||||
optimizer, scheduler, criterion = get_optim_scheduler(
|
||||
net.parameters(), config)
|
||||
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
|
||||
# network, criterion = torch.nn.DataParallel(net).to(torch.device(f"cuda:{device}")), criterion.to(torch.device(f"cuda:{device}"))
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(
|
||||
), AverageMeter(), config.epochs + config.warmup
|
||||
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {
|
||||
}, {}, {}, {}, {}, {}
|
||||
train_times, valid_times = {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
|
||||
train_loss, train_acc1, train_acc5, train_tm = procedure(
|
||||
train_loader, network, criterion, scheduler, optimizer, 'train')
|
||||
train_losses[epoch] = train_loss
|
||||
train_acc1es[epoch] = train_acc1
|
||||
train_acc5es[epoch] = train_acc5
|
||||
train_times[epoch] = train_tm
|
||||
with torch.no_grad():
|
||||
for key, xloader in valid_loaders.items():
|
||||
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
|
||||
xloader, network, criterion, None, None, 'valid')
|
||||
valid_losses['{:}@{:}'.format(key, epoch)] = valid_loss
|
||||
valid_acc1es['{:}@{:}'.format(key, epoch)] = valid_acc1
|
||||
valid_acc5es['{:}@{:}'.format(key, epoch)] = valid_acc5
|
||||
valid_times['{:}@{:}'.format(key, epoch)] = valid_tm
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format(convert_secs2time(
|
||||
epoch_time.avg * (total_epoch-epoch-1), True))
|
||||
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(
|
||||
time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
|
||||
info_seed = {'flop': flop,
|
||||
'param': param,
|
||||
'channel': arch_config['channel'],
|
||||
'num_cells': arch_config['num_cells'],
|
||||
'config': config._asdict(),
|
||||
'total_epoch': total_epoch,
|
||||
'train_losses': train_losses,
|
||||
'train_acc1es': train_acc1es,
|
||||
'train_acc5es': train_acc5es,
|
||||
'train_times': train_times,
|
||||
'valid_losses': valid_losses,
|
||||
'valid_acc1es': valid_acc1es,
|
||||
'valid_acc5es': valid_acc5es,
|
||||
'valid_times': valid_times,
|
||||
'net_state_dict': net.state_dict(),
|
||||
'net_string': '{:}'.format(net),
|
||||
'finish-train': True
|
||||
}
|
||||
return info_seed
|
||||
@@ -0,0 +1,9 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# none of these modules rely on pytorch or tensorflow
|
||||
# all dependencies are listed here: os, sys, time, numpy, (possibly) matplotlib
|
||||
from .logger import Logger#, PrintLogger
|
||||
from .meter import AverageMeter
|
||||
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_secs2time
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from pathlib import Path
|
||||
import importlib, warnings
|
||||
import os, sys, time, numpy as np
|
||||
if sys.version_info.major == 2: # Python 2.x
|
||||
from StringIO import StringIO as BIO
|
||||
else: # Python 3.x
|
||||
from io import BytesIO as BIO
|
||||
|
||||
if importlib.util.find_spec('tensorflow'):
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class PrintLogger(object):
|
||||
|
||||
def __init__(self):
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
self.name = 'PrintLogger'
|
||||
|
||||
def log(self, string):
|
||||
print (string)
|
||||
|
||||
def close(self):
|
||||
print ('-'*30 + ' close printer ' + '-'*30)
|
||||
|
||||
|
||||
class Logger(object):
|
||||
|
||||
def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
self.seed = int(seed)
|
||||
self.log_dir = Path(log_dir)
|
||||
self.model_dir = Path(log_dir) / 'checkpoint'
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
if create_model_dir:
|
||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||
#self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
|
||||
|
||||
self.use_tf = bool(use_tf)
|
||||
self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h', time.gmtime(time.time()) )))
|
||||
#self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
|
||||
self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time())))
|
||||
self.logger_file = open(self.logger_path, 'w')
|
||||
|
||||
if self.use_tf:
|
||||
self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
|
||||
self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
|
||||
else:
|
||||
self.writer = None
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def path(self, mode):
|
||||
valids = ('model', 'best', 'info', 'log')
|
||||
if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
|
||||
elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
|
||||
elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
|
||||
elif mode == 'log' : return self.log_dir
|
||||
else: raise TypeError('Unknown mode = {:}, valid modes = {:}'.format(mode, valids))
|
||||
|
||||
def extract_log(self):
|
||||
return self.logger_file
|
||||
|
||||
def close(self):
|
||||
self.logger_file.close()
|
||||
if self.writer is not None:
|
||||
self.writer.close()
|
||||
|
||||
def log(self, string, save=True, stdout=False):
|
||||
if stdout:
|
||||
sys.stdout.write(string); sys.stdout.flush()
|
||||
else:
|
||||
print (string)
|
||||
if save:
|
||||
self.logger_file.write('{:}\n'.format(string))
|
||||
self.logger_file.flush()
|
||||
|
||||
def scalar_summary(self, tags, values, step):
|
||||
"""Log a scalar variable."""
|
||||
if not self.use_tf:
|
||||
warnings.warn('use_tf is not enabled, but scalar_summary was called')
|
||||
else:
|
||||
assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
|
||||
if not isinstance(tags, list):
|
||||
tags, values = [tags], [values]
|
||||
for tag, value in zip(tags, values):
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
|
||||
def image_summary(self, tag, images, step):
|
||||
"""Log a list of images."""
|
||||
import scipy
|
||||
if not self.use_tf:
|
||||
warnings.warn('use_tf is not enabled, but image_summary was called')
|
||||
return
|
||||
|
||||
img_summaries = []
|
||||
for i, img in enumerate(images):
|
||||
# Write the image to a string
|
||||
s = BIO()  # BytesIO on Python 3 / StringIO on Python 2, imported above as BIO
|
||||
scipy.misc.toimage(img).save(s, format="png")
|
||||
|
||||
# Create an Image object
|
||||
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
|
||||
height=img.shape[0],
|
||||
width=img.shape[1])
|
||||
# Create a Summary value
|
||||
img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=img_summaries)
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
|
||||
def histo_summary(self, tag, values, step, bins=1000):
|
||||
"""Log a histogram of the tensor of values."""
|
||||
if not self.use_tf: raise ValueError('use_tf is not enabled, but histo_summary requires TensorFlow')
|
||||
import tensorflow as tf
|
||||
|
||||
# Create a histogram using numpy
|
||||
counts, bin_edges = np.histogram(values, bins=bins)
|
||||
|
||||
# Fill the fields of the histogram proto
|
||||
hist = tf.HistogramProto()
|
||||
hist.min = float(np.min(values))
|
||||
hist.max = float(np.max(values))
|
||||
hist.num = int(np.prod(values.shape))
|
||||
hist.sum = float(np.sum(values))
|
||||
hist.sum_squares = float(np.sum(values**2))
|
||||
|
||||
# Drop the start of the first bin
|
||||
bin_edges = bin_edges[1:]
|
||||
|
||||
# Add bin edges and counts
|
||||
for edge in bin_edges:
|
||||
hist.bucket_limit.append(edge)
|
||||
for c in counts:
|
||||
hist.bucket.append(c)
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.sum = 0.0
|
||||
self.count = 0.0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
|
||||
class RecorderMeter(object):
|
||||
"""Computes and stores the minimum loss value and its epoch index"""
|
||||
def __init__(self, total_epoch):
|
||||
self.reset(total_epoch)
|
||||
|
||||
def reset(self, total_epoch):
|
||||
assert total_epoch > 0, 'total_epoch should be greater than 0 vs {:}'.format(total_epoch)
|
||||
self.total_epoch = total_epoch
|
||||
self.current_epoch = 0
|
||||
self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
|
||||
self.epoch_losses = self.epoch_losses - 1
|
||||
self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
|
||||
self.epoch_accuracy= self.epoch_accuracy
|
||||
|
||||
def update(self, idx, train_loss, train_acc, val_loss, val_acc):
|
||||
assert 0 <= idx < self.total_epoch, 'total_epoch : {}, but update was called with index {}'.format(self.total_epoch, idx)
|
||||
self.epoch_losses [idx, 0] = train_loss
|
||||
self.epoch_losses [idx, 1] = val_loss
|
||||
self.epoch_accuracy[idx, 0] = train_acc
|
||||
self.epoch_accuracy[idx, 1] = val_acc
|
||||
self.current_epoch = idx + 1
|
||||
return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]
|
||||
|
||||
def max_accuracy(self, istrain):
|
||||
if self.current_epoch <= 0: return 0
|
||||
if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
|
||||
else: return self.epoch_accuracy[:self.current_epoch, 1].max()
|
||||
|
||||
def plot_curve(self, save_path):
|
||||
import matplotlib
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
title = 'the accuracy/loss curve of train/val'
|
||||
dpi = 100
|
||||
width, height = 1600, 1000
|
||||
legend_fontsize = 10
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
|
||||
y_axis = np.zeros(self.total_epoch)
|
||||
|
||||
plt.xlim(0, self.total_epoch)
|
||||
plt.ylim(0, 100)
|
||||
interval_y = 5
|
||||
interval_x = 5
|
||||
plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
|
||||
plt.yticks(np.arange(0, 100 + interval_y, interval_y))
|
||||
plt.grid()
|
||||
plt.title(title, fontsize=20)
|
||||
plt.xlabel('the training epoch', fontsize=16)
|
||||
plt.ylabel('accuracy', fontsize=16)
|
||||
|
||||
y_axis[:] = self.epoch_accuracy[:, 0]
|
||||
plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
y_axis[:] = self.epoch_accuracy[:, 1]
|
||||
plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
|
||||
y_axis[:] = self.epoch_losses[:, 0]
|
||||
plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
y_axis[:] = self.epoch_losses[:, 1]
|
||||
plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
if save_path is not None:
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
|
||||
print ('---- save figure {} into {}'.format(title, save_path))
|
||||
plt.close(fig)
|
||||
@@ -0,0 +1,42 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import time, sys
|
||||
import numpy as np
|
||||
|
||||
def time_for_file():
|
||||
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
|
||||
return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
|
||||
def time_string():
|
||||
ISOTIMEFORMAT='%Y-%m-%d %X'
|
||||
string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
return string
|
||||
|
||||
def time_string_short():
|
||||
ISOTIMEFORMAT='%Y%m%d'
|
||||
string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
return string
|
||||
|
||||
def time_print(string, is_print=True):
|
||||
if (is_print):
|
||||
print('{} : {}'.format(time_string(), string))
|
||||
|
||||
def convert_secs2time(epoch_time, return_str=False):
|
||||
need_hour = int(epoch_time / 3600)
|
||||
need_mins = int((epoch_time - 3600*need_hour) / 60)
|
||||
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
|
||||
if return_str:
|
||||
str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
|
||||
return str
|
||||
else:
|
||||
return need_hour, need_mins, need_secs
|
||||
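# Example (added note, not in the original file):
#   convert_secs2time(3725)        # -> (1, 2, 5)
#   convert_secs2time(3725, True)  # -> '[01:02:05]'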
|
||||
def print_log(print_string, log):
|
||||
#if isinstance(log, Logger): log.log('{:}'.format(print_string))
|
||||
if hasattr(log, 'log'): log.log('{:}'.format(print_string))
|
||||
else:
|
||||
print("{:}".format(print_string))
|
||||
if log is not None:
|
||||
log.write('{:}\n'.format(print_string))
|
||||
log.flush()
|
||||
@@ -0,0 +1,4 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .get_dataset_with_transform import get_datasets
|
||||
@@ -0,0 +1,179 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from __future__ import print_function
|
||||
import torch.utils.data as data
|
||||
from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader
|
||||
from PIL import Image
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_dataset(dir, image_ids, targets):
|
||||
assert (len(image_ids) == len(targets))
|
||||
images = []
|
||||
dir = os.path.expanduser(dir)
|
||||
for i in range(len(image_ids)):
|
||||
item = (os.path.join(dir, 'data', 'images',
|
||||
'%s.jpg' % image_ids[i]), targets[i])
|
||||
images.append(item)
|
||||
return images
|
||||
|
||||
|
||||
def find_classes(classes_file):
|
||||
# read classes file, separating out image IDs and class names
|
||||
image_ids = []
|
||||
targets = []
|
||||
f = open(classes_file, 'r')
|
||||
for line in f:
|
||||
split_line = line.split(' ')
|
||||
image_ids.append(split_line[0])
|
||||
targets.append(' '.join(split_line[1:]))
|
||||
f.close()
|
||||
|
||||
# index class names
|
||||
classes = np.unique(targets)
|
||||
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||
targets = [class_to_idx[c] for c in targets]
|
||||
|
||||
return (image_ids, targets, classes, class_to_idx)
|
||||
|
||||
|
||||
class FGVCAircraft(data.Dataset):
|
||||
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
|
||||
Args:
|
||||
root (string): Root directory path to dataset.
|
||||
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
|
||||
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
|
||||
transform (callable, optional): A function/transform that takes in a PIL image
|
||||
and returns a transformed version. E.g. ``transforms.RandomCrop``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
loader (callable, optional): A function to load an image given its path.
|
||||
download (bool, optional): If true, downloads the dataset from the internet and
|
||||
puts it in the root directory. If dataset is already downloaded, it is not
|
||||
downloaded again.
|
||||
"""
|
||||
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
|
||||
class_types = ('variant', 'family', 'manufacturer')
|
||||
splits = ('train', 'val', 'trainval', 'test')
|
||||
|
||||
def __init__(self, root, class_type='variant', split='train', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False):
|
||||
if split not in self.splits:
|
||||
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
|
||||
split, ', '.join(self.splits),
|
||||
))
|
||||
if class_type not in self.class_types:
|
||||
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
|
||||
class_type, ', '.join(self.class_types),
|
||||
))
|
||||
self.root = os.path.expanduser(root)
|
||||
self.root = os.path.join(self.root, 'fgvc-aircraft-2013b')
|
||||
self.class_type = class_type
|
||||
self.split = split
|
||||
self.classes_file = os.path.join(self.root, 'data',
|
||||
'images_%s_%s.txt' % (self.class_type, self.split))
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
||||
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
|
||||
samples = make_dataset(self.root, image_ids, targets)
|
||||
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
self.loader = loader
|
||||
|
||||
self.samples = samples
|
||||
self.classes = classes
|
||||
self.class_to_idx = class_to_idx
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
Returns:
|
||||
tuple: (sample, target) where target is class_index of the target class.
|
||||
"""
|
||||
|
||||
path, target = self.samples[index]
|
||||
sample = self.loader(path)
|
||||
if self.transform is not None:
|
||||
sample = self.transform(sample)
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return sample, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __repr__(self):
|
||||
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
|
||||
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
|
||||
fmt_str += ' Root Location: {}\n'.format(self.root)
|
||||
tmp = ' Transforms (if any): '
|
||||
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
|
||||
tmp = ' Target Transforms (if any): '
|
||||
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
|
||||
return fmt_str
|
||||
|
||||
def _check_exists(self):
|
||||
return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
|
||||
os.path.exists(self.classes_file)
|
||||
|
||||
def download(self):
|
||||
"""Download the FGVC-Aircraft data if it doesn't exist already."""
|
||||
from six.moves import urllib
|
||||
import tarfile
|
||||
|
||||
if self._check_exists():
|
||||
return
|
||||
|
||||
# prepare to download data to PARENT_DIR/fgvc-aircraft-2013b.tar.gz
|
||||
print('Downloading %s ... (may take a few minutes)' % self.url)
|
||||
parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
|
||||
tar_name = self.url.rpartition('/')[-1]
|
||||
tar_path = os.path.join(parent_dir, tar_name)
|
||||
data = urllib.request.urlopen(self.url)
|
||||
|
||||
# download .tar.gz file
|
||||
with open(tar_path, 'wb') as f:
|
||||
f.write(data.read())
|
||||
|
||||
# extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
|
||||
data_folder = tar_path[:-len('.tar.gz')]  # drop the suffix (str.strip would remove characters, not the suffix)
|
||||
print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
|
||||
tar = tarfile.open(tar_path)
|
||||
tar.extractall(parent_dir)
|
||||
|
||||
# if necessary, rename data folder to self.root
|
||||
if not os.path.samefile(data_folder, self.root):
|
||||
print('Renaming %s to %s ...' % (data_folder, self.root))
|
||||
os.rename(data_folder, self.root)
|
||||
|
||||
# delete .tar.gz file
|
||||
print('Deleting %s ...' % tar_path)
|
||||
os.remove(tar_path)
|
||||
|
||||
print('Done!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='train', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False)
|
||||
print(len(air))
|
||||
print(len(air))
|
||||
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='val', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False)
|
||||
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='trainval', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False)
|
||||
print(len(air))
|
||||
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b/', class_type='manufacturer', split='test', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False)
|
||||
print(len(air))
|
||||
import pdb;
|
||||
pdb.set_trace()
|
||||
print(len(air))
|
||||
@@ -0,0 +1,304 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung 2021. 03.
|
||||
##################################################
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
from copy import deepcopy
|
||||
from PIL import Image  # needed by Lighting.__call__ below
|
||||
import random
|
||||
import pdb
|
||||
from .aircraft import FGVCAircraft
|
||||
from .pets import PetDataset
|
||||
from config_utils import load_config
|
||||
|
||||
Dataset2Class = {'cifar10': 10,
|
||||
'cifar100': 100,
|
||||
'mnist': 10,
|
||||
'svhn': 10,
|
||||
'aircraft': 30,
|
||||
'pets': 37}
|
||||
|
||||
|
||||
class CUTOUT(object):
|
||||
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def __call__(self, img):
|
||||
h, w = img.size(1), img.size(2)
|
||||
mask = np.ones((h, w), np.float32)
|
||||
y = np.random.randint(h)
|
||||
x = np.random.randint(w)
|
||||
|
||||
y1 = np.clip(y - self.length // 2, 0, h)
|
||||
y2 = np.clip(y + self.length // 2, 0, h)
|
||||
x1 = np.clip(x - self.length // 2, 0, w)
|
||||
x2 = np.clip(x + self.length // 2, 0, w)
|
||||
|
||||
mask[y1: y2, x1: x2] = 0.
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = mask.expand_as(img)
|
||||
img *= mask
|
||||
return img
|
||||
|
||||
|
||||
imagenet_pca = {
|
||||
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
|
||||
'eigvec': np.asarray([
|
||||
[-0.5675, 0.7192, 0.4009],
|
||||
[-0.5808, -0.0045, -0.8140],
|
||||
[-0.5836, -0.6948, 0.4203],
|
||||
])
|
||||
}
|
||||
|
||||
|
||||
class Lighting(object):
|
||||
def __init__(self, alphastd,
|
||||
eigval=imagenet_pca['eigval'],
|
||||
eigvec=imagenet_pca['eigvec']):
|
||||
self.alphastd = alphastd
|
||||
assert eigval.shape == (3,)
|
||||
assert eigvec.shape == (3, 3)
|
||||
self.eigval = eigval
|
||||
self.eigvec = eigvec
|
||||
|
||||
def __call__(self, img):
|
||||
if self.alphastd == 0.:
|
||||
return img
|
||||
rnd = np.random.randn(3) * self.alphastd
|
||||
rnd = rnd.astype('float32')
|
||||
v = rnd
|
||||
old_dtype = np.asarray(img).dtype
|
||||
v = v * self.eigval
|
||||
v = v.reshape((3, 1))
|
||||
inc = np.dot(self.eigvec, v).reshape((3,))
|
||||
img = np.add(img, inc)
|
||||
if old_dtype == np.uint8:
|
||||
img = np.clip(img, 0, 255)
|
||||
img = Image.fromarray(img.astype(old_dtype), 'RGB')
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '()'
|
||||
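# Note added for clarity (not in the original file): Lighting applies AlexNet-style
# PCA colour jitter to a PIL image, shifting each pixel by eigvec @ (alpha * eigval)
# with alpha ~ N(0, alphastd). It is typically placed before ToTensor in a
# torchvision pipeline, e.g. transforms.Compose([Lighting(0.1), transforms.ToTensor()]).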
|
||||
|
||||
def get_datasets(name, root, cutout, use_num_cls=None):
|
||||
if name == 'cifar10':
|
||||
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
|
||||
std = [x / 255 for x in [63.0, 62.1, 66.7]]
|
||||
elif name == 'cifar100':
|
||||
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
|
||||
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||
elif name.startswith('mnist'):
|
||||
mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
|
||||
elif name.startswith('svhn'):
|
||||
mean, std = [0.4376821, 0.4437697, 0.47280442], [
|
||||
0.19803012, 0.20101562, 0.19703614]
|
||||
elif name.startswith('aircraft'):
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
elif name.startswith('pets'):
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
# Data Argumentation
|
||||
if name == 'cifar10' or name == 'cifar100':
|
||||
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
|
||||
transforms.Normalize(mean, std)]
|
||||
if cutout > 0:
|
||||
lists += [CUTOUT(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('cub200'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('mnist'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
|
||||
transforms.Normalize(mean, std),
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
|
||||
transforms.Normalize(mean, std)
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('svhn'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('aircraft'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std),
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('pets'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std),
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
if name == 'cifar10':
|
||||
train_data = dset.CIFAR10(
|
||||
root, train=True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR10(
|
||||
root, train=False, transform=test_transform, download=True)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name == 'cifar100':
|
||||
train_data = dset.CIFAR100(
|
||||
root, train=True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(
|
||||
root, train=False, transform=test_transform, download=True)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name == 'mnist':
|
||||
train_data = dset.MNIST(
|
||||
root, train=True, transform=train_transform, download=True)
|
||||
test_data = dset.MNIST(
|
||||
root, train=False, transform=test_transform, download=True)
|
||||
assert len(train_data) == 60000 and len(test_data) == 10000
|
||||
elif name == 'svhn':
|
||||
train_data = dset.SVHN(root, split='train',
|
||||
transform=train_transform, download=True)
|
||||
test_data = dset.SVHN(root, split='test',
|
||||
transform=test_transform, download=True)
|
||||
assert len(train_data) == 73257 and len(test_data) == 26032
|
||||
elif name == 'aircraft':
|
||||
train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval',
|
||||
transform=train_transform, download=False)
|
||||
test_data = FGVCAircraft(root, class_type='manufacturer', split='test',
|
||||
transform=test_transform, download=False)
|
||||
assert len(train_data) == 6667 and len(test_data) == 3333
|
||||
elif name == 'pets':
|
||||
train_data = PetDataset(root, train=True, num_cl=37,
|
||||
val_split=0.15, transforms=train_transform)
|
||||
test_data = PetDataset(root, train=False, num_cl=37,
|
||||
val_split=0.15, transforms=test_transform)
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
class_num = Dataset2Class[name] if use_num_cls is None else len(
|
||||
use_num_cls)
|
||||
return train_data, test_data, xshape, class_num
|
||||
|
||||
|
||||
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None):
|
||||
if isinstance(batch_size, (list, tuple)):
|
||||
batch, test_batch = batch_size
|
||||
else:
|
||||
batch, test_batch = batch_size, batch_size
|
||||
if dataset == 'cifar10':
|
||||
# split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(
|
||||
'{:}/cifar-split.txt'.format(config_root), None, None)
|
||||
# search over the proposed training and validation set
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
# logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
|
||||
# To split data
|
||||
xvalid_data = deepcopy(train_data)
|
||||
if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
|
||||
xvalid_data.transforms = valid_data.transform
|
||||
xvalid_data.transform = deepcopy(valid_data.transform)
|
||||
search_data = SearchDataset(
|
||||
dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
|
||||
pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
train_split),
|
||||
num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
valid_split),
|
||||
num_workers=workers, pin_memory=True)
|
||||
elif dataset == 'cifar100':
|
||||
cifar100_test_split = load_config(
|
||||
'{:}/cifar100-test-split.txt'.format(config_root), None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data)
|
||||
search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
|
||||
list(range(len(search_train_data))),
|
||||
cifar100_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
|
||||
pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers,
|
||||
pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
|
||||
elif dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
|
||||
if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)):
|
||||
import json
|
||||
label_list = list(range(len(valid_data)))
|
||||
random.shuffle(label_list)
|
||||
strlist = [str(label_list[i]) for i in range(len(label_list))]
|
||||
split = {'xvalid': ["int", strlist[:len(valid_data) // 2]],
|
||||
'xtest': ["int", strlist[len(valid_data) // 2:]]}
|
||||
with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f:
|
||||
f.write(json.dumps(split))
|
||||
test_split = load_config(
|
||||
'{:}/{}-test-split.txt'.format(config_root, dataset), None, None)
|
||||
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data)
|
||||
search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
|
||||
list(range(len(search_train_data))), test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True,
|
||||
num_workers=workers, pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True,
|
||||
num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
test_split.xvalid), num_workers=workers, pin_memory=True)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
return search_loader, train_loader, valid_loader
|
||||
@@ -0,0 +1,45 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
import torch
|
||||
from glob import glob
|
||||
from torch.utils.data.dataset import Dataset
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def load_image(filename):
|
||||
img = Image.open(filename)
|
||||
img = img.convert('RGB')
|
||||
return img
|
||||
|
||||
class PetDataset(Dataset):
|
||||
def __init__(self, root, train=True, num_cl=37, val_split=0.2, transforms=None):
|
||||
self.data = torch.load(os.path.join(root,'{}{}.pth'.format('train' if train else 'test',
|
||||
int(100*(1-val_split)) if train else int(100*val_split))))
|
||||
self.len = len(self.data)
|
||||
self.transform = transforms
|
||||
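# Note added for clarity (assumption about the preprocessing used elsewhere in this
# repo): with the defaults train=True, val_split=0.2 this loads 'train80.pth'
# (and 'test20.pth' when train=False) from `root`, expected to hold a sequence of
# (image, label) pairs.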
def __getitem__(self, index):
|
||||
img, label = self.data[index]
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, label
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Added
|
||||
import torchvision.transforms as transforms
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_transform = transforms.Compose(
|
||||
[transforms.Resize(256), transforms.RandomRotation(45), transforms.CenterCrop(224),
|
||||
transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
|
||||
test_transform = transforms.Compose(
|
||||
[transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
|
||||
root = '/w14/dataset/MetaGen/pets'
|
||||
train_data, test_data = get_pets(root, num_cl=37, val_split=0.2,
|
||||
tr_transform=train_transform,
|
||||
te_transform=test_transform)
|
||||
import pdb;
|
||||
pdb.set_trace()
|
||||
@@ -0,0 +1,34 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def additive_func(A, B):
|
||||
assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
|
||||
C = min(A.size(1), B.size(1))
|
||||
if A.size(1) == B.size(1):
|
||||
return A + B
|
||||
elif A.size(1) < B.size(1):
|
||||
out = B.clone()
|
||||
out[:,:C] += A
|
||||
return out
|
||||
else:
|
||||
out = A.clone()
|
||||
out[:,:C] += B
|
||||
return out
|
||||
|
||||
|
||||
def change_key(key, value):
|
||||
def func(m):
|
||||
if hasattr(m, key):
|
||||
setattr(m, key, value)
|
||||
return func
|
||||
|
||||
|
||||
def parse_channel_info(xstring):
|
||||
blocks = xstring.split(' ')
|
||||
blocks = [x.split('-') for x in blocks]
|
||||
blocks = [[int(_) for _ in x] for x in blocks]
|
||||
return blocks
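if __name__ == '__main__':
    # Minimal usage sketch (added for illustration; not part of the original file).
    # additive_func adds two feature maps whose channel counts may differ: the extra
    # channels of the wider tensor are passed through unchanged.
    A, B = torch.zeros(2, 8, 4, 4), torch.ones(2, 16, 4, 4)
    print(additive_func(A, B).shape)            # torch.Size([2, 16, 4, 4])

    # change_key builds a hook for nn.Module.apply that sets an attribute when present.
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=True))
    net.apply(change_key('inplace', False))     # turns off in-place ReLU only

    # parse_channel_info decodes a per-block 'in-out' channel string.
    print(parse_channel_info('3-16 16-32 32-64'))  # [[3, 16], [16, 32], [32, 64]]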
|
||||
@@ -0,0 +1,45 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from os import path as osp
|
||||
from typing import List, Text
|
||||
import torch
|
||||
|
||||
__all__ = ['get_cell_based_tiny_net', 'get_search_spaces', \
|
||||
'CellStructure', 'CellArchitectures'
|
||||
]
|
||||
|
||||
# useful modules
|
||||
from config_utils import dict2config
|
||||
from .SharedUtils import change_key
|
||||
from .cell_searchs import CellStructure, CellArchitectures
|
||||
|
||||
|
||||
# Cell-based NAS Models
|
||||
def get_cell_based_tiny_net(config):
|
||||
if config.name == 'infer.tiny':
|
||||
from .cell_infers import TinyNetwork
|
||||
if hasattr(config, 'genotype'):
|
||||
genotype = config.genotype
|
||||
elif hasattr(config, 'arch_str'):
|
||||
genotype = CellStructure.str2structure(config.arch_str)
|
||||
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
|
||||
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
|
||||
else:
|
||||
raise ValueError('invalid network name : {:}'.format(config.name))
|
||||
|
||||
|
||||
# obtain the search space, i.e., the list of candidate operation names for a cell-based ('cell'/'tss') space; the size search space ('sss') returns a dict of channel candidates instead
|
||||
def get_search_spaces(xtype, name) -> List[Text]:
|
||||
if xtype == 'cell' or xtype == 'tss': # The topology search space.
|
||||
from .cell_operations import SearchSpaceNames
|
||||
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
|
||||
return SearchSpaceNames[name]
|
||||
elif xtype == 'sss': # The size search space.
|
||||
if name == 'nas-bench-301':
|
||||
return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64],
|
||||
'numbers': 5}
|
||||
else:
|
||||
raise ValueError('Invalid name : {:}'.format(name))
|
||||
else:
|
||||
raise ValueError('invalid search-space type is {:}'.format(xtype))
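# Usage sketch (illustration only; the dict2config call assumes its (dict, logger)
# signature from config_utils, and the config values below are hypothetical):
#   space = get_search_spaces('cell', 'nas-bench-201')
#   # -> ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
#   config = dict2config({'name': 'infer.tiny', 'C': 16, 'N': 5, 'num_classes': 10,
#                         'arch_str': '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|'
#                                     '+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|'}, None)
#   net = get_cell_based_tiny_net(config)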
|
||||
@@ -0,0 +1,4 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork
@@ -0,0 +1,122 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
# Cell for NAS-Bench-201
|
||||
class InferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_in, C_out, stride):
|
||||
super(InferCell, self).__init__()
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.node_IN = []
|
||||
self.node_IX = []
|
||||
self.genotype = deepcopy(genotype)
|
||||
for i in range(1, len(genotype)):
|
||||
node_info = genotype[i-1]
|
||||
cur_index = []
|
||||
cur_innod = []
|
||||
for (op_name, op_in) in node_info:
|
||||
if op_in == 0:
|
||||
layer = OPS[op_name](C_in , C_out, stride, True, True)
|
||||
else:
|
||||
layer = OPS[op_name](C_out, C_out, 1, True, True)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
cur_index.append( len(self.layers) )
|
||||
cur_innod.append( op_in )
|
||||
self.layers.append( layer )
|
||||
self.node_IX.append( cur_index )
|
||||
self.node_IN.append( cur_innod )
|
||||
self.nodes = len(genotype)
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
|
||||
def extra_repr(self):
|
||||
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
|
||||
laystr = []
|
||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
|
||||
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
|
||||
x = '{:}<-({:})'.format(i+1, ','.join(y))
|
||||
laystr.append( x )
|
||||
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
|
||||
|
||||
def forward(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
|
||||
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
|
||||
nodes.append( node_feature )
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
class NASNetInferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
|
||||
super(NASNetInferCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
|
||||
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
|
||||
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
|
||||
|
||||
if not reduction:
|
||||
nodes, concats = genotype['normal'], genotype['normal_concat']
|
||||
else:
|
||||
nodes, concats = genotype['reduce'], genotype['reduce_concat']
|
||||
self._multiplier = len(concats)
|
||||
self._concats = concats
|
||||
self._steps = len(nodes)
|
||||
self._nodes = nodes
|
||||
self.edges = nn.ModuleDict()
|
||||
for i, node in enumerate(nodes):
|
||||
for in_node in node:
|
||||
name, j = in_node[0], in_node[1]
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
node_str = '{:}<-{:}'.format(i+2, j)
|
||||
self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
|
||||
|
||||
# [TODO] to support drop_prob in this function..
|
||||
def forward(self, s0, s1, unused_drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i, node in enumerate(self._nodes):
|
||||
clist = []
|
||||
for in_node in node:
|
||||
name, j = in_node[0], in_node[1]
|
||||
node_str = '{:}<-{:}'.format(i+2, j)
|
||||
op = self.edges[ node_str ]
|
||||
clist.append( op(states[j]) )
|
||||
states.append( sum(clist) )
|
||||
return torch.cat([states[x] for x in self._concats], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0),-1))
|
||||
return x
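# Usage sketch (illustration only; CellStructure comes from ..cell_searchs):
#   genotype = CellStructure.str2structure(
#       '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|')
#   cell = InferCell(genotype, C_in=16, C_out=16, stride=1)
#   out  = cell(torch.randn(2, 16, 32, 32))   # -> [2, 16, 32, 32]; stride 1 keeps the resolution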
|
||||
@@ -0,0 +1,66 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .cells import InferCell
|
||||
|
||||
|
||||
# The macro structure for architectures in NAS-Bench-201
|
||||
class TinyNetwork(nn.Module):
|
||||
|
||||
def __init__(self, C, N, genotype, num_classes):
|
||||
super(TinyNetwork, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C))
|
||||
|
||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev = C
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
|
||||
else:
|
||||
cell = InferCell(genotype, C_prev, C_curr, 1)
|
||||
self.cells.append( cell )
|
||||
C_prev = cell.out_dim
|
||||
self._Layer= len(self.cells)
|
||||
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
feature = cell(feature)
|
||||
'''
|
||||
out2 = self.lastact(feature)
|
||||
out = self.global_pooling( out2 )
|
||||
out = out.view(out.size(0), -1)
|
||||
out2 = out2.view(out2.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
return out2, logits
|
||||
|
||||
'''
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling( out )
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
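# Usage sketch (illustration only; the import path below is what a caller inside the
# models package would use): build the NAS-Bench-201 macro network from an
# architecture string and run a forward pass on a CIFAR-sized batch.
#   from ..cell_searchs import CellStructure
#   genotype = CellStructure.str2structure(
#       '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|')
#   net = TinyNetwork(C=16, N=5, genotype=genotype, num_classes=10)
#   features, logits = net(torch.randn(2, 3, 32, 32))   # logits: [2, 10]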
|
||||
@@ -0,0 +1,308 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
|
||||
|
||||
OPS = {
|
||||
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
|
||||
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
|
||||
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
|
||||
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
|
||||
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
|
||||
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
|
||||
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
|
||||
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
|
||||
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
|
||||
}
|
||||
|
||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
|
||||
|
||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||
'nas-bench-201': NAS_BENCH_201,
|
||||
'nas-bench-301': NAS_BENCH_201,
|
||||
'darts' : DARTS_SPACE}
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
|
||||
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
|
||||
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DualSepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(DualSepConv, self).__init__()
|
||||
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
|
||||
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.op_a(x)
|
||||
x = self.op_b(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, stride, affine=True):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
|
||||
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
|
||||
if stride == 2:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
|
||||
elif inplanes != planes:
|
||||
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.in_dim = inplanes
|
||||
self.out_dim = planes
|
||||
self.stride = stride
|
||||
self.num_conv = 2
|
||||
|
||||
def extra_repr(self):
|
||||
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
|
||||
return string
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
return residual + basicblock
|
||||
|
||||
|
||||
class POOLING(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
|
||||
super(POOLING, self).__init__()
|
||||
if C_in == C_out:
|
||||
self.preprocess = None
|
||||
else:
|
||||
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
|
||||
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
|
||||
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
|
||||
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.preprocess: x = self.preprocess(inputs)
|
||||
else : x = inputs
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.stride = stride
|
||||
self.is_zero = True
|
||||
|
||||
def forward(self, x):
|
||||
if self.C_in == self.C_out:
|
||||
if self.stride == 1: return x.mul(0.)
|
||||
else : return x[:,:,::self.stride,::self.stride].mul(0.)
|
||||
else:
|
||||
shape = list(x.shape)
|
||||
shape[1] = self.C_out
|
||||
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
|
||||
return zeros
|
||||
|
||||
def extra_repr(self):
|
||||
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
self.stride = stride
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
if stride == 2:
|
||||
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
|
||||
C_outs = [C_out // 2, C_out - C_out // 2]
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(2):
|
||||
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
elif stride == 1:
|
||||
self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
|
||||
else:
|
||||
raise ValueError('Invalid stride : {:}'.format(stride))
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 2:
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
|
||||
else:
|
||||
out = self.conv(x)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
def extra_repr(self):
|
||||
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
|
||||
|
||||
|
||||
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
|
||||
class PartAwareOp(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, part=4):
|
||||
super().__init__()
|
||||
self.part = part
|
||||
self.hidden = C_in // 3
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.local_conv_list = nn.ModuleList()
|
||||
for i in range(self.part):
|
||||
self.local_conv_list.append(
|
||||
nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True))
|
||||
)
|
||||
self.W_K = nn.Linear(self.hidden, self.hidden)
|
||||
self.W_Q = nn.Linear(self.hidden, self.hidden)
|
||||
|
||||
if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2, affine=True, track_running_stats=True)
|
||||
elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1, affine=True, track_running_stats=True)
|
||||
else: raise ValueError('Invalid Stride : {:}'.format(stride))
|
||||
|
||||
def forward(self, x):
|
||||
batch, C, H, W = x.size()
|
||||
assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part)
|
||||
IHs = [0]
|
||||
for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) )
|
||||
local_feat_list = []
|
||||
for i in range(self.part):
|
||||
feature = x[:, :, IHs[i]:IHs[i+1], :]
|
||||
xfeax = self.avg_pool(feature)
|
||||
xfea = self.local_conv_list[i]( xfeax )
|
||||
local_feat_list.append( xfea )
|
||||
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
|
||||
part_feature = part_feature.transpose(1,2).contiguous()
|
||||
part_K = self.W_K(part_feature)
|
||||
part_Q = self.W_Q(part_feature).transpose(1,2).contiguous()
|
||||
weight_att = torch.bmm(part_K, part_Q)
|
||||
attention = torch.softmax(weight_att, dim=2)
|
||||
aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous()
|
||||
features = []
|
||||
for i in range(self.part):
|
||||
feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i])
|
||||
feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1)
|
||||
features.append( feature )
|
||||
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
|
||||
final_fea = torch.cat((x,features), dim=1)
|
||||
outputs = self.last( final_fea )
|
||||
return outputs
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1. - drop_prob
|
||||
mask = x.new_zeros(x.size(0), 1, 1, 1)
|
||||
mask = mask.bernoulli_(keep_prob)
|
||||
x = torch.div(x, keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours
|
||||
class GDAS_Reduction_Cell(nn.Module):
|
||||
|
||||
def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats):
|
||||
super(GDAS_Reduction_Cell, self).__init__()
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats)
|
||||
self.multiplier = multiplier
|
||||
|
||||
self.reduction = True
|
||||
self.ops1 = nn.ModuleList(
|
||||
[nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
|
||||
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
|
||||
nn.BatchNorm2d(C, affine=True),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C, affine=True)),
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
|
||||
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
|
||||
nn.BatchNorm2d(C, affine=True),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C, affine=True))])
|
||||
|
||||
self.ops2 = nn.ModuleList(
|
||||
[nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(C, affine=True)),
|
||||
nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(C, affine=True))])
|
||||
|
||||
def forward(self, s0, s1, drop_prob = -1):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
X0 = self.ops1[0] (s0)
|
||||
X1 = self.ops1[1] (s1)
|
||||
if self.training and drop_prob > 0.:
|
||||
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
|
||||
|
||||
#X2 = self.ops2[0] (X0+X1)
|
||||
X2 = self.ops2[0] (s0)
|
||||
X3 = self.ops2[1] (s1)
|
||||
if self.training and drop_prob > 0.:
|
||||
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
|
||||
return torch.cat([X0, X1, X2, X3], dim=1)
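if __name__ == '__main__':
    # Minimal shape-check sketch (added for illustration; not part of the original file).
    x = torch.randn(2, 16, 32, 32)
    conv = OPS['nor_conv_3x3'](16, 16, 1, True, True)    # (C_in, C_out, stride, affine, track_running_stats)
    pool = OPS['avg_pool_3x3'](16, 16, 2, True, True)    # stride-2 pooling halves the resolution
    skip = OPS['skip_connect'](16, 32, 2, True, True)    # channel/stride mismatch -> FactorizedReduce
    print(conv(x).shape, pool(x).shape, skip(x).shape)   # [2,16,32,32] [2,16,16,16] [2,32,16,16]

    # drop_path zeroes a whole sample with probability drop_prob and rescales the survivors.
    print(drop_path(x.clone(), 0.2).shape)               # [2, 16, 32, 32]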
|
||||
@@ -0,0 +1,26 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# The macro structure is defined in NAS-Bench-201
|
||||
# from .search_model_darts import TinyNetworkDarts
|
||||
# from .search_model_gdas import TinyNetworkGDAS
|
||||
# from .search_model_setn import TinyNetworkSETN
|
||||
# from .search_model_enas import TinyNetworkENAS
|
||||
# from .search_model_random import TinyNetworkRANDOM
|
||||
# from .generic_model import GenericNAS201Model
|
||||
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
|
||||
# NASNet-based macro structure
|
||||
# from .search_model_gdas_nasnet import NASNetworkGDAS
|
||||
# from .search_model_darts_nasnet import NASNetworkDARTS
|
||||
|
||||
|
||||
# nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
|
||||
# "DARTS-V2": TinyNetworkDarts,
|
||||
# "GDAS": TinyNetworkGDAS,
|
||||
# "SETN": TinyNetworkSETN,
|
||||
# "ENAS": TinyNetworkENAS,
|
||||
# "RANDOM": TinyNetworkRANDOM,
|
||||
# "generic": GenericNAS201Model}
|
||||
|
||||
# nasnet_super_nets = {"GDAS": NASNetworkGDAS,
|
||||
# "DARTS": NASNetworkDARTS}
|
||||
@@ -0,0 +1,198 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def get_combination(space, num):
|
||||
combs = []
|
||||
for i in range(num):
|
||||
if i == 0:
|
||||
for func in space:
|
||||
combs.append( [(func, i)] )
|
||||
else:
|
||||
new_combs = []
|
||||
for string in combs:
|
||||
for func in space:
|
||||
xstring = string + [(func, i)]
|
||||
new_combs.append( xstring )
|
||||
combs = new_combs
|
||||
return combs
|
||||
|
||||
|
||||
class Structure:
|
||||
|
||||
def __init__(self, genotype):
|
||||
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
|
||||
self.node_num = len(genotype) + 1
|
||||
self.nodes = []
|
||||
self.node_N = []
|
||||
for idx, node_info in enumerate(genotype):
|
||||
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
|
||||
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
|
||||
for node_in in node_info:
|
||||
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
|
||||
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
|
||||
self.node_N.append( len(node_info) )
|
||||
self.nodes.append( tuple(deepcopy(node_info)) )
|
||||
|
||||
def tolist(self, remove_str):
|
||||
# convert this class to a list; if remove_str is 'none', remove the 'none' operations
|
||||
# note that we re-order the input nodes in this function
|
||||
# return the genotype list and a success flag [if unsuccessful, the cell is not a valid connectivity]
|
||||
genotypes = []
|
||||
for node_info in self.nodes:
|
||||
node_info = list( node_info )
|
||||
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
|
||||
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
|
||||
if len(node_info) == 0: return None, False
|
||||
genotypes.append( node_info )
|
||||
return genotypes, True
|
||||
|
||||
def node(self, index):
|
||||
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
|
||||
return self.nodes[index]
|
||||
|
||||
def tostr(self):
|
||||
strings = []
|
||||
for node_info in self.nodes:
|
||||
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
|
||||
string = '|{:}|'.format(string)
|
||||
strings.append( string )
|
||||
return '+'.join(strings)
|
||||
|
||||
def check_valid(self):
|
||||
nodes = {0: True}
|
||||
for i, node_info in enumerate(self.nodes):
|
||||
sums = []
|
||||
for op, xin in node_info:
|
||||
if op == 'none' or nodes[xin] is False: x = False
|
||||
else: x = True
|
||||
sums.append( x )
|
||||
nodes[i+1] = sum(sums) > 0
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def to_unique_str(self, consider_zero=False):
|
||||
# this is used to identify isomorphic cells, which requires prior knowledge of the operations
|
||||
# two operations are special, i.e., none and skip_connect
|
||||
nodes = {0: '0'}
|
||||
for i_node, node_info in enumerate(self.nodes):
|
||||
cur_node = []
|
||||
for op, xin in node_info:
|
||||
if consider_zero is None:
|
||||
x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
elif consider_zero:
|
||||
if op == 'none' or nodes[xin] == '#': x = '#' # zero
|
||||
elif op == 'skip_connect': x = nodes[xin]
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
else:
|
||||
if op == 'skip_connect': x = nodes[xin]
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
cur_node.append(x)
|
||||
nodes[i_node+1] = '+'.join( sorted(cur_node) )
|
||||
return nodes[ len(self.nodes) ]
|
||||
|
||||
def check_valid_op(self, op_names):
|
||||
for node_info in self.nodes:
|
||||
for inode_edge in node_info:
|
||||
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
|
||||
if inode_edge[0] not in op_names: return False
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.nodes) + 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.nodes[index]
|
||||
|
||||
@staticmethod
|
||||
def str2structure(xstr):
|
||||
if isinstance(xstr, Structure): return xstr
|
||||
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||
nodestrs = xstr.split('+')
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||
inputs = ( xi.split('~') for xi in inputs )
|
||||
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
|
||||
genotypes.append( input_infos )
|
||||
return Structure( genotypes )
|
||||
|
||||
@staticmethod
|
||||
def str2fullstructure(xstr, default_name='none'):
|
||||
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||
nodestrs = xstr.split('+')
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||
inputs = ( xi.split('~') for xi in inputs )
|
||||
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
|
||||
all_in_nodes= list(x[1] for x in input_infos)
|
||||
for j in range(i):
|
||||
if j not in all_in_nodes: input_infos.append((default_name, j))
|
||||
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
|
||||
genotypes.append( tuple(node_info) )
|
||||
return Structure( genotypes )
|
||||
|
||||
@staticmethod
|
||||
def gen_all(search_space, num, return_ori):
|
||||
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
|
||||
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
|
||||
all_archs = get_combination(search_space, 1)
|
||||
for i, arch in enumerate(all_archs):
|
||||
all_archs[i] = [ tuple(arch) ]
|
||||
|
||||
for inode in range(2, num):
|
||||
cur_nodes = get_combination(search_space, inode)
|
||||
new_all_archs = []
|
||||
for previous_arch in all_archs:
|
||||
for cur_node in cur_nodes:
|
||||
new_all_archs.append( previous_arch + [tuple(cur_node)] )
|
||||
all_archs = new_all_archs
|
||||
if return_ori:
|
||||
return all_archs
|
||||
else:
|
||||
return [Structure(x) for x in all_archs]
|
||||
|
||||
|
||||
|
||||
ResNet_CODE = Structure(
|
||||
[(('nor_conv_3x3', 0), ), # node-1
|
||||
(('nor_conv_3x3', 1), ), # node-2
|
||||
(('skip_connect', 0), ('skip_connect', 2))] # node-3
|
||||
)
|
||||
|
||||
AllConv3x3_CODE = Structure(
|
||||
[(('nor_conv_3x3', 0), ), # node-1
|
||||
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
|
||||
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
|
||||
)
|
||||
|
||||
AllFull_CODE = Structure(
|
||||
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
|
||||
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
|
||||
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
|
||||
)
|
||||
|
||||
AllConv1x1_CODE = Structure(
|
||||
[(('nor_conv_1x1', 0), ), # node-1
|
||||
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
|
||||
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
|
||||
)
|
||||
|
||||
AllIdentity_CODE = Structure(
|
||||
[(('skip_connect', 0), ), # node-1
|
||||
(('skip_connect', 0), ('skip_connect', 1)), # node-2
|
||||
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
|
||||
)
|
||||
|
||||
architectures = {'resnet' : ResNet_CODE,
|
||||
'all_c3x3': AllConv3x3_CODE,
|
||||
'all_c1x1': AllConv1x1_CODE,
|
||||
'all_idnt': AllIdentity_CODE,
|
||||
'all_full': AllFull_CODE}
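if __name__ == '__main__':
    # Minimal usage sketch (added for illustration; not part of the original file).
    # Round-trip an architecture string through Structure.
    arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|'
    structure = Structure.str2structure(arch)
    assert structure.tostr() == arch and structure.check_valid()

    # Enumerate the full NAS-Bench-201 topology space: 4 nodes, 5 candidate ops,
    # 1+2+3 = 6 edges -> 5**6 = 15625 cells.
    space = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
    all_cells = Structure.gen_all(space, 4, return_ori=False)
    print(len(all_cells), 'cells;  resnet-like cell =', ResNet_CODE.tostr())  # 15625 cells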
|
||||
@@ -0,0 +1,167 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else : self.avg = None
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
|
||||
if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg : out = self.avg( inputs )
|
||||
else : out = inputs
|
||||
conv = self.conv( out )
|
||||
if self.bn : out = self.bn( conv )
|
||||
else : out = conv
|
||||
if self.relu: out = self.relu( out )
|
||||
else : out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class InferCifarResNet(nn.Module):
|
||||
|
||||
def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual):
|
||||
super(InferCifarResNet, self).__init__()
|
||||
|
||||
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == 'ResNetBasicblock':
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == 'ResNetBottleneck':
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError('invalid block : {:}'.format(block_name))
|
||||
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
|
||||
|
||||
self.message = 'InferCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append ( module )
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL+1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer( x )
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -0,0 +1,150 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else : self.avg = None
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
|
||||
if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg : out = self.avg( inputs )
|
||||
else : out = inputs
|
||||
conv = self.conv( out )
|
||||
if self.bn : out = self.bn( conv )
|
||||
else : out = conv
|
||||
if self.relu: out = self.relu( out )
|
||||
else : out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
|
||||
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
|
||||
elif inplanes != planes*self.expansion:
|
||||
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes*self.expansion
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class InferDepthCifarResNet(nn.Module):
|
||||
|
||||
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
|
||||
super(InferDepthCifarResNet, self).__init__()
|
||||
|
||||
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == 'ResNetBasicblock':
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == 'ResNetBottleneck':
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError('invalid block : {:}'.format(block_name))
|
||||
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
|
||||
|
||||
self.message = 'InferDepthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
|
||||
self.num_classes = num_classes
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
self.channels = [16]
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2**stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append( module.out_dim )
|
||||
self.layers.append ( module )
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.channels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer( x )
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -0,0 +1,160 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else : self.avg = None
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
|
||||
if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg : out = self.avg( inputs )
|
||||
else : out = inputs
|
||||
conv = self.conv( out )
|
||||
if self.bn : out = self.bn( conv )
|
||||
else : out = conv
|
||||
if self.relu: out = self.relu( out )
|
||||
else : out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class InferWidthCifarResNet(nn.Module):
|
||||
|
||||
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
|
||||
super(InferWidthCifarResNet, self).__init__()
|
||||
|
||||
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == 'ResNetBasicblock':
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == 'ResNetBottleneck':
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError('invalid block : {:}'.format(block_name))
|
||||
|
||||
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append ( module )
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer( x )
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -0,0 +1,170 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
num_conv = 1
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else : self.avg = None
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
|
||||
if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg : out = self.avg( inputs )
|
||||
else : out = inputs
|
||||
conv = self.conv( out )
|
||||
if self.bn : out = self.bn( conv )
|
||||
else : out = conv
|
||||
if self.relu: out = self.relu( out )
|
||||
else : out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class InferImagenetResNet(nn.Module):
|
||||
|
||||
def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual):
|
||||
super(InferImagenetResNet, self).__init__()
|
||||
|
||||
# block_name selects the residual block type (BasicBlock or Bottleneck) for this ImageNet model
|
||||
if block_name == 'BasicBlock':
|
||||
block = ResNetBasicblock
|
||||
elif block_name == 'Bottleneck':
|
||||
block = ResNetBottleneck
|
||||
else:
|
||||
raise ValueError('invalid block : {:}'.format(block_name))
|
||||
assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks)
|
||||
|
||||
self.message = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
if not deep_stem:
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
last_channel_idx = 1
|
||||
else:
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
last_channel_idx = 2
|
||||
self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) )
|
||||
for stage, layer_blocks in enumerate(layers):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append ( module )
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL+1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer( x )
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -0,0 +1,122 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
|
||||
from torch import nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import parse_channel_info
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
|
||||
if has_bn: self.bn = nn.BatchNorm2d(out_planes)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU6(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv( x )
|
||||
if self.bn: out = self.bn ( out )
|
||||
if self.relu: out = self.relu( out )
|
||||
return out
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, channels, stride, expand_ratio, additive):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2], 'invalid stride : {:}'.format(stride)
|
||||
assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels)
|
||||
|
||||
if len(channels) == 2:
|
||||
layers = []
|
||||
else:
|
||||
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
|
||||
# pw-linear
|
||||
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
self.additive = additive
|
||||
if self.additive and channels[0] != channels[-1]:
|
||||
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
|
||||
else:
|
||||
self.shortcut = None
|
||||
self.out_dim = channels[-1]
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
# if self.additive: return additive_func(out, x)
|
||||
if self.shortcut: return out + self.shortcut(x)
|
||||
else : return out
|
||||
|
||||
|
||||
class InferMobileNetV2(nn.Module):
|
||||
def __init__(self, num_classes, xchannels, xblocks, dropout):
|
||||
super(InferMobileNetV2, self).__init__()
|
||||
block = InvertedResidual
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16 , 1, 1],
|
||||
[6, 24 , 2, 2],
|
||||
[6, 32 , 3, 2],
|
||||
[6, 64 , 4, 2],
|
||||
[6, 96 , 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks))
|
||||
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
|
||||
assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting)
|
||||
xchannels = parse_channel_info(xchannels)
|
||||
#for i, chs in enumerate(xchannels):
|
||||
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
|
||||
self.xchannels = xchannels
|
||||
self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks)
|
||||
# building first layer
|
||||
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
|
||||
last_channel_idx = 1
|
||||
|
||||
# building inverted residual blocks
|
||||
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
additv = (i > 0)  # only non-first blocks in a stage use the additive shortcut
|
||||
module = block(self.xchannels[last_channel_idx], stride, t, additv)
|
||||
features.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c)
|
||||
last_channel_idx += 1
|
||||
if i + 1 == xblocks[stage]:
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(i+1, n):
|
||||
last_channel_idx += 1
|
||||
self.xchannels[last_channel_idx][0] = module.out_dim
|
||||
break
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1))
|
||||
assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
self.apply( initialize_resnet )
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
features = self.features(inputs)
|
||||
vectors = features.mean([2, 3])
|
||||
predicts = self.classifier(vectors)
|
||||
return features, predicts
|
||||
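A shape check for the InvertedResidual block above, with illustrative channel sizes (expand 24 -> 144, depthwise 3x3 with stride 2, project to 32):

import torch

block = InvertedResidual(channels=[24, 144, 32], stride=2, expand_ratio=6, additive=False)
print(block(torch.randn(1, 24, 56, 56)).shape)   # torch.Size([1, 32, 28, 28])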
@@ -0,0 +1,58 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
from typing import List, Text, Any
|
||||
import torch.nn as nn
|
||||
from models.cell_operations import ResNetBasicblock
|
||||
from models.cell_infers.cells import InferCell
|
||||
|
||||
|
||||
class DynamicShapeTinyNet(nn.Module):
|
||||
|
||||
def __init__(self, channels: List[int], genotype: Any, num_classes: int):
|
||||
super(DynamicShapeTinyNet, self).__init__()
|
||||
self._channels = channels
|
||||
if len(channels) % 3 != 2:
|
||||
raise ValueError('invalid number of layers : {:}'.format(len(channels)))
|
||||
self._num_stage = N = len(channels) // 3
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(channels[0]))
|
||||
|
||||
# layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
c_prev = channels[0]
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
|
||||
if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True)
|
||||
else : cell = InferCell(genotype, c_prev, c_curr, 1)
|
||||
self.cells.append( cell )
|
||||
c_prev = cell.out_dim
|
||||
self._num_layer = len(self.cells)
|
||||
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(c_prev, num_classes)
|
||||
|
||||
def get_message(self) -> Text:
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling( out )
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
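A configuration sketch for the constructor above: channels must have length 3N+2, and following the commented layer_channels pattern with C=16 and N=1 a hypothetical call looks like

net = DynamicShapeTinyNet(channels=[16, 32, 32, 64, 64], genotype=some_genotype, num_classes=10)  # some_genotype is a placeholder

which builds stem -> InferCell(16) -> ResNetBasicblock(16->32, stride 2) -> InferCell(32) -> ResNetBasicblock(32->64, stride 2) -> InferCell(64) -> BN+ReLU -> global pool -> classifier.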
@@ -0,0 +1,9 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
from .InferCifarResNet_width import InferWidthCifarResNet
|
||||
from .InferImagenetResNet import InferImagenetResNet
|
||||
from .InferCifarResNet_depth import InferDepthCifarResNet
|
||||
from .InferCifarResNet import InferCifarResNet
|
||||
from .InferMobileNetV2 import InferMobileNetV2
|
||||
from .InferTinyCellNet import DynamicShapeTinyNet
|
||||
@@ -0,0 +1,5 @@
|
||||
def parse_channel_info(xstring):
|
||||
blocks = xstring.split(' ')
|
||||
blocks = [x.split('-') for x in blocks]
|
||||
blocks = [[int(_) for _ in x] for x in blocks]
|
||||
return blocks
|
||||
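A minimal sketch of the string format parse_channel_info expects (the string below is hypothetical):

print(parse_channel_info('3-16 16-32 32-64'))   # [[3, 16], [16, 32], [32, 64]]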
@@ -0,0 +1,2 @@
|
||||
from .evaluation_utils import obtain_accuracy
|
||||
from .flop_benchmark import get_model_infos
|
||||
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
def obtain_accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
# correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
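A usage sketch for obtain_accuracy with illustrative tensors:

if __name__ == '__main__':
    logits = torch.randn(64, 10)                      # [batch, num_classes] scores
    labels = torch.randint(0, 10, (64,))              # integer class labels
    top1, top5 = obtain_accuracy(logits, labels, topk=(1, 5))
    print(top1.item(), top5.item())                   # precision@1 / precision@5 in percent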
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
def count_parameters_in_MB(model):
|
||||
if isinstance(model, nn.Module):
|
||||
return np.sum([np.prod(v.size()) for v in model.parameters()]) / 1e6
|
||||
else:
|
||||
return np.sum([np.prod(v.size()) for v in model]) / 1e6
|
||||
|
||||
|
||||
def get_model_infos(model, shape):
|
||||
#model = copy.deepcopy( model )
|
||||
|
||||
model = add_flops_counting_methods(model)
|
||||
#model = model.cuda()
|
||||
model.eval()
|
||||
|
||||
#cache_inputs = torch.zeros(*shape).cuda()
|
||||
#cache_inputs = torch.zeros(*shape)
|
||||
cache_inputs = torch.rand(*shape)
|
||||
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
|
||||
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
|
||||
with torch.no_grad():
|
||||
_____ = model(cache_inputs)
|
||||
FLOPs = compute_average_flops_cost( model ) / 1e6
|
||||
Param = count_parameters_in_MB(model)
|
||||
|
||||
if hasattr(model, 'auxiliary_param'):
|
||||
aux_params = count_parameters_in_MB(model.auxiliary_param())
|
||||
print ('The auxiliary params of this model is : {:}'.format(aux_params))
|
||||
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
|
||||
Param = Param - aux_params
|
||||
|
||||
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
|
||||
torch.cuda.empty_cache()
|
||||
model.apply( remove_hook_function )
|
||||
return FLOPs, Param
|
||||
|
||||
|
||||
# ---- Public functions
|
||||
def add_flops_counting_methods( model ):
|
||||
model.__batch_counter__ = 0
|
||||
add_batch_counter_hook_function( model )
|
||||
model.apply( add_flops_counter_variable_or_reset )
|
||||
model.apply( add_flops_counter_hook_function )
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def compute_average_flops_cost(model):
|
||||
"""
|
||||
A method that will be available after add_flops_counting_methods() is called on a desired net object.
|
||||
Returns current mean flops consumption per image.
|
||||
"""
|
||||
batches_count = model.__batch_counter__
|
||||
flops_sum = 0
|
||||
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
|
||||
or isinstance(module, torch.nn.Conv1d) \
|
||||
or hasattr(module, 'calculate_flop_self'):
|
||||
flops_sum += module.__flops__
|
||||
return flops_sum / batches_count
|
||||
|
||||
|
||||
# ---- Internal functions
|
||||
def pool_flops_counter_hook(pool_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
kernel_size = pool_module.kernel_size
|
||||
out_C, output_height, output_width = output.shape[1:]
|
||||
assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())
|
||||
|
||||
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
|
||||
pool_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def self_calculate_flops_counter_hook(self_module, inputs, output):
|
||||
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
|
||||
self_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def fc_flops_counter_hook(fc_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
xin, xout = fc_module.in_features, fc_module.out_features
|
||||
assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
|
||||
overall_flops = batch_size * xin * xout
|
||||
if fc_module.bias is not None:
|
||||
overall_flops += batch_size * xout
|
||||
fc_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
|
||||
batch_size = inputs[0].size(0)
|
||||
outL = outputs.shape[-1]
|
||||
[kernel] = conv_module.kernel_size
|
||||
in_channels = conv_module.in_channels
|
||||
out_channels = conv_module.out_channels
|
||||
groups = conv_module.groups
|
||||
conv_per_position_flops = kernel * in_channels * out_channels / groups
|
||||
|
||||
active_elements_count = batch_size * outL
|
||||
overall_flops = conv_per_position_flops * active_elements_count
|
||||
|
||||
if conv_module.bias is not None:
|
||||
overall_flops += out_channels * active_elements_count
|
||||
conv_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def conv2d_flops_counter_hook(conv_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
output_height, output_width = output.shape[2:]
|
||||
|
||||
kernel_height, kernel_width = conv_module.kernel_size
|
||||
in_channels = conv_module.in_channels
|
||||
out_channels = conv_module.out_channels
|
||||
groups = conv_module.groups
|
||||
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
|
||||
|
||||
active_elements_count = batch_size * output_height * output_width
|
||||
overall_flops = conv_per_position_flops * active_elements_count
|
||||
|
||||
if conv_module.bias is not None:
|
||||
overall_flops += out_channels * active_elements_count
|
||||
conv_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def batch_counter_hook(module, inputs, output):
|
||||
# Can have multiple inputs, getting the first one
|
||||
inputs = inputs[0]
|
||||
batch_size = inputs.shape[0]
|
||||
module.__batch_counter__ += batch_size
|
||||
|
||||
|
||||
def add_batch_counter_hook_function(module):
|
||||
if not hasattr(module, '__batch_counter_handle__'):
|
||||
handle = module.register_forward_hook(batch_counter_hook)
|
||||
module.__batch_counter_handle__ = handle
|
||||
|
||||
|
||||
def add_flops_counter_variable_or_reset(module):
|
||||
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
|
||||
or isinstance(module, torch.nn.Conv1d) \
|
||||
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
|
||||
or hasattr(module, 'calculate_flop_self'):
|
||||
module.__flops__ = 0
|
||||
|
||||
|
||||
def add_flops_counter_hook_function(module):
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(conv2d_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.Conv1d):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(conv1d_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.Linear):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(fc_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(pool_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif hasattr(module, 'calculate_flop_self'): # self-defined module
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
|
||||
|
||||
def remove_hook_function(module):
|
||||
hookers = ['__batch_counter_handle__', '__flops_handle__']
|
||||
for hooker in hookers:
|
||||
if hasattr(module, hooker):
|
||||
handle = getattr(module, hooker)
|
||||
handle.remove()
|
||||
keys = ['__flops__', '__batch_counter__'] + hookers
|
||||
for ckey in keys:
|
||||
if hasattr(module, ckey): delattr(module, ckey)
|
||||
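A minimal usage sketch of the FLOP counter above, with a toy model and an assumed 32x32 RGB input:

if __name__ == '__main__':
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(inplace=True),
                        nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
    flops, params = get_model_infos(net, shape=(1, 3, 32, 32))
    print('FLOPs = {:.3f} M, Params = {:.3f} M'.format(flops, params))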
@@ -0,0 +1,28 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .starts import get_machine_info, save_checkpoint, copy_checkpoint
|
||||
from .optimizers import get_optim_scheduler
|
||||
from .starts import prepare_seed #, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
|
||||
'''
|
||||
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
|
||||
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
|
||||
from .funcs_nasbench import get_nas_bench_loaders
|
||||
|
||||
def get_procedures(procedure):
|
||||
from .basic_main import basic_train, basic_valid
|
||||
from .search_main import search_train, search_valid
|
||||
from .search_main_v2 import search_train_v2
|
||||
from .simple_KD_main import simple_KD_train, simple_KD_valid
|
||||
|
||||
train_funcs = {'basic' : basic_train, \
|
||||
'search': search_train,'Simple-KD': simple_KD_train, \
|
||||
'search-v2': search_train_v2}
|
||||
valid_funcs = {'basic' : basic_valid, \
|
||||
'search': search_valid,'Simple-KD': simple_KD_valid, \
|
||||
'search-v2': search_valid}
|
||||
|
||||
train_func = train_funcs[procedure]
|
||||
valid_func = valid_funcs[procedure]
|
||||
return train_func, valid_func
|
||||
'''
|
||||
@@ -0,0 +1,204 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
from bisect import bisect_right
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class _LRScheduler(object):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs):
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault('initial_lr', group['lr'])
|
||||
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
|
||||
self.max_epochs = epochs
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
self.current_iter = 0
|
||||
|
||||
def extra_repr(self):
|
||||
return ''
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
|
||||
+ ', {:})'.format(self.extra_repr()))
|
||||
|
||||
def state_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_min_info(self):
|
||||
lrs = self.get_lr()
|
||||
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
|
||||
|
||||
def get_min_lr(self):
|
||||
return min( self.get_lr() )
|
||||
|
||||
def update(self, cur_epoch, cur_iter):
|
||||
if cur_epoch is not None:
|
||||
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
|
||||
self.current_epoch = cur_epoch
|
||||
if cur_iter is not None:
|
||||
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
|
||||
self.current_iter = cur_iter
|
||||
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
param_group['lr'] = lr
|
||||
|
||||
|
||||
|
||||
class CosineAnnealingLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
#if last_epoch < self.T_max:
|
||||
#if last_epoch < self.max_epochs:
|
||||
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
|
||||
#else:
|
||||
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
||||
elif self.current_epoch >= self.max_epochs:
|
||||
lr = self.eta_min
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
||||
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
idx = bisect_right(self.milestones, last_epoch)
|
||||
lr = base_lr
|
||||
for x in self.gammas[:idx]: lr *= x
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
|
||||
self.gamma = gamma
|
||||
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
lr = base_lr * (self.gamma ** last_epoch)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
class LinearLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
|
||||
self.max_LR = max_LR
|
||||
self.min_LR = min_LR
|
||||
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
|
||||
lr = base_lr * (1-ratio)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
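# Worked example (illustrative numbers): with num_classes=4 and epsilon=0.1,
# a hard label y=2 becomes the soft target
#   (1 - 0.1) * [0, 0, 1, 0] + 0.1 / 4 = [0.025, 0.025, 0.925, 0.025],
# and the loss is -sum(target * log_probs) averaged over the batch.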
|
||||
|
||||
|
||||
def get_optim_scheduler(parameters, config):
|
||||
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
|
||||
if config.optim == 'SGD':
|
||||
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
|
||||
elif config.optim == 'RMSprop':
|
||||
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
|
||||
else:
|
||||
raise ValueError('invalid optim : {:}'.format(config.optim))
|
||||
|
||||
if config.scheduler == 'cos':
|
||||
T_max = getattr(config, 'T_max', config.epochs)
|
||||
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
|
||||
elif config.scheduler == 'multistep':
|
||||
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
|
||||
elif config.scheduler == 'exponential':
|
||||
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
||||
elif config.scheduler == 'linear':
|
||||
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
|
||||
else:
|
||||
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
|
||||
|
||||
if config.criterion == 'Softmax':
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.criterion == 'SmoothSoftmax':
|
||||
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
|
||||
else:
|
||||
raise ValueError('invalid criterion : {:}'.format(config.criterion))
|
||||
return optim, scheduler, criterion
|
||||
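A minimal configuration sketch for get_optim_scheduler; the field values are illustrative and any attribute-style object exposing them works:

if __name__ == '__main__':
    from types import SimpleNamespace
    toy = nn.Linear(8, 2)
    config = SimpleNamespace(optim='SGD', LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
                             scheduler='cos', warmup=5, epochs=100, eta_min=0.0,
                             criterion='Softmax')
    optim, scheduler, criterion = get_optim_scheduler(toy.parameters(), config)
    scheduler.update(10, 0.0)          # epoch 10 is past the 5 warmup epochs -> cosine phase
    print(scheduler.get_min_info())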
@@ -0,0 +1,64 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, torch, random, PIL, copy, numpy as np
|
||||
from os import path as osp
|
||||
from shutil import copyfile
|
||||
|
||||
|
||||
def prepare_seed(rand_seed):
|
||||
random.seed(rand_seed)
|
||||
np.random.seed(rand_seed)
|
||||
torch.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed_all(rand_seed)
|
||||
|
||||
|
||||
def prepare_logger(xargs):
|
||||
args = copy.deepcopy( xargs )
|
||||
from log_utils import Logger
|
||||
logger = Logger(args.save_dir, args.rand_seed)
|
||||
logger.log('Main Function with logger : {:}'.format(logger))
|
||||
logger.log('Arguments : -------------------------------')
|
||||
for name, value in args._get_kwargs():
|
||||
logger.log('{:16} : {:}'.format(name, value))
|
||||
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
|
||||
logger.log("Pillow Version : {:}".format(PIL.__version__))
|
||||
logger.log("PyTorch Version : {:}".format(torch.__version__))
|
||||
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
|
||||
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
|
||||
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
|
||||
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
|
||||
return logger
|
||||
|
||||
|
||||
def get_machine_info():
|
||||
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
|
||||
info+= "\nPillow Version : {:}".format(PIL.__version__)
|
||||
info+= "\nPyTorch Version : {:}".format(torch.__version__)
|
||||
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
|
||||
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
|
||||
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ:
|
||||
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
|
||||
else:
|
||||
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
|
||||
return info
|
||||
|
||||
|
||||
def save_checkpoint(state, filename, logger):
|
||||
if osp.isfile(filename):
|
||||
if hasattr(logger, 'log'): logger.log('Found {:} already exists; deleting it before saving'.format(filename))
|
||||
os.remove(filename)
|
||||
torch.save(state, filename)
|
||||
assert osp.isfile(filename), 'failed to save checkpoint to {:} (file not found after saving).'.format(filename)
|
||||
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
|
||||
return filename
|
||||
|
||||
|
||||
def copy_checkpoint(src, dst, logger):
|
||||
if osp.isfile(dst):
|
||||
if hasattr(logger, 'log'): logger.log('Found {:} already exists; deleting it before copying'.format(dst))
|
||||
os.remove(dst)
|
||||
copyfile(src, dst)
|
||||
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))
|
||||
83
NAS-Bench-201/main_exp/transfer_nag/run_multi_proc.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from torch.multiprocessing import Process
|
||||
import os
|
||||
from absl import app, flags
|
||||
import sys
|
||||
import torch
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from nas_bench_201 import train_single_model
|
||||
from all_path import NASBENCH201
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_split", 15, "The number of splits")
|
||||
flags.DEFINE_list("arch_idx_lst", None, "arch index list")
|
||||
flags.DEFINE_list("arch_str_lst", None, "arch str list")
|
||||
flags.DEFINE_string("meta_test_path", None, "meta test path")
|
||||
flags.DEFINE_string("data_name", None, "data_name")
|
||||
flags.DEFINE_string("raw_data_path", None, "raw_data_path")
|
||||
|
||||
|
||||
def run_single_process(rank, seed, arch_idx, meta_test_path, data_name,
|
||||
raw_data_path, num_split=15, backend="nccl"):
|
||||
# map each rank onto one of 8 GPUs (the list covers up to 32 ranks)
|
||||
device = ['0', '1', '2', '3', '4', '5', '6', '7', '0', '1', '2', '3', '4', '5', '6', '7',
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '0', '1', '2', '3', '4', '5', '6', '7'][rank]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = device
|
||||
|
||||
save_path = os.path.join(meta_test_path, str(arch_idx))
|
||||
if type(seed) == int:
|
||||
seeds = [seed]
|
||||
elif type(seed) in [list, tuple]:
|
||||
seeds = seed
|
||||
|
||||
nasbench201 = torch.load(NASBENCH201)
|
||||
arch_str = nasbench201['arch']['str'][arch_idx]
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
train_single_model(save_dir=save_path,
|
||||
workers=24,
|
||||
datasets=[data_name],
|
||||
xpaths=[f'{raw_data_path}/{data_name}'],
|
||||
splits=[0],
|
||||
use_less=False,
|
||||
seeds=seeds,
|
||||
model_str=arch_str,
|
||||
arch_config={'channel': 16, 'num_cells': 5})
|
||||
|
||||
|
||||
def run_multi_process(argv):
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "1234"
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
processes = []
|
||||
|
||||
arch_idx_lst = [int(i) for i in FLAGS.arch_idx_lst]
|
||||
seeds = [777, 888, 999] * len(arch_idx_lst)
|
||||
arch_idx_lst_ = []
|
||||
for i in arch_idx_lst:
|
||||
arch_idx_lst_ += [i] * 3
|
||||
|
||||
for arch_idx in arch_idx_lst:
|
||||
os.makedirs(os.path.join(FLAGS.meta_test_path, str(arch_idx)), exist_ok=True)
|
||||
|
||||
for rank in range(FLAGS.num_split):
|
||||
arch_idx = arch_idx_lst_[rank]
|
||||
seed = seeds[rank]
|
||||
p = Process(target=run_single_process, args=(rank,
|
||||
seed,
|
||||
arch_idx,
|
||||
FLAGS.meta_test_path,
|
||||
FLAGS.data_name,
|
||||
FLAGS.raw_data_path))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
while any(p.is_alive() for p in processes):
|
||||
continue
|
||||
print("All processes have completed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(run_multi_process)
|
||||
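An invocation sketch for the launcher above; the dataset name, paths and architecture indices are hypothetical:

# python NAS-Bench-201/main_exp/transfer_nag/run_multi_proc.py \
#     --num_split 6 --arch_idx_lst 11,42 \
#     --meta_test_path ./meta_test --data_name cifar10 --raw_data_path ./raw_data
# Seeds 777/888/999 are cycled per architecture, so num_split is expected to be
# three times the number of architecture indices.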
@@ -0,0 +1,38 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from set_encoder.setenc_modules import *
|
||||
|
||||
|
||||
class SetPool(nn.Module):
|
||||
def __init__(self, dim_input, num_outputs, dim_output,
|
||||
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
|
||||
super(SetPool, self).__init__()
|
||||
if 'sab' in mode: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
|
||||
else: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
|
||||
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
|
||||
if 'PF' in mode: # [32, 1, 501]
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
elif 'P' in mode:
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
|
||||
else: # torch.Size([32, 1, 501])
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
# "", sm, sab, sabsm
|
||||
|
||||
def forward(self, X):
|
||||
x1 = self.enc(X)
|
||||
x2 = self.dec(x1)
|
||||
return x2
|
||||
@@ -0,0 +1,67 @@
|
||||
#####################################################################################
|
||||
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class MAB(nn.Module):
|
||||
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
|
||||
super(MAB, self).__init__()
|
||||
self.dim_V = dim_V
|
||||
self.num_heads = num_heads
|
||||
self.fc_q = nn.Linear(dim_Q, dim_V)
|
||||
self.fc_k = nn.Linear(dim_K, dim_V)
|
||||
self.fc_v = nn.Linear(dim_K, dim_V)
|
||||
if ln:
|
||||
self.ln0 = nn.LayerNorm(dim_V)
|
||||
self.ln1 = nn.LayerNorm(dim_V)
|
||||
self.fc_o = nn.Linear(dim_V, dim_V)
|
||||
|
||||
def forward(self, Q, K):
|
||||
Q = self.fc_q(Q)
|
||||
K, V = self.fc_k(K), self.fc_v(K)
|
||||
|
||||
dim_split = self.dim_V // self.num_heads
|
||||
Q_ = torch.cat(Q.split(dim_split, 2), 0)
|
||||
K_ = torch.cat(K.split(dim_split, 2), 0)
|
||||
V_ = torch.cat(V.split(dim_split, 2), 0)
|
||||
|
||||
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
|
||||
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
|
||||
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
|
||||
O = O + F.relu(self.fc_o(O))
|
||||
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
|
||||
return O
|
||||
|
||||
class SAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, ln=False):
|
||||
super(SAB, self).__init__()
|
||||
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(X, X)
|
||||
|
||||
class ISAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
|
||||
super(ISAB, self).__init__()
|
||||
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
|
||||
nn.init.xavier_uniform_(self.I)
|
||||
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
|
||||
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
|
||||
return self.mab1(X, H)
|
||||
|
||||
class PMA(nn.Module):
|
||||
def __init__(self, dim, num_heads, num_seeds, ln=False):
|
||||
super(PMA, self).__init__()
|
||||
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
|
||||
nn.init.xavier_uniform_(self.S)
|
||||
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
|
||||
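A quick shape check for the attention blocks above, with illustrative sizes:

if __name__ == '__main__':
    sab = SAB(dim_in=64, dim_out=128, num_heads=4, ln=True)
    pma = PMA(dim=128, num_heads=4, num_seeds=1, ln=True)
    X = torch.randn(32, 400, 64)       # [batch, set size, feature dim]
    print(pma(sab(X)).shape)           # torch.Size([32, 1, 128])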
243
NAS-Bench-201/main_exp/transfer_nag/unnoised_model.py
Normal file
@@ -0,0 +1,243 @@
|
||||
######################################################################################
|
||||
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
|
||||
######################################################################################
|
||||
import torch
|
||||
from torch import nn
|
||||
from set_encoder.setenc_models import SetPool
|
||||
|
||||
|
||||
class MetaSurrogateUnnoisedModel(nn.Module):
|
||||
def __init__(self, args, graph_config):
|
||||
super(MetaSurrogateUnnoisedModel, self).__init__()
|
||||
self.max_n = graph_config['max_n'] # maximum number of vertices
|
||||
self.nvt = args.nvt # number of vertex types
|
||||
self.START_TYPE = graph_config['START_TYPE']
|
||||
self.END_TYPE = graph_config['END_TYPE']
|
||||
self.hs = args.hs # hidden state size of each vertex
|
||||
self.nz = args.nz # size of latent representation z
|
||||
self.gs = args.hs # size of graph state
|
||||
self.bidir = True # whether to use bidirectional encoding
|
||||
self.vid = True
|
||||
self.device = None
|
||||
self.input_type = 'DG'
|
||||
self.num_sample = args.num_sample
|
||||
|
||||
if self.vid:
|
||||
self.vs = self.hs + self.max_n # vertex state size = hidden state + vid
|
||||
else:
|
||||
self.vs = self.hs
|
||||
|
||||
# 0. encoding-related
|
||||
self.grue_forward = nn.GRUCell(self.nvt, self.hs) # encoder GRU
|
||||
self.grue_backward = nn.GRUCell(
|
||||
self.nvt, self.hs) # backward encoder GRU
|
||||
self.fc1 = nn.Linear(self.gs, self.nz) # latent mean
|
||||
self.fc2 = nn.Linear(self.gs, self.nz) # latent logvar
|
||||
|
||||
# 2. gate-related
|
||||
self.gate_forward = nn.Sequential(
|
||||
nn.Linear(self.vs, self.hs),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
self.gate_backward = nn.Sequential(
|
||||
nn.Linear(self.vs, self.hs),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
self.mapper_forward = nn.Sequential(
|
||||
nn.Linear(self.vs, self.hs, bias=False),
|
||||
) # disable bias to ensure padded zeros also mapped to zeros
|
||||
self.mapper_backward = nn.Sequential(
|
||||
nn.Linear(self.vs, self.hs, bias=False),
|
||||
)
|
||||
|
||||
# 3. bidir-related, to unify sizes
|
||||
if self.bidir:
|
||||
self.hv_unify = nn.Sequential(
|
||||
nn.Linear(self.hs * 2, self.hs),
|
||||
)
|
||||
self.hg_unify = nn.Sequential(
|
||||
nn.Linear(self.gs * 2, self.gs),
|
||||
)
|
||||
|
||||
# 4. other
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.tanh = nn.Tanh()
|
||||
self.logsoftmax1 = nn.LogSoftmax(1)
|
||||
|
||||
# 6. predictor
|
||||
np = self.gs
|
||||
self.intra_setpool = SetPool(dim_input=512,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.inter_setpool = SetPool(dim_input=self.nz,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.set_fc = nn.Sequential(
|
||||
nn.Linear(512, self.nz),
|
||||
nn.ReLU())
|
||||
|
||||
input_dim = 0
|
||||
if 'D' in self.input_type:
|
||||
input_dim += self.nz
|
||||
if 'G' in self.input_type:
|
||||
input_dim += self.nz
|
||||
|
||||
self.pred_fc = nn.Sequential(
|
||||
nn.Linear(input_dim, self.hs),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.hs, 1)
|
||||
)
|
||||
self.mseloss = nn.MSELoss(reduction='sum')
|
||||
|
||||
def predict(self, D_mu, G_mu):
|
||||
input_vec = []
|
||||
if 'D' in self.input_type:
|
||||
input_vec.append(D_mu)
|
||||
if 'G' in self.input_type:
|
||||
input_vec.append(G_mu)
|
||||
input_vec = torch.cat(input_vec, dim=1)
|
||||
return self.pred_fc(input_vec)
|
||||
|
||||
def get_device(self):
|
||||
if self.device is None:
|
||||
self.device = next(self.parameters()).device
|
||||
return self.device
|
||||
|
||||
def _get_zeros(self, n, length):
|
||||
# get a zero hidden state
|
||||
return torch.zeros(n, length).to(self.get_device())
|
||||
|
||||
def _get_zero_hidden(self, n=1):
|
||||
return self._get_zeros(n, self.hs) # get a zero hidden state
|
||||
|
||||
def _one_hot(self, idx, length):
|
||||
if type(idx) in [list, range]:
|
||||
if idx == []:
|
||||
return None
|
||||
idx = torch.LongTensor(idx).unsqueeze(0).t()
|
||||
x = torch.zeros((len(idx), length)).scatter_(
|
||||
1, idx, 1).to(self.get_device())
|
||||
else:
|
||||
idx = torch.LongTensor([idx]).unsqueeze(0)
|
||||
x = torch.zeros((1, length)).scatter_(
|
||||
1, idx, 1).to(self.get_device())
|
||||
return x
|
||||
|
||||
def _gated(self, h, gate, mapper):
|
||||
return gate(h) * mapper(h)
|
||||
|
||||
def _collate_fn(self, G):
|
||||
return [g.copy() for g in G]
|
||||
|
||||
def _propagate_to(self, G, v, propagator, H=None, reverse=False, gate=None, mapper=None):
|
||||
# propagate messages to vertex index v for all graphs in G
|
||||
# return the new messages (states) at v
|
||||
G = [g for g in G if g.vcount() > v]
|
||||
if len(G) == 0:
|
||||
return
|
||||
if H is not None:
|
||||
idx = [i for i, g in enumerate(G) if g.vcount() > v]
|
||||
H = H[idx]
|
||||
v_types = [g.vs[v]['type'] for g in G]
|
||||
X = self._one_hot(v_types, self.nvt)
|
||||
if reverse:
|
||||
H_name = 'H_backward' # name of the hidden states attribute
|
||||
H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
|
||||
if self.vid:
|
||||
vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
|
||||
gate, mapper = self.gate_backward, self.mapper_backward
|
||||
else:
|
||||
H_name = 'H_forward' # name of the hidden states attribute
|
||||
H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
|
||||
if self.vid:
|
||||
vids = [self._one_hot(g.predecessors(v), self.max_n)
|
||||
for g in G]
|
||||
if gate is None:
|
||||
gate, mapper = self.gate_forward, self.mapper_forward
|
||||
if self.vid:
|
||||
H_pred = [[torch.cat([x[i], y[i:i + 1]], 1)
|
||||
for i in range(len(x))] for x, y in zip(H_pred, vids)]
|
||||
# if h is not provided, use gated sum of v's predecessors' states as the input hidden state
|
||||
if H is None:
|
||||
# maximum number of predecessors
|
||||
max_n_pred = max([len(x) for x in H_pred])
|
||||
if max_n_pred == 0:
|
||||
H = self._get_zero_hidden(len(G))
|
||||
else:
|
||||
H_pred = [torch.cat(h_pred +
|
||||
[self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
|
||||
for h_pred in H_pred] # pad all to same length
|
||||
H_pred = torch.cat(H_pred, 0) # batch * max_n_pred * vs
|
||||
H = self._gated(H_pred, gate, mapper).sum(1) # batch * hs
|
||||
Hv = propagator(X, H)
|
||||
for i, g in enumerate(G):
|
||||
g.vs[v][H_name] = Hv[i:i + 1]
|
||||
return Hv
|
||||
|
||||
def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
|
||||
# perform a series of propagation_to steps starting from v following a topo order
|
||||
# assume the original vertex indices are in a topological order
|
||||
if reverse:
|
||||
prop_order = range(v, -1, -1)
|
||||
else:
|
||||
prop_order = range(v, self.max_n)
|
||||
Hv = self._propagate_to(G, v, propagator, H0,
|
||||
reverse=reverse) # the initial vertex
|
||||
for v_ in prop_order[1:]:
|
||||
self._propagate_to(G, v_, propagator, reverse=reverse)
|
||||
return Hv
|
||||
|
||||
def _get_graph_state(self, G, decode=False):
|
||||
# get the graph states
|
||||
# when decoding, use the last generated vertex's state as the graph state
|
||||
# when encoding, use the ending vertex state or unify the starting and ending vertex states
|
||||
Hg = []
|
||||
for g in G:
|
||||
hg = g.vs[g.vcount() - 1]['H_forward']
|
||||
if self.bidir and not decode: # decoding never uses backward propagation
|
||||
hg_b = g.vs[0]['H_backward']
|
||||
hg = torch.cat([hg, hg_b], 1)
|
||||
Hg.append(hg)
|
||||
Hg = torch.cat(Hg, 0)
|
||||
if self.bidir and not decode:
|
||||
Hg = self.hg_unify(Hg)
|
||||
return Hg
|
||||
|
||||
def set_encode(self, X):
|
||||
proto_batch = []
|
||||
for x in X:
|
||||
cls_protos = self.intra_setpool(
|
||||
x.view(-1, self.num_sample, 512)).squeeze(1)
|
||||
proto_batch.append(
|
||||
self.inter_setpool(cls_protos.unsqueeze(0)))
|
||||
v = torch.stack(proto_batch).squeeze()
|
||||
return v
|
||||
|
||||
def graph_encode(self, G):
|
||||
# encode graphs G into latent vectors
|
||||
if type(G) != list:
|
||||
G = [G]
|
||||
self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
|
||||
reverse=False)
|
||||
if self.bidir:
|
||||
self._propagate_from(G, self.max_n - 1, self.grue_backward,
|
||||
H0=self._get_zero_hidden(len(G)), reverse=True)
|
||||
Hg = self._get_graph_state(G)
|
||||
mu = self.fc1(Hg)
|
||||
# logvar = self.fc2(Hg)
|
||||
return mu # , logvar
|
||||
|
||||
def reparameterize(self, mu, logvar, eps_scale=0.01):
|
||||
# return z ~ N(mu, std)
|
||||
if self.training:
|
||||
std = logvar.mul(0.5).exp_()
|
||||
eps = torch.randn_like(std) * eps_scale
|
||||
return eps.mul(std).add_(mu)
|
||||
else:
|
||||
return mu
|
||||
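A short restatement of the sampling rule above: during training, reparameterize draws z = mu + eps_scale * eps * exp(0.5 * logvar) with eps ~ N(0, I), so gradients flow through both mu and logvar; at evaluation time it returns mu deterministically.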
33
NAS-Bench-201/main_exp/utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
def restore_checkpoint(ckpt_dir, state, device, resume=False):
|
||||
if not resume:
|
||||
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
|
||||
return state
|
||||
elif not os.path.exists(ckpt_dir):
|
||||
if not os.path.exists(os.path.dirname(ckpt_dir)):
|
||||
os.makedirs(os.path.dirname(ckpt_dir))
|
||||
logging.warning(f"No checkpoint found at {ckpt_dir}. "
|
||||
f"Returned the same state as input")
|
||||
return state
|
||||
else:
|
||||
loaded_state = torch.load(ckpt_dir, map_location=device)
|
||||
for k in state:
|
||||
if k in ['optimizer', 'model', 'ema']:
|
||||
state[k].load_state_dict(loaded_state[k])
|
||||
else:
|
||||
state[k] = loaded_state[k]
|
||||
return state
|
||||
|
||||
|
||||
def reset_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
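A usage sketch for restore_checkpoint; the checkpoint path is hypothetical, the 'model'/'optimizer'/'ema' entries must expose load_state_dict, and other keys are copied as-is when resuming:

if __name__ == '__main__':
    net = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    state = dict(model=net, optimizer=opt, step=0)
    state = restore_checkpoint('./checkpoints/example/checkpoint.pth', state, device='cpu', resume=False)
    print(state['step'])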
0
NAS-Bench-201/models/__init__.py
Executable file
391
NAS-Bench-201/models/cate.py
Normal file
@@ -0,0 +1,391 @@
|
||||
# Most of this code is from https://github.com/AIoT-MLSys-Lab/CATE.git
|
||||
# which was authored by Shen Yan, Kaiqiang Song, Fei Liu, Mi Zhang, 2021
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from . import utils
|
||||
from .transformer import Encoder, SemanticEmbedding
|
||||
from .set_encoder.setenc_models import SetPool
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
|
||||
"""
|
||||
num_layers: number of layers in the neural network (EXCLUDING the input layer). If num_layers=1, this reduces to a linear model.
|
||||
input_dim: dimensionality of input features
|
||||
hidden_dim: dimensionality of hidden units at ALL layers
|
||||
output_dim: number of classes for prediction
|
||||
num_classes: the number of classes of input, to be treated with different gains and biases,
|
||||
(see the definition of class `ConditionalLayer1d`)
|
||||
"""
|
||||
|
||||
super(MLP, self).__init__()
|
||||
|
||||
self.linear_or_not = True # default is linear model
|
||||
self.num_layers = num_layers
|
||||
self.use_bn = use_bn
|
||||
self.activate_func = activate_func
|
||||
|
||||
if num_layers < 1:
|
||||
raise ValueError("number of layers should be positive!")
|
||||
elif num_layers == 1:
|
||||
# Linear model
|
||||
self.linear = torch.nn.Linear(input_dim, output_dim)
|
||||
else:
|
||||
# Multi-layer model
|
||||
self.linear_or_not = False
|
||||
self.linears = torch.nn.ModuleList()
|
||||
|
||||
self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
|
||||
for layer in range(num_layers - 2):
|
||||
self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
|
||||
self.linears.append(torch.nn.Linear(hidden_dim, output_dim))
|
||||
|
||||
if self.use_bn:
|
||||
self.batch_norms = torch.nn.ModuleList()
|
||||
for layer in range(num_layers - 1):
|
||||
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param x: [num_classes * batch_size, N, F_i], batch of node features
|
||||
note that in self.cond_layers[layer],
|
||||
`x` is split into `num_classes` groups in dim=0,
|
||||
and then treated with different gains and biases
|
||||
"""
|
||||
if self.linear_or_not:
|
||||
# If linear model
|
||||
return self.linear(x)
|
||||
else:
|
||||
# If MLP
|
||||
h = x
|
||||
for layer in range(self.num_layers - 1):
|
||||
h = self.linears[layer](h)
|
||||
if self.use_bn:
|
||||
h = self.batch_norms[layer](h)
|
||||
h = self.activate_func(h)
|
||||
return self.linears[self.num_layers - 1](h)
|
||||
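# Shape sketch (illustrative sizes): a 3-layer MLP applied per node,
#   mlp = MLP(num_layers=3, input_dim=16, hidden_dim=32, output_dim=7)
#   y = mlp(torch.randn(8, 4, 16))   # -> torch.Size([8, 4, 7])
# With use_bn=True the input should be 2-D [batch, features], since BatchNorm1d
# normalizes over the feature dimension.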
|
||||
|
||||
""" Transformer Encoder """
|
||||
class GraphEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(GraphEncoder, self).__init__()
|
||||
# Forward Transformers
|
||||
self.encoder_f = Encoder(config)
|
||||
|
||||
def forward(self, x, mask):
|
||||
h_f, hs_f, attns_f = self.encoder_f(x, mask)
|
||||
h = torch.cat(hs_f, dim=-1)
|
||||
return h
|
||||
|
||||
@staticmethod
|
||||
def get_embeddings(h_x):
|
||||
h_x = h_x.cpu()
|
||||
return h_x[:, -1]
|
||||
|
||||
|
||||
class CLSHead(nn.Module):
|
||||
def __init__(self, config, init_weights=None):
|
||||
super(CLSHead, self).__init__()
|
||||
self.layer_1 = nn.Linear(config.d_model, config.d_model)
|
||||
self.dropout = nn.Dropout(p=config.dropout)
|
||||
self.layer_2 = nn.Linear(config.d_model, config.n_vocab)
|
||||
if init_weights is not None:
|
||||
self.layer_2.weight = init_weights
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dropout(torch.tanh(self.layer_1(x)))
|
||||
return F.log_softmax(self.layer_2(x), dim=-1)
|
||||
|
||||
|
||||
@utils.register_model(name='CATE')
|
||||
class CATE(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(CATE, self).__init__()
|
||||
# Shared Embedding Layer
|
||||
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
|
||||
self.dropout_op = nn.Dropout(p=config.model.dropout)
|
||||
self.d_model = config.model.graph_encoder.d_model
|
||||
self.act = act = get_act(config)
|
||||
# Time
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
|
||||
# 2 GraphEncoder for X and Y
|
||||
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
|
||||
|
||||
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
|
||||
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
if 'pos_enc_type' in config.model:
|
||||
self.pos_enc_type = config.model.pos_enc_type
|
||||
if self.pos_enc_type == 1:
|
||||
raise NotImplementedError
|
||||
elif self.pos_enc_type == 2:
|
||||
if config.data.name == 'NASBench201':
|
||||
self.pos_encoder = PositionalEncoding_Cell(d_model=self.d_model, max_len=config.data.max_node)
|
||||
else:
|
||||
self.pos_encoder = PositionalEncoding_StageWise(d_model=self.d_model, max_len=config.data.max_node)
|
||||
elif self.pos_enc_type == 3:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.pos_encoder = None
|
||||
else:
|
||||
self.pos_encoder = None
|
||||
|
||||
|
||||
def forward(self, X, time_cond, maskX):
|
||||
|
||||
emb_x = self.dropout_op(self.opEmb(X))
|
||||
|
||||
if self.pos_encoder is not None:
|
||||
emb_p = self.pos_encoder(emb_x)
|
||||
emb_x = emb_x + emb_p
|
||||
|
||||
# Time embedding
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)
|
||||
emb_t = self.timeEmb1(emb_t) # [32, 512]
|
||||
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
|
||||
emb_t = emb_t.unsqueeze(1)
|
||||
emb = emb_x + emb_t
|
||||
|
||||
h_x = self.graph_encoder(emb, maskX)
|
||||
h_x = self.final(h_x)
|
||||
|
||||
return h_x
|
||||
|
||||
|
||||
@utils.register_model(name='PredictorCATE')
|
||||
class PredictorCATE(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PredictorCATE, self).__init__()
|
||||
# Shared Embedding Layer
|
||||
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
|
||||
self.dropout_op = nn.Dropout(p=config.model.dropout)
|
||||
self.d_model = config.model.graph_encoder.d_model
|
||||
self.act = act = get_act(config)
|
||||
# Time
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
|
||||
# 2 GraphEncoder for X and Y
|
||||
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
|
||||
|
||||
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
|
||||
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
self.rdim = int(config.data.max_node * config.data.n_vocab)
|
||||
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=1,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
def forward(self, X, time_cond, maskX):
|
||||
|
||||
emb_x = self.dropout_op(self.opEmb(X))
|
||||
|
||||
# Time embedding
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)
|
||||
emb_t = self.timeEmb1(emb_t)
|
||||
emb_t = self.timeEmb2(self.act(emb_t))
|
||||
emb_t = emb_t.unsqueeze(1)
|
||||
emb = emb_x + emb_t
|
||||
|
||||
h_x = self.graph_encoder(emb, maskX)
|
||||
h_x = self.final(h_x)
|
||||
h_x = h_x.reshape(h_x.size(0), -1)
|
||||
h_x = self.regeress(h_x)
|
||||
|
||||
return h_x
|
||||
|
||||
|
||||
class PositionalEncoding_StageWise(nn.Module):
|
||||
|
||||
def __init__(self, d_model, max_len):
|
||||
super(PositionalEncoding_StageWise, self).__init__()
|
||||
NUM_STAGE = 5
|
||||
max_len = int(max_len / NUM_STAGE)
|
||||
self.encoding = torch.zeros(max_len, d_model)
|
||||
self.encoding.requires_grad = False
|
||||
pos = torch.arange(0, max_len)
|
||||
pos = pos.float().unsqueeze(dim=1)
|
||||
_2i = torch.arange(0, d_model, step=2).float()
|
||||
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
|
||||
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
|
||||
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_len, _ = x.size()
|
||||
return self.encoding[:seq_len, :].to(x.device)
|
||||
|
||||
|
||||
class PositionalEncoding_Cell(nn.Module):
|
||||
|
||||
def __init__(self, d_model, max_len):
|
||||
super(PositionalEncoding_Cell, self).__init__()
|
||||
NUM_STAGE = 1
|
||||
max_len = int(max_len / NUM_STAGE)
|
||||
self.encoding = torch.zeros(max_len, d_model)
|
||||
self.encoding.requires_grad = False
|
||||
pos = torch.arange(0, max_len)
|
||||
pos = pos.float().unsqueeze(dim=1)
|
||||
_2i = torch.arange(0, d_model, step=2).float()
|
||||
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
|
||||
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
|
||||
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_len, _ = x.size()
|
||||
return self.encoding[:seq_len, :].to(x.device)
|
||||
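For illustration only (not part of the original file): the positional encodings defined above are added to the operation embeddings inside CATE.forward; the sizes below are assumptions.

import torch

pos_enc = PositionalEncoding_Cell(d_model=64, max_len=8)
emb_x = torch.randn(4, 8, 64)      # [batch, seq_len, d_model]
emb_x = emb_x + pos_enc(emb_x)     # [seq_len, d_model] broadcasts over the batch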
|
||||
|
||||
@utils.register_model(name='MetaPredictorCATE')
|
||||
class MetaPredictorCATE(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(MetaPredictorCATE, self).__init__()
|
||||
|
||||
self.input_type = config.model.input_type
|
||||
self.hs = config.model.hs
|
||||
|
||||
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
|
||||
self.dropout_op = nn.Dropout(p=config.model.dropout)
|
||||
self.d_model = config.model.graph_encoder.d_model
|
||||
self.act = act = get_act(config)
|
||||
|
||||
# Time
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
|
||||
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
|
||||
|
||||
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
|
||||
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
self.rdim = int(config.data.max_node * config.data.n_vocab)
|
||||
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=2*self.rdim,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
# Set
|
||||
self.nz = config.model.nz
|
||||
self.num_sample = config.model.num_sample
|
||||
self.intra_setpool = SetPool(dim_input=512,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.inter_setpool = SetPool(dim_input=self.nz,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.set_fc = nn.Sequential(
|
||||
nn.Linear(512, self.nz),
|
||||
nn.ReLU())
|
||||
|
||||
input_dim = 0
|
||||
if 'D' in self.input_type:
|
||||
input_dim += self.nz
|
||||
if 'A' in self.input_type:
|
||||
input_dim += 2*self.rdim
|
||||
|
||||
self.pred_fc = nn.Sequential(
|
||||
nn.Linear(input_dim, self.hs),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.hs, 1)
|
||||
)
|
||||
|
||||
self.sample_state = False
|
||||
self.D_mu = None
|
||||
|
||||
|
||||
def arch_encode(self, X, time_cond, maskX):
|
||||
emb_x = self.dropout_op(self.opEmb(X))
|
||||
|
||||
# Time embedding
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)  # time embedding
|
||||
emb_t = self.timeEmb1(emb_t) # [32, 512]
|
||||
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
|
||||
emb_t = emb_t.unsqueeze(1)
|
||||
emb = emb_x + emb_t
|
||||
|
||||
h_x = self.graph_encoder(emb, maskX)
|
||||
h_x = self.final(h_x)
|
||||
|
||||
h_x = h_x.reshape(h_x.size(0), -1)
|
||||
h_x = self.regeress(h_x)
|
||||
return h_x
|
||||
|
||||
|
||||
def set_encode(self, task):
|
||||
proto_batch = []
|
||||
for x in task:
|
||||
cls_protos = self.intra_setpool(
|
||||
x.view(-1, self.num_sample, 512)).squeeze(1)
|
||||
proto_batch.append(
|
||||
self.inter_setpool(cls_protos.unsqueeze(0)))
|
||||
v = torch.stack(proto_batch).squeeze()
|
||||
return v
|
||||
|
||||
|
||||
def predict(self, D_mu, A_mu):
|
||||
input_vec = []
|
||||
if 'D' in self.input_type:
|
||||
input_vec.append(D_mu)
|
||||
if 'A' in self.input_type:
|
||||
input_vec.append(A_mu)
|
||||
input_vec = torch.cat(input_vec, dim=1)
|
||||
return self.pred_fc(input_vec)
|
||||
|
||||
|
||||
def forward(self, X, time_cond, maskX, task):
|
||||
if self.sample_state:
|
||||
if self.D_mu is None:
|
||||
self.D_mu = self.set_encode(task)
|
||||
D_mu = self.D_mu
|
||||
else:
|
||||
D_mu = self.set_encode(task)
|
||||
A_mu = self.arch_encode(X, time_cond, maskX)
|
||||
y_pred = self.predict(D_mu, A_mu)
|
||||
return y_pred
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
# magic number 10000 is from transformers
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
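An illustrative check (not in the original file) of the sinusoidal time embedding above; the time steps and embedding dimension are arbitrary assumptions.

import torch

t = torch.tensor([0., 250., 999.])                  # a batch of diffusion time steps
emb = get_timestep_embedding(t, embedding_dim=64)   # sin/cos features -> [3, 64]
assert emb.shape == (3, 64)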
|
||||
|
||||
def get_act(config):
|
||||
"""Get actiuvation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
elif config.model.nonlinearity.lower() == 'tanh':
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
||||
125
NAS-Bench-201/models/digcn.py
Normal file
125
NAS-Bench-201/models/digcn.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
|
||||
# which was authored by Yuge Zhang, 2020
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from . import utils
|
||||
from models.cate import PositionalEncoding_StageWise
|
||||
|
||||
|
||||
def normalize_adj(adj):
|
||||
# Row-normalize matrix
|
||||
last_dim = adj.size(-1)
|
||||
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
|
||||
return torch.div(adj, rowsum)
|
||||
|
||||
|
||||
def graph_pooling(inputs, num_vertices):
|
||||
num_vertices = num_vertices.to(inputs.device)
|
||||
out = inputs.sum(1)
|
||||
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
|
||||
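A tiny illustrative example (not part of the original file) of what normalize_adj computes: each row of the (batched) adjacency matrix is divided by its row sum.

import torch

adj = torch.tensor([[[0., 1., 1.],
                     [1., 0., 0.],
                     [0., 0., 1.]]])   # [B=1, N=3, N=3]
norm = normalize_adj(adj)
# rows of `norm` now sum to 1, e.g. the first row becomes [0.0, 0.5, 0.5]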
|
||||
|
||||
class DirectedGraphConvolution(nn.Module):
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.xavier_uniform_(self.weight1.data)
|
||||
nn.init.xavier_uniform_(self.weight2.data)
|
||||
|
||||
def forward(self, inputs, adj):
|
||||
inputs = inputs.to(self.weight1.device)
|
||||
adj = adj.to(self.weight1.device)
|
||||
norm_adj = normalize_adj(adj)
|
||||
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
|
||||
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
|
||||
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
|
||||
out = (output1 + output2) / 2
|
||||
out = self.dropout(out)
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ str(self.in_features) + ' -> ' \
|
||||
+ str(self.out_features) + ')'
|
||||
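An illustrative forward pass (not in the original file) through the directed graph convolution above; the feature sizes and the all-ones adjacency are assumptions.

import torch

conv = DirectedGraphConvolution(in_features=8, out_features=16)
x = torch.randn(2, 5, 8)      # [B, N, in_features] node features
adj = torch.ones(2, 5, 5)     # toy dense adjacency (no zero-degree rows)
out = conv(x, adj)            # averages forward/reverse propagation -> [2, 5, 16]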
|
||||
|
||||
@utils.register_model(name='NeuralPredictor')
|
||||
class NeuralPredictor(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
|
||||
config.model.graph_encoder.gcn_hidden)
|
||||
for i in range(config.model.graph_encoder.gcn_layers)]
|
||||
self.gcn = nn.ModuleList(self.gcn)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
|
||||
self.fc2 = nn.Linear(config.model.graph_encoder.linear_hidden, 1, bias=False)
|
||||
# Time
|
||||
self.d_model = config.model.graph_encoder.gcn_hidden
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
self.act = act = get_act(config)
|
||||
|
||||
def forward(self, X, time_cond, maskX):
|
||||
out = X
|
||||
adj = maskX
|
||||
|
||||
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device)  # number of vertices per graph
|
||||
gs = adj.size(1) # graph node number
|
||||
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)  # time embedding
|
||||
emb_t = self.timeEmb1(emb_t)
|
||||
emb_t = self.timeEmb2(self.act(emb_t)) # (5, 144)
|
||||
|
||||
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device)) # assuming diagonal is not 1
|
||||
for layer in self.gcn:
|
||||
out = layer(out, adj_with_diag)
|
||||
out = graph_pooling(out, numv)
|
||||
# time
|
||||
out = out + emb_t
|
||||
out = self.fc1(out)
|
||||
out = self.dropout(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
|
||||
def get_act(config):
|
||||
"""Get actiuvation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
elif config.model.nonlinearity.lower() == 'tanh':
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
||||
190
NAS-Bench-201/models/digcn_meta.py
Normal file
190
NAS-Bench-201/models/digcn_meta.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
|
||||
# which was authored by Yuge Zhang, 2020
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from . import utils
|
||||
from .set_encoder.setenc_models import SetPool
|
||||
|
||||
|
||||
def normalize_adj(adj):
|
||||
# Row-normalize matrix
|
||||
last_dim = adj.size(-1)
|
||||
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
|
||||
return torch.div(adj, rowsum)
|
||||
|
||||
|
||||
def graph_pooling(inputs, num_vertices):
|
||||
num_vertices = num_vertices.to(inputs.device)
|
||||
out = inputs.sum(1)
|
||||
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
|
||||
|
||||
|
||||
class DirectedGraphConvolution(nn.Module):
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.xavier_uniform_(self.weight1.data)
|
||||
nn.init.xavier_uniform_(self.weight2.data)
|
||||
|
||||
def forward(self, inputs, adj):
|
||||
inputs = inputs.to(self.weight1.device)
|
||||
adj = adj.to(self.weight1.device)
|
||||
norm_adj = normalize_adj(adj)
|
||||
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
|
||||
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
|
||||
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
|
||||
out = (output1 + output2) / 2
|
||||
out = self.dropout(out)
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ str(self.in_features) + ' -> ' \
|
||||
+ str(self.out_features) + ')'
|
||||
|
||||
|
||||
@utils.register_model(name='MetaNeuralPredictor')
|
||||
class MetaNeuralPredictor(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
# Arch
|
||||
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
|
||||
config.model.graph_encoder.gcn_hidden)
|
||||
for i in range(config.model.graph_encoder.gcn_layers)]
|
||||
self.gcn = nn.ModuleList(self.gcn)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
|
||||
|
||||
# Time
|
||||
self.d_model = config.model.graph_encoder.gcn_hidden
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
|
||||
self.act = act = get_act(config)
|
||||
self.input_type = config.model.input_type
|
||||
self.hs = config.model.hs
|
||||
|
||||
# Set
|
||||
self.nz = config.model.nz
|
||||
self.num_sample = config.model.num_sample
|
||||
self.intra_setpool = SetPool(dim_input=512,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.inter_setpool = SetPool(dim_input=self.nz,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.set_fc = nn.Sequential(
|
||||
nn.Linear(512, self.nz),
|
||||
nn.ReLU())
|
||||
|
||||
input_dim = 0
|
||||
if 'D' in self.input_type:
|
||||
input_dim += self.nz
|
||||
if 'A' in self.input_type:
|
||||
input_dim += config.model.graph_encoder.linear_hidden
|
||||
|
||||
self.pred_fc = nn.Sequential(
|
||||
nn.Linear(input_dim, self.hs),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.hs, 1)
|
||||
)
|
||||
|
||||
self.sample_state = False
|
||||
self.D_mu = None
|
||||
|
||||
def arch_encode(self, X, time_cond, maskX):
|
||||
out = X
|
||||
adj = maskX
|
||||
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device)
|
||||
gs = adj.size(1) # graph node number
|
||||
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)
|
||||
emb_t = self.timeEmb1(emb_t)
|
||||
emb_t = self.timeEmb2(self.act(emb_t))
|
||||
|
||||
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device))
|
||||
for layer in self.gcn:
|
||||
out = layer(out, adj_with_diag)
|
||||
out = graph_pooling(out, numv)
|
||||
# time
|
||||
out = out + emb_t
|
||||
out = self.fc1(out)
|
||||
out = self.dropout(out)
|
||||
|
||||
return out
|
||||
|
||||
def set_encode(self, task):
|
||||
proto_batch = []
|
||||
for x in task:
|
||||
cls_protos = self.intra_setpool(
|
||||
x.view(-1, self.num_sample, 512)).squeeze(1)
|
||||
proto_batch.append(
|
||||
self.inter_setpool(cls_protos.unsqueeze(0)))
|
||||
v = torch.stack(proto_batch).squeeze()
|
||||
return v
|
||||
|
||||
def predict(self, D_mu, A_mu):
|
||||
input_vec = []
|
||||
if 'D' in self.input_type:
|
||||
input_vec.append(D_mu)
|
||||
if 'A' in self.input_type:
|
||||
input_vec.append(A_mu)
|
||||
input_vec = torch.cat(input_vec, dim=1)
|
||||
return self.pred_fc(input_vec)
|
||||
|
||||
def forward(self, X, time_cond, maskX, task):
|
||||
if self.sample_state:
|
||||
if self.D_mu is None:
|
||||
self.D_mu = self.set_encode(task)
|
||||
D_mu = self.D_mu
|
||||
else:
|
||||
D_mu = self.set_encode(task)
|
||||
A_mu = self.arch_encode(X, time_cond, maskX)
|
||||
y_pred = self.predict(D_mu, A_mu)
|
||||
return y_pred
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
# magic number 10000 is from transformers
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
|
||||
def get_act(config):
|
||||
"""Get actiuvation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
elif config.model.nonlinearity.lower() == 'tanh':
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
||||
85
NAS-Bench-201/models/ema.py
Normal file
85
NAS-Bench-201/models/ema.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
|
||||
|
||||
class ExponentialMovingAverage:
|
||||
"""
|
||||
Maintains (exponential) moving average of a set of parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, parameters, decay, use_num_updates=True):
|
||||
"""
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`.
|
||||
decay: The exponential decay.
|
||||
use_num_updates: Whether to use number of updates when computing averages.
|
||||
"""
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
self.decay = decay
|
||||
self.num_updates = 0 if use_num_updates else None
|
||||
self.shadow_params = [p.clone().detach()
|
||||
for p in parameters if p.requires_grad]
|
||||
self.collected_params = []
|
||||
|
||||
def update(self, parameters):
|
||||
"""
|
||||
Update currently maintained parameters.
|
||||
|
||||
Call this every time the parameters are updated, i.e. after each `optimizer.step()` call.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to
|
||||
initialize this object.
|
||||
"""
|
||||
decay = self.decay
|
||||
if self.num_updates is not None:
|
||||
self.num_updates += 1
|
||||
decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
one_minus_decay = 1.0 - decay
|
||||
with torch.no_grad():
|
||||
parameters = [p for p in parameters if p.requires_grad]
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
s_param.sub_(one_minus_decay * (s_param - param))
|
||||
|
||||
def copy_to(self, parameters):
|
||||
"""
|
||||
Copy current parameters into given collection of parameters.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored moving averages.
|
||||
"""
|
||||
parameters = [p for p in parameters if p.requires_grad]
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
if param.requires_grad:
|
||||
param.data.copy_(s_param.data)
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the original optimization process.
|
||||
Store the parameters before the `copy_to` method.
|
||||
After validation (or model saving), use this to restore the former parameters.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
|
||||
def state_dict(self):
|
||||
return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.decay = state_dict['decay']
|
||||
self.num_updates = state_dict['num_updates']
|
||||
self.shadow_params = state_dict['shadow_params']
|
||||
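An illustrative training-loop sketch (not part of the original commit) showing the intended store / copy_to / restore pattern around evaluation; the linear model and optimizer below are placeholders.

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

for _ in range(10):                        # toy training steps
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(model.parameters())         # track the averaged weights

ema.store(model.parameters())              # keep the current (raw) weights
ema.copy_to(model.parameters())            # evaluate with EMA weights here
ema.restore(model.parameters())            # then switch back for training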
82
NAS-Bench-201/models/gnns.py
Normal file
82
NAS-Bench-201/models/gnns.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from .trans_layers import *
|
||||
|
||||
|
||||
class pos_gnn(nn.Module):
|
||||
def __init__(self, act, x_ch, pos_ch, out_ch, max_node, graph_layer, n_layers=3, edge_dim=None, heads=4,
|
||||
temb_dim=None, dropout=0.1, attn_clamp=False):
|
||||
super().__init__()
|
||||
self.out_ch = out_ch
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.act = act
|
||||
self.max_node = max_node
|
||||
self.n_layers = n_layers
|
||||
|
||||
if temb_dim is not None:
|
||||
self.Dense_node0 = nn.Linear(temb_dim, x_ch)
|
||||
self.Dense_node1 = nn.Linear(temb_dim, pos_ch)
|
||||
self.Dense_edge0 = nn.Linear(temb_dim, edge_dim)
|
||||
self.Dense_edge1 = nn.Linear(temb_dim, edge_dim)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.edge_convs = nn.ModuleList()
|
||||
self.edge_layer = nn.Linear(edge_dim * 2 + self.out_ch, edge_dim)
|
||||
|
||||
for i in range(n_layers):
|
||||
if i == 0:
|
||||
self.convs.append(eval(graph_layer)(x_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2,
|
||||
act=act, attn_clamp=attn_clamp))
|
||||
else:
|
||||
self.convs.append(eval(graph_layer)
|
||||
(self.out_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2, act=act,
|
||||
attn_clamp=attn_clamp))
|
||||
self.edge_convs.append(nn.Linear(self.out_ch, edge_dim*2))
|
||||
|
||||
def forward(self, x_degree, x_pos, edge_index, dense_ori, dense_spd, dense_index, temb=None):
|
||||
"""
|
||||
Args:
|
||||
x_degree: node degree feature [B*N, x_ch]
|
||||
x_pos: node rwpe feature [B*N, pos_ch]
|
||||
edge_index: [2, edge_length]
|
||||
dense_ori: edge feature [B, N, N, nf//2]
|
||||
dense_spd: edge shortest path distance feature [B, N, N, nf//2] # Do we need this part? # TODO
|
||||
dense_index
|
||||
temb: [B, temb_dim]
|
||||
"""
|
||||
|
||||
B, N, _, _ = dense_ori.shape
|
||||
|
||||
if temb is not None:
|
||||
dense_ori = dense_ori + self.Dense_edge0(self.act(temb))[:, None, None, :]
|
||||
dense_spd = dense_spd + self.Dense_edge1(self.act(temb))[:, None, None, :]
|
||||
|
||||
temb = temb.unsqueeze(1).repeat(1, self.max_node, 1)
|
||||
temb = temb.reshape(-1, temb.shape[-1])
|
||||
x_degree = x_degree + self.Dense_node0(self.act(temb))
|
||||
x_pos = x_pos + self.Dense_node1(self.act(temb))
|
||||
|
||||
dense_edge = torch.cat([dense_ori, dense_spd], dim=-1)
|
||||
|
||||
ori_edge_attr = dense_edge
|
||||
h = x_degree
|
||||
h_pos = x_pos
|
||||
|
||||
for i_layer in range(self.n_layers):
|
||||
h_edge = dense_edge[dense_index]
|
||||
# update node feature
|
||||
h, h_pos = self.convs[i_layer](h, h_pos, edge_index, h_edge)
|
||||
h = self.Dropout_0(h)
|
||||
h_pos = self.Dropout_0(h_pos)
|
||||
|
||||
# update dense edge feature
|
||||
h_dense_node = h.reshape(B, N, -1)
|
||||
cur_edge_attr = h_dense_node.unsqueeze(1) + h_dense_node.unsqueeze(2) # [B, N, N, nf]
|
||||
dense_edge = (dense_edge + self.act(self.edge_convs[i_layer](cur_edge_attr))) / math.sqrt(2.)
|
||||
dense_edge = self.Dropout_0(dense_edge)
|
||||
|
||||
# Concat edge attribute
|
||||
h_dense_edge = torch.cat([ori_edge_attr, dense_edge], dim=-1)
|
||||
h_dense_edge = self.edge_layer(h_dense_edge).permute(0, 3, 1, 2)
|
||||
|
||||
return h_dense_edge
|
||||
44
NAS-Bench-201/models/layers.py
Normal file
44
NAS-Bench-201/models/layers.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Common layers"""
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
def get_act(config):
|
||||
"""Get actiuvation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
elif config.model.nonlinearity.lower() == 'tanh':
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0):
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
|
||||
padding=padding)
|
||||
return conv
|
||||
|
||||
|
||||
# from DDPM
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
# magic number 10000 is from transformers
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
38
NAS-Bench-201/models/set_encoder/setenc_models.py
Normal file
38
NAS-Bench-201/models/set_encoder/setenc_models.py
Normal file
@@ -0,0 +1,38 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from .setenc_modules import *
|
||||
|
||||
|
||||
class SetPool(nn.Module):
|
||||
def __init__(self, dim_input, num_outputs, dim_output,
|
||||
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
|
||||
super(SetPool, self).__init__()
|
||||
if 'sab' in mode: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
|
||||
else: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
|
||||
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
|
||||
if 'PF' in mode: # [32, 1, 501]
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
elif 'P' in mode:
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
|
||||
else: # torch.Size([32, 1, 501])
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
# "", sm, sab, sabsm
|
||||
|
||||
def forward(self, X):
|
||||
x1 = self.enc(X)
|
||||
x2 = self.dec(x1)
|
||||
return x2
|
||||
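A minimal illustrative instantiation (not in the original file) matching how the meta-predictors above build their set encoders; the hidden size, batch size, and set size are assumptions.

import torch

pool = SetPool(dim_input=512, num_outputs=1, dim_output=56,
               dim_hidden=56, mode='sabPF')
x = torch.randn(4, 20, 512)    # [batch, set_size, dim_input]
out = pool(x)                  # pooled set representation -> [4, 1, 56]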
67
NAS-Bench-201/models/set_encoder/setenc_modules.py
Normal file
67
NAS-Bench-201/models/set_encoder/setenc_modules.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#####################################################################################
|
||||
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class MAB(nn.Module):
|
||||
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
|
||||
super(MAB, self).__init__()
|
||||
self.dim_V = dim_V
|
||||
self.num_heads = num_heads
|
||||
self.fc_q = nn.Linear(dim_Q, dim_V)
|
||||
self.fc_k = nn.Linear(dim_K, dim_V)
|
||||
self.fc_v = nn.Linear(dim_K, dim_V)
|
||||
if ln:
|
||||
self.ln0 = nn.LayerNorm(dim_V)
|
||||
self.ln1 = nn.LayerNorm(dim_V)
|
||||
self.fc_o = nn.Linear(dim_V, dim_V)
|
||||
|
||||
def forward(self, Q, K):
|
||||
Q = self.fc_q(Q)
|
||||
K, V = self.fc_k(K), self.fc_v(K)
|
||||
|
||||
dim_split = self.dim_V // self.num_heads
|
||||
Q_ = torch.cat(Q.split(dim_split, 2), 0)
|
||||
K_ = torch.cat(K.split(dim_split, 2), 0)
|
||||
V_ = torch.cat(V.split(dim_split, 2), 0)
|
||||
|
||||
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
|
||||
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
|
||||
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
|
||||
O = O + F.relu(self.fc_o(O))
|
||||
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
|
||||
return O
|
||||
|
||||
class SAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, ln=False):
|
||||
super(SAB, self).__init__()
|
||||
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(X, X)
|
||||
|
||||
class ISAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
|
||||
super(ISAB, self).__init__()
|
||||
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
|
||||
nn.init.xavier_uniform_(self.I)
|
||||
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
|
||||
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
|
||||
return self.mab1(X, H)
|
||||
|
||||
class PMA(nn.Module):
|
||||
def __init__(self, dim, num_heads, num_seeds, ln=False):
|
||||
super(PMA, self).__init__()
|
||||
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
|
||||
nn.init.xavier_uniform_(self.S)
|
||||
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
|
||||
144
NAS-Bench-201/models/trans_layers.py
Normal file
144
NAS-Bench-201/models/trans_layers.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import math
|
||||
from typing import Union, Tuple, Optional
|
||||
from torch_geometric.typing import PairTensor, Adj, OptTensor
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Linear
|
||||
from torch_scatter import scatter
|
||||
from torch_geometric.nn.conv import MessagePassing
|
||||
from torch_geometric.utils import softmax
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PosTransLayer(MessagePassing):
|
||||
"""Involving the edge feature and updating position feature. Multiply Msg."""
|
||||
|
||||
_alpha: OptTensor
|
||||
|
||||
def __init__(self, x_channels: int, pos_channels: int, out_channels: int,
|
||||
heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
|
||||
bias: bool = True, act=None, attn_clamp: bool = False, **kwargs):
|
||||
kwargs.setdefault('aggr', 'add')
|
||||
super(PosTransLayer, self).__init__(node_dim=0, **kwargs)
|
||||
|
||||
self.x_channels = x_channels
|
||||
self.pos_channels = pos_channels
|
||||
self.in_channels = in_channels = x_channels + pos_channels
|
||||
self.out_channels = out_channels
|
||||
self.heads = heads
|
||||
self.dropout = dropout
|
||||
self.edge_dim = edge_dim
|
||||
self.attn_clamp = attn_clamp
|
||||
|
||||
if act is None:
|
||||
self.act = nn.LeakyReLU(negative_slope=0.2)
|
||||
else:
|
||||
self.act = act
|
||||
|
||||
self.lin_key = Linear(in_channels, heads * out_channels)
|
||||
self.lin_query = Linear(in_channels, heads * out_channels)
|
||||
self.lin_value = Linear(in_channels, heads * out_channels)
|
||||
|
||||
self.lin_edge0 = Linear(edge_dim, heads * out_channels, bias=False)
|
||||
self.lin_edge1 = Linear(edge_dim, heads * out_channels, bias=False)
|
||||
|
||||
self.lin_pos = Linear(heads * out_channels, pos_channels, bias=False)
|
||||
|
||||
self.lin_skip = Linear(x_channels, heads * out_channels, bias=bias)
|
||||
self.norm1 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
|
||||
num_channels=heads * out_channels, eps=1e-6)
|
||||
self.norm2 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
|
||||
num_channels=heads * out_channels, eps=1e-6)
|
||||
# FFN
|
||||
self.FFN = nn.Sequential(Linear(heads * out_channels, heads * out_channels),
|
||||
self.act,
|
||||
Linear(heads * out_channels, heads * out_channels))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.lin_key.reset_parameters()
|
||||
self.lin_query.reset_parameters()
|
||||
self.lin_value.reset_parameters()
|
||||
self.lin_skip.reset_parameters()
|
||||
self.lin_edge0.reset_parameters()
|
||||
self.lin_edge1.reset_parameters()
|
||||
self.lin_pos.reset_parameters()
|
||||
|
||||
def forward(self, x: OptTensor,
|
||||
pos: Tensor,
|
||||
edge_index: Adj,
|
||||
edge_attr: OptTensor = None
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
""""""
|
||||
|
||||
H, C = self.heads, self.out_channels
|
||||
|
||||
x_feat = torch.cat([x, pos], -1)
|
||||
query = self.lin_query(x_feat).view(-1, H, C)
|
||||
key = self.lin_key(x_feat).view(-1, H, C)
|
||||
value = self.lin_value(x_feat).view(-1, H, C)
|
||||
|
||||
# propagate_type: (x: PairTensor, edge_attr: OptTensor)
|
||||
out_x, out_pos = self.propagate(edge_index, query=query, key=key, value=value, pos=pos, edge_attr=edge_attr,
|
||||
size=None)
|
||||
|
||||
out_x = out_x.view(-1, self.heads * self.out_channels)
|
||||
|
||||
# skip connection for x
|
||||
x_r = self.lin_skip(x)
|
||||
out_x = (out_x + x_r) / math.sqrt(2)
|
||||
out_x = self.norm1(out_x)
|
||||
|
||||
# FFN
|
||||
out_x = (out_x + self.FFN(out_x)) / math.sqrt(2)
|
||||
out_x = self.norm2(out_x)
|
||||
|
||||
# skip connection for pos
|
||||
out_pos = pos + torch.tanh(pos + out_pos)
|
||||
|
||||
return out_x, out_pos
|
||||
|
||||
def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
|
||||
pos_j: Tensor,
|
||||
edge_attr: OptTensor,
|
||||
index: Tensor, ptr: OptTensor,
|
||||
size_i: Optional[int]) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
edge_attn = self.lin_edge0(edge_attr).view(-1, self.heads, self.out_channels)
|
||||
alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(self.out_channels)
|
||||
if self.attn_clamp:
|
||||
alpha = alpha.clamp(min=-5., max=5.)
|
||||
|
||||
alpha = softmax(alpha, index, ptr, size_i)
|
||||
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
|
||||
|
||||
# node feature message
|
||||
msg = value_j
|
||||
msg = msg * self.lin_edge1(edge_attr).view(-1, self.heads, self.out_channels)
|
||||
msg = msg * alpha.view(-1, self.heads, 1)
|
||||
|
||||
# node position message
|
||||
pos_msg = pos_j * self.lin_pos(msg.reshape(-1, self.heads * self.out_channels))
|
||||
|
||||
return msg, pos_msg
|
||||
|
||||
def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor,
|
||||
ptr: Optional[Tensor] = None,
|
||||
dim_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
|
||||
if ptr is not None:
|
||||
raise NotImplementedError("Not implement Ptr in aggregate")
|
||||
else:
|
||||
return (scatter(inputs[0], index, 0, dim_size=dim_size, reduce=self.aggr),
|
||||
scatter(inputs[1], index, 0, dim_size=dim_size, reduce="mean"))
|
||||
|
||||
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
return inputs
|
||||
|
||||
def __repr__(self):
|
||||
return '{}({}, {}, heads={})'.format(self.__class__.__name__,
|
||||
self.in_channels,
|
||||
self.out_channels, self.heads)
|
||||
255
NAS-Bench-201/models/transformer.py
Executable file
255
NAS-Bench-201/models/transformer.py
Executable file
@@ -0,0 +1,255 @@
|
||||
from copy import deepcopy as cp
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def clones(module, N):
|
||||
return nn.ModuleList([cp(module) for _ in range(N)])
|
||||
|
||||
|
||||
def attention(query, key, value, mask = None, dropout = None):
|
||||
d_k = query.size(-1)
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, -1e9)
|
||||
attn = F.softmax(scores, dim = -1)
|
||||
if dropout is not None:
|
||||
attn = dropout(attn)
|
||||
return torch.matmul(attn, value), attn
|
||||
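An illustrative shape check (not part of the original file) for the scaled dot-product attention helper above; all sizes are arbitrary.

import torch

q = torch.randn(2, 4, 5, 16)     # [batch, heads, query_len, d_k]
k = torch.randn(2, 4, 7, 16)     # [batch, heads, key_len, d_k]
v = torch.randn(2, 4, 7, 16)
out, attn = attention(q, k, v)   # out: [2, 4, 5, 16], attn: [2, 4, 5, 7]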
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
|
||||
self.d_model = config.d_model
|
||||
self.n_head = config.n_head
|
||||
self.d_k = config.d_model // config.n_head
|
||||
|
||||
self.linears = clones(nn.Linear(self.d_model, self.d_model), 4)
|
||||
self.dropout = nn.Dropout(p=config.dropout)
|
||||
|
||||
def forward(self, query, key, value, mask = None):
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1)
|
||||
batch_size = query.size(0)
|
||||
|
||||
query, key, value = [l(x).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2) for l, x in zip(self.linears, (query, key, value))]
|
||||
x, attn = attention(query, key, value, mask = mask, dropout = self.dropout)
|
||||
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
|
||||
return self.linears[3](x), attn
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
|
||||
self.w_1 = nn.Linear(config.d_model, config.d_ff)
|
||||
self.w_2 = nn.Linear(config.d_ff, config.d_model)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
||||
|
||||
|
||||
class PositionwiseFeedForwardLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionwiseFeedForwardLast, self).__init__()
|
||||
|
||||
self.w_1 = nn.Linear(config.d_model, config.d_ff)
|
||||
self.w_2 = nn.Linear(config.d_ff, config.n_vocab)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
||||
|
||||
|
||||
class SelfAttentionBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SelfAttentionBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.attn = MultiHeadAttention(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x_ = self.norm(x)
|
||||
x_ , attn = self.attn(x_, x_, x_, mask)
|
||||
return self.dropout(x_) + x, attn
|
||||
|
||||
|
||||
class SourceAttentionBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SourceAttentionBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.attn = MultiHeadAttention(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x, m, mask):
|
||||
x_ = self.norm(x)
|
||||
x_, attn = self.attn(x_, m, m, mask)
|
||||
return self.dropout(x_) + x, attn
|
||||
|
||||
|
||||
class FeedForwardBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(FeedForwardBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.feed_forward = PositionwiseFeedForward(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x_ = self.norm(x)
|
||||
x_ = self.feed_forward(x_)
|
||||
return self.dropout(x_) + x
|
||||
|
||||
|
||||
class FeedForwardBlockLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(FeedForwardBlockLast, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.feed_forward = PositionwiseFeedForwardLast(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
# Only for the last layer
|
||||
self.proj_fc = nn.Linear(config.d_model, config.n_vocab)
|
||||
|
||||
def forward(self, x):
|
||||
x_ = self.norm(x)
|
||||
x_ = self.feed_forward(x_)
|
||||
return self.dropout(x_) + self.proj_fc(x)
|
||||
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(EncoderBlock, self).__init__()
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlock(config)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x, attn = self.self_attn(x, mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn
|
||||
|
||||
|
||||
class EncoderBlockLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(EncoderBlockLast, self).__init__()
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlockLast(config)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x, attn = self.self_attn(x, mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(DecoderBlock, self).__init__()
|
||||
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.src_attn = SourceAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlock(config)
|
||||
|
||||
def forward(self, x, m, src_mask, tgt_mask):
|
||||
x, attn_tgt = self.self_attn(x, tgt_mask)
|
||||
x, attn_src = self.src_attn(x, m, src_mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn_src, attn_tgt
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Encoder, self).__init__()
|
||||
self.layers = clones(EncoderBlock(config), config.n_layers)
|
||||
self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers)
|
||||
|
||||
def forward(self, x, mask):
|
||||
outputs = []
|
||||
attns = []
|
||||
for layer, norm in zip(self.layers, self.norms):
|
||||
x, attn = layer(x, mask)
|
||||
outputs.append(norm(x))
|
||||
attns.append(attn)
|
||||
return outputs[-1], outputs, attns
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionalEmbedding, self).__init__()
|
||||
|
||||
p2e = torch.zeros(config.max_len, config.d_model)
|
||||
position = torch.arange(0.0, config.max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0.0, config.d_model, 2) * (- math.log(10000.0) / config.d_model))
|
||||
p2e[:, 0::2] = torch.sin(position * div_term)
|
||||
p2e[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
self.register_buffer('p2e', p2e)
|
||||
|
||||
def forward(self, x):
|
||||
shp = x.size()
|
||||
with torch.no_grad():
|
||||
emb = torch.index_select(self.p2e, 0, x.view(-1)).view(shp + (-1,))
|
||||
return emb
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Transformer, self).__init__()
|
||||
self.p2e = PositionalEmbedding(config)
|
||||
self.encoder = Encoder(config)
|
||||
|
||||
def forward(self, input_emb, position_ids, attention_mask):
|
||||
# position embedding projection
|
||||
projection = self.p2e(position_ids) + input_emb
|
||||
return self.encoder(projection, attention_mask)
|
||||
|
||||
|
||||
class TokenTypeEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TokenTypeEmbedding, self).__init__()
|
||||
self.t2e = nn.Embedding(config.n_token_type, config.d_model)
|
||||
self.d_model = config.d_model
|
||||
|
||||
def forward(self, x):
|
||||
return self.t2e(x) * math.sqrt(self.d_model)
|
||||
|
||||
|
||||
class SemanticEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SemanticEmbedding, self).__init__()
|
||||
self.d_model = config.d_model
|
||||
self.fc = nn.Linear(config.n_vocab, config.d_model)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc(x) * math.sqrt(self.d_model)
|
||||
|
||||
|
||||
class Embeddings(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Embeddings, self).__init__()
|
||||
|
||||
self.w2e = SemanticEmbedding(config)
|
||||
self.p2e = PositionalEmbedding(config)
|
||||
self.t2e = TokenTypeEmbedding(config)
|
||||
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, input_ids, position_ids = None, token_type_ids = None):
|
||||
if position_ids is None:
|
||||
batch_size, length = input_ids.size()
|
||||
with torch.no_grad():
|
||||
position_ids = torch.arange(0, length).repeat(batch_size, 1)
|
||||
if torch.cuda.is_available():
|
||||
position_ids = position_ids.cuda(device=input_ids.device)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
embeddings = self.w2e(input_ids) + self.p2e(position_ids) + self.t2e(token_type_ids)
|
||||
return self.dropout(embeddings)
|
||||
289
NAS-Bench-201/models/utils.py
Normal file
289
NAS-Bench-201/models/utils.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import sde_lib
|
||||
|
||||
_MODELS = {}
|
||||
|
||||
|
||||
def register_model(cls=None, *, name=None):
|
||||
"""A decorator for registering model classes."""
|
||||
|
||||
def _register(cls):
|
||||
if name is None:
|
||||
local_name = cls.__name__
|
||||
else:
|
||||
local_name = name
|
||||
if local_name in _MODELS:
|
||||
raise ValueError(
|
||||
f'Already registered model with name: {local_name}')
|
||||
_MODELS[local_name] = cls
|
||||
return cls
|
||||
|
||||
if cls is None:
|
||||
return _register
|
||||
else:
|
||||
return _register(cls)
|
||||
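For illustration only (the class below is hypothetical and not defined anywhere in this repo): registering a model class and looking it up in the registry above.

import torch

@register_model(name='ToyModel')
class ToyModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.fc = torch.nn.Linear(4, 1)

    def forward(self, x, labels):
        return self.fc(x)

assert _MODELS['ToyModel'] is ToyModel    # later retrievable via get_model('ToyModel')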
|
||||
|
||||
def get_model(name):
|
||||
return _MODELS[name]
|
||||
|
||||
|
||||
def create_model(config):
|
||||
"""Create the model."""
|
||||
model_name = config.model.name
|
||||
model = get_model(model_name)(config)
|
||||
model = model.to(config.device)
|
||||
return model
|
||||
|
||||
|
||||
def get_model_fn(model, train=False):
|
||||
"""Create a function to give the output of the score-based model.
|
||||
|
||||
Args:
|
||||
model: The score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
|
||||
Returns:
|
||||
A model function.
|
||||
"""
|
||||
|
||||
def model_fn(x, labels, *args, **kwargs):
|
||||
"""Compute the output of the score-based model.
|
||||
|
||||
Args:
|
||||
x: A mini-batch of input data (Adjacency matrices).
|
||||
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
|
||||
for different models.
|
||||
mask: Mask for adjacency matrices.
|
||||
|
||||
Returns:
|
||||
A tuple of (model output, new mutable states)
|
||||
"""
|
||||
if not train:
|
||||
model.eval()
|
||||
return model(x, labels, *args, **kwargs)
|
||||
else:
|
||||
model.train()
|
||||
return model(x, labels, *args, **kwargs)
|
||||
|
||||
return model_fn
|
||||
|
||||
|
||||
def get_score_fn(sde, model, train=False, continuous=False):
|
||||
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
model: A score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
|
||||
|
||||
Returns:
|
||||
A score function.
|
||||
"""
|
||||
model_fn = get_model_fn(model, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def score_fn(x, t, *args, **kwargs):
|
||||
# Scale neural network output by standard deviation and flip sign
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of the time embedding is assumed to be 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
|
||||
labels.long()]
|
||||
|
||||
score = -score / std[:, None, None]
|
||||
return score
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def score_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
return score
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
return score_fn
|
||||
|
||||
|
||||
def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
|
||||
regress=True, labels='max'):
|
||||
logit_fn = get_logit_fn(sde, classifier, train, continuous)
|
||||
|
||||
def classifier_grad_fn(x, t, *args, **kwargs):
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
if regress:
|
||||
assert labels in ['max', 'min']
|
||||
logit = logit_fn(x_in, t, *args, **kwargs)
|
||||
if labels == 'max':
|
||||
prob = logit.sum()
|
||||
elif labels == 'min':
|
||||
prob = -logit.sum()
|
||||
else:
|
||||
logit = logit_fn(x_in, t, *args, **kwargs)
|
||||
log_prob = F.log_softmax(logit, dim=-1)
|
||||
prob = log_prob[range(len(logit)), labels.view(-1)].sum()
|
||||
classifier_grad = torch.autograd.grad(prob, x_in)[0]
|
||||
return classifier_grad
|
||||
|
||||
return classifier_grad_fn
|
||||
|
||||
|
||||
def get_logit_fn(sde, classifier, train=False, continuous=False):
|
||||
classifier_fn = get_model_fn(classifier, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def logit_fn(x, t, *args, **kwargs):
|
||||
# Convert continuous time steps to the discrete label convention expected by the classifier
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of the time embedding is assumed to be 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
return logit
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def logit_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
return logit
|
||||
|
||||
return logit_fn
|
||||
|
||||
|
||||
def get_predictor_fn(sde, model, train=False, continuous=False):
|
||||
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
model: A predictor model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
|
||||
|
||||
Returns:
|
||||
A predictor function.
|
||||
"""
|
||||
model_fn = get_model_fn(model, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def predictor_fn(x, t, *args, **kwargs):
|
||||
# Convert continuous time steps to the label convention expected by the predictor
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of the time embedding is assumed to be 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[
|
||||
labels.long()]
|
||||
|
||||
return pred
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def predictor_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
return pred
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
return predictor_fn
|
||||
|
||||
|
||||
def to_flattened_numpy(x):
|
||||
"""Flatten a torch tensor `x` and convert it to numpy."""
|
||||
return x.detach().cpu().numpy().reshape((-1,))
|
||||
|
||||
|
||||
def from_flattened_numpy(x, shape):
|
||||
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
|
||||
return torch.from_numpy(x.reshape(shape))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def mask_adj2node(adj_mask):
|
||||
"""Convert batched adjacency mask matrices to batched node mask matrices.
|
||||
|
||||
Args:
|
||||
adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edge.
|
||||
|
||||
Output:
|
||||
node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
|
||||
"""
|
||||
|
||||
batch_size, max_num_nodes, _ = adj_mask.shape
|
||||
|
||||
node_mask = adj_mask[:, 0, :].clone()
|
||||
node_mask[:, 0] = 1
|
||||
|
||||
return node_mask
|
||||
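# Minimal usage sketch (hypothetical values, not from the original code):
#   adj = torch.ones(4, 8, 8) - torch.eye(8)   # batched adjacency masks without self-loops
#   node_mask = mask_adj2node(adj)             # -> shape [4, 8], all ones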
|
||||
|
||||
@torch.no_grad()
|
||||
def get_rw_feat(k_step, dense_adj):
|
||||
"""Compute k_step Random Walk for given dense adjacency matrix."""
|
||||
|
||||
rw_list = []
|
||||
deg = dense_adj.sum(-1, keepdims=True)
|
||||
AD = dense_adj / (deg + 1e-8)
|
||||
rw_list.append(AD)
|
||||
|
||||
for _ in range(k_step):
|
||||
rw = torch.bmm(rw_list[-1], AD)
|
||||
rw_list.append(rw)
|
||||
rw_map = torch.stack(rw_list[1:], dim=1) # [B, k_step, N, N]
|
||||
|
||||
rw_landing = torch.diagonal(
|
||||
rw_map, offset=0, dim1=2, dim2=3) # [B, k_step, N]
|
||||
rw_landing = rw_landing.permute(0, 2, 1) # [B, N, rw_depth]
|
||||
|
||||
# get the shortest path distance indices
|
||||
tmp_rw = rw_map.sort(dim=1)[0]
|
||||
spd_ind = (tmp_rw <= 0).sum(dim=1) # [B, N, N]
|
||||
|
||||
spd_onehot = torch.nn.functional.one_hot(
|
||||
spd_ind, num_classes=k_step+1).to(torch.float)
|
||||
spd_onehot = spd_onehot.permute(0, 3, 1, 2) # [B, kstep, N, N]
|
||||
|
||||
return rw_landing, spd_onehot
|
||||
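# Shape summary (assuming dense_adj of shape [B, N, N] and k_step = K):
#   rw_landing: [B, N, K]        per-node random-walk return probabilities
#   spd_onehot: [B, K+1, N, N]   one-hot truncated shortest-path-distance encoding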
520
NAS-Bench-201/run_lib.py
Normal file
@@ -0,0 +1,520 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import logging
|
||||
from absl import flags
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
import torch
|
||||
|
||||
from models import cate
|
||||
from models import digcn
|
||||
from models import digcn_meta
|
||||
import losses
|
||||
import sampling
|
||||
from models import utils as mutils
|
||||
from models.ema import ExponentialMovingAverage
|
||||
import datasets_nas
|
||||
import sde_lib
|
||||
from utils import *
|
||||
from logger import Logger
|
||||
from analysis.arch_metrics import SamplingArchMetrics, SamplingArchMetricsMeta
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def set_exp_name(config):
|
||||
if config.task == 'tr_scorenet':
|
||||
exp_name = f'./results/{config.task}/{config.folder_name}'
|
||||
data = config.data
|
||||
|
||||
elif config.task == 'tr_meta_surrogate':
|
||||
exp_name = f'./results/{config.task}/{config.folder_name}'
|
||||
|
||||
os.makedirs(exp_name, exist_ok=True)
|
||||
config.exp_name = exp_name
|
||||
set_random_seed(config)
|
||||
|
||||
return exp_name
|
||||
|
||||
|
||||
def set_random_seed(config):
|
||||
seed = config.seed
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
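# Fixing the Python, NumPy, and CUDA seeds together with cuDNN's deterministic mode makes runs
# reproducible, at the cost of potentially slower convolutions.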
|
||||
|
||||
def scorenet_train(config):
|
||||
"""Runs the score network training pipeline.
|
||||
Args:
|
||||
config: Configuration to use.
|
||||
"""
|
||||
|
||||
## Set logger
|
||||
exp_name = set_exp_name(config)
|
||||
logger = Logger(
|
||||
log_dir=exp_name,
|
||||
write_textfile=True)
|
||||
logger.update_config(config, is_args=True)
|
||||
logger.write_str(str(vars(config)))
|
||||
logger.write_str('-' * 100)
|
||||
|
||||
## Create directories for experimental logs
|
||||
sample_dir = os.path.join(exp_name, "samples")
|
||||
os.makedirs(sample_dir, exist_ok=True)
|
||||
|
||||
## Initialize model and optimizer
|
||||
score_model = mutils.create_model(config)
|
||||
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
|
||||
optimizer = losses.get_optimizer(config, score_model.parameters())
|
||||
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0, config=config)
|
||||
|
||||
## Create checkpoints directory
|
||||
checkpoint_dir = os.path.join(exp_name, "checkpoints")
|
||||
|
||||
## Intermediate checkpoints to resume training
|
||||
checkpoint_meta_dir = os.path.join(exp_name, "checkpoints-meta", "checkpoint.pth")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
|
||||
|
||||
## Resume training when intermediate checkpoints are detected
|
||||
if config.resume:
|
||||
state = restore_checkpoint(config.resume_ckpt_path, state, config.device, resume=config.resume)
|
||||
initial_step = int(state['step'])
|
||||
|
||||
## Build dataloader and iterators
|
||||
train_ds, eval_ds, test_ds = datasets_nas.get_dataset(config)
|
||||
train_loader, eval_loader, test_loader = datasets_nas.get_dataloader(config, train_ds, eval_ds, test_ds)
|
||||
train_iter = iter(train_loader)
|
||||
|
||||
# Create data normalizer and its inverse
|
||||
scaler = datasets_nas.get_data_scaler(config)
|
||||
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
|
||||
|
||||
## Setup SDEs
|
||||
if config.training.sde.lower() == 'vpsde':
|
||||
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
|
||||
sampling_eps = 1e-3
|
||||
elif config.training.sde.lower() == 'vesde':
|
||||
sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
|
||||
sampling_eps = 1e-5
|
||||
else:
|
||||
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
|
||||
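# For reference (see sde_lib.py below): the VE-SDE perturbs data with
#   sigma(t) = sigma_min * (sigma_max / sigma_min) ** t,
# while the VP-SDE uses a linear beta schedule between beta_min and beta_max.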
|
||||
# Build one-step training and evaluation functions
|
||||
optimize_fn = losses.optimization_manager(config)
|
||||
continuous = config.training.continuous
|
||||
reduce_mean = config.training.reduce_mean
|
||||
likelihood_weighting = config.training.likelihood_weighting
|
||||
train_step_fn = losses.get_step_fn(sde=sde,
|
||||
train=True,
|
||||
optimize_fn=optimize_fn,
|
||||
reduce_mean=reduce_mean,
|
||||
continuous=continuous,
|
||||
likelihood_weighting=likelihood_weighting,
|
||||
data=config.data.name)
|
||||
eval_step_fn = losses.get_step_fn(sde=sde,
|
||||
train=False,
|
||||
optimize_fn=optimize_fn,
|
||||
reduce_mean=reduce_mean,
|
||||
continuous=continuous,
|
||||
likelihood_weighting=likelihood_weighting,
|
||||
data=config.data.name)
|
||||
|
||||
## Build sampling functions
|
||||
if config.training.snapshot_sampling:
|
||||
sampling_shape = (config.training.eval_batch_size, config.data.max_node, config.data.n_vocab)
|
||||
sampling_fn = sampling.get_sampling_fn(config=config,
|
||||
sde=sde,
|
||||
shape=sampling_shape,
|
||||
inverse_scaler=inverse_scaler,
|
||||
eps=sampling_eps)
|
||||
|
||||
## Build analysis tools
|
||||
sampling_metrics = SamplingArchMetrics(config, train_ds, exp_name)
|
||||
|
||||
## Start training the score network
|
||||
logging.info("Starting training loop at step %d." % (initial_step,))
|
||||
element = {'train': ['training_loss'],
|
||||
'eval': ['eval_loss'],
|
||||
'test': ['test_loss'],
|
||||
'sample': ['r_valid', 'r_unique', 'r_novel']}
|
||||
|
||||
num_train_steps = config.training.n_iters
|
||||
is_best = False
|
||||
min_test_loss = 1e05
|
||||
for step in range(initial_step, num_train_steps+1):
|
||||
try:
|
||||
x, adj, extra = next(train_iter)
|
||||
except StopIteration:
|
||||
train_iter = train_loader.__iter__()
|
||||
x, adj, extra = next(train_iter)
|
||||
mask = aug_mask(adj, algo=config.data.aug_mask_algo, data=config.data.name)
|
||||
x, adj, mask = scaler(x.to(config.device)), adj.to(config.device), mask.to(config.device)
|
||||
batch = (x, adj, mask)
|
||||
|
||||
## Execute one training step
|
||||
loss = train_step_fn(state, batch)
|
||||
logger.update(key="training_loss", v=loss.item())
|
||||
if step % config.training.log_freq == 0:
|
||||
logging.info("step: %d, training_loss: %.5e" % (step, loss.item()))
|
||||
|
||||
## Report the loss on evaluation dataset periodically
|
||||
if step % config.training.eval_freq == 0:
|
||||
for eval_x, eval_adj, eval_extra in eval_loader:
|
||||
eval_mask = aug_mask(eval_adj, algo=config.data.aug_mask_algo, data=config.data.name)
|
||||
eval_x, eval_adj, eval_mask = scaler(eval_x.to(config.device)), eval_adj.to(config.device), eval_mask.to(config.device)
|
||||
eval_batch = (eval_x, eval_adj, eval_mask)
|
||||
eval_loss = eval_step_fn(state, eval_batch)
|
||||
logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss.item()))
|
||||
logger.update(key="eval_loss", v=eval_loss.item())
|
||||
for test_x, test_adj, test_extra in test_loader:
|
||||
test_mask = aug_mask(test_adj, algo=config.data.aug_mask_algo, data=config.data.name)
|
||||
test_x, test_adj, test_mask = scaler(test_x.to(config.device)), test_adj.to(config.device), test_mask.to(config.device)
|
||||
test_batch = (test_x, test_adj, test_mask)
|
||||
test_loss = eval_step_fn(state, test_batch)
|
||||
logging.info("step: %d, test_loss: %.5e" % (step, test_loss.item()))
|
||||
logger.update(key="test_loss", v=test_loss.item())
|
||||
if logger.logs['test_loss'].avg < min_test_loss:
|
||||
is_best = True
min_test_loss = logger.logs['test_loss'].avg
|
||||
|
||||
## Save the checkpoint
|
||||
if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
|
||||
save_step = step // config.training.snapshot_freq
|
||||
save_checkpoint(checkpoint_dir, state, step, save_step, is_best)
|
||||
|
||||
## Generate samples
|
||||
if config.training.snapshot_sampling:
|
||||
ema.store(score_model.parameters())
|
||||
ema.copy_to(score_model.parameters())
|
||||
sample, sample_steps, _ = sampling_fn(score_model, mask)
|
||||
quantized_sample = quantize(sample)
|
||||
this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step))
|
||||
os.makedirs(this_sample_dir, exist_ok=True)
|
||||
|
||||
## Evaluate samples
|
||||
arch_metric = sampling_metrics(arch_list=quantized_sample, this_sample_dir=this_sample_dir)
|
||||
r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2]
|
||||
logger.update(key="r_valid", v=r_valid)
|
||||
logger.update(key="r_unique", v=r_unique)
|
||||
logger.update(key="r_novel", v=r_novel)
|
||||
logging.info("r_valid: %.5e" % (r_valid))
|
||||
logging.info("r_unique: %.5e" % (r_unique))
|
||||
logging.info("r_novel: %.5e" % (r_novel))
|
||||
|
||||
if step % config.training.eval_freq == 0:
|
||||
logger.write_log(element=element, step=step)
|
||||
else:
|
||||
logger.write_log(element={'train': ['training_loss']}, step=step)
|
||||
|
||||
logger.reset()
|
||||
|
||||
logger.save_log()
|
||||
|
||||
|
||||
def scorenet_evaluate(config):
|
||||
"""Evaluate trained score network.
|
||||
Args:
|
||||
config: Configuration to use.
|
||||
"""
|
||||
|
||||
## Set logger
|
||||
exp_name = set_exp_name(config)
|
||||
logger = Logger(
|
||||
log_dir=exp_name,
|
||||
write_textfile=True)
|
||||
logger.update_config(config, is_args=True)
|
||||
logger.write_str(str(vars(config)))
|
||||
logger.write_str('-' * 100)
|
||||
|
||||
## Load the config of pre-trained score network
|
||||
score_config = torch.load(config.scorenet_ckpt_path)['config']
|
||||
|
||||
## Setup SDEs
|
||||
if score_config.training.sde.lower() == 'vpsde':
|
||||
sde = sde_lib.VPSDE(beta_min=score_config.model.beta_min, beta_max=score_config.model.beta_max, N=score_config.model.num_scales)
|
||||
sampling_eps = 1e-3
|
||||
elif score_config.training.sde.lower() == 'vesde':
|
||||
sde = sde_lib.VESDE(sigma_min=score_config.model.sigma_min, sigma_max=score_config.model.sigma_max, N=score_config.model.num_scales)
|
||||
sampling_eps = 1e-5
|
||||
else:
|
||||
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
|
||||
|
||||
## Create data normalizer and its inverse
|
||||
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
|
||||
|
||||
# Build the sampling function
|
||||
sampling_shape = (config.eval.batch_size, score_config.data.max_node, score_config.data.n_vocab)
|
||||
sampling_fn = sampling.get_sampling_fn(config=config,
|
||||
sde=sde,
|
||||
shape=sampling_shape,
|
||||
inverse_scaler=inverse_scaler,
|
||||
eps=sampling_eps)
|
||||
|
||||
## Load pre-trained score network
|
||||
score_model = mutils.create_model(score_config)
|
||||
ema = ExponentialMovingAverage(score_model.parameters(), decay=score_config.model.ema_rate)
|
||||
state = dict(model=score_model, ema=ema, step=0, config=score_config)
|
||||
state = restore_checkpoint(config.scorenet_ckpt_path, state, device=config.device, resume=True)
|
||||
ema.store(score_model.parameters())
|
||||
ema.copy_to(score_model.parameters())
|
||||
|
||||
## Build dataset
|
||||
train_ds, eval_ds, test_ds = datasets_nas.get_dataset(score_config)
|
||||
|
||||
## Build analysis tools
|
||||
sampling_metrics = SamplingArchMetrics(config, train_ds, exp_name)
|
||||
|
||||
## Create directories for experimental logs
|
||||
sample_dir = os.path.join(exp_name, "samples")
|
||||
os.makedirs(sample_dir, exist_ok=True)
|
||||
|
||||
## Start sampling
|
||||
logging.info("Starting sampling")
|
||||
element = {'sample': ['r_valid', 'r_unique', 'r_novel']}
|
||||
|
||||
num_sampling_rounds = int(np.ceil(config.eval.num_samples / config.eval.batch_size))
|
||||
print(f'>>> Sampling for {num_sampling_rounds} rounds...')
|
||||
|
||||
all_samples = []
|
||||
adj = train_ds.adj.to(config.device)
|
||||
mask = train_ds.mask(algo=score_config.data.aug_mask_algo).to(config.device)
|
||||
if len(adj.shape) == 2: adj = adj.unsqueeze(0)
|
||||
if len(mask.shape) == 2: mask = mask.unsqueeze(0)
|
||||
|
||||
for _ in range(num_sampling_rounds):
|
||||
sample, sample_steps, _ = sampling_fn(score_model, mask)
|
||||
quantized_sample = quantize(sample)
|
||||
all_samples += quantized_sample
|
||||
|
||||
## Evaluate samples
|
||||
all_samples = all_samples[:config.eval.num_samples]
|
||||
arch_metric = sampling_metrics(arch_list=all_samples, this_sample_dir=sample_dir)
|
||||
r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2]
|
||||
logger.update(key="r_valid", v=r_valid)
|
||||
logger.update(key="r_unique", v=r_unique)
|
||||
logger.update(key="r_novel", v=r_novel)
|
||||
logger.write_log(element=element, step=1)
|
||||
logger.save_log()
|
||||
|
||||
|
||||
def meta_surrogate_train(config):
|
||||
"""Runs the meta-predictor model training pipeline.
|
||||
Args:
|
||||
config: Configuration to use.
|
||||
"""
|
||||
## Set logger
|
||||
exp_name = set_exp_name(config)
|
||||
logger = Logger(
|
||||
log_dir=exp_name,
|
||||
write_textfile=True)
|
||||
logger.update_config(config, is_args=True)
|
||||
logger.write_str(str(vars(config)))
|
||||
logger.write_str('-' * 100)
|
||||
|
||||
## Create directories for experimental logs
|
||||
sample_dir = os.path.join(exp_name, "samples")
|
||||
os.makedirs(sample_dir, exist_ok=True)
|
||||
|
||||
## Initialize model and optimizer
|
||||
surrogate_model = mutils.create_model(config)
|
||||
optimizer = losses.get_optimizer(config, surrogate_model.parameters())
|
||||
state = dict(optimizer=optimizer, model=surrogate_model, step=0, config=config)
|
||||
|
||||
## Create checkpoints directory
|
||||
checkpoint_dir = os.path.join(exp_name, "checkpoints")
|
||||
|
||||
## Intermediate checkpoints to resume training
|
||||
checkpoint_meta_dir = os.path.join(exp_name, "checkpoints-meta", "checkpoint.pth")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
|
||||
|
||||
## Resume training when intermediate checkpoints are detected and resume=True
|
||||
state = restore_checkpoint(checkpoint_meta_dir, state, config.device, resume=config.resume)
|
||||
initial_step = int(state['step'])
|
||||
|
||||
## Build dataloader and iterators
|
||||
train_ds, eval_ds, test_ds = datasets_nas.get_meta_dataset(config)
|
||||
train_loader, eval_loader, _ = datasets_nas.get_dataloader(config, train_ds, eval_ds, test_ds)
|
||||
train_iter = iter(train_loader)
|
||||
|
||||
## Create data normalizer and its inverse
|
||||
scaler = datasets_nas.get_data_scaler(config)
|
||||
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
|
||||
|
||||
## Setup SDEs
|
||||
if config.training.sde.lower() == 'vpsde':
|
||||
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
|
||||
sampling_eps = 1e-3
|
||||
elif config.training.sde.lower() == 'vesde':
|
||||
sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
|
||||
sampling_eps = 1e-5
|
||||
else:
|
||||
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
|
||||
|
||||
## Build one-step training and evaluation functions
|
||||
optimize_fn = losses.optimization_manager(config)
|
||||
continuous = config.training.continuous
|
||||
reduce_mean = config.training.reduce_mean
|
||||
likelihood_weighting = config.training.likelihood_weighting
|
||||
train_step_fn = losses.get_step_fn_predictor(sde=sde,
|
||||
train=True,
|
||||
optimize_fn=optimize_fn,
|
||||
reduce_mean=reduce_mean,
|
||||
continuous=continuous,
|
||||
likelihood_weighting=likelihood_weighting,
|
||||
data=config.data.name,
|
||||
label_list=config.data.label_list,
|
||||
noised=config.training.noised)
|
||||
eval_step_fn = losses.get_step_fn_predictor(sde,
|
||||
train=False,
|
||||
optimize_fn=optimize_fn,
|
||||
reduce_mean=reduce_mean,
|
||||
continuous=continuous,
|
||||
likelihood_weighting=likelihood_weighting,
|
||||
data=config.data.name,
|
||||
label_list=config.data.label_list,
|
||||
noised=config.training.noised)
|
||||
|
||||
## Build sampling functions
|
||||
if config.training.snapshot_sampling:
|
||||
sampling_shape = (config.training.eval_batch_size, config.data.max_node, config.data.n_vocab)
|
||||
sampling_fn = sampling.get_sampling_fn(config=config,
|
||||
sde=sde,
|
||||
shape=sampling_shape,
|
||||
inverse_scaler=inverse_scaler,
|
||||
eps=sampling_eps,
|
||||
conditional=True,
|
||||
data_name=config.sampling.check_dataname, # for sanity check
|
||||
num_sample=config.model.num_sample)
|
||||
## Load pre-trained score network
|
||||
score_config = torch.load(config.scorenet_ckpt_path)['config']
|
||||
check_config(score_config, config)
|
||||
score_model = mutils.create_model(score_config)
|
||||
score_ema = ExponentialMovingAverage(score_model.parameters(), decay=score_config.model.ema_rate)
|
||||
score_state = dict(model=score_model, ema=score_ema, step=0, config=score_config)
|
||||
score_state = restore_checkpoint(config.scorenet_ckpt_path, score_state, device=config.device, resume=True)
|
||||
score_ema.copy_to(score_model.parameters())
|
||||
|
||||
## Build analysis tools
|
||||
sampling_metrics = SamplingArchMetricsMeta(config, train_ds, exp_name)
|
||||
|
||||
## Start training
|
||||
logging.info("Starting training loop at step %d." % (initial_step,))
|
||||
element = {'train': ['training_loss'],
|
||||
'eval': ['eval_loss', 'eval_p_corr', 'eval_s_corr'],
|
||||
'sample': ['r_valid', 'r_unique', 'r_novel']}
|
||||
num_train_steps = config.training.n_iters
|
||||
is_best = False
|
||||
max_eval_p_corr = -1
|
||||
for step in range(initial_step, num_train_steps + 1):
|
||||
try:
|
||||
x, adj, extra, task = next(train_iter)
|
||||
except StopIteration:
|
||||
train_iter = train_loader.__iter__()
|
||||
x, adj, extra, task = next(train_iter)
|
||||
mask = aug_mask(adj, algo=config.data.aug_mask_algo, data=config.data.name)
|
||||
x, adj, mask, task = scaler(x.to(config.device)), adj.to(config.device), mask.to(config.device), task.to(config.device)
|
||||
batch = (x, adj, mask, extra, task)
|
||||
|
||||
## Execute one training step
|
||||
loss, pred, labels = train_step_fn(state, batch)
|
||||
logger.update(key="training_loss", v=loss.item())
|
||||
if step % config.training.log_freq == 0:
|
||||
logging.info("step: %d, training_loss: %.5e" % (step, loss.item()))
|
||||
|
||||
## Report the loss on evaluation dataset periodically
|
||||
if step % config.training.eval_freq == 0:
|
||||
eval_pred_list, eval_labels_list = list(), list()
|
||||
for eval_x, eval_adj, eval_extra, eval_task in eval_loader:
|
||||
eval_mask = aug_mask(eval_adj, algo=config.data.aug_mask_algo, data=config.data.name)
|
||||
eval_x, eval_adj, eval_mask, eval_task = scaler(eval_x.to(config.device)), eval_adj.to(config.device), eval_mask.to(config.device), eval_task.to(config.device)
|
||||
eval_batch = (eval_x, eval_adj, eval_mask, eval_extra, eval_task)
|
||||
eval_loss, eval_pred, eval_labels = eval_step_fn(state, eval_batch)
|
||||
eval_pred_list += [v.detach().item() for v in eval_pred.squeeze()]
|
||||
eval_labels_list += [v.detach().item() for v in eval_labels.squeeze()]
|
||||
logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss.item()))
|
||||
logger.update(key="eval_loss", v=eval_loss.item())
|
||||
eval_p_corr = pearsonr(np.array(eval_pred_list), np.array(eval_labels_list))[0]
|
||||
eval_s_corr = spearmanr(np.array(eval_pred_list), np.array(eval_labels_list))[0]
|
||||
logging.info("step: %d, eval_p_corr: %.5e" % (step, eval_p_corr))
|
||||
logging.info("step: %d, eval_s_corr: %.5e" % (step, eval_s_corr))
|
||||
logger.update(key="eval_p_corr", v=eval_p_corr)
|
||||
logger.update(key="eval_s_corr", v=eval_s_corr)
|
||||
if eval_p_corr > max_eval_p_corr:
|
||||
is_best = True
|
||||
max_eval_p_corr = eval_p_corr
|
||||
|
||||
## Save a checkpoint periodically and generate samples
|
||||
if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
|
||||
## Save the checkpoint.
|
||||
save_step = step // config.training.snapshot_freq
|
||||
save_checkpoint(checkpoint_dir, state, step, save_step, is_best)
|
||||
## Generate and save samples
|
||||
if config.training.snapshot_sampling:
|
||||
score_ema.store(score_model.parameters())
|
||||
score_ema.copy_to(score_model.parameters())
|
||||
sample = sampling_fn(score_model=score_model,
|
||||
mask=mask,
|
||||
classifier=surrogate_model,
|
||||
classifier_scale=config.sampling.classifier_scale)
|
||||
quantized_sample = quantize(sample) # quantization
|
||||
this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step))
|
||||
os.makedirs(this_sample_dir, exist_ok=True)
|
||||
## Evaluate samples
|
||||
arch_metric = sampling_metrics(arch_list=quantized_sample,
|
||||
this_sample_dir=this_sample_dir,
|
||||
check_dataname=config.sampling.check_dataname)
|
||||
r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2]
|
||||
logging.info("step: %d, r_valid: %.5e" % (step, r_valid))
|
||||
logging.info("step: %d, r_unique: %.5e" % (step, r_unique))
|
||||
logging.info("step: %d, r_novel: %.5e" % (step, r_novel))
|
||||
logger.update(key="r_valid", v=r_valid)
|
||||
logger.update(key="r_unique", v=r_unique)
|
||||
logger.update(key="r_novel", v=r_novel)
|
||||
|
||||
if step % config.training.eval_freq == 0:
|
||||
logger.write_log(element=element, step=step)
|
||||
else:
|
||||
logger.write_log(element={'train': ['training_loss']}, step=step)
|
||||
|
||||
logger.reset()
|
||||
|
||||
|
||||
def check_config(config1, config2):
|
||||
assert config1.model.sigma_min == config2.model.sigma_min
|
||||
assert config1.model.sigma_max == config2.model.sigma_max
|
||||
assert config1.training.sde == config2.training.sde
|
||||
assert config1.training.continuous == config2.training.continuous
|
||||
assert config1.data.centered == config2.data.centered
|
||||
assert config1.data.max_node == config2.data.max_node
|
||||
assert config1.data.n_vocab == config2.data.n_vocab
|
||||
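# These asserts ensure the meta-surrogate is trained under the same diffusion and data settings
# as the pre-trained score network, so that surrogate-guided sampling remains consistent.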
|
||||
|
||||
run_train_dict = {
|
||||
'scorenet': scorenet_train,
|
||||
'meta_surrogate': meta_surrogate_train
|
||||
}
|
||||
|
||||
|
||||
run_eval_dict = {
|
||||
'scorenet': scorenet_evaluate,
|
||||
}
|
||||
|
||||
|
||||
def train(config):
|
||||
run_train_dict[config.model_type](config)
|
||||
|
||||
|
||||
def evaluate(config):
|
||||
run_eval_dict[config.model_type](config)
|
||||
|
||||
579
NAS-Bench-201/sampling.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""Various sampling methods."""
|
||||
|
||||
import functools
|
||||
import torch
|
||||
import numpy as np
|
||||
import abc
|
||||
from tqdm import trange
|
||||
import sde_lib
|
||||
from models import utils as mutils
|
||||
from datasets_nas import MetaTestDataset
|
||||
from all_path import DATA_PATH
|
||||
|
||||
|
||||
_CORRECTORS = {}
|
||||
_PREDICTORS = {}
|
||||
|
||||
|
||||
def register_predictor(cls=None, *, name=None):
|
||||
"""A decorator for registering predictor classes."""
|
||||
|
||||
def _register(cls):
|
||||
if name is None:
|
||||
local_name = cls.__name__
|
||||
else:
|
||||
local_name = name
|
||||
if local_name in _PREDICTORS:
|
||||
raise ValueError(f'Already registered predictor with name: {local_name}')
|
||||
_PREDICTORS[local_name] = cls
|
||||
return cls
|
||||
|
||||
if cls is None:
|
||||
return _register
|
||||
else:
|
||||
return _register(cls)
|
||||
|
||||
|
||||
def register_corrector(cls=None, *, name=None):
|
||||
"""A decorator for registering corrector classes."""
|
||||
|
||||
def _register(cls):
|
||||
if name is None:
|
||||
local_name = cls.__name__
|
||||
else:
|
||||
local_name = name
|
||||
if local_name in _CORRECTORS:
|
||||
raise ValueError(f'Already registered corrector with name: {local_name}')
|
||||
_CORRECTORS[local_name] = cls
|
||||
return cls
|
||||
|
||||
if cls is None:
|
||||
return _register
|
||||
else:
|
||||
return _register(cls)
|
||||
|
||||
|
||||
def get_predictor(name):
|
||||
return _PREDICTORS[name]
|
||||
|
||||
|
||||
def get_corrector(name):
|
||||
return _CORRECTORS[name]
|
||||
|
||||
|
||||
def get_sampling_fn(
|
||||
config,
|
||||
sde,
|
||||
shape,
|
||||
inverse_scaler,
|
||||
eps,
|
||||
conditional=False,
|
||||
data_name='cifar10',
|
||||
num_sample=20):
|
||||
"""Create a sampling function.
|
||||
|
||||
Args:
|
||||
config: A `ml_collections.ConfigDict` object that contains all configuration information.
|
||||
sde: A `sde_lib.SDE` object that represents the forward SDE.
|
||||
shape: A sequence of integers representing the expected shape of a single sample.
|
||||
inverse_scaler: The inverse data normalizer function.
|
||||
eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.
|
||||
conditional: If `True`, the sampling function is conditional
|
||||
data_name: A `str` name of the dataset.
|
||||
num_sample: An `int` number of samples for each class of the dataset.
|
||||
|
||||
Returns:
|
||||
A sampling function that takes a score model (and, for conditional sampling, a surrogate classifier) and outputs samples with the
|
||||
trailing dimensions matching `shape`.
|
||||
"""
|
||||
|
||||
sampler_name = config.sampling.method
|
||||
|
||||
# Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases.
|
||||
if sampler_name.lower() == 'pc':
|
||||
predictor = get_predictor(config.sampling.predictor.lower())
|
||||
corrector = get_corrector(config.sampling.corrector.lower())
|
||||
|
||||
if not conditional:
|
||||
print('>>> Get pc_sampler...')
|
||||
sampling_fn = get_pc_sampler_nas(sde=sde,
|
||||
shape=shape,
|
||||
predictor=predictor,
|
||||
corrector=corrector,
|
||||
inverse_scaler=inverse_scaler,
|
||||
snr=config.sampling.snr,
|
||||
n_steps=config.sampling.n_steps_each,
|
||||
probability_flow=config.sampling.probability_flow,
|
||||
continuous=config.training.continuous,
|
||||
denoise=config.sampling.noise_removal,
|
||||
eps=eps,
|
||||
device=config.device)
|
||||
else:
|
||||
print('>>> Get pc_conditional_sampler...')
|
||||
sampling_fn = get_pc_conditional_sampler_meta_nas(sde=sde,
|
||||
shape=shape,
|
||||
predictor=predictor,
|
||||
corrector=corrector,
|
||||
inverse_scaler=inverse_scaler,
|
||||
snr=config.sampling.snr,
|
||||
n_steps=config.sampling.n_steps_each,
|
||||
probability_flow=config.sampling.probability_flow,
|
||||
continuous=config.training.continuous,
|
||||
denoise=config.sampling.noise_removal,
|
||||
eps=eps,
|
||||
device=config.device,
|
||||
regress=config.sampling.regress,
|
||||
labels=config.sampling.labels,
|
||||
data_name=data_name,
|
||||
num_sample=num_sample)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Sampler name {sampler_name} unknown.")
|
||||
|
||||
return sampling_fn
|
||||
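# Usage sketch (mirrors the calls in run_lib.py):
#   sampling_fn = get_sampling_fn(config, sde, shape, inverse_scaler, eps)
#   sample, n_evals, _ = sampling_fn(score_model, mask)                      # unconditional
#   cond_fn = get_sampling_fn(config, sde, shape, inverse_scaler, eps, conditional=True)
#   sample = cond_fn(score_model=score_model, mask=mask,
#                    classifier=surrogate_model, classifier_scale=config.sampling.classifier_scale)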
|
||||
|
||||
class Predictor(abc.ABC):
|
||||
"""The abstract class for a predictor algorithm."""
|
||||
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
super().__init__()
|
||||
self.sde = sde
|
||||
# Compute the reverse SDE/ODE
|
||||
if isinstance(sde, tuple):
|
||||
self.rsde = (sde[0].reverse(score_fn, probability_flow), sde[1].reverse(score_fn, probability_flow))
|
||||
else:
|
||||
self.rsde = sde.reverse(score_fn, probability_flow)
|
||||
self.score_fn = score_fn
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_fn(self, x, t, *args, **kwargs):
|
||||
"""One update of the predictor.
|
||||
|
||||
Args:
|
||||
x: A PyTorch tensor representing the current state.
|
||||
t: A PyTorch tensor representing the current time step.
|
||||
|
||||
Returns:
|
||||
x: A PyTorch tensor of the next state.
|
||||
x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Corrector(abc.ABC):
|
||||
"""The abstract class for a corrector algorithm."""
|
||||
|
||||
def __init__(self, sde, score_fn, snr, n_steps):
|
||||
super().__init__()
|
||||
self.sde = sde
|
||||
self.score_fn = score_fn
|
||||
self.snr = snr
|
||||
self.n_steps = n_steps
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_fn(self, x, t, *args, **kwargs):
|
||||
"""One update of the corrector.
|
||||
|
||||
Args:
|
||||
x: A PyTorch tensor representing the current state.
|
||||
t: A PyTorch tensor representing the current time step.
|
||||
|
||||
Returns:
|
||||
x: A PyTorch tensor of the next state.
|
||||
x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@register_predictor(name='euler_maruyama')
|
||||
class EulerMaruyamaPredictor(Predictor):
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
super().__init__(sde, score_fn, probability_flow)
|
||||
|
||||
def update_fn(self, x, t, *args, **kwargs):
|
||||
dt = -1. / self.rsde.N
|
||||
z = torch.randn_like(x)
|
||||
drift, diffusion = self.rsde.sde(x, t, *args, **kwargs)
|
||||
x_mean = x + drift * dt
|
||||
x = x_mean + diffusion[:, None, None] * np.sqrt(-dt) * z
|
||||
return x, x_mean
|
||||
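# Euler-Maruyama step of the reverse SDE: x <- x + drift * dt + diffusion * sqrt(-dt) * z,
# with dt = -1/N, z ~ N(0, I), and (drift, diffusion) taken from the reverse SDE above.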
|
||||
|
||||
@register_predictor(name='reverse_diffusion')
|
||||
class ReverseDiffusionPredictor(Predictor):
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
super().__init__(sde, score_fn, probability_flow)
|
||||
|
||||
def update_fn(self, x, t, *args, **kwargs):
|
||||
f, G = self.rsde.discretize(x, t, *args, **kwargs)
|
||||
z = torch.randn_like(x)
|
||||
x_mean = x - f
|
||||
x = x_mean + G[:, None, None] * z
|
||||
return x, x_mean
|
||||
|
||||
|
||||
@register_predictor(name='none')
|
||||
class NonePredictor(Predictor):
|
||||
"""An empty predictor that does nothing."""
|
||||
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
pass
|
||||
|
||||
def update_fn(self, x, t, *args, **kwargs):
|
||||
return x, x
|
||||
|
||||
|
||||
@register_corrector(name='langevin')
|
||||
class LangevinCorrector(Corrector):
|
||||
def __init__(self, sde, score_fn, snr, n_steps):
|
||||
super().__init__(sde, score_fn, snr, n_steps)
|
||||
|
||||
def update_fn(self, x, t, *args, **kwargs):
|
||||
sde = self.sde
|
||||
score_fn = self.score_fn
|
||||
n_steps = self.n_steps
|
||||
target_snr = self.snr
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
timestep = (t * (sde.N - 1) / sde.T).long()
|
||||
# Note: it seems that subVPSDE doesn't set alphas
|
||||
alpha = sde.alphas.to(t.device)[timestep]
|
||||
else:
|
||||
alpha = torch.ones_like(t)
|
||||
|
||||
for i in range(n_steps):
|
||||
|
||||
grad = score_fn(x, t, *args, **kwargs)
|
||||
noise = torch.randn_like(x)
|
||||
|
||||
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
|
||||
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
|
||||
|
||||
step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
|
||||
x_mean = x + step_size[:, None, None] * grad
|
||||
x = x_mean + torch.sqrt(step_size * 2)[:, None, None] * noise
|
||||
|
||||
return x, x_mean
|
||||
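# Annealed Langevin step: the step size is set from the target SNR as
#   epsilon = 2 * alpha * (snr * ||noise|| / ||grad||) ** 2,
# and the state is updated as x <- x + epsilon * grad + sqrt(2 * epsilon) * noise, as above.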
|
||||
|
||||
@register_corrector(name='none')
|
||||
class NoneCorrector(Corrector):
|
||||
"""An empty corrector that does nothing."""
|
||||
|
||||
def __init__(self, sde, score_fn, snr, n_steps):
|
||||
pass
|
||||
|
||||
def update_fn(self, x, t, *args, **kwargs):
|
||||
return x, x
|
||||
|
||||
|
||||
def shared_predictor_update_fn(x, t, sde, model,
|
||||
predictor, probability_flow, continuous, *args, **kwargs):
|
||||
"""A wrapper that configures and returns the update function of predictors."""
|
||||
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
|
||||
if predictor is None:
|
||||
# Corrector-only sampler
|
||||
predictor_obj = NonePredictor(sde, score_fn, probability_flow)
|
||||
else:
|
||||
predictor_obj = predictor(sde, score_fn, probability_flow)
|
||||
|
||||
return predictor_obj.update_fn(x, t, *args, **kwargs)
|
||||
|
||||
|
||||
def shared_corrector_update_fn(x, t, sde, model,
|
||||
corrector, continuous, snr, n_steps, *args, **kwargs):
|
||||
"""A wrapper that configures and returns the update function of correctors."""
|
||||
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
|
||||
|
||||
if corrector is None:
|
||||
# Predictor-only sampler
|
||||
corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
|
||||
else:
|
||||
corrector_obj = corrector(sde, score_fn, snr, n_steps)
|
||||
|
||||
return corrector_obj.update_fn(x, t, *args, **kwargs)
|
||||
|
||||
|
||||
def get_pc_sampler(sde,
|
||||
shape,
|
||||
predictor,
|
||||
corrector,
|
||||
inverse_scaler,
|
||||
snr,
|
||||
n_steps=1,
|
||||
probability_flow=False,
|
||||
continuous=False,
|
||||
denoise=True,
|
||||
eps=1e-3,
|
||||
device='cuda'):
|
||||
"""Create a Predictor-Corrector (PC) sampler.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object representing the forward SDE.
|
||||
shape: A sequence of integers. The expected shape of a single sample.
|
||||
predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
|
||||
corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
|
||||
inverse_scaler: The inverse data normalizer.
|
||||
snr: A `float` number. The signal-to-noise ratio for configuring correctors.
|
||||
n_steps: An integer. The number of corrector steps per predictor update.
|
||||
probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
|
||||
continuous: `True` indicates that the score model was continuously trained.
|
||||
denoise: If `True`, add one-step denoising to the final samples.
|
||||
eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
|
||||
device: PyTorch device.
|
||||
|
||||
Returns:
|
||||
A sampling function that returns samples and the number of function evaluations during sampling.
|
||||
"""
|
||||
# Create predictor & corrector update functions
|
||||
predictor_update_fn = functools.partial(shared_predictor_update_fn,
|
||||
sde=sde,
|
||||
predictor=predictor,
|
||||
probability_flow=probability_flow,
|
||||
continuous=continuous)
|
||||
corrector_update_fn = functools.partial(shared_corrector_update_fn,
|
||||
sde=sde,
|
||||
corrector=corrector,
|
||||
continuous=continuous,
|
||||
snr=snr,
|
||||
n_steps=n_steps)
|
||||
|
||||
def pc_sampler(model, n_nodes_pmf):
|
||||
"""The PC sampler function.
|
||||
|
||||
Args:
|
||||
model: A score model.
|
||||
n_nodes_pmf: Probability mass function of graph nodes.
|
||||
|
||||
Returns:
|
||||
Samples, number of function evaluations.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# Initial sample
|
||||
x = sde.prior_sampling(shape).to(device)
|
||||
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
|
||||
|
||||
# Sample the number of nodes
|
||||
n_nodes = torch.multinomial(n_nodes_pmf, shape[0], replacement=True)
|
||||
mask = torch.zeros((shape[0], shape[-1]), device=device)
|
||||
for i in range(shape[0]):
|
||||
mask[i][:n_nodes[i]] = 1.
|
||||
mask = (mask[:, None, :] * mask[:, :, None]).unsqueeze(1)
|
||||
mask = torch.tril(mask, -1)
|
||||
mask = mask + mask.transpose(-1, -2)
|
||||
|
||||
x = x * mask
|
||||
|
||||
for i in range(sde.N):
|
||||
t = timesteps[i]
|
||||
vec_t = torch.ones(shape[0], device=t.device) * t
|
||||
x, x_mean = corrector_update_fn(x, vec_t, model=model, mask=mask)
|
||||
x = x * mask
|
||||
x, x_mean = predictor_update_fn(x, vec_t, model=model, mask=mask)
|
||||
x = x * mask
|
||||
|
||||
return inverse_scaler(x_mean if denoise else x) * mask, sde.N * (n_steps + 1), n_nodes
|
||||
|
||||
return pc_sampler
|
||||
|
||||
|
||||
def get_pc_sampler_nas(sde,
|
||||
shape,
|
||||
predictor,
|
||||
corrector,
|
||||
inverse_scaler,
|
||||
snr,
|
||||
n_steps=1,
|
||||
probability_flow=False,
|
||||
continuous=False,
|
||||
denoise=True,
|
||||
eps=1e-3,
|
||||
device='cuda'):
|
||||
"""Create a Predictor-Corrector (PC) sampler.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object representing the forward SDE.
|
||||
shape: A sequence of integers. The expected shape of a single sample.
|
||||
predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
|
||||
corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
|
||||
inverse_scaler: The inverse data normalizer.
|
||||
snr: A `float` number. The signal-to-noise ratio for configuring correctors.
|
||||
n_steps: An integer. The number of corrector steps per predictor update.
|
||||
probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
|
||||
continuous: `True` indicates that the score model was continuously trained.
|
||||
denoise: If `True`, add one-step denoising to the final samples.
|
||||
eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
|
||||
device: PyTorch device.
|
||||
|
||||
Returns:
|
||||
A sampling function that returns samples and the number of function evaluations during sampling.
|
||||
"""
|
||||
# Create predictor & corrector update functions
|
||||
predictor_update_fn = functools.partial(shared_predictor_update_fn,
|
||||
sde=sde,
|
||||
predictor=predictor,
|
||||
probability_flow=probability_flow,
|
||||
continuous=continuous)
|
||||
corrector_update_fn = functools.partial(shared_corrector_update_fn,
|
||||
sde=sde,
|
||||
corrector=corrector,
|
||||
continuous=continuous,
|
||||
snr=snr,
|
||||
n_steps=n_steps)
|
||||
|
||||
def pc_sampler(model, mask):
|
||||
"""The PC sampler function.
|
||||
|
||||
Args:
|
||||
model: A score model.
|
||||
mask: Adjacency mask that is broadcast over the batch during sampling.
|
||||
|
||||
Returns:
|
||||
Samples, number of function evaluations.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# Initial sample
|
||||
x = sde.prior_sampling(shape).to(device)
|
||||
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
|
||||
mask = mask[0].unsqueeze(0).repeat(x.size(0), 1, 1)
|
||||
|
||||
for i in trange(sde.N, desc='[PC sampling]', position=1, leave=False):
|
||||
t = timesteps[i]
|
||||
vec_t = torch.ones(shape[0], device=t.device) * t
|
||||
x, x_mean = corrector_update_fn(x, vec_t, model=model, maskX=mask)
|
||||
x, x_mean = predictor_update_fn(x, vec_t, model=model, maskX=mask)
|
||||
return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1), None
|
||||
|
||||
return pc_sampler
|
||||
|
||||
|
||||
def get_pc_conditional_sampler_meta_nas(
|
||||
sde,
|
||||
shape,
|
||||
predictor,
|
||||
corrector,
|
||||
inverse_scaler,
|
||||
snr,
|
||||
n_steps=1,
|
||||
probability_flow=False,
|
||||
continuous=False,
|
||||
denoise=True,
|
||||
eps=1e-5,
|
||||
device='cuda',
|
||||
regress=True,
|
||||
labels='max',
|
||||
data_name='cifar10',
|
||||
num_sample=20):
|
||||
|
||||
"""Class-conditional sampling with Predictor-Corrector (PC) samplers.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
score_model: A `torch.nn.Module` object that represents the architecture of the score-based model.
|
||||
classifier: A `torch.nn.Module` object that represents the architecture of the noise-dependent classifier.
|
||||
# classifier_params: A dictionary that contains the weights of the classifier.
|
||||
shape: A sequence of integers. The expected shape of a single sample.
|
||||
predictor: A subclass of `sampling.predictor` that represents a predictor algorithm.
|
||||
corrector: A subclass of `sampling.corrector` that represents a corrector algorithm.
|
||||
inverse_scaler: The inverse data normalizer.
|
||||
snr: A `float` number. The signal-to-noise ratio for correctors.
|
||||
n_steps: An integer. The number of corrector steps per update of the predictor.
|
||||
probability_flow: If `True`, solve the probability flow ODE for sampling with the predictor.
|
||||
continuous: `True` indicates the score-based model was trained with continuous time.
|
||||
denoise: If `True`, add one-step denoising to final samples.
|
||||
eps: A `float` number. The SDE/ODE will be integrated to `eps` to avoid numerical issues.
|
||||
|
||||
Returns: A class-conditional architecture sampler.
|
||||
"""
|
||||
|
||||
# --------- Meta-NAS ---------- #
|
||||
test_dataset = MetaTestDataset(
|
||||
data_path=DATA_PATH,
|
||||
data_name=data_name,
|
||||
num_sample=num_sample)
|
||||
|
||||
|
||||
def conditional_predictor_update_fn(score_model, classifier, x, t, labels, maskX, classifier_scale, *args, **kwargs):
|
||||
"""The predictor update function for class-conditional sampling."""
|
||||
score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
|
||||
classifier_grad_fn = mutils.get_classifier_grad_fn(sde, classifier, train=False, continuous=continuous,
|
||||
regress=regress, labels=labels)
|
||||
|
||||
def total_grad_fn(x, t, *args, **kwargs):
|
||||
score = score_fn(x, t, maskX)
|
||||
classifier_grad = classifier_grad_fn(x, t, maskX, *args, **kwargs)
|
||||
return score + classifier_scale * classifier_grad
|
||||
|
||||
if predictor is None:
|
||||
predictor_obj = NonePredictor(sde, total_grad_fn, probability_flow)
|
||||
else:
|
||||
predictor_obj = predictor(sde, total_grad_fn, probability_flow)
|
||||
|
||||
return predictor_obj.update_fn(x, t, *args, **kwargs)
|
||||
|
||||
|
||||
def conditional_corrector_update_fn(score_model, classifier, x, t, labels, maskX, classifier_scale, *args, **kwargs):
|
||||
"""The corrector update function for class-conditional sampling."""
|
||||
score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
|
||||
classifier_grad_fn = mutils.get_classifier_grad_fn(sde, classifier, train=False, continuous=continuous,
|
||||
regress=regress, labels=labels)
|
||||
|
||||
def total_grad_fn(x, t, *args, **kwargs):
|
||||
score = score_fn(x, t, maskX)
|
||||
classifier_grad = classifier_grad_fn(x, t, maskX, *args, **kwargs)
|
||||
return score + classifier_scale * classifier_grad
|
||||
|
||||
if corrector is None:
|
||||
corrector_obj = NoneCorrector(sde, total_grad_fn, snr, n_steps)
|
||||
else:
|
||||
corrector_obj = corrector(sde, total_grad_fn, snr, n_steps)
|
||||
|
||||
return corrector_obj.update_fn(x, t, *args, **kwargs)
|
||||
|
||||
|
||||
def pc_conditional_sampler(
|
||||
score_model,
|
||||
mask,
|
||||
classifier,
|
||||
classifier_scale=None,
|
||||
task=None):
|
||||
|
||||
"""Generate class-conditional samples with Predictor-Corrector (PC) samplers.
|
||||
|
||||
Args:
|
||||
score_model: A `torch.nn.Module` object that represents the training state
|
||||
of the score-based model.
|
||||
labels: Target labels used for guidance ('max' or 'min' for regression guidance, or a tensor of class indices).
|
||||
|
||||
Returns:
|
||||
Class-conditional samples.
|
||||
"""
|
||||
|
||||
# disable gradient tracking to accelerate sampling
|
||||
with torch.no_grad():
|
||||
if task is None:
|
||||
task = test_dataset[0]
|
||||
task = task.repeat(shape[0], 1, 1)
|
||||
task = task.to(device)
|
||||
else:
|
||||
task = task.repeat(shape[0], 1, 1)
|
||||
task = task.to(device)
|
||||
classifier.sample_state = True
|
||||
classifier.D_mu = None
|
||||
|
||||
# initial sample
|
||||
x = sde.prior_sampling(shape).to(device)
|
||||
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
|
||||
|
||||
if len(mask.shape) == 3: mask = mask[0]
|
||||
mask = mask.unsqueeze(0).repeat(x.size(0), 1, 1) # adj
|
||||
|
||||
for i in trange(sde.N, desc='[PC conditional sampling]', position=1, leave=False):
|
||||
t = timesteps[i]
|
||||
vec_t = torch.ones(shape[0], device=t.device) * t
|
||||
x, x_mean = conditional_corrector_update_fn(score_model, classifier, x, vec_t, labels=labels, maskX=mask, task=task, classifier_scale=classifier_scale)
|
||||
x, x_mean = conditional_predictor_update_fn(score_model, classifier, x, vec_t, labels=labels, maskX=mask, task=task, classifier_scale=classifier_scale)
|
||||
classifier.sample_state = False
|
||||
return inverse_scaler(x_mean if denoise else x)
|
||||
|
||||
return pc_conditional_sampler
|
||||
4
NAS-Bench-201/script/download_preprocessed_dataset.sh
Normal file
@@ -0,0 +1,4 @@
|
||||
export LD_LIBRARY_PATH=/opt/conda/envs/gtctnz_2/lib/python3.7/site-packages/nvidia/cublas/lib/
|
||||
|
||||
echo '[Downloading processed]'
|
||||
python main_exp/transfer_nag/get_files/get_preprocessed_data.py
|
||||
15
NAS-Bench-201/script/download_raw_dataset.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
export LD_LIBRARY_PATH=/opt/conda/envs/gtctnz_2/lib/python3.7/site-packages/nvidia/cublas/lib/
|
||||
|
||||
DATANAME=$1
|
||||
|
||||
if [[ $DATANAME = 'aircraft' ]]; then
|
||||
echo '[Downloading aircraft]'
|
||||
python main_exp/transfer_nag/get_files/get_aircraft.py
|
||||
|
||||
elif [[ $DATANAME = 'pets' ]]; then
|
||||
echo '[Downloading pets]'
|
||||
python main_exp/transfer_nag/get_files/get_pets.py
|
||||
|
||||
else
|
||||
echo 'Not Implemented'
|
||||
fi
|
||||
6
NAS-Bench-201/script/tr_meta_surrogate.sh
Normal file
@@ -0,0 +1,6 @@
|
||||
FOLDER_NAME='tr_meta_surrogate_nb201'
|
||||
|
||||
CUDA_VISIBLE_DEVICES=$1 python main.py --config configs/tr_meta_surrogate.py \
|
||||
--mode train \
|
||||
--config.folder_name $FOLDER_NAME
|
||||
|
||||
5
NAS-Bench-201/script/tr_scorenet.sh
Normal file
@@ -0,0 +1,5 @@
|
||||
FOLDER_NAME='tr_scorenet_nb201'
|
||||
|
||||
CUDA_VISIBLE_DEVICES=$1 python main.py --config configs/tr_scorenet.py \
|
||||
--mode train \
|
||||
--config.folder_name $FOLDER_NAME
|
||||
10
NAS-Bench-201/script/transfer_nag.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
FOLDER_NAME='transfer_nag_nb201'
|
||||
|
||||
GPU=$1
|
||||
DATANAME=$2
|
||||
|
||||
CUDA_VISIBLE_DEVICES=$GPU python main_exp/transfer_nag/main.py \
|
||||
--gpu $GPU \
|
||||
--test \
|
||||
--folder_name $FOLDER_NAME \
|
||||
--data-name $DATANAME
|
||||
300
NAS-Bench-201/sde_lib.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Abstract SDE classes, Reverse SDE, and VP SDEs."""
|
||||
|
||||
import abc
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDE(abc.ABC):
|
||||
"""SDE abstract class. Functions are designed for a mini-batch of inputs."""
|
||||
|
||||
def __init__(self, N):
|
||||
"""Construct an SDE.
|
||||
|
||||
Args:
|
||||
N: number of discretization time steps.
|
||||
"""
|
||||
super().__init__()
|
||||
self.N = N
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def T(self):
|
||||
"""End time of the SDE."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sde(self, x, t):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def marginal_prob(self, x, t):
|
||||
"""Parameters to determine the marginal distribution of the SDE, $p_t(x)$"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def prior_sampling(self, shape):
|
||||
"""Generate one sample from the prior distribution, $p_T(x)$."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def prior_logp(self, z, mask):
|
||||
"""Compute log-density of the prior distribution.
|
||||
|
||||
Useful for computing the log-likelihood via probability flow ODE.
|
||||
|
||||
Args:
|
||||
z: latent code
|
||||
Returns:
|
||||
log probability density
|
||||
"""
|
||||
pass
|
||||
|
||||
def discretize(self, x, t):
|
||||
"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
|
||||
|
||||
Useful for reverse diffusion sampling and probability flow sampling.
|
||||
Defaults to Euler-Maruyama discretization.
|
||||
|
||||
Args:
|
||||
x: a torch tensor
|
||||
t: a torch float representing the time step (from 0 to `self.T`)
|
||||
|
||||
Returns:
|
||||
f, G
|
||||
"""
|
||||
dt = 1 / self.N
|
||||
drift, diffusion = self.sde(x, t)
|
||||
f = drift * dt
|
||||
G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
|
||||
return f, G
|
||||
|
||||
def reverse(self, score_fn, probability_flow=False):
|
||||
"""Create the reverse-time SDE/ODE.
|
||||
|
||||
Args:
|
||||
score_fn: A time-dependent score-based model that takes x and t and returns the score.
|
||||
probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
|
||||
"""
|
||||
|
||||
N = self.N
|
||||
T = self.T
|
||||
sde_fn = self.sde
|
||||
discretize_fn = self.discretize
|
||||
|
||||
# Build the class for reverse-time SDE.
|
||||
class RSDE(self.__class__):
|
||||
def __init__(self):
|
||||
self.N = N
|
||||
self.probability_flow = probability_flow
|
||||
|
||||
@property
|
||||
def T(self):
|
||||
return T
|
||||
|
||||
def sde(self, x, t, *args, **kwargs):
|
||||
"""Create the drift and diffusion functions for the reverse SDE/ODE."""
|
||||
|
||||
drift, diffusion = sde_fn(x, t)
|
||||
score = score_fn(x, t, *args, **kwargs)
|
||||
drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
|
||||
# Set the diffusion function to zero for ODEs.
|
||||
diffusion = 0. if self.probability_flow else diffusion
|
||||
return drift, diffusion
|
||||
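# This implements the reverse-time SDE
#   dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw_bar,
# with a 0.5 factor on the score term and zero diffusion for the probability-flow ODE.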
|
||||
'''
|
||||
def sde_score(self, x, t, score):
|
||||
"""Create the drift and diffusion functions for the reverse SDE/ODE, given score values."""
|
||||
drift, diffusion = sde_fn(x, t)
|
||||
if len(score.shape) == 4:
|
||||
drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
|
||||
elif len(score.shape) == 3:
|
||||
drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
|
||||
else:
|
||||
raise ValueError
|
||||
diffusion = 0. if self.probability_flow else diffusion
|
||||
return drift, diffusion
|
||||
'''
|
||||
|
||||
def discretize(self, x, t, *args, **kwargs):
|
||||
"""Create discretized iteration rules for the reverse diffusion sampler."""
|
||||
f, G = discretize_fn(x, t)
|
||||
rev_f = f - G[:, None, None] ** 2 * score_fn(x, t, *args, **kwargs) * \
|
||||
(0.5 if self.probability_flow else 1.)
|
||||
rev_G = torch.zeros_like(G) if self.probability_flow else G
|
||||
return rev_f, rev_G
|
||||
|
||||
'''
|
||||
def discretize_score(self, x, t, score):
|
||||
"""Create discretized iteration rules for the reverse diffusion sampler, given score values."""
|
||||
f, G = discretize_fn(x, t)
|
||||
if len(score.shape) == 4:
|
||||
rev_f = f - G[:, None, None, None] ** 2 * score * \
|
||||
(0.5 if self.probability_flow else 1.)
|
||||
elif len(score.shape) == 3:
|
||||
rev_f = f - G[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
|
||||
else:
|
||||
raise ValueError
|
||||
rev_G = torch.zeros_like(G) if self.probability_flow else G
|
||||
return rev_f, rev_G
|
||||
'''
|
||||
|
||||
return RSDE()
|
||||
|
||||
|
||||
class VPSDE(SDE):
|
||||
def __init__(self, beta_min=0.1, beta_max=20, N=1000):
|
||||
"""Construct a Variance Preserving SDE.
|
||||
|
||||
Args:
|
||||
beta_min: value of beta(0)
|
||||
beta_max: value of beta(1)
|
||||
N: number of discretization steps
|
||||
"""
|
||||
super().__init__(N)
|
||||
self.beta_0 = beta_min
|
||||
self.beta_1 = beta_max
|
||||
self.N = N
|
||||
self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
|
||||
self.alphas = 1. - self.discrete_betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
|
||||
|
||||
@property
|
||||
def T(self):
|
||||
return 1
|
||||
|
||||
def sde(self, x, t):
|
||||
beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
|
||||
if len(x.shape) == 4:
|
||||
drift = -0.5 * beta_t[:, None, None, None] * x
|
||||
elif len(x.shape) == 3:
|
||||
drift = -0.5 * beta_t[:, None, None] * x
|
||||
else:
|
||||
raise NotImplementedError
|
||||
diffusion = torch.sqrt(beta_t)
|
||||
return drift, diffusion
|
||||
|
||||
def marginal_prob(self, x, t):
|
||||
log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||
if len(x.shape) == 4:
|
||||
mean = torch.exp(log_mean_coeff[:, None, None, None]) * x
|
||||
elif len(x.shape) == 3:
|
||||
mean = torch.exp(log_mean_coeff[:, None, None]) * x
|
||||
else:
|
||||
raise ValueError("The shape of x in marginal_prob is not correct.")
|
||||
std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
|
||||
return mean, std
|
||||
|
||||
def prior_sampling(self, shape):
|
||||
return torch.randn(*shape)
|
||||
|
||||
def prior_logp(self, z, mask):
|
||||
N = torch.sum(mask, dim=tuple(range(1, len(mask.shape))))
|
||||
logps = -N / 2. * np.log(2 * np.pi) - torch.sum((z * mask) ** 2, dim=(1, 2, 3)) / 2.
|
||||
return logps
|
||||
|
||||
def discretize(self, x, t):
|
||||
"""DDPM discretization."""
|
||||
timestep = (t * (self.N - 1) / self.T).long()
|
||||
beta = self.discrete_betas.to(x.device)[timestep]
|
||||
alpha = self.alphas.to(x.device)[timestep]
|
||||
sqrt_beta = torch.sqrt(beta)
|
||||
if len(x.shape) == 4:
|
||||
f = torch.sqrt(alpha)[:, None, None, None] * x - x
|
||||
elif len(x.shape) == 3:
|
||||
f = torch.sqrt(alpha)[:, None, None] * x - x
|
||||
else:
|
||||
raise NotImplementedError
|
||||
G = sqrt_beta
|
||||
return f, G
|
||||
|
||||
|
||||
class subVPSDE(SDE):
    def __init__(self, beta_min=0.1, beta_max=20, N=1000):
        """Construct the sub-VP SDE that excels at likelihoods.

        Args:
            beta_min: value of beta(0)
            beta_max: value of beta(1)
            N: number of discretization steps
        """
        super().__init__(N)
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        self.N = N

    @property
    def T(self):
        return 1

    def sde(self, x, t):
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * beta_t[:, None, None] * x
        discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2)
        diffusion = torch.sqrt(beta_t * discount)
        return drift, diffusion

    def marginal_prob(self, x, t):
        log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        mean = torch.exp(log_mean_coeff)[:, None, None] * x
        std = 1 - torch.exp(2. * log_mean_coeff)
        return mean, std

    def prior_sampling(self, shape):
        return torch.randn(*shape)

    def prior_logp(self, z):
        shape = z.shape
        N = np.prod(shape[1:])
        return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.

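A quick sanity check of the sub-VP schedule (a sketch with assumed toy shapes): the discount term drives the diffusion coefficient to zero at t = 0 and back toward the VP value at t = T.

import torch

sde = subVPSDE(beta_min=0.1, beta_max=20, N=1000)
x = torch.randn(4, 8, 7)                          # illustrative 3-D batch
_, diff0 = sde.sde(x, torch.zeros(4))             # ~0 at t = 0
_, diff1 = sde.sde(x, torch.ones(4) * sde.T)      # ~sqrt(beta_max) at t = T
print(diff0, diff1)
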
class VESDE(SDE):
    def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
        """Construct a Variance Exploding SDE.

        Args:
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            N: number of discretization steps
        """
        super().__init__(N)
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
        self.N = N

    @property
    def T(self):
        return 1

    def sde(self, x, t):
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        drift = torch.zeros_like(x)
        diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
                                                    device=t.device))
        return drift, diffusion

    def marginal_prob(self, x, t):
        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        mean = x
        return mean, std

    def prior_sampling(self, shape):
        return torch.randn(*shape) * self.sigma_max

    def prior_logp(self, z):
        shape = z.shape
        N = np.prod(shape[1:])
        return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)

    def discretize(self, x, t):
        """SMLD (NCSN) discretization."""
        timestep = (t * (self.N - 1) / self.T).long()
        sigma = self.discrete_sigmas.to(t.device)[timestep]
        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
                                     self.discrete_sigmas[timestep.cpu() - 1].to(t.device))
        f = torch.zeros_like(x)
        G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
        return f, G
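As a small check of the VE schedule (a sketch; the shapes are illustrative), the marginal standard deviation should interpolate geometrically from sigma_min at t = 0 to sigma_max at t = T.

import torch

sde = VESDE(sigma_min=0.01, sigma_max=50, N=1000)
x = torch.zeros(2, 8, 7)                          # values are irrelevant for the std here
_, std0 = sde.marginal_prob(x, torch.zeros(2))    # ~0.01
_, std1 = sde.marginal_prob(x, torch.ones(2) * sde.T)   # ~50.0
print(std0, std1)

f, G = sde.discretize(x, torch.ones(2) * 0.5)     # SMLD step size at mid-trajectory
print(G)
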
262
NAS-Bench-201/utils.py
Normal file
262
NAS-Bench-201/utils.py
Normal file
@@ -0,0 +1,262 @@
import os
import logging
import shutil

import torch
from torch_scatter import scatter


@torch.no_grad()
def to_dense_adj(edge_index, batch=None, edge_attr=None, max_num_nodes=None):
    r"""Converts batched sparse adjacency matrices given by edge indices and
    edge attributes to a single dense batched adjacency matrix.

    Args:
        edge_index (LongTensor): The edge indices.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        max_num_nodes (int, optional): The size of the output node dimension.
            (default: :obj:`None`)

    Returns:
        adj: [batch_size, max_num_nodes, max_num_nodes] Dense adjacency matrices.
        mask: Mask for dense adjacency matrices.
    """
    if batch is None:
        batch = edge_index.new_zeros(edge_index.max().item() + 1)

    batch_size = batch.max().item() + 1
    one = batch.new_ones(batch.size(0))
    num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='add')
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    idx0 = batch[edge_index[0]]
    idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
    idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]

    if max_num_nodes is None:
        max_num_nodes = num_nodes.max().item()
    elif idx1.max() >= max_num_nodes or idx2.max() >= max_num_nodes:
        mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
        idx0 = idx0[mask]
        idx1 = idx1[mask]
        idx2 = idx2[mask]
        edge_attr = None if edge_attr is None else edge_attr[mask]

    if edge_attr is None:
        edge_attr = torch.ones(idx0.numel(), device=edge_index.device)

    size = [batch_size, max_num_nodes, max_num_nodes]
    size += list(edge_attr.size())[1:]
    adj = torch.zeros(size, dtype=edge_attr.dtype, device=edge_index.device)

    flattened_size = batch_size * max_num_nodes * max_num_nodes
    adj = adj.view([flattened_size] + list(adj.size())[3:])
    idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2
    scatter(edge_attr, idx, dim=0, out=adj, reduce='add')
    adj = adj.view(size)

    node_idx = torch.arange(batch.size(0), dtype=torch.long, device=edge_index.device)
    node_idx = (node_idx - cum_nodes[batch]) + (batch * max_num_nodes)
    mask = torch.zeros(batch_size * max_num_nodes, dtype=adj.dtype, device=adj.device)
    mask[node_idx] = 1
    mask = mask.view(batch_size, max_num_nodes)

    mask = mask[:, None, :] * mask[:, :, None]

    return adj, mask

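A minimal sketch of the expected input/output shapes, using a hypothetical two-graph batch (a 3-node and a 2-node graph); the numbers are purely illustrative.

import torch

# Nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
edge_index = torch.tensor([[0, 1, 3],
                           [1, 2, 4]])
batch = torch.tensor([0, 0, 0, 1, 1])

adj, mask = to_dense_adj(edge_index, batch, max_num_nodes=4)
print(adj.shape, mask.shape)   # torch.Size([2, 4, 4]) torch.Size([2, 4, 4])
print(adj[0])                  # edges (0->1) and (1->2) of the first graph
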
def restore_checkpoint_partial(model, pretrained_stdict):
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_stdict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    return model


def restore_checkpoint(ckpt_dir, state, device, resume=False):
    if not resume:
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state
    elif not os.path.exists(ckpt_dir):
        if not os.path.exists(os.path.dirname(ckpt_dir)):
            os.makedirs(os.path.dirname(ckpt_dir))
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    else:
        loaded_state = torch.load(ckpt_dir, map_location=device)
        for k in state:
            if k in ['optimizer', 'model', 'ema']:
                state[k].load_state_dict(loaded_state[k])
            else:
                state[k] = loaded_state[k]
        return state


def save_checkpoint(ckpt_dir, state, step, save_step, is_best, remove_except_best=False):
    saved_state = {}
    for k in state:
        if k in ['optimizer', 'model', 'ema']:
            saved_state.update({k: state[k].state_dict()})
        else:
            saved_state.update({k: state[k]})
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{step}_{save_step}.pth.tar'))
    if is_best:
        shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{step}_{save_step}.pth.tar'),
                    os.path.join(ckpt_dir, 'model_best.pth.tar'))
    # remove the ckpt except is_best state
    if remove_except_best:
        for ckpt_file in sorted(os.listdir(ckpt_dir)):
            if not ckpt_file.startswith('checkpoint'):
                continue
            if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
                os.remove(os.path.join(ckpt_dir, ckpt_file))

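A hedged round-trip sketch of how these two helpers can be combined; the tiny model, optimizer, and checkpoint directory are illustrative stand-ins (the 'ema' entry is omitted here), and behaviour of torch.load may vary slightly across PyTorch versions.

import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
state = dict(model=net, optimizer=opt, step=0)

save_checkpoint('./tmp_ckpt', state, step=0, save_step=100, is_best=True)
state = restore_checkpoint('./tmp_ckpt/model_best.pth.tar', state,
                           device='cpu', resume=True)
print(state['step'])
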
def floyed(r):
    """Floyd-Warshall-style transitive closure of a binary adjacency matrix.

    :param r: a numpy NxN matrix with float 0,1
    :return: a numpy NxN matrix with float 0,1
    """
    if type(r) == torch.Tensor:
        r = r.cpu().numpy()
    N = r.shape[0]
    for k in range(N):
        for i in range(N):
            for j in range(N):
                if r[i, k] > 0 and r[k, j] > 0:
                    r[i, j] = 1
    return r

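For example, on a 3-node chain 0 -> 1 -> 2 the closure adds the long-range edge 0 -> 2 (a small illustrative check with a toy matrix):

import numpy as np

chain = np.array([[0., 1., 0.],
                  [0., 0., 1.],
                  [0., 0., 0.]])
closure = floyed(chain.copy())
print(closure)   # [[0,1,1],[0,0,1],[0,0,0]] - node 0 now also reaches node 2
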
def aug_mask(adj, algo='floyed', data='NASBench201'):
    if len(adj.shape) == 2:
        adj = adj.unsqueeze(0)

    if data.lower() in ['nasbench201', 'ofa']:
        assert len(adj.shape) == 3
        r = adj[0].clone().detach()
        if algo == 'long_range':
            mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
        elif algo == 'floyed':
            mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
        else:
            mask_i = r
        masks = [mask_i] * adj.size(0)
        return torch.stack(masks)
    else:
        masks = []
        for r in adj:
            if algo == 'long_range':
                mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
            elif algo == 'floyed':
                mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
            else:
                mask_i = r
            masks.append(mask_i)
        return torch.stack(masks)

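A toy call illustrating the NASBench201 branch (the 4-node upper-triangular DAG below is an assumption for demonstration, not the actual cell topology): the mask is the transitive closure of the first adjacency, repeated across the batch.

import torch

adj = torch.tensor([[0., 1., 0., 0.],
                    [0., 0., 1., 0.],
                    [0., 0., 0., 1.],
                    [0., 0., 0., 0.]])
mask = aug_mask(adj, algo='floyed', data='NASBench201')
print(mask.shape)   # torch.Size([1, 4, 4])
print(mask[0])      # closure of the chain: every earlier node reaches every later one
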
def long_range(r):
    """
    :param r: a numpy NxN matrix with float 0,1
    :return: a numpy NxN matrix with float 0,1
    """
    if type(r) == torch.Tensor:
        r = r.cpu().numpy()
    N = r.shape[0]
    for j in range(1, N):
        col_j = r[:, j][:j]
        in_to_j = [i for i, val in enumerate(col_j) if val > 0]
        if len(in_to_j) > 0:
            for i in in_to_j:
                col_i = r[:, i][:i]
                in_to_i = [k for k, val in enumerate(col_i) if val > 0]
                if len(in_to_i) > 0:
                    for k in in_to_i:
                        r[k, j] = 1
    return r


def dense_adj(graph_data, max_num_nodes, scaler=None, dequantization=False):
    """Convert PyG DataBatch to dense adjacency matrices.

    Args:
        graph_data: DataBatch object.
        max_num_nodes: The size of the output node dimension.
        scaler: Data normalizer.
        dequantization: uniform dequantization.

    Returns:
        adj: Dense adjacency matrices.
        mask: Mask for adjacency matrices.
    """
    adj, adj_mask = to_dense_adj(graph_data.edge_index, graph_data.batch, max_num_nodes=max_num_nodes)  # [B, N, N]
    # adj: [32, 20, 20] / adj_mask: [32, 20, 20]
    if dequantization:
        noise = torch.rand_like(adj)
        noise = torch.tril(noise, -1)
        noise = noise + noise.transpose(1, 2)
        adj = (noise + adj) / 2.
    adj = scaler(adj[:, None, :, :])  # [32, 1, 20, 20]
    # set diag = 0 in adj_mask
    adj_mask = torch.tril(adj_mask, -1)  # [32, 20, 20]
    adj_mask = adj_mask + adj_mask.transpose(1, 2)

    return adj, adj_mask[:, None, :, :]

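A minimal sketch of the expected shapes, assuming torch_geometric is available; the two toy graphs and the identity lambda standing in for the data normalizer are illustrative assumptions.

import torch
from torch_geometric.data import Data, Batch

g1 = Data(edge_index=torch.tensor([[0, 1], [1, 2]]), num_nodes=3)
g2 = Data(edge_index=torch.tensor([[0], [1]]), num_nodes=2)
batch = Batch.from_data_list([g1, g2])

adj, mask = dense_adj(batch, max_num_nodes=4, scaler=lambda a: a)
print(adj.shape, mask.shape)   # torch.Size([2, 1, 4, 4]) torch.Size([2, 1, 4, 4])
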
def adj2graph(adj, sample_nodes):
    """Convert the PyTorch tensor adjacency matrices to numpy arrays.

    Args:
        adj: [Batch_size, channel, Max_node, Max_node], assume channel=1
        sample_nodes: [Batch_size]
    """
    adj_list = []
    # discretization
    adj[adj >= 0.5] = 1.
    adj[adj < 0.5] = 0.
    for i in range(adj.shape[0]):
        adj_tmp = adj[i, 0]
        # symmetric
        adj_tmp = torch.tril(adj_tmp, -1)
        adj_tmp = adj_tmp + adj_tmp.transpose(0, 1)
        # truncate
        adj_tmp = adj_tmp.cpu().numpy()[:sample_nodes[i], :sample_nodes[i]]
        adj_list.append(adj_tmp)

    return adj_list


def quantize(x):
    """Convert the PyTorch tensor x to a list of numpy arrays after thresholding.

    Args:
        x: [Batch_size, Max_node, N_vocab]
    """
    x_list = []

    # discretization
    x[x >= 0.5] = 1.
    x[x < 0.5] = 0.

    for i in range(x.shape[0]):
        x_tmp = x[i]
        x_tmp = x_tmp.cpu().numpy()
        x_list.append(x_tmp)

    return x_list
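A small illustrative call (toy tensors only; the shapes are assumptions) showing the 0.5 thresholding and the per-graph truncation performed by the two helpers above:

import torch

adj = torch.rand(2, 1, 4, 4)                  # soft adjacency scores in [0, 1]
graphs = adj2graph(adj, sample_nodes=[4, 3])
print(graphs[0].shape, graphs[1].shape)       # (4, 4) (3, 3)

x = torch.rand(2, 4, 7)                       # soft operation matrix
ops = quantize(x)
print(ops[0].shape)                           # (4, 7)
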