commit 2410fe9f5e
parent 5a1dc89756

    update
correlation/NAS-Bench-101.py (new file)
@@ -0,0 +1,133 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import pickle
import torch
import argparse
import json
import numpy as np
from thop import profile

from foresight.models import *
from foresight.pruners import *
from foresight.dataset import *


def get_num_classes(args):
    return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120


def parse_arguments():
    parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-101')
    parser.add_argument('--api_loc', default='../data/nasbench_only108.tfrecord',
                        type=str, help='path to API')
    parser.add_argument('--json_loc', default='data/all_graphs.json',
                        type=str, help='path to JSON database')
    parser.add_argument('--outdir', default='./',
                        type=str, help='output directory')
    parser.add_argument('--outfname', default='test',
                        type=str, help='output filename')
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset to use [cifar10, cifar100, ImageNet16-120]')
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on')
    parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
    parser.add_argument('--dataload', type=str, default='random', help='random or grasp supported')
    parser.add_argument('--dataload_info', type=int, default=1,
                        help='number of batches to use for random dataload or number of samples per class for grasp dataload')
    parser.add_argument('--start', type=int, default=5, help='start index')
    parser.add_argument('--end', type=int, default=10, help='end index')
    parser.add_argument('--write_freq', type=int, default=100, help='frequency of write to file')
    args = parser.parse_args()
    args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    return args


def get_op_names(v):
    o = []
    for op in v:
        if op == -1:
            o.append('input')
        elif op == -2:
            o.append('output')
        elif op == 0:
            o.append('conv3x3-bn-relu')
        elif op == 1:
            o.append('conv1x1-bn-relu')
        elif op == 2:
            o.append('maxpool3x3')
    return o


if __name__ == '__main__':
    args = parse_arguments()
    # nasbench = api.NASBench(args.api_loc)
    models = json.load(open(args.json_loc))

    print(f'Running models {args.start} to {args.end} out of {len(models.keys())}')

    train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset,
                                                     args.num_data_workers)

    all_points = []
    pre = 'cf' if 'cifar' in args.dataset else 'im'

    if args.outfname == 'test':
        fn = f'nb1_{pre}{get_num_classes(args)}.p'
    else:
        fn = f'{args.outfname}.p'
    op = os.path.join(args.outdir, fn)

    print('outfile =', op)
    first = True

    # loop over nasbench1 archs (k=hash, v=[adj_matrix, ops])
    idx = 0
    cached_res = []
    for k, v in models.items():

        if idx < args.start:
            idx += 1
            continue
        if idx >= args.end:
            break
        print(f'idx = {idx}')
        idx += 1

        res = {}
        res['hash'] = k

        # model
        spec = nasbench1_spec._ToModelSpec(v[0], get_op_names(v[1]))
        net = nasbench1.Network(spec, stem_out=128, num_stacks=3, num_mods=3, num_classes=get_num_classes(args))
        net.to(args.device)

        measures = predictive.find_measures(net,
                                            train_loader,
                                            (args.dataload, args.dataload_info, get_num_classes(args)),
                                            args.device)
        res['logmeasures'] = measures

        print(res)
        cached_res.append(res)

        # write to file
        if idx % args.write_freq == 0 or idx == args.end or idx == args.start + 10:
            print(f'writing {len(cached_res)} results to {op}')
            pf = open(op, 'ab')
            for cr in cached_res:
                pickle.dump(cr, pf)
            pf.close()
            cached_res = []
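A quick sketch of the pieces this script relies on (not part of the commit): the op-code mapping used by get_op_names, and how to read back the output file, which is a stream of independently pickled dicts appended with pickle.dump in 'ab' mode, so it must be consumed with repeated pickle.load calls until EOF.

# Decoding one NAS-Bench-101 ops vector:
ops = [-1, 0, 1, 2, -2]
print(get_op_names(ops))
# -> ['input', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3', 'output']

# Reading back every result record written above (path is whatever `op` was):
import pickle

def load_results(path):
    results = []
    with open(path, 'rb') as f:
        while True:
            try:
                results.append(pickle.load(f))
            except EOFError:
                break
    return results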
correlation/NAS-Bench-201.py (new file)
@@ -0,0 +1,128 @@
import argparse
import os
import time
import pickle

import torch

from foresight.dataset import *
from foresight.models import nasbench2
from foresight.pruners import predictive
from foresight.weight_initializers import init_net
from models import get_cell_based_tiny_net


def get_num_classes(args):
    return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120


def parse_arguments():
    parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-201')
    parser.add_argument('--api_loc', default='../data/NAS-Bench-201-v1_0-e61699.pth',
                        type=str, help='path to API')
    parser.add_argument('--outdir', default='./',
                        type=str, help='output directory')
    parser.add_argument('--init_w_type', type=str, default='none',
                        help='weight initialization (before pruning) type [none, xavier, kaiming, zero, one]')
    parser.add_argument('--init_b_type', type=str, default='none',
                        help='bias initialization (before pruning) type [none, xavier, kaiming, zero, one]')
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--dataset', type=str, default='ImageNet16-120',
                        help='dataset to use [cifar10, cifar100, ImageNet16-120]')
    parser.add_argument('--search_space', type=str, default='tss',
                        help='search space [tss, sss]')  # required by the tss/sss branches below
    parser.add_argument('--gpu', type=int, default=5, help='GPU index to work on')
    parser.add_argument('--data_size', type=int, default=32, help='data_size')
    parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
    parser.add_argument('--dataload', type=str, default='appoint', help='random, grasp, appoint supported')
    parser.add_argument('--dataload_info', type=int, default=1,
                        help='number of batches to use for random dataload or number of samples per class for grasp dataload')
    parser.add_argument('--seed', type=int, default=42, help='pytorch manual seed')
    parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file')
    parser.add_argument('--start', type=int, default=0, help='start index')
    parser.add_argument('--end', type=int, default=0, help='end index')
    parser.add_argument('--noacc', default=False, action='store_true',
                        help='avoid loading the NAS-Bench-201 API and instead load a pickle file with tuples (index, arch_str)')
    args = parser.parse_args()
    args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    return args


if __name__ == '__main__':
    args = parse_arguments()
    print(args.device)

    if args.noacc:
        api = pickle.load(open(args.api_loc, 'rb'))
    else:
        from nas_201_api import NASBench201API as API
        api = API(args.api_loc)

    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset, args.num_data_workers, resize=args.data_size)
    x, y = next(iter(train_loader))
    # random data
    # x = torch.rand((args.batch_size, 3, args.data_size, args.data_size))
    # y = 0

    cached_res = []
    pre = 'cf' if 'cifar' in args.dataset else 'im'
    pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p'
    op = os.path.join(args.outdir, pfn)

    end = len(api) if args.end == 0 else args.end

    # loop over nasbench2 archs
    for i, arch_str in enumerate(api):

        if i < args.start:
            continue
        if i >= end:
            break

        res = {'i': i, 'arch': arch_str}
        # print(arch_str)
        if args.search_space == 'tss':
            net = nasbench2.get_model_from_arch_str(arch_str, get_num_classes(args))
            arch_str2 = nasbench2.get_arch_str_from_model(net)
            if arch_str != arch_str2:
                print(arch_str)
                print(arch_str2)
                raise ValueError
        elif args.search_space == 'sss':
            config = api.get_net_config(i, args.dataset)
            # print(config)
            net = get_cell_based_tiny_net(config)
        net.to(args.device)
        # print(net)

        init_net(net, args.init_w_type, args.init_b_type)

        # print(x.size(), y)
        # NOTE: get_score is expected to be provided by the repo's scoring module; it is not imported in this file.
        measures = get_score(net, x, i, args.device)

        res['meco'] = measures

        if not args.noacc:
            info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset, iepoch=None,
                                     hp='200', is_random=False)

            trainacc = info['train-accuracy']
            valacc = info['valid-accuracy']
            testacc = info['test-accuracy']

            res['trainacc'] = trainacc
            res['valacc'] = valacc
            res['testacc'] = testacc

        print(res)
        cached_res.append(res)

        # write to file
        if i % args.write_freq == 0 or i == len(api) - 1 or i == 10:
            print(f'writing {len(cached_res)} results to {op}')
            pf = open(op, 'ab')
            for cr in cached_res:
                pickle.dump(cr, pf)
            pf.close()
            cached_res = []
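Since this directory is named "correlation", a natural follow-up is to rank-correlate the stored zero-cost score against the benchmark accuracy. A minimal sketch, assuming the records were written with both 'meco' and 'testacc' keys as in the script above; the file name is a hypothetical example of the pattern built in pfn.

import pickle
from scipy.stats import spearmanr

scores, accs = [], []
with open('nb2_tss_cf10_seed42_dlappoint_dlinfo1_initwnone_initbnone_64.p', 'rb') as f:
    while True:
        try:
            r = pickle.load(f)
        except EOFError:
            break
        if 'testacc' in r:
            scores.append(r['meco'])
            accs.append(r['testacc'])
print('Spearman rho:', spearmanr(scores, accs).correlation)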
exp_scripts/zerocostpt_nb201_pipeline.sh (new file)
@@ -0,0 +1,38 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar10}
seed=${seed:-0}
gpu=${gpu:-"auto"}
pool_size=${pool_size:-10}
batch_size=${batch_size:-256}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
metric=${metric:-'jacob'}
while [ $# -gt 0 ]; do
    if [[ $1 == *"--"* ]]; then
        param="${1/--/}"
        declare $param="$2"
        # echo $1 $2 # optional: print each parameter:value pair
    fi
    shift
done

echo 'id:' $id 'seed:' $seed 'dataset:' $dataset
echo 'gpu:' $gpu

cd ../nasbench201/
python3 networks_proposal.py \
    --dataset $dataset \
    --save $id --gpu $gpu --seed $seed \
    --edge_decision $edge_decision --proj_crit $metric \
    --batch_size $batch_size \
    --pool_size $pool_size

cd ../zerocostnas/
python3 post_validate.py \
    --ckpt_path ../experiments/nas-bench-201/prop-$id-$seed-$pool_size-$metric \
    --save $id --seed $seed --gpu $gpu \
    --edge_decision $edge_decision --proj_crit $metric \
    --batch_size $batch_size \
    --validate_rounds $validate_rounds
nasbench201/DownsampledImageNet.py (new file)
@@ -0,0 +1,110 @@
import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data
import pickle


def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    print(fpath)
    if not os.path.isfile(fpath): return False
    if md5 is None: return True
    else: return check_md5(fpath, md5)


class ImageNet16(data.Dataset):
    # http://image-net.org/download-images
    # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
    # https://arxiv.org/pdf/1707.08819.pdf

    train_list = [
        ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
        ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
        ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
        ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
        ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
        ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
        ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
        ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
        ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
        ['train_data_batch_10', '8f03f34ac4b42271a294f91bf480f29b'],
    ]
    valid_list = [
        ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
    ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        self.root = root
        self.transform = transform
        self.train = train  # training set or valid set
        if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')

        if self.train: downloaded_list = self.train_list
        else: downloaded_list = self.valid_list
        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for i, (file_name, checksum) in enumerate(downloaded_list):
            file_path = os.path.join(self.root, file_name)
            # print('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])
        self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        if use_num_of_class_only is not None:
            assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
            new_data, new_targets = [], []
            for I, L in zip(self.data, self.targets):
                if 1 <= L <= use_num_of_class_only:
                    new_data.append(I)
                    new_targets.append(L)
            self.data = new_data
            self.targets = new_targets
        # self.mean.append(entry['mean'])
        # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
        # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
        # print('Mean : {:}'.format(self.mean))
        # temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
        # std_data = np.std(temp, axis=0)
        # std_data = np.mean(np.mean(std_data, axis=0), axis=0)
        # print('Std : {:}'.format(std_data))

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index] - 1

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.valid_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, filename)
            if not check_integrity(fpath, md5):
                return False
        return True
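A minimal usage sketch of the dataset class above. The root path is a hypothetical location for the downsampled ImageNet batch files; use_num_of_class_only=120 keeps the first 120 classes, matching the ImageNet16-120 setting referenced elsewhere in this commit.

from torchvision import transforms

train_data = ImageNet16(root='./data/ImageNet16', train=True,
                        transform=transforms.ToTensor(),
                        use_num_of_class_only=120)
img, label = train_data[0]
print(len(train_data), img.shape, label)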
nasbench201/architect_ig.py (new file)
@@ -0,0 +1,52 @@
import torch


class Architect(object):
    def __init__(self, model, args):
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=args.arch_learning_rate, betas=(0.5, 0.999),
                                          weight_decay=args.arch_weight_decay)

        self._init_arch_parameters = []
        for alpha in self.model.arch_parameters():
            alpha_init = torch.zeros_like(alpha)
            alpha_init.data.copy_(alpha)
            self._init_arch_parameters.append(alpha_init)

        #### mode
        if args.method in ['darts', 'darts-proj', 'sdarts', 'sdarts-proj']:
            self.method = 'fo'  # first order update
        elif 'so' in args.method:
            print('ERROR: PLEASE USE architect.py for second order darts')
        elif args.method in ['blank', 'blank-proj']:
            self.method = 'blank'
        else:
            print('ERROR: WRONG ARCH UPDATE METHOD', args.method); exit(0)

    def reset_arch_parameters(self):
        for alpha, alpha_init in zip(self.model.arch_parameters(), self._init_arch_parameters):
            alpha.data.copy_(alpha_init.data)

    def step(self, input_train, target_train, input_valid, target_valid, *args, **kwargs):
        if self.method == 'fo':
            shared = self._step_fo(input_train, target_train, input_valid, target_valid)
        elif self.method == 'so':
            raise NotImplementedError
        elif self.method == 'blank':  ## do not update alpha
            shared = None

        return shared

    #### first order
    def _step_fo(self, input_train, target_train, input_valid, target_valid):
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()
        self.optimizer.step()
        return None

    #### darts 2nd order
    def _step_darts_so(self, input_train, target_train, input_valid, target_valid, eta, model_optimizer):
        raise NotImplementedError
nasbench201/cell_infers/cells.py (new file)
@@ -0,0 +1,120 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################

import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS


# Cell for NAS-Bench-201
class InferCell(nn.Module):

    def __init__(self, genotype, C_in, C_out, stride):
        super(InferCell, self).__init__()

        self.layers = nn.ModuleList()
        self.node_IN = []
        self.node_IX = []
        self.genotype = deepcopy(genotype)
        for i in range(1, len(genotype)):
            node_info = genotype[i-1]
            cur_index = []
            cur_innod = []
            for (op_name, op_in) in node_info:
                if op_in == 0:
                    layer = OPS[op_name](C_in, C_out, stride, True, True)
                else:
                    layer = OPS[op_name](C_out, C_out, 1, True, True)
                cur_index.append(len(self.layers))
                cur_innod.append(op_in)
                self.layers.append(layer)
            self.node_IX.append(cur_index)
            self.node_IN.append(cur_innod)
        self.nodes = len(genotype)
        self.in_dim = C_in
        self.out_dim = C_out

    def extra_repr(self):
        string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
        laystr = []
        for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
            y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
            x = '{:}<-({:})'.format(i+1, ','.join(y))
            laystr.append(x)
        return string + ', [{:}]'.format(' | '.join(laystr)) + ', {:}'.format(self.genotype.tostr())

    def forward(self, inputs):
        nodes = [inputs]
        for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
            node_feature = sum(self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods))
            nodes.append(node_feature)
        return nodes[-1]


# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
        super(NASNetInferCell, self).__init__()
        self.reduction = reduction
        if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
        else: self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
        self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)

        if not reduction:
            nodes, concats = genotype['normal'], genotype['normal_concat']
        else:
            nodes, concats = genotype['reduce'], genotype['reduce_concat']
        self._multiplier = len(concats)
        self._concats = concats
        self._steps = len(nodes)
        self._nodes = nodes
        self.edges = nn.ModuleDict()
        for i, node in enumerate(nodes):
            for in_node in node:
                name, j = in_node[0], in_node[1]
                stride = 2 if reduction and j < 2 else 1
                node_str = '{:}<-{:}'.format(i+2, j)
                self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)

    # [TODO] to support drop_prob in this function..
    def forward(self, s0, s1, unused_drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i, node in enumerate(self._nodes):
            clist = []
            for in_node in node:
                name, j = in_node[0], in_node[1]
                node_str = '{:}<-{:}'.format(i+2, j)
                op = self.edges[node_str]
                clist.append(op(states[j]))
            states.append(sum(clist))
        return torch.cat([states[x] for x in self._concats], dim=1)


class AuxiliaryHeadCIFAR(nn.Module):

    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x
nasbench201/cell_infers/tiny_network.py (new file)
@@ -0,0 +1,82 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell


# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):

    def __init__(self, C, N, genotype, num_classes):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))

        layer_channels = [C] * N + [C*2] + [C*2] * N + [C*4] + [C*4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
            else:
                cell = InferCell(genotype, C_prev, C_curr, 1)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self._Layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

        self.requires_feature = True

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        if self.requires_feature:
            return logits, out
        else:
            return logits

    # NOTE: _criterion, optimizer and get_weights() are expected to be attached to the
    # model by the training code; they are not defined in this class.
    def _loss(self, input, target, return_logits=False):
        logits, _ = self(input)
        loss = self._criterion(logits, target)

        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()

        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt
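A minimal sketch of instantiating this macro network from a genotype string (the string below is the ResNet-like cell also defined as ResNet_CODE in nasbench201/genotypes.py later in this commit). It assumes the repo root is on PYTHONPATH and that the repo-internal Layers package used by cell_operations.py is importable; names here are illustrative, not part of this file.

import torch
from nasbench201.genotypes import Structure

genotype = Structure.str2structure('|nor_conv_3x3~0|+|nor_conv_3x3~1|+|skip_connect~0|skip_connect~2|')
net = TinyNetwork(C=16, N=5, genotype=genotype, num_classes=10)
logits, feats = net(torch.randn(2, 3, 32, 32))  # requires_feature=True, so forward returns two values
print(logits.shape, feats.shape)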
nasbench201/cell_operations.py (new file)
@@ -0,0 +1,289 @@
import sys
import torch
import torch.nn as nn
sys.path.insert(0, '../')
from Layers import layers
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']

OPS = {
    'noise': lambda C_in, C_out, stride, affine, track_running_stats: NoiseOp(stride, 0., 1.),  # C_in, C_out not needed
    'none': lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
    'max_pool_3x3': lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
    'nor_conv_7x7': lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
    'nor_conv_3x3': lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
    'nor_conv_1x1': lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
    'dua_sepc_3x3': lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
    'dua_sepc_5x5': lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
    'dil_sepc_3x3': lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
    'dil_sepc_5x5': lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
    'skip_connect': lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}

CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
#### wrc modified
NAS_BENCH_201_SKIP = ['none', 'skip_connect', 'nor_conv_1x1_skip', 'nor_conv_3x3_skip', 'avg_pool_3x3']
NAS_BENCH_201_SIMPLE = ['skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
NAS_BENCH_201_S2 = ['skip_connect', 'nor_conv_3x3']
NAS_BENCH_201_S4 = ['noise', 'nor_conv_3x3']
NAS_BENCH_201_S10 = ['none', 'nor_conv_3x3']

SearchSpaceNames = {'connect-nas': CONNECT_NAS_BENCHMARK,
                    'nas-bench-201': NAS_BENCH_201,
                    'nas-bench-201-simple': NAS_BENCH_201_SIMPLE,
                    'nas-bench-201-s2': NAS_BENCH_201_S2,
                    'nas-bench-201-s4': NAS_BENCH_201_S4,
                    'nas-bench-201-s10': NAS_BENCH_201_S10,
                    'darts': DARTS_SPACE}


class NoiseOp(nn.Module):
    def __init__(self, stride, mean, std):
        super(NoiseOp, self).__init__()
        self.stride = stride
        self.mean = mean
        self.std = std

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        if self.stride != 1:
            x_new = x[:, :, ::self.stride, ::self.stride]
        else:
            x_new = x
        noise = x_new.data.new(x_new.size()).normal_(self.mean, self.std)
        return noise


class ReLUConvBN(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
        )

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        return self.op(x)

    def score(self):
        score = 0
        for l in self.op:
            if hasattr(l, 'score'):
                score += torch.sum(l.score).cpu().numpy()
        return score


#### wrc modified
class ReLUConvBNSkip(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(ReLUConvBNSkip, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
        )

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        return self.op(x) + x

    def score(self):
        score = 0
        for l in self.op:
            if hasattr(l, 'score'):
                score += torch.sum(l.score).cpu().numpy()
        return score
####


class SepConv(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            layers.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
            layers.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
        )

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        return self.op(x)

    def score(self):
        score = 0
        for l in self.op:
            if hasattr(l, 'score'):
                score += torch.sum(l.score).cpu().numpy()
        return score


class DualSepConv(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(DualSepConv, self).__init__()
        self.op_a = SepConv(C_in, C_in, kernel_size, stride, padding, dilation, affine, track_running_stats)
        self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        x = self.op_a(x)
        x = self.op_b(x)
        return x

    def score(self):
        score = self.op_a.score() + self.op_b.score()
        return score


class ResNetBasicblock(nn.Module):

    def __init__(self, inplanes, planes, stride, affine=True):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1, affine)
        if stride == 2:
            self.downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
        elif inplanes != planes:
            self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
        else:
            self.downsample = None
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def extra_repr(self):
        string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
        return string

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        return residual + basicblock

    def score(self):
        return self.conv_a.score() + self.conv_b.score()


class POOLING(nn.Module):

    def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
        super(POOLING, self).__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
        if mode == 'avg': self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
        else: raise ValueError('Invalid mode={:} in POOLING'.format(mode))

    def forward(self, inputs, block_input=False):
        if block_input:
            inputs = inputs * 0
        if self.preprocess: x = self.preprocess(inputs)
        else: x = inputs
        return self.op(x)

    def score(self):
        if self.preprocess:
            return self.preprocess.score()
        else:
            return 0


class Identity(nn.Module):

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        return x


class Zero(nn.Module):

    def __init__(self, C_in, C_out, stride):
        super(Zero, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.is_zero = True

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        if self.C_in == self.C_out:
            if self.stride == 1: return x.mul(0.)
            else: return x[:, :, ::self.stride, ::self.stride].mul(0.)
        else:  ## this is never called in nasbench201
            shape = list(x.shape)
            shape[1] = self.C_out
            zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
            return zeros

    def extra_repr(self):
        return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)


class FactorizedReduce(nn.Module):

    def __init__(self, C_in, C_out, stride, affine, track_running_stats):
        super(FactorizedReduce, self).__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)
        if stride == 2:
            # assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
            C_outs = [C_out // 2, C_out - C_out // 2]
            self.convs = nn.ModuleList()
            for i in range(2):
                self.convs.append(layers.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        elif stride == 1:
            self.conv = layers.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
        else:
            raise ValueError('Invalid stride : {:}'.format(stride))
        self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        if self.stride == 2:
            x = self.relu(x)
            y = self.pad(x)
            out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
        else:
            out = self.conv(x)
        out = self.bn(out)
        return out

    def extra_repr(self):
        return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)

    def score(self):
        if self.stride == 1:
            return self.conv.score()
        else:
            return self.convs[0].score() + self.convs[1].score()
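Every entry in OPS above shares the signature (C_in, C_out, stride, affine, track_running_stats), so an operation can be instantiated uniformly from its name. A small sketch, assuming the repo-internal Layers.layers.Conv2d behaves like nn.Conv2d for the forward pass:

op = OPS['nor_conv_3x3'](16, 16, 1, True, True)
out = op(torch.randn(2, 16, 32, 32))
print(out.shape)  # -> torch.Size([2, 16, 32, 32]); stride 1 keeps the spatial size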
nasbench201/genotypes.py (new file)
@@ -0,0 +1,194 @@
from copy import deepcopy


def get_combination(space, num):
    combs = []
    for i in range(num):
        if i == 0:
            for func in space:
                combs.append([(func, i)])
        else:
            new_combs = []
            for string in combs:
                for func in space:
                    xstring = string + [(func, i)]
                    new_combs.append(xstring)
            combs = new_combs
    return combs


class Structure:

    def __init__(self, genotype):
        assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
        self.node_num = len(genotype) + 1
        self.nodes = []
        self.node_N = []
        for idx, node_info in enumerate(genotype):
            assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
            assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
            for node_in in node_info:
                assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
                assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
            self.node_N.append(len(node_info))
            self.nodes.append(tuple(deepcopy(node_info)))

    def tolist(self, remove_str):
        # convert this class to a list; if remove_str is 'none', remove the 'none' operation.
        # note that we re-order the input nodes in this function
        # return the genotype list and a success flag [if unsuccessful, it is not a valid connectivity]
        genotypes = []
        for node_info in self.nodes:
            node_info = list(node_info)
            node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
            node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
            if len(node_info) == 0: return None, False
            genotypes.append(node_info)
        return genotypes, True

    def node(self, index):
        assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
        return self.nodes[index]

    def tostr(self):
        strings = []
        for node_info in self.nodes:
            string = '|'.join([x[0] + '~{:}'.format(x[1]) for x in node_info])
            string = '|{:}|'.format(string)
            strings.append(string)
        return '+'.join(strings)

    def check_valid(self):
        nodes = {0: True}
        for i, node_info in enumerate(self.nodes):
            sums = []
            for op, xin in node_info:
                if op == 'none' or nodes[xin] is False: x = False
                else: x = True
                sums.append(x)
            nodes[i+1] = sum(sums) > 0
        return nodes[len(self.nodes)]

    def to_unique_str(self, consider_zero=False):
        # this is used to identify isomorphic cells, which requires prior knowledge of the operations
        # two operations are special, i.e., none and skip_connect
        nodes = {0: '0'}
        for i_node, node_info in enumerate(self.nodes):
            cur_node = []
            for op, xin in node_info:
                if consider_zero is None:
                    x = '(' + nodes[xin] + ')' + '@{:}'.format(op)
                elif consider_zero:
                    if op == 'none' or nodes[xin] == '#': x = '#'  # zero
                    elif op == 'skip_connect': x = nodes[xin]
                    else: x = '(' + nodes[xin] + ')' + '@{:}'.format(op)
                else:
                    if op == 'skip_connect': x = nodes[xin]
                    else: x = '(' + nodes[xin] + ')' + '@{:}'.format(op)
                cur_node.append(x)
            nodes[i_node+1] = '+'.join(sorted(cur_node))
        return nodes[len(self.nodes)]

    def check_valid_op(self, op_names):
        for node_info in self.nodes:
            for inode_edge in node_info:
                # assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
                if inode_edge[0] not in op_names: return False
        return True

    def __repr__(self):
        return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))

    def __len__(self):
        return len(self.nodes) + 1

    def __getitem__(self, index):
        return self.nodes[index]

    @staticmethod
    def str2structure(xstr):
        assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
        nodestrs = xstr.split('+')
        genotypes = []
        for i, node_str in enumerate(nodestrs):
            inputs = list(filter(lambda x: x != '', node_str.split('|')))
            for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
            inputs = (xi.split('~') for xi in inputs)
            input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
            genotypes.append(input_infos)
        return Structure(genotypes)

    @staticmethod
    def str2fullstructure(xstr, default_name='none'):
        assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
        nodestrs = xstr.split('+')
        genotypes = []
        for i, node_str in enumerate(nodestrs):
            inputs = list(filter(lambda x: x != '', node_str.split('|')))
            for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
            inputs = (xi.split('~') for xi in inputs)
            input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
            all_in_nodes = list(x[1] for x in input_infos)
            for j in range(i):
                if j not in all_in_nodes: input_infos.append((default_name, j))
            node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
            genotypes.append(tuple(node_info))
        return Structure(genotypes)

    @staticmethod
    def gen_all(search_space, num, return_ori):
        assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
        assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
        all_archs = get_combination(search_space, 1)
        for i, arch in enumerate(all_archs):
            all_archs[i] = [tuple(arch)]

        for inode in range(2, num):
            cur_nodes = get_combination(search_space, inode)
            new_all_archs = []
            for previous_arch in all_archs:
                for cur_node in cur_nodes:
                    new_all_archs.append(previous_arch + [tuple(cur_node)])
            all_archs = new_all_archs
        if return_ori:
            return all_archs
        else:
            return [Structure(x) for x in all_archs]


ResNet_CODE = Structure(
    [(('nor_conv_3x3', 0), ),  # node-1
     (('nor_conv_3x3', 1), ),  # node-2
     (('skip_connect', 0), ('skip_connect', 2))]  # node-3
)

AllConv3x3_CODE = Structure(
    [(('nor_conv_3x3', 0), ),  # node-1
     (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)),  # node-2
     (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))]  # node-3
)

AllFull_CODE = Structure(
    [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)),  # node-1
     (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)),  # node-2
     (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))]  # node-3
)

AllConv1x1_CODE = Structure(
    [(('nor_conv_1x1', 0), ),  # node-1
     (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)),  # node-2
     (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))]  # node-3
)

AllIdentity_CODE = Structure(
    [(('skip_connect', 0), ),  # node-1
     (('skip_connect', 0), ('skip_connect', 1)),  # node-2
     (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))]  # node-3
)

architectures = {'resnet': ResNet_CODE,
                 'all_c3x3': AllConv3x3_CODE,
                 'all_c1x1': AllConv1x1_CODE,
                 'all_idnt': AllIdentity_CODE,
                 'all_full': AllFull_CODE}
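A short sketch of the string format handled by Structure: tostr() produces the familiar NAS-Bench-201 architecture string, and str2structure() parses it back, so the two are round-trip consistent.

arch = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_1x1~2|'
s = Structure.str2structure(arch)
assert s.tostr() == arch
print(len(s), s.check_valid())  # -> 4 True (4 nodes; the cell still has a path of non-'none' ops)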
619
nasbench201/init_projection.py
Normal file
619
nasbench201/init_projection.py
Normal file
@ -0,0 +1,619 @@
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as f
|
||||
sys.path.insert(0, '../')
|
||||
import nasbench201.utils as ig_utils
|
||||
import logging
|
||||
import torch.utils
|
||||
import copy
|
||||
import scipy.stats as ss
|
||||
from collections import OrderedDict
|
||||
from foresight.pruners import *
|
||||
from op_score import Jocab_Score, get_ntk_n
|
||||
import gc
|
||||
from nasbench201.linear_region import Linear_Region_Collector
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
np.set_printoptions(precision=4, suppress=True)
|
||||
|
||||
# global-edge-iter: similar toglobal-op-iterbut iteratively selects edge e from E based on the average score of all operations on each edge
|
||||
def global_op_greedy_pt_project(proj_queue, model, args):
|
||||
def project(model, args):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
##get remain eid numbers
|
||||
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||
compare = lambda x, y : x < y
|
||||
|
||||
crit_extrema = None
|
||||
best_eid = None
|
||||
input, target = next(iter(proj_queue))
|
||||
for eid in remain_eids:
|
||||
for opid in range(num_op):
|
||||
# projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[eid] = weights[eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||
crit = valid_stats
|
||||
|
||||
if crit_extrema is None or compare(crit, crit_extrema):
|
||||
crit_extrema = crit
|
||||
best_opid = opid
|
||||
best_eid = eid
|
||||
|
||||
logging.info('best opid %d', best_opid)
|
||||
return best_eid, best_opid
|
||||
|
||||
tune_epochs = model.arch_parameters()[0].shape[0]
|
||||
|
||||
for epoch in range(tune_epochs):
|
||||
logging.info('epoch %d', epoch)
|
||||
logging.info('project')
|
||||
selected_eid, best_opid = project(model, args)
|
||||
model.project_op(selected_eid, best_opid)
|
||||
|
||||
return
|
||||
|
||||
# global-edge-iter: similar toglobal-op-oncebut uses the average score of operations on edges to obtain the edge discretization order
|
||||
def global_edge_greedy_pt_project(proj_queue, model, args):
|
||||
def select_eid(model, args):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
##get remain eid numbers
|
||||
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||
compare = lambda x, y : x < y
|
||||
|
||||
crit_extrema = None
|
||||
best_eid = None
|
||||
input, target = next(iter(proj_queue))
|
||||
for eid in remain_eids:
|
||||
eid_score = []
|
||||
for opid in range(num_op):
|
||||
# projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[eid] = weights[eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||
crit = valid_stats
|
||||
eid_score.append(crit)
|
||||
eid_score = np.mean(eid_score)
|
||||
|
||||
if crit_extrema is None or compare(eid_score, crit_extrema):
|
||||
crit_extrema = eid_score
|
||||
best_eid = eid
|
||||
return best_eid
|
||||
|
||||
def project(model, args, selected_eid):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
## select the best operation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
else:
|
||||
crit_idx = 4
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
best_opid = 0
|
||||
crit_list = []
|
||||
op_ids = []
|
||||
input, target = next(iter(proj_queue))
|
||||
for opid in range(num_op):
|
||||
## projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[selected_eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||
crit = valid_stats
|
||||
|
||||
crit_list.append(crit)
|
||||
op_ids.append(opid)
|
||||
|
||||
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||
|
||||
logging.info('best opid %d', best_opid)
|
||||
logging.info(crit_list)
|
||||
return selected_eid, best_opid
|
||||
|
||||
num_edges = model.arch_parameters()[0].shape[0]
|
||||
|
||||
for epoch in range(num_edges):
|
||||
logging.info('epoch %d', epoch)
|
||||
|
||||
logging.info('project')
|
||||
selected_eid = select_eid(model, args)
|
||||
selected_eid, best_opid = project(model, args, selected_eid)
|
||||
model.project_op(selected_eid, best_opid)
|
||||
return
|
||||
|
||||
# global-op-once: only evaluates S(A−(e,o)) for all operations once to obtain a ranking order of the operations, and discretizes the edgesEaccording to this order
|
||||
def global_op_once_pt_project(proj_queue, model, args):
|
||||
def order(model, args):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
##get remain eid numbers
|
||||
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||
compare = lambda x, y : x < y
|
||||
|
||||
edge_score = OrderedDict()
|
||||
input, target = next(iter(proj_queue))
|
||||
for eid in remain_eids:
|
||||
crit_list = []
|
||||
for opid in range(num_op):
|
||||
# projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[eid] = weights[eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||
crit = valid_stats
|
||||
|
||||
crit_list.append(crit)
|
||||
edge_score[eid] = np.nanargmin(crit_list)
|
||||
return edge_score
|
||||
|
||||
def project(model, args, selected_eid):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
## select the best operation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
else:
|
||||
crit_idx = 4
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
best_opid = 0
|
||||
crit_list = []
|
||||
op_ids = []
|
||||
input, target = next(iter(proj_queue))
|
||||
for opid in range(num_op):
|
||||
## projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[selected_eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit = Jocab_Score(model, input, target, weights=weights)
|
||||
crit_list.append(crit)
|
||||
op_ids.append(opid)
|
||||
|
||||
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||
|
||||
logging.info('best opid %d', best_opid)
|
||||
logging.info(crit_list)
|
||||
return selected_eid, best_opid
|
||||
|
||||
num_edges = model.arch_parameters()[0].shape[0]
|
||||
|
||||
eid_order = order(model, args)
|
||||
for epoch in range(num_edges):
|
||||
logging.info('epoch %d', epoch)
|
||||
logging.info('project')
|
||||
selected_eid, _ = eid_order.popitem()
|
||||
selected_eid, best_opid = project(model, args, selected_eid)
|
||||
model.project_op(selected_eid, best_opid)
|
||||
|
||||
return
|
||||
|
||||
# global-edge-once: similar toglobal-op-oncebut uses the average score of operations on dges to obtain the edge discretization order
|
||||
def global_edge_once_pt_project(proj_queue, model, args):
|
||||
def order(model, args):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
##get remain eid numbers
|
||||
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||
compare = lambda x, y : x < y
|
||||
|
||||
edge_score = OrderedDict()
|
||||
crit_extrema = None
|
||||
best_eid = None
|
||||
input, target = next(iter(proj_queue))
|
||||
for eid in remain_eids:
|
||||
crit_list = []
|
||||
for opid in range(num_op):
|
||||
# projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[eid] = weights[eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit = Jocab_Score(model, input, target, weights=weights)
|
||||
|
||||
crit_list.append(crit)
|
||||
edge_score[eid] = np.mean(crit_list)
|
||||
return edge_score
|
||||
|
||||
def project(model, args, selected_eid):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
## select the best operation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
else:
|
||||
crit_idx = 4
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
best_opid = 0
|
||||
crit_extrema = None
|
||||
crit_list = []
|
||||
op_ids = []
|
||||
input, target = next(iter(proj_queue))
|
||||
for opid in range(num_op):
|
||||
## projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[selected_eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit = Jocab_Score(model, input, target, weights=weights)
|
||||
crit_list.append(crit)
|
||||
op_ids.append(opid)
|
||||
|
||||
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||
|
||||
logging.info('best opid %d', best_opid)
|
||||
logging.info(crit_list)
|
||||
return selected_eid, best_opid
|
||||
|
||||
num_edges = model.arch_parameters()[0].shape[0]
|
||||
|
||||
eid_order = order(model, args)
|
||||
for epoch in range(num_edges):
|
||||
logging.info('epoch %d', epoch)
|
||||
logging.info('project')
|
||||
selected_eid, _ = eid_order.popitem()
|
||||
selected_eid, best_opid = project(model, args, selected_eid)
|
||||
model.project_op(selected_eid, best_opid)
|
||||
|
||||
return
|
||||
|
||||
# fixed [reverse, order]: discretizes the edges in a fixed order; in our experiments we discretize from the input towards the output of the cell structure
|
||||
# random: discretizes the edges in a random order (DARTS-PT)
|
||||
# NOTE: only this method allows using other zero-cost proxy metrics
|
||||
def pt_project(proj_queue, model, args):
|
||||
def project(model, args):
|
||||
## macros: the cell has 6 edges in total, each with 5 candidate operations
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
## select an edge
|
||||
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||
# print('candidate_flags:', model.candidate_flags)
|
||||
# print(model.candidate_flags)
|
||||
# how the next edge is selected
|
||||
if args.edge_decision == "random":
|
||||
# np.random.choice returns an array; take its single element
|
||||
selected_eid = np.random.choice(remain_eids, size=1)[0]
|
||||
elif args.edge_decision == "reverse":
|
||||
selected_eid = remain_eids[-1]
|
||||
else:
|
||||
selected_eid = remain_eids[0]
|
||||
|
||||
## select the best operation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
else:
|
||||
crit_idx = 4
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
if args.dataset == 'cifar100':
|
||||
n_classes = 100
|
||||
elif args.dataset == 'imagenet16-120':
|
||||
n_classes = 120
|
||||
else:
|
||||
n_classes = 10
|
||||
|
||||
best_opid = 0
|
||||
crit_extrema = None
|
||||
crit_list = []
|
||||
op_ids = []
|
||||
input, target = next(iter(proj_queue))
|
||||
for opid in range(num_op):
|
||||
## projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[selected_eid])
|
||||
# print(selected_eid, weights[selected_eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit = Jocab_Score(model, input, target, weights=weights)
|
||||
else:
|
||||
cache_weight = model.proj_weights[selected_eid]
|
||||
cache_flag = model.candidate_flags[selected_eid]
|
||||
|
||||
|
||||
for idx in range(num_op):
|
||||
if idx == opid:
|
||||
model.proj_weights[selected_eid][opid] = 0
|
||||
else:
|
||||
model.proj_weights[selected_eid][idx] = 1.0/num_op
|
||||
|
||||
|
||||
model.candidate_flags[selected_eid] = False
|
||||
# print(model.get_projected_weights())
|
||||
|
||||
if args.proj_crit == 'comb':
|
||||
synflow = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=['synflow'])
|
||||
var = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=['var'])
|
||||
# print(synflow, var)
|
||||
comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
|
||||
measures = {'comb': comb}
|
||||
else:
|
||||
measures = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=[args.proj_crit])
|
||||
|
||||
# print(measures)
|
||||
for idx in range(num_op):
|
||||
model.proj_weights[selected_eid][idx] = 0
|
||||
model.candidate_flags[selected_eid] = cache_flag
|
||||
crit = measures[args.proj_crit]
|
||||
|
||||
crit_list.append(crit)
|
||||
op_ids.append(opid)
|
||||
|
||||
|
||||
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||
# best_opid = op_ids[np.nanargmax(crit_list)]
|
||||
|
||||
logging.info('best opid %d', best_opid)
|
||||
logging.info('current edge id %d', selected_eid)
|
||||
logging.info(crit_list)
|
||||
return selected_eid, best_opid
|
||||
|
||||
num_edges = model.arch_parameters()[0].shape[0]
|
||||
|
||||
for epoch in range(num_edges):
|
||||
logging.info('epoch %d', epoch)
|
||||
logging.info('project')
|
||||
selected_eid, best_opid = project(model, args)
|
||||
model.project_op(selected_eid, best_opid)
|
||||
|
||||
return
|
||||
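# Illustrative sketch (assumption, not part of the original search code): the
# projection step used by the routines above boils down to masking one
# candidate operation at a time on the chosen edge, scoring the supernet with
# a zero-cost proxy, and keeping the operation selected by np.nanargmin.
# `score_fn` is a hypothetical stand-in for a scorer such as Jocab_Score.
def _select_op_by_masking(weights, eid, num_op, score_fn):
    crit_list = []
    for opid in range(num_op):
        masked = weights.clone()
        masked[eid][opid] = 0            # drop this candidate operation
        crit_list.append(score_fn(masked))
    return int(np.nanargmin(crit_list))  # same tie-breaking as the code above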
|
||||
def tenas_project(proj_queue, model, model_thin, args):
|
||||
def project(model, args):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
##get remain eid numbers
|
||||
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||
compare = lambda x, y : x < y
|
||||
|
||||
ntks = []
|
||||
lrs = []
|
||||
edge_op_id = []
|
||||
best_eid = None
|
||||
|
||||
if args.proj_crit == 'tenas':
|
||||
lrc_model = Linear_Region_Collector(input_size=(1000, 1, 3, 3), sample_batch=3, dataset=args.dataset, data_path=args.data, seed=args.seed)
|
||||
for eid in remain_eids:
|
||||
for opid in range(num_op):
|
||||
# projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[eid] = weights[eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'tenas':
|
||||
lrc_model.reinit(ori_models=[model_thin], seed=args.seed, weights=weights)
|
||||
lr = lrc_model.forward_batch_sample()
|
||||
lrc_model.clear()
|
||||
ntk = get_ntk_n(proj_queue, [model], recalbn=0, train_mode=True, num_batch=1, weights=weights)
|
||||
ntks.append(ntk)
|
||||
lrs.append(lr)
|
||||
edge_op_id.append('{}:{}'.format(eid, opid))
|
||||
print('ntks', ntks)
|
||||
print('lrs', lrs)
|
||||
ntks_ranks = ss.rankdata(ntks)
|
||||
lrs_ranks = ss.rankdata(lrs)
|
||||
ntks_ranks = len(ntks_ranks) - ntks_ranks.astype(int)
|
||||
op_ranks = []
|
||||
for i in range(len(edge_op_id)):
|
||||
op_ranks.append(ntks_ranks[i]+lrs_ranks[i])
|
||||
|
||||
best_op_index = edge_op_id[np.nanargmin(op_ranks[0:num_op])]
|
||||
best_eid, best_opid = [int(x) for x in best_op_index.split(':')]
|
||||
|
||||
logging.info(op_ranks)
|
||||
logging.info('best eid %d', best_eid)
|
||||
logging.info('best opid %d', best_opid)
|
||||
return best_eid, best_opid
|
||||
num_edges = model.arch_parameters()[0].shape[0]
|
||||
|
||||
for epoch in range(num_edges):
|
||||
logging.info('epoch %d', epoch)
|
||||
logging.info('project')
|
||||
selected_eid, best_opid = project(model, args)
|
||||
model.project_op(selected_eid, best_opid)
|
||||
|
||||
return
|
||||
|
||||
#new methods
|
||||
# Randomly propose candidate networks and transfer them into the supernet, then perform global op selection within this subspace
|
||||
def shrink_pt_project(proj_queue, model, args):
|
||||
def project(model, args):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
## select an edge
|
||||
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||
selected_eid = np.random.choice(remain_eids, size=1)[0]
|
||||
|
||||
|
||||
## select the best operation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
else:
|
||||
crit_idx = 4
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
if args.dataset == 'cifar100':
|
||||
n_classes = 100
|
||||
elif args.dataset == 'imagenet16-120':
|
||||
n_classes = 120
|
||||
else:
|
||||
n_classes = 10
|
||||
|
||||
best_opid = 0
|
||||
crit_extrema = None
|
||||
crit_list = []
|
||||
op_ids = []
|
||||
input, target = next(iter(proj_queue))
|
||||
for opid in range(num_op):
|
||||
## projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[selected_eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
crit = Jocab_Score(model, input, target, weights=weights)
|
||||
else:
|
||||
cache_weight = model.proj_weights[selected_eid]
|
||||
cache_flag = model.candidate_flags[selected_eid]
|
||||
|
||||
for idx in range(num_op):
|
||||
if idx == opid:
|
||||
model.proj_weights[selected_eid][opid] = 0
|
||||
else:
|
||||
model.proj_weights[selected_eid][idx] = 1.0/num_op
|
||||
model.candidate_flags[selected_eid] = False
|
||||
|
||||
measures = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=[args.proj_crit])
|
||||
for idx in range(num_op):
|
||||
model.proj_weights[selected_eid][idx] = 0
|
||||
model.candidate_flags[selected_eid] = cache_flag
|
||||
crit = measures[args.proj_crit]
|
||||
|
||||
crit_list.append(crit)
|
||||
op_ids.append(opid)
|
||||
|
||||
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||
|
||||
logging.info('best opid %d', best_opid)
|
||||
logging.info('current edge id %d', selected_eid)
|
||||
logging.info(crit_list)
|
||||
return selected_eid, best_opid
|
||||
|
||||
def global_project(model, args):
|
||||
## macros
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
|
||||
##get remain eid numbers
|
||||
remain_eids = torch.nonzero(model.subspace_candidate_flags).cpu().numpy().T[0]
|
||||
compare = lambda x, y : x < y
|
||||
|
||||
crit_extrema = None
|
||||
best_eid = None
|
||||
best_opid = None
|
||||
input, target = next(iter(proj_queue))
|
||||
for eid in remain_eids:
|
||||
remain_oids = torch.nonzero(model.proj_weights[eid]).cpu().numpy().T[0]
|
||||
for opid in remain_oids:
|
||||
# projection
|
||||
weights = model.get_projected_weights()
|
||||
proj_mask = torch.ones_like(weights[eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[eid] = weights[eid] * proj_mask
|
||||
## proj evaluation
|
||||
if args.proj_crit == 'jacob':
|
||||
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||
crit = valid_stats
|
||||
|
||||
if crit_extrema is None or compare(crit, crit_extrema):
|
||||
crit_extrema = crit
|
||||
best_opid = opid
|
||||
best_eid = eid
|
||||
|
||||
|
||||
logging.info('best eid %d', best_eid)
|
||||
logging.info('best opid %d', best_opid)
|
||||
model.subspace_candidate_flags[best_eid] = False
|
||||
proj_mask = torch.zeros_like(model.proj_weights[best_eid])
|
||||
model.proj_weights[best_eid] = model.proj_weights[best_eid] * proj_mask
|
||||
model.proj_weights[best_eid][best_opid] = 1
|
||||
return best_eid, best_opid
|
||||
|
||||
num_edges = model.arch_parameters()[0].shape[0]
|
||||
|
||||
#subspace
|
||||
logging.info('Start subspace proposal')
|
||||
subspace = copy.deepcopy(model.proj_weights)
|
||||
for i in range(20):
|
||||
model.reset_arch_parameters()
|
||||
for epoch in range(num_edges):
|
||||
logging.info('epoch %d', epoch)
|
||||
logging.info('project')
|
||||
selected_eid, best_opid = project(model, args)
|
||||
model.project_op(selected_eid, best_opid)
|
||||
subspace += model.proj_weights
|
||||
|
||||
model.reset_arch_parameters()
|
||||
subspace = torch.gt(subspace, 0).int().float()
|
||||
subspace = f.normalize(subspace, p=1, dim=1)
|
||||
model.proj_weights += subspace
|
||||
for i in range(num_edges):
|
||||
model.candidate_flags[i] = False
|
||||
logging.info('Start final search in subspace')
|
||||
logging.info(subspace)
|
||||
|
||||
model.subspace_candidate_flags = torch.tensor(len(model._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
|
||||
for epoch in range(num_edges):
|
||||
logging.info('epoch %d', epoch)
|
||||
logging.info('project')
|
||||
selected_eid, best_opid = global_project(model, args)
|
||||
model.printing(logging)
|
||||
#model.project_op(selected_eid, best_opid)
|
||||
return
|
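# Toy illustration (assumption, not part of the original file): the subspace
# proposal above unions several one-hot proposals per edge, binarises the sum
# with torch.gt, and row-normalises it so every surviving operation on an edge
# gets equal weight ('f' is the torch.nn.functional alias used above).
def _subspace_merge_example():
    proposals = torch.tensor([[1., 0., 0.],
                              [0., 0., 1.]])              # two proposals for one edge
    subspace = torch.gt(proposals.sum(0, keepdim=True), 0).float()
    return f.normalize(subspace, p=1, dim=1)              # tensor([[0.5, 0.0, 0.5]])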
270
nasbench201/linear_region.py
Normal file
@ -0,0 +1,270 @@
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as dset
|
||||
from pdb import set_trace as bp
|
||||
from operator import mul
|
||||
from functools import reduce
|
||||
import copy
|
||||
Dataset2Class = {'cifar10': 10,
|
||||
'cifar100': 100,
|
||||
'imagenet-1k-s': 1000,
|
||||
'imagenet-1k': 1000,
|
||||
}
|
||||
|
||||
|
||||
class CUTOUT(object):
|
||||
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def __call__(self, img):
|
||||
h, w = img.size(1), img.size(2)
|
||||
mask = np.ones((h, w), np.float32)
|
||||
y = np.random.randint(h)
|
||||
x = np.random.randint(w)
|
||||
|
||||
y1 = np.clip(y - self.length // 2, 0, h)
|
||||
y2 = np.clip(y + self.length // 2, 0, h)
|
||||
x1 = np.clip(x - self.length // 2, 0, w)
|
||||
x2 = np.clip(x + self.length // 2, 0, w)
|
||||
|
||||
mask[y1: y2, x1: x2] = 0.
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = mask.expand_as(img)
|
||||
img *= mask
|
||||
return img
|
||||
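# Illustrative check (assumption, not part of the original file): CUTOUT masks
# a random length x length square of the (already tensorised) image, so some
# pixels of a dummy input are always zeroed.
def _cutout_example():
    img = torch.ones(3, 32, 32)
    out = CUTOUT(length=8)(img)
    return int((out == 0).sum().item())   # > 0 whenever the mask overlaps the image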
|
||||
|
||||
imagenet_pca = {
|
||||
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
|
||||
'eigvec': np.asarray([
|
||||
[-0.5675, 0.7192, 0.4009],
|
||||
[-0.5808, -0.0045, -0.8140],
|
||||
[-0.5836, -0.6948, 0.4203],
|
||||
])
|
||||
}
|
||||
|
||||
|
||||
class RandChannel(object):
|
||||
# randomly pick channels from input
|
||||
def __init__(self, num_channel):
|
||||
self.num_channel = num_channel
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(num_channel={num_channel})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def __call__(self, img):
|
||||
channel = img.size(0)
|
||||
channel_choice = sorted(np.random.choice(list(range(channel)), size=self.num_channel, replace=False))
|
||||
return torch.index_select(img, 0, torch.Tensor(channel_choice).long())
|
||||
|
||||
|
||||
def get_datasets(name, root, input_size, cutout=-1):
|
||||
assert len(input_size) in [3, 4]
|
||||
if len(input_size) == 4:
|
||||
input_size = input_size[1:]
|
||||
assert input_size[1] == input_size[2]
|
||||
|
||||
if name == 'cifar10':
|
||||
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
|
||||
std = [x / 255 for x in [63.0, 62.1, 66.7]]
|
||||
elif name == 'cifar100':
|
||||
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
|
||||
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||
elif name.startswith('imagenet-1k'):
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
elif name.startswith('ImageNet16'):
|
||||
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
|
||||
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
# print(input_size)
|
||||
# Data augmentation
|
||||
if name == 'cifar10' or name == 'cifar100':
|
||||
lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
|
||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
elif name.startswith('ImageNet16'):
|
||||
lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
|
||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
elif name.startswith('imagenet-1k'):
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
if name == 'imagenet-1k':
|
||||
xlists = []
|
||||
xlists.append(transforms.Resize((32, 32), interpolation=2))
|
||||
xlists.append(transforms.RandomCrop(input_size[1], padding=0))
|
||||
elif name == 'imagenet-1k-s':
|
||||
xlists = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))]
|
||||
|
||||
else: raise ValueError('invalid name : {:}'.format(name))
|
||||
xlists.append(transforms.ToTensor())
|
||||
xlists.append(normalize)
|
||||
xlists.append(RandChannel(input_size[0]))
|
||||
train_transform = transforms.Compose(xlists)
|
||||
test_transform = transforms.Compose([transforms.Resize(40), transforms.CenterCrop(32), transforms.ToTensor(), normalize])
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
if name == 'cifar10':
|
||||
train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name == 'cifar100':
|
||||
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name.startswith('imagenet-1k'):
|
||||
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
|
||||
else: raise TypeError("Unknown dataset : {:}".format(name))
|
||||
|
||||
class_num = Dataset2Class[name]
|
||||
return train_data, test_data, class_num
|
||||
|
||||
|
||||
class LinearRegionCount(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self, n_samples):
|
||||
self.ActPattern = {}
|
||||
self.n_LR = -1
|
||||
self.n_samples = n_samples
|
||||
self.ptr = 0
|
||||
self.activations = None
|
||||
|
||||
@torch.no_grad()
|
||||
def update2D(self, activations):
|
||||
n_batch = activations.size()[0]
|
||||
n_neuron = activations.size()[1]
|
||||
self.n_neuron = n_neuron
|
||||
if self.activations is None:
|
||||
self.activations = torch.zeros(self.n_samples, n_neuron).cuda()
|
||||
self.activations[self.ptr:self.ptr+n_batch] = torch.sign(activations) # after ReLU
|
||||
self.ptr += n_batch
|
||||
|
||||
@torch.no_grad()
|
||||
def calc_LR(self):
|
||||
res = torch.matmul(self.activations.half(), (1-self.activations).T.half()) # each element in res: A * (1 - B)
|
||||
res += res.T # make symmetric; each element in res: A * (1 - B) + (1 - A) * B, a non-zero element indicates a pair of different linear regions
|
||||
res = 1 - torch.sign(res) # a non-zero element now indicates that the two linear regions are identical
|
||||
res = res.sum(1) # for each sample: how many samples (including itself) share an identical linear region
|
||||
res = 1. / res.float() # contribution of each redundant (repeated) linear region
|
||||
self.n_LR = res.sum().item() # sum of unique regions (by aggregating contribution of all regions)
|
||||
del self.activations, res
|
||||
self.activations = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@torch.no_grad()
|
||||
def update1D(self, activationList):
|
||||
code_string = ''
|
||||
for key, value in activationList.items():
|
||||
n_neuron = value.size()[0]
|
||||
for i in range(n_neuron):
|
||||
if value[i] > 0:
|
||||
code_string += '1'
|
||||
else:
|
||||
code_string += '0'
|
||||
if code_string not in self.ActPattern:
|
||||
self.ActPattern[code_string] = 1
|
||||
|
||||
def getLinearReginCount(self):
|
||||
if self.n_LR == -1:
|
||||
self.calc_LR()
|
||||
return self.n_LR
|
||||
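# Toy illustration (assumption, not part of the original file): calc_LR counts
# distinct ReLU sign patterns. Rows 0 and 1 below share the same binary code,
# so the expected number of unique linear regions is 2.
def _count_unique_patterns_example():
    a = torch.tensor([[1., 0., 1.],
                      [1., 0., 1.],
                      [0., 1., 1.]])
    res = a @ (1 - a).T
    res = res + res.T                    # non-zero -> the two codes differ
    same = 1 - torch.sign(res)           # 1 -> identical codes
    contrib = 1. / same.sum(1)           # each sample contributes 1/(group size)
    return contrib.sum().item()          # 2.0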
|
||||
|
||||
class Linear_Region_Collector:
|
||||
def __init__(self, models=[], input_size=(64, 3, 32, 32), sample_batch=100, dataset='cifar100', data_path=None, seed=0):
|
||||
self.models = []
|
||||
self.input_size = input_size # BCHW
|
||||
self.sample_batch = sample_batch
|
||||
self.input_numel = reduce(mul, self.input_size, 1)
|
||||
self.interFeature = []
|
||||
self.dataset = dataset
|
||||
self.data_path = data_path
|
||||
self.seed = seed
|
||||
self.reinit(models, input_size, sample_batch, seed)
|
||||
|
||||
def reinit(self, ori_models=None, input_size=None, sample_batch=None, seed=None, weights=None):
|
||||
models = []
|
||||
for network in ori_models:
|
||||
network = network.cuda()
|
||||
net = copy.deepcopy(network)
|
||||
net.proj_weights = weights
|
||||
num_edge, num_op = net.num_edge, net.num_op
|
||||
for i in range(num_edge):
|
||||
net.candidate_flags[i] = False
|
||||
net.eval()
|
||||
models.append(net)
|
||||
|
||||
if models is not None:
|
||||
assert isinstance(models, list)
|
||||
del self.models
|
||||
self.models = models
|
||||
for model in self.models:
|
||||
self.register_hook(model)
|
||||
device = torch.cuda.current_device()
|
||||
model = model.cuda(device=device)
|
||||
self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(models))]
|
||||
if input_size is not None or sample_batch is not None:
|
||||
if input_size is not None:
|
||||
self.input_size = input_size # BCHW
|
||||
self.input_numel = reduce(mul, self.input_size, 1)
|
||||
if sample_batch is not None:
|
||||
self.sample_batch = sample_batch
|
||||
if self.data_path is not None:
|
||||
self.train_data, _, class_num = get_datasets(self.dataset, self.data_path, self.input_size, -1)
|
||||
self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=self.input_size[0], num_workers=16, pin_memory=True, drop_last=True, shuffle=True)
|
||||
self.loader = iter(self.train_loader)
|
||||
if seed is not None and seed != self.seed:
|
||||
self.seed = seed
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
del self.interFeature
|
||||
self.interFeature = []
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def clear(self):
|
||||
self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(self.models))]
|
||||
del self.interFeature
|
||||
self.interFeature = []
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def register_hook(self, model):
|
||||
for m in model.modules():
|
||||
if isinstance(m, nn.ReLU):
|
||||
m.register_forward_hook(hook=self.hook_in_forward)
|
||||
|
||||
def hook_in_forward(self, module, input, output):
|
||||
if isinstance(input, tuple) and len(input[0].size()) == 4:
|
||||
self.interFeature.append(output.detach()) # for ReLU
|
||||
|
||||
def forward_batch_sample(self):
|
||||
for _ in range(self.sample_batch):
|
||||
try:
|
||||
inputs, targets = next(self.loader)
|
||||
except Exception:
|
||||
del self.loader
|
||||
self.loader = iter(self.train_loader)
|
||||
inputs, targets = next(self.loader)
|
||||
for model, LRCount in zip(self.models, self.LRCounts):
|
||||
self.forward(model, LRCount, inputs)
|
||||
output = [LRCount.getLinearReginCount() for LRCount in self.LRCounts]
|
||||
return output
|
||||
|
||||
def forward(self, model, LRCount, input_data):
|
||||
self.interFeature = []
|
||||
with torch.no_grad():
|
||||
model.forward(input_data.cuda())
|
||||
if len(self.interFeature) == 0: return
|
||||
feature_data = torch.cat([f.view(input_data.size(0), -1) for f in self.interFeature], 1)
|
||||
LRCount.update2D(feature_data)
|
245
nasbench201/networks_proposal.py
Normal file
@ -0,0 +1,245 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, '../')
|
||||
import time
|
||||
import glob
|
||||
import json
|
||||
import shutil
|
||||
import logging
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.autograd import Variable
|
||||
|
||||
import nasbench201.utils as ig_utils
|
||||
from nasbench201.search_model_darts_proj import TinyNetworkDartsProj
|
||||
from nasbench201.cell_operations import SearchSpaceNames
|
||||
from nasbench201.init_projection import pt_project, global_op_greedy_pt_project, global_op_once_pt_project, global_edge_greedy_pt_project, global_edge_once_pt_project, shrink_pt_project, tenas_project
|
||||
from nas_201_api import NASBench201API as API
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
np.set_printoptions(precision=4, suppress=True)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser("sota")
|
||||
# data related
|
||||
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
|
||||
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'imagenet16-120'], help='choose dataset')
|
||||
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='batch size for alpha')
|
||||
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
|
||||
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
|
||||
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
|
||||
parser.add_argument('--seed', type=int, default=2, help='random seed')
|
||||
|
||||
#search space setting
|
||||
parser.add_argument('--search_space', type=str, default='nas-bench-201')
|
||||
|
||||
parser.add_argument('--pool_size', type=int, default=100, help='number of models to propose')
|
||||
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
|
||||
|
||||
#system configurations
|
||||
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
|
||||
parser.add_argument('--save', type=str, default='exp', help='experiment name')
|
||||
|
||||
#default opt setting for model
|
||||
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
|
||||
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--nesterov', action='store_true', default=True, help='use nesterov momentum for SGD')
|
||||
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
|
||||
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
|
||||
|
||||
#### common
|
||||
parser.add_argument('--fast', action='store_true', default=True, help='skip loading api which is slow')
|
||||
|
||||
#### projection
|
||||
parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'shrink_pt_project'], help='which edge to be projected next')
|
||||
parser.add_argument('--proj_crit', type=str, default="comb", choices=['loss', 'acc', 'jacob', 'snip', 'fisher', 'synflow', 'grad_norm', 'grasp', 'jacob_cov','tenas', 'var', 'cor', 'norm', 'comb', 'meco'], help='criteria for projection')
|
||||
args = parser.parse_args()
|
||||
|
||||
#### args augment
|
||||
expid = args.save
|
||||
args.save = '../experiments/nas-bench-201/prop-{}-{}-{}'.format(args.save, args.seed, args.pool_size)
|
||||
if not args.dataset == 'cifar10':
|
||||
args.save += '-' + args.dataset
|
||||
if not args.edge_decision == 'random':
|
||||
args.save += '-' + args.edge_decision
|
||||
if not args.proj_crit == 'jacob':
|
||||
args.save += '-' + args.proj_crit
|
||||
|
||||
#### logging
|
||||
scripts_to_save = glob.glob('*.py')
|
||||
# + ['../exp_scripts/{}.sh'.format(expid)]
|
||||
if os.path.exists(args.save):
|
||||
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
|
||||
print('proceed to override saving directory')
|
||||
shutil.rmtree(args.save)
|
||||
else:
|
||||
exit(0)
|
||||
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
|
||||
|
||||
log_format = '%(asctime)s %(message)s'
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
||||
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
||||
|
||||
log_file = 'log.txt'
|
||||
log_path = os.path.join(args.save, log_file)
|
||||
logging.info('======> log filename: %s', log_file)
|
||||
|
||||
if os.path.exists(log_path):
|
||||
if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y':
|
||||
print('proceed to override log file directory')
|
||||
else:
|
||||
exit(0)
|
||||
|
||||
fh = logging.FileHandler(log_path, mode='w')
|
||||
fh.setFormatter(logging.Formatter(log_format))
|
||||
logging.getLogger().addHandler(fh)
|
||||
writer = SummaryWriter(args.save + '/runs')
|
||||
|
||||
#### macros
|
||||
if args.dataset == 'cifar100':
|
||||
n_classes = 100
|
||||
elif args.dataset == 'imagenet16-120':
|
||||
n_classes = 120
|
||||
else:
|
||||
n_classes = 10
|
||||
|
||||
def main():
|
||||
torch.set_num_threads(3)
|
||||
if not torch.cuda.is_available():
|
||||
logging.info('no gpu device available')
|
||||
sys.exit(1)
|
||||
|
||||
np.random.seed(args.seed)
|
||||
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
|
||||
torch.cuda.set_device(gpu)
|
||||
cudnn.benchmark = True
|
||||
torch.manual_seed(args.seed)
|
||||
cudnn.enabled = True
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
logging.info("args = %s", args)
|
||||
logging.info('gpu device = %d' % gpu)
|
||||
|
||||
#### model
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
search_space = SearchSpaceNames[args.search_space]
|
||||
|
||||
# initialize the supernet
|
||||
model = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args)
|
||||
model_thin = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args, stem_channels=1)
|
||||
model = model.cuda()
|
||||
model_thin = model_thin.cuda()
|
||||
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
|
||||
|
||||
#### data
|
||||
if args.dataset == 'cifar10':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
|
||||
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'cifar100':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
|
||||
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'imagenet16-120':
|
||||
import torchvision.transforms as transforms
|
||||
from nasbench201.DownsampledImageNet import ImageNet16
|
||||
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
|
||||
std = [x / 255 for x in [63.22, 61.26, 65.09]]
|
||||
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
|
||||
valid_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=False, transform=train_transform, use_num_of_class_only=120)
|
||||
assert len(train_data) == 151700
|
||||
|
||||
num_train = len(train_data)
|
||||
indices = list(range(num_train))
|
||||
split = int(np.floor(args.train_portion * num_train))
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
|
||||
pin_memory=True)
|
||||
|
||||
|
||||
# format the network pool dictionary
|
||||
networks_pool={}
|
||||
networks_pool['search_space'] = args.search_space
|
||||
networks_pool['dataset'] = args.dataset
|
||||
networks_pool['networks'] = []
|
||||
networks_pool['pool_size'] = args.pool_size
|
||||
#### architecture selection / projection
|
||||
for i in range(args.pool_size):
|
||||
network_info={}
|
||||
logging.info('{} MODELS HAVE BEEN SEARCHED'.format(i+1))
|
||||
if args.edge_decision == 'global_op_greedy':
|
||||
global_op_greedy_pt_project(train_queue, model, args)
|
||||
elif args.edge_decision == 'global_op_once':
|
||||
global_op_once_pt_project(train_queue, model, args)
|
||||
elif args.edge_decision == 'global_edge_greedy':
|
||||
global_edge_greedy_pt_project(train_queue, model, args)
|
||||
elif args.edge_decision == 'global_edge_once':
|
||||
global_edge_once_pt_project(train_queue, model, args)
|
||||
elif args.edge_decision == 'shrink_pt_project':
|
||||
shrink_pt_project(train_queue, model, args)
|
||||
api = API('../data/NAS-Bench-201-v1_0-e61699.pth')
|
||||
cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
|
||||
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, model.genotype().tostr(), logging)
|
||||
else:
|
||||
if args.proj_crit == 'jacob':
|
||||
pt_project(train_queue, model, args)
|
||||
else:
|
||||
pt_project(train_queue, model, args)
|
||||
# tenas_project(train_queue, model, model_thin, args)
|
||||
|
||||
network_info['id'] = str(i)
|
||||
network_info['genotype'] = model.genotype().tostr()
|
||||
networks_pool['networks'].append(network_info)
|
||||
model.reset_arch_parameters()
|
||||
|
||||
with open(os.path.join(args.save,'networks_pool.json'), 'w') as save_file:
|
||||
json.dump(networks_pool, save_file)
|
||||
|
||||
|
||||
#### util functions
|
||||
def distill(result):
|
||||
result = result.split('\n')
|
||||
cifar10 = result[5].replace(' ', '').split(':')
|
||||
cifar100 = result[7].replace(' ', '').split(':')
|
||||
imagenet16 = result[9].replace(' ', '').split(':')
|
||||
|
||||
cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
|
||||
cifar10_test = float(cifar10[2][-7:-2].strip('='))
|
||||
cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
|
||||
cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
|
||||
cifar100_test = float(cifar100[3][-7:-2].strip('='))
|
||||
imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
|
||||
imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
|
||||
imagenet16_test = float(imagenet16[3][-7:-2].strip('='))
|
||||
|
||||
return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
|
||||
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
|
||||
|
||||
|
||||
def query(api, genotype, logging):
|
||||
result = api.query_by_arch(genotype, hp='200')
|
||||
logging.info('{:}'.format(result))
|
||||
cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
|
||||
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
|
||||
logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
|
||||
logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
|
||||
logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
|
||||
return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
|
||||
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
113
nasbench201/op_score.py
Normal file
@ -0,0 +1,113 @@
|
||||
import gc
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn.functional as f
|
||||
from operator import mul
|
||||
from functools import reduce
|
||||
import copy
|
||||
sys.path.insert(0, '../')
|
||||
|
||||
def Jocab_Score(ori_model, input, target, weights=None):
|
||||
model = copy.deepcopy(ori_model)
|
||||
model.eval()
|
||||
model.proj_weights = weights
|
||||
num_edge, num_op = model.num_edge, model.num_op
|
||||
for i in range(num_edge):
|
||||
model.candidate_flags[i] = False
|
||||
batch_size = input.shape[0]
|
||||
model.K = torch.zeros(batch_size, batch_size).cuda()
|
||||
|
||||
def counting_forward_hook(module, inp, out):
|
||||
try:
|
||||
if isinstance(inp, tuple):
|
||||
inp = inp[0]
|
||||
inp = inp.view(inp.size(0), -1)
|
||||
x = (inp > 0).float()
|
||||
K = x @ x.t()
|
||||
K2 = (1.-x) @ (1.-x.t())
|
||||
model.K = model.K + K + K2
|
||||
except:
|
||||
pass
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if 'ReLU' in str(type(module)):
|
||||
module.register_forward_hook(counting_forward_hook)
|
||||
|
||||
input = input.cuda()
|
||||
model(input)
|
||||
score = hooklogdet(model.K.cpu().numpy())
|
||||
del model
|
||||
del input
|
||||
return score
|
||||
|
||||
def hooklogdet(K, labels=None):
|
||||
s, ld = np.linalg.slogdet(K)
|
||||
return ld
|
||||
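# Illustrative example (assumption, not part of the original file): the score
# returned by Jocab_Score is the NASWOT-style log-determinant of the kernel K,
# where K[i, j] counts how many ReLU activation signs inputs i and j share.
# The same quantity can be reproduced on a random binary code matrix:
def _naswot_logdet_example(batch=8, neurons=32, seed=0):
    rng = np.random.RandomState(seed)
    x = (rng.rand(batch, neurons) > 0.5).astype(np.float32)  # binary ReLU codes
    K = x @ x.T + (1. - x) @ (1. - x).T                       # agreement kernel
    return hooklogdet(K)                                      # higher -> more diverse codes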
|
||||
# NTK
|
||||
#------------------------------------------------------------
|
||||
#https://github.com/VITA-Group/TENAS/blob/main/lib/procedures/ntk.py
|
||||
#
|
||||
def recal_bn(network, xloader, recalbn, device):
|
||||
for m in network.modules():
|
||||
if isinstance(m, torch.nn.BatchNorm2d):
|
||||
m.running_mean.data.fill_(0)
|
||||
m.running_var.data.fill_(0)
|
||||
m.num_batches_tracked.data.zero_()
|
||||
m.momentum = None
|
||||
network.train()
|
||||
with torch.no_grad():
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if i >= recalbn: break
|
||||
inputs = inputs.cuda(device=device, non_blocking=True)
|
||||
_, _ = network(inputs)
|
||||
return network
|
||||
|
||||
def get_ntk_n(xloader, networks, recalbn=0, train_mode=False, num_batch=-1, weights=None):
|
||||
device = torch.cuda.current_device()
|
||||
ntks = []
|
||||
copied_networks = []
|
||||
for network in networks:
|
||||
network = network.cuda(device=device)
|
||||
net = copy.deepcopy(network)
|
||||
net.proj_weights = weights
|
||||
num_edge, num_op = net.num_edge, net.num_op
|
||||
for i in range(num_edge):
|
||||
net.candidate_flags[i] = False
|
||||
if train_mode:
|
||||
net.train()
|
||||
else:
|
||||
net.eval()
|
||||
copied_networks.append(net)
|
||||
######
|
||||
grads = [[] for _ in range(len(copied_networks))]
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if num_batch > 0 and i >= num_batch: break
|
||||
inputs = inputs.cuda(device=device, non_blocking=True)
|
||||
for net_idx, network in enumerate(copied_networks):
|
||||
network.zero_grad()
|
||||
inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
|
||||
logit = network(inputs_)
|
||||
if isinstance(logit, tuple):
|
||||
logit = logit[1] # 201 networks: return features and logits
|
||||
for _idx in range(len(inputs_)):
|
||||
logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)
|
||||
grad = []
|
||||
for name, W in network.named_parameters():
|
||||
if 'weight' in name and W.grad is not None:
|
||||
grad.append(W.grad.view(-1).detach())
|
||||
grads[net_idx].append(torch.cat(grad, -1))
|
||||
network.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
######
|
||||
grads = [torch.stack(_grads, 0) for _grads in grads]
|
||||
ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
|
||||
conds = []
|
||||
for ntk in ntks:
|
||||
eigenvalues, _ = torch.symeig(ntk) # ascending
|
||||
conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0))
|
||||
|
||||
del copied_networks
|
||||
return conds
|
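# Toy illustration (assumption, not part of the original file): get_ntk_n
# stacks per-sample gradients into a Gram (NTK) matrix and reports its
# eigenvalue condition number; TENAS-style ranking treats a smaller condition
# number as better. torch.linalg.eigvalsh is used here in place of the
# deprecated torch.symeig.
def _ntk_cond_example(seed=0):
    g = torch.randn(4, 16, generator=torch.Generator().manual_seed(seed))  # fake per-sample grads
    ntk = torch.einsum('nc,mc->nm', [g, g])
    eig = torch.linalg.eigvalsh(ntk)     # ascending eigenvalues
    return (eig[-1] / eig[0]).item()     # condition number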
182
nasbench201/search_cells.py
Normal file
@ -0,0 +1,182 @@
|
||||
import math, random, torch
|
||||
import warnings
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
sys.path.insert(0, '../')
|
||||
from nasbench201.cell_operations import OPS
|
||||
|
||||
|
||||
# This module is used for NAS-Bench-201; it represents a small search space with a complete DAG
|
||||
class NAS201SearchCell(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
|
||||
super(NAS201SearchCell, self).__init__()
|
||||
|
||||
self.op_names = deepcopy(op_names)
|
||||
self.edges = nn.ModuleDict()
|
||||
self.max_nodes = max_nodes
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
if j == 0:
|
||||
xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
|
||||
else:
|
||||
xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
|
||||
self.edges[ node_str ] = nn.ModuleList( xlists )
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
def extra_repr(self):
|
||||
string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
|
||||
return string
|
||||
|
||||
def forward(self, inputs, weightss):
|
||||
return self._forward(inputs, weightss)
|
||||
|
||||
def _forward(self, inputs, weightss):
|
||||
with torch.autograd.set_detect_anomaly(True):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
weights = weightss[ self.edge2index[node_str] ]
|
||||
inter_nodes.append(sum(layer(nodes[j], block_input=True)*w if w==0 else layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights)) )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
# GDAS
|
||||
def forward_gdas(self, inputs, hardwts, index):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
weights = hardwts[ self.edge2index[node_str] ]
|
||||
argmaxs = index[ self.edge2index[node_str] ].item()
|
||||
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
|
||||
inter_nodes.append( weigsum )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
# joint
|
||||
def forward_joint(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
weights = weightss[ self.edge2index[node_str] ]
|
||||
#aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
|
||||
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) )
|
||||
inter_nodes.append( aggregation )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
# uniform random sampling per iteration, SETN
|
||||
def forward_urs(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
while True: # to avoid select zero for all ops
|
||||
sops, has_non_zero = [], False
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
candidates = self.edges[node_str]
|
||||
select_op = random.choice(candidates)
|
||||
sops.append( select_op )
|
||||
if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
|
||||
if has_non_zero: break
|
||||
inter_nodes = []
|
||||
for j, select_op in enumerate(sops):
|
||||
inter_nodes.append( select_op(nodes[j]) )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
# select the argmax
|
||||
def forward_select(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
weights = weightss[ self.edge2index[node_str] ]
|
||||
inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
|
||||
#inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
# forward with a specific structure
|
||||
def forward_dynamic(self, inputs, structure):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
cur_op_node = structure.nodes[i-1]
|
||||
inter_nodes = []
|
||||
for op_name, j in cur_op_node:
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
op_index = self.op_names.index( op_name )
|
||||
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
def channel_shuffle(x, groups):
|
||||
batchsize, num_channels, height, width = x.data.size()
|
||||
channels_per_group = num_channels // groups
|
||||
# reshape
|
||||
x = x.view(batchsize, groups,
|
||||
channels_per_group, height, width)
|
||||
x = torch.transpose(x, 1, 2).contiguous()
|
||||
# flatten
|
||||
x = x.view(batchsize, -1, height, width)
|
||||
return x
|
||||
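# Illustrative check (assumption, not part of the original file): shuffling
# keeps the tensor shape but interleaves the channel groups, which is how the
# partial-channel cell below mixes convolved and untouched channels.
def _channel_shuffle_example():
    x = torch.arange(8, dtype=torch.float32).view(1, 8, 1, 1)
    y = channel_shuffle(x, groups=4)
    return y.view(-1).tolist()           # [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]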
|
||||
|
||||
class NAS201SearchCell_PartialChannel(NAS201SearchCell):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True, k=4):
|
||||
super(NAS201SearchCell, self).__init__()
|
||||
|
||||
self.k = k
|
||||
self.op_names = deepcopy(op_names)
|
||||
self.edges = nn.ModuleDict()
|
||||
self.max_nodes = max_nodes
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
if j == 0:
|
||||
xlists = [OPS[op_name](C_in//self.k , C_out//self.k, stride, affine, track_running_stats) for op_name in op_names]
|
||||
else:
|
||||
xlists = [OPS[op_name](C_in//self.k , C_out//self.k, 1, affine, track_running_stats) for op_name in op_names]
|
||||
self.edges[ node_str ] = nn.ModuleList( xlists )
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
def MixedOp(self, x, ops, weights):
|
||||
dim_2 = x.shape[1]
|
||||
xtemp = x[ : , : dim_2//self.k, :, :]
|
||||
xtemp2 = x[ : , dim_2//self.k:, :, :]
|
||||
temp1 = sum(w * op(xtemp) for w, op in zip(weights, ops))
|
||||
ans = torch.cat([temp1,xtemp2],dim=1)
|
||||
ans = channel_shuffle(ans,self.k)
|
||||
return ans
|
||||
|
||||
def forward(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
weights = weightss[ self.edge2index[node_str] ]
|
||||
# inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
||||
inter_nodes.append(self.MixedOp(x=nodes[j], ops=self.edges[node_str], weights=weights))
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
202
nasbench201/search_model.py
Normal file
@ -0,0 +1,202 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
from .cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
from torch.autograd import Variable
|
||||
|
||||
class TinyNetwork(nn.Module):
|
||||
|
||||
def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, affine=False, track_running_stats=True, stem_channels=3):
|
||||
super(TinyNetwork, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self._num_classes = num_classes
|
||||
self._criterion = criterion
|
||||
self._args = args
|
||||
self._affine = affine
|
||||
self._track_running_stats = track_running_stats
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C))
|
||||
|
||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
|
||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
||||
self.cells.append( cell )
|
||||
C_prev = cell.out_dim
|
||||
self.num_edge = num_edge
|
||||
self.num_op = len(search_space)
|
||||
self.op_names = deepcopy( search_space )
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
# self._arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||
self._arch_parameters = Variable(1e-3*torch.randn(num_edge, len(search_space)).cuda(), requires_grad=True)
|
||||
|
||||
## optimizer
|
||||
## record the id (memory address) of each architecture parameter so they can be separated from the model weights
|
||||
arch_params = set(id(m) for m in self.arch_parameters())
|
||||
self._model_params = [m for m in self.parameters() if id(m) not in arch_params]
|
||||
|
||||
# optimizer for the model (weight) parameters
|
||||
self.optimizer = torch.optim.SGD(
|
||||
self._model_params,
|
||||
args.learning_rate,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
nesterov= args.nesterov)
|
||||
|
||||
|
||||
def entropy_y_x(self, p_logit):
|
||||
p = F.softmax(p_logit, dim=1)
|
||||
return - torch.sum(p * F.log_softmax(p_logit, dim=1)) / p_logit.shape[0]
|
||||
|
||||
def _loss(self, input, target, return_logits=False):
|
||||
logits = self(input)
|
||||
loss = self._criterion(logits, target)
|
||||
|
||||
return (loss, logits) if return_logits else loss
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
||||
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
|
||||
xlist+= list( self.classifier.parameters() )
|
||||
return xlist
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self._arch_parameters]
|
||||
|
||||
def get_theta(self):
|
||||
return nn.functional.softmax(self._arch_parameters, dim=-1).cpu()
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self._arch_parameters[ self.edge2index[node_str] ]
|
||||
op_name = self.op_names[ weights.argmax().item() ]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append( tuple(xlist) )
|
||||
return Structure( genotypes )
|
||||
|
||||
def forward(self, inputs, weights=None):
|
||||
sim_nn = []
|
||||
|
||||
weights = nn.functional.softmax(self._arch_parameters, dim=-1) if weights is None else weights
|
||||
|
||||
if getattr(self, 'slim', False):  # 'slim' is not set in __init__; default to False when absent
|
||||
weights[1].data.fill_(0)
|
||||
weights[3].data.fill_(0)
|
||||
weights[4].data.fill_(0)
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell(feature, weights)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling( out )
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return logits
|
||||
|
||||
def _save_arch_parameters(self):
|
||||
self._saved_arch_parameters = [p.clone() for p in self._arch_parameters]
|
||||
|
||||
def project_arch(self):
|
||||
self._save_arch_parameters()
|
||||
for p in self.arch_parameters():
|
||||
m, n = p.size()
|
||||
maxIndexs = p.data.cpu().numpy().argmax(axis=1)
|
||||
p.data = self.proximal_step(p, maxIndexs)
|
||||
|
||||
def proximal_step(self, var, maxIndexs=None):
|
||||
values = var.data.cpu().numpy()
|
||||
m, n = values.shape
|
||||
alphas = []
|
||||
for i in range(m):
|
||||
for j in range(n):
|
||||
if j == maxIndexs[i]:
|
||||
alphas.append(values[i][j].copy())
|
||||
values[i][j] = 1
|
||||
else:
|
||||
values[i][j] = 0
|
||||
return torch.Tensor(values).cuda()
|
||||
|
||||
def restore_arch_parameters(self):
|
||||
for i, p in enumerate(self._arch_parameters):
|
||||
p.data.copy_(self._saved_arch_parameters[i])
|
||||
del self._saved_arch_parameters
|
||||
|
||||
def new(self):
|
||||
model_new = TinyNetwork(self._C, self._layerN, self.max_nodes, self._num_classes, self._criterion,
|
||||
self.op_names, self._args, self._affine, self._track_running_stats).cuda()
|
||||
for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
|
||||
x.data.copy_(y.data)
|
||||
|
||||
return model_new
|
||||
|
||||
def step(self, input, target, args, shared=None, return_grad=False):
|
||||
Lt, logit_t = self._loss(input, target, return_logits=True)
|
||||
Lt.backward()
|
||||
if args.grad_clip != 0:
|
||||
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
|
||||
self.optimizer.step()
|
||||
|
||||
if return_grad:
|
||||
grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
|
||||
return logit_t, Lt, grad
|
||||
else:
|
||||
return logit_t, Lt
|
||||
|
||||
def printing(self, logging):
|
||||
logging.info(self.get_theta())
|
||||
|
||||
def set_arch_parameters(self, new_alphas):
|
||||
for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
|
||||
alpha.data.copy_(new_alpha.data)
|
||||
|
||||
def save_arch_parameters(self):
|
||||
self._saved_arch_parameters = self._arch_parameters.clone()
|
||||
|
||||
def restore_arch_parameters(self):
|
||||
self.set_arch_parameters(self._saved_arch_parameters)
|
||||
|
||||
def reset_optimizer(self, lr, momentum, weight_decay):
|
||||
del self.optimizer
|
||||
self.optimizer = torch.optim.SGD(
|
||||
self.get_weights(),
|
||||
lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
nesterov=self._args.nesterov)  # 'args' is not in scope here; use the args stored at construction time
|
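# Toy illustration (assumption, not part of the original file): project_arch /
# proximal_step above snap each row of the architecture matrix to a one-hot
# vector at its argmax; the same operation on a small tensor:
def _proximal_step_example():
    alphas = torch.tensor([[0.1, 0.7, 0.2],
                           [0.5, 0.3, 0.2]])
    onehot = torch.zeros_like(alphas)
    onehot[torch.arange(alphas.size(0)), alphas.argmax(dim=1)] = 1
    return onehot                        # [[0., 1., 0.], [1., 0., 0.]]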
33
nasbench201/search_model_darts.py
Normal file
@ -0,0 +1,33 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .search_model import TinyNetwork as TinyNetwork
|
||||
|
||||
|
||||
class TinyNetworkDarts(TinyNetwork):
|
||||
def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
|
||||
affine=False, track_running_stats=True, stem_channels=3):
|
||||
super(TinyNetworkDarts, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args,
|
||||
affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels)
|
||||
|
||||
self.theta_map = lambda x: torch.softmax(x, dim=-1)
|
||||
|
||||
def get_theta(self):
|
||||
return self.theta_map(self._arch_parameters).cpu()
|
||||
|
||||
def forward(self, inputs):
|
||||
weights = self.theta_map(self._arch_parameters)
|
||||
feature = self.stem(inputs)
|
||||
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell(feature, weights)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling( out )
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return logits
|
80
nasbench201/search_model_darts_proj.py
Normal file
@ -0,0 +1,80 @@
|
||||
import torch
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .search_model import TinyNetwork as TinyNetwork
|
||||
from .genotypes import Structure
|
||||
from torch.autograd import Variable
|
||||
|
||||
class TinyNetworkDartsProj(TinyNetwork):
|
||||
def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
|
||||
affine=False, track_running_stats=True, stem_channels=3):
|
||||
super(TinyNetworkDartsProj, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args,
|
||||
affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels)
|
||||
self.theta_map = lambda x: torch.softmax(x, dim=-1)
|
||||
|
||||
#### for edgewise projection
|
||||
self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
|
||||
self.proj_weights = torch.zeros_like(self._arch_parameters)
|
||||
|
||||
def project_op(self, eid, opid):
|
||||
self.proj_weights[eid][opid] = 1 ## hard by default
|
||||
self.candidate_flags[eid] = False
|
||||
|
||||
def get_projected_weights(self):
|
||||
weights = self.theta_map(self._arch_parameters)
|
||||
|
||||
## proj op
|
||||
for eid in range(len(self._arch_parameters)):
|
||||
if not self.candidate_flags[eid]:
|
||||
weights[eid].data.copy_(self.proj_weights[eid])
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, inputs, weights=None):
|
||||
# anomaly detection is a debugging aid and noticeably slows forward/backward;
# it can be dropped once the projection code is trusted
with torch.autograd.set_detect_anomaly(True):
|
||||
if weights is None:
|
||||
weights = self.get_projected_weights()
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell(feature, weights)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return logits
|
||||
|
||||
#### utils
|
||||
def get_theta(self):
|
||||
return self.get_projected_weights()
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self._arch_parameters]
|
||||
|
||||
def set_arch_parameters(self, new_alphas):
|
||||
for eid, alpha in enumerate(self.arch_parameters()):
|
||||
alpha.data.copy_(new_alphas[eid])
|
||||
|
||||
def reset_arch_parameters(self):
|
||||
self._arch_parameters = Variable(1e-3*torch.randn(self.num_edge, len(self.op_names)).cuda(), requires_grad=True)
|
||||
self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
|
||||
self.proj_weights = torch.zeros_like(self._arch_parameters)
|
||||
|
||||
def genotype(self):
|
||||
proj_weights = self.get_projected_weights()
|
||||
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = proj_weights[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
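# Sketch of the edgewise projection bookkeeping above, reduced to plain
# tensors: once an edge is projected its softmax row is replaced by a one-hot
# row, mirroring project_op()/get_projected_weights().
import torch

num_edges, num_ops = 6, 5
alphas = 1e-3 * torch.randn(num_edges, num_ops)
candidate_flags = torch.ones(num_edges, dtype=torch.bool)
proj_weights = torch.zeros(num_edges, num_ops)

def project_op(eid, opid):
    proj_weights[eid, opid] = 1.0   # hard, one-hot choice
    candidate_flags[eid] = False

def get_projected_weights():
    weights = torch.softmax(alphas, dim=-1)
    for eid in range(num_edges):
        if not candidate_flags[eid]:
            weights[eid] = proj_weights[eid]
    return weights

project_op(eid=2, opid=0)
print(get_projected_weights()[2])  # one-hot row for the projected edge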
494
nasbench201/utils.py
Normal file
@ -0,0 +1,494 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path
|
||||
import sys
|
||||
import shutil
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from PIL import Image
|
||||
from torch.autograd import Variable
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets import utils
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
|
||||
class AvgrageMeter(object):
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.cnt = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.sum += val * n
|
||||
self.cnt += n
|
||||
self.avg = self.sum / self.cnt
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].contiguous().view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
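# Quick self-contained check of accuracy(): top-1/top-5 on a random batch,
# returned as percentages.
import torch

logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
top1, top5 = accuracy(logits, labels, topk=(1, 5))
print(top1.item(), top5.item())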
||||
|
||||
|
||||
class Cutout(object):
|
||||
def __init__(self, length, prob=1.0):
|
||||
self.length = length
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, img):
|
||||
if np.random.binomial(1, self.prob):
|
||||
h, w = img.size(1), img.size(2)
|
||||
mask = np.ones((h, w), np.float32)
|
||||
y = np.random.randint(h)
|
||||
x = np.random.randint(w)
|
||||
|
||||
y1 = np.clip(y - self.length // 2, 0, h)
|
||||
y2 = np.clip(y + self.length // 2, 0, h)
|
||||
x1 = np.clip(x - self.length // 2, 0, w)
|
||||
x2 = np.clip(x + self.length // 2, 0, w)
|
||||
|
||||
mask[y1: y2, x1: x2] = 0.
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = mask.expand_as(img)
|
||||
img *= mask
|
||||
return img
|
||||
|
||||
def _data_transforms_svhn(args):
|
||||
SVHN_MEAN = [0.4377, 0.4438, 0.4728]
|
||||
SVHN_STD = [0.1980, 0.2010, 0.1970]
|
||||
|
||||
train_transform = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(SVHN_MEAN, SVHN_STD),
|
||||
])
|
||||
if args.cutout:
|
||||
train_transform.transforms.append(Cutout(args.cutout_length,
|
||||
args.cutout_prob))
|
||||
|
||||
valid_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(SVHN_MEAN, SVHN_STD),
|
||||
])
|
||||
return train_transform, valid_transform
|
||||
|
||||
|
||||
def _data_transforms_cifar100(args):
|
||||
CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
|
||||
CIFAR_STD = [0.2673, 0.2564, 0.2762]
|
||||
|
||||
train_transform = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
|
||||
])
|
||||
if args.cutout:
|
||||
train_transform.transforms.append(Cutout(args.cutout_length,
|
||||
args.cutout_prob))
|
||||
|
||||
valid_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
|
||||
])
|
||||
return train_transform, valid_transform
|
||||
|
||||
|
||||
def _data_transforms_cifar10(args):
|
||||
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
|
||||
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
|
||||
|
||||
train_transform = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
|
||||
])
|
||||
if args.cutout:
|
||||
train_transform.transforms.append(Cutout(args.cutout_length,
|
||||
args.cutout_prob))
|
||||
|
||||
valid_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
|
||||
])
|
||||
return train_transform, valid_transform
|
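# Minimal sketch of wiring the transform builders above into a dataloader; the
# argparse fields mirror how they are read above (cutout, cutout_length,
# cutout_prob), everything else here is illustrative.
import argparse
import torch
import torchvision

def build_cifar10_train_loader(data_root='../data', batch_size=64):
    args = argparse.Namespace(cutout=True, cutout_length=16, cutout_prob=1.0)
    train_transform, _ = _data_transforms_cifar10(args)
    train_set = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                             download=True, transform=train_transform)
    return torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)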
||||
|
||||
|
||||
def count_parameters_in_MB(model):
# use the builtin sum: passing a generator to np.sum is deprecated and falls back to it anyway
return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
|
||||
|
||||
|
||||
def count_parameters_in_Compact(model):
|
||||
from sota.cnn.model import Network as CompactModel
|
||||
genotype = model.genotype()
|
||||
compact_model = CompactModel(36, model._num_classes, 20, True, genotype)
|
||||
num_params = count_parameters_in_MB(compact_model)
|
||||
return num_params
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''):
|
||||
filename = prefix
|
||||
if per_epoch:
|
||||
epoch = state['epoch']
|
||||
filename += 'checkpoint_{}.pth.tar'.format(epoch)
|
||||
else:
|
||||
filename += 'checkpoint.pth.tar'
|
||||
filename = os.path.join(save, filename)
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
best_filename = os.path.join(save, 'model_best.pth.tar')
|
||||
shutil.copyfile(filename, best_filename)
|
||||
|
||||
|
||||
def load_checkpoint(model, optimizer, save, epoch=None):
|
||||
if epoch is None:
|
||||
filename = 'checkpoint.pth.tar'
|
||||
else:
|
||||
filename = 'checkpoint_{}.pth.tar'.format(epoch)
|
||||
filename = os.path.join(save, filename)
|
||||
start_epoch = 0
best_acc_top1 = 0  # keep the return below well-defined when no checkpoint file exists
|
||||
if os.path.isfile(filename):
|
||||
print("=> loading checkpoint '{}'".format(filename))
|
||||
checkpoint = torch.load(filename)
|
||||
start_epoch = checkpoint['epoch']
|
||||
best_acc_top1 = checkpoint['best_acc_top1']
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
print("=> loaded checkpoint '{}' (epoch {})"
|
||||
.format(filename, checkpoint['epoch']))
|
||||
else:
|
||||
print("=> no checkpoint found at '{}'".format(filename))
|
||||
|
||||
return model, optimizer, start_epoch, best_acc_top1
|
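# Sketch of the checkpoint round trip implied by save_checkpoint/load_checkpoint:
# the state dict is expected to carry 'epoch', 'best_acc_top1', 'state_dict' and
# 'optimizer'; `model`, `optimizer` and `save_dir` are assumed to exist.
def checkpoint_round_trip(model, optimizer, save_dir, epoch, best_acc_top1, is_best=False):
    save_checkpoint({'epoch': epoch,
                     'best_acc_top1': best_acc_top1,
                     'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict()},
                    is_best, save_dir)
    return load_checkpoint(model, optimizer, save_dir)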
||||
|
||||
|
||||
def save(model, model_path):
|
||||
torch.save(model.state_dict(), model_path)
|
||||
|
||||
|
||||
def load(model, model_path):
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1. - drop_prob
|
||||
mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
|
||||
x.div_(keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
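# drop_path above zeroes whole samples with probability drop_prob and rescales
# the survivors by 1/keep_prob, so the expected activation is preserved; a
# quick numerical check (CUDA is needed because the mask is allocated there):
import torch

if torch.cuda.is_available():
    x = torch.ones(10000, 1, 1, 1).cuda()
    y = drop_path(x.clone(), drop_prob=0.2)
    print(y.mean().item())  # close to 1.0 on average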
||||
|
||||
|
||||
def create_exp_dir(path, scripts_to_save=None):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
print('Experiment dir : {}'.format(path))
|
||||
|
||||
if scripts_to_save is not None:
|
||||
os.mkdir(os.path.join(path, 'scripts'))
|
||||
for script in scripts_to_save:
|
||||
dst_file = os.path.join(path, 'scripts', os.path.basename(script))
|
||||
shutil.copyfile(script, dst_file)
|
||||
|
||||
|
||||
class CIFAR10(VisionDataset):
|
||||
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
|
||||
Args:
|
||||
root (string): Root directory of dataset where directory
|
||||
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
|
||||
train (bool, optional): If True, creates dataset from training set, otherwise
|
||||
creates from test set.
|
||||
transform (callable, optional): A function/transform that takes in an PIL image
|
||||
and returns a transformed version. E.g, ``transforms.RandomCrop``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
download (bool, optional): If true, downloads the dataset from the internet and
|
||||
puts it in root directory. If dataset is already downloaded, it is not
|
||||
downloaded again.
|
||||
|
||||
"""
|
||||
base_folder = 'cifar-10-batches-py'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
||||
filename = "cifar-10-python.tar.gz"
|
||||
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
|
||||
train_list = [
|
||||
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
|
||||
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
|
||||
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
|
||||
['data_batch_4', '634d18415352ddfa80567beed471001a'],
|
||||
#['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
|
||||
]
|
||||
|
||||
test_list = [
|
||||
['test_batch', '40351d587109b95175f43aff81a1287e'],
|
||||
]
|
||||
meta = {
|
||||
'filename': 'batches.meta',
|
||||
'key': 'label_names',
|
||||
'md5': '5ff9c542aee3614f3951f8cda6e48888',
|
||||
}
|
||||
|
||||
def __init__(self, root, train=True, transform=None, target_transform=None,
|
||||
download=False):
|
||||
|
||||
super(CIFAR10, self).__init__(root, transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
self.train = train # training set or test set
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError('Dataset not found or corrupted.' +
|
||||
' You can use download=True to download it')
|
||||
|
||||
if self.train:
|
||||
downloaded_list = self.train_list
|
||||
else:
|
||||
downloaded_list = self.test_list
|
||||
|
||||
self.data = []
|
||||
self.targets = []
|
||||
|
||||
# now load the pickled numpy arrays
|
||||
for file_name, checksum in downloaded_list:
|
||||
file_path = os.path.join(self.root, self.base_folder, file_name)
|
||||
with open(file_path, 'rb') as f:
|
||||
if sys.version_info[0] == 2:
|
||||
entry = pickle.load(f)
|
||||
else:
|
||||
entry = pickle.load(f, encoding='latin1')
|
||||
self.data.append(entry['data'])
|
||||
if 'labels' in entry:
|
||||
self.targets.extend(entry['labels'])
|
||||
else:
|
||||
self.targets.extend(entry['fine_labels'])
|
||||
|
||||
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
|
||||
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
|
||||
|
||||
self._load_meta()
|
||||
|
||||
def _load_meta(self):
|
||||
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
|
||||
if not utils.check_integrity(path, self.meta['md5']):
|
||||
raise RuntimeError('Dataset metadata file not found or corrupted.' +
|
||||
' You can use download=True to download it')
|
||||
with open(path, 'rb') as infile:
|
||||
if sys.version_info[0] == 2:
|
||||
data = pickle.load(infile)
|
||||
else:
|
||||
data = pickle.load(infile, encoding='latin1')
|
||||
self.classes = data[self.meta['key']]
|
||||
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
|
||||
Returns:
|
||||
tuple: (image, target) where target is index of the target class.
|
||||
"""
|
||||
img, target = self.data[index], self.targets[index]
|
||||
|
||||
# doing this so that it is consistent with all other datasets
|
||||
# to return a PIL Image
|
||||
img = Image.fromarray(img)
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def _check_integrity(self):
|
||||
root = self.root
|
||||
for fentry in (self.train_list + self.test_list):
|
||||
filename, md5 = fentry[0], fentry[1]
|
||||
fpath = os.path.join(root, self.base_folder, filename)
|
||||
if not utils.check_integrity(fpath, md5):
|
||||
return False
|
||||
return True
|
||||
|
||||
def download(self):
|
||||
if self._check_integrity():
|
||||
print('Files already downloaded and verified')
|
||||
return
|
||||
utils.download_and_extract_archive(self.url, self.root,
|
||||
filename=self.filename,
|
||||
md5=self.tgz_md5)
|
||||
|
||||
def extra_repr(self):
|
||||
return "Split: {}".format("Train" if self.train is True else "Test")
|
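# This local CIFAR10 subclass mirrors torchvision's, but train_list deliberately
# drops 'data_batch_5' (commented out above), so the train split holds 40k
# images; presumably the remaining 10k serve elsewhere as a held-out set for
# architecture search. A minimal instantiation sketch:
train_data = CIFAR10(root='../data', train=True, download=True, transform=None)
print(len(train_data))  # 40000 rather than the usual 50000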
||||
|
||||
|
||||
def pick_gpu_lowest_memory():
|
||||
import gpustat
|
||||
stats = gpustat.GPUStatCollection.new_query()
|
||||
ids = map(lambda gpu: int(gpu.entry['index']), stats)
|
||||
ratios = map(lambda gpu: float(gpu.memory_used)/float(gpu.memory_total), stats)
|
||||
bestGPU = min(zip(ids, ratios), key=lambda x: x[1])[0]
|
||||
return bestGPU
|
||||
|
||||
|
||||
#### early stopping (from RobustNAS)
|
||||
class EVLocalAvg(object):
|
||||
def __init__(self, window=5, ev_freq=2, total_epochs=50):
|
||||
""" Keep track of the eigenvalues local average.
|
||||
Args:
|
||||
window (int): number of elements used to compute local average.
|
||||
Default: 5
|
||||
ev_freq (int): frequency used to compute eigenvalues. Default:
|
||||
every 2 epochs
|
||||
total_epochs (int): total number of epochs that DARTS runs.
|
||||
Default: 50
|
||||
"""
|
||||
self.window = window
|
||||
self.ev_freq = ev_freq
|
||||
self.epochs = total_epochs
|
||||
|
||||
self.stop_search = False
|
||||
self.stop_epoch = total_epochs - 1
|
||||
self.stop_genotype = None
|
||||
self.stop_numparam = 0
|
||||
|
||||
self.ev = []
|
||||
self.ev_local_avg = []
|
||||
self.genotypes = {}
|
||||
self.numparams = {}
|
||||
self.la_epochs = {}
|
||||
|
||||
# start and end index of the local average window
|
||||
self.la_start_idx = 0
|
||||
self.la_end_idx = self.window
|
||||
|
||||
def reset(self):
|
||||
self.ev = []
|
||||
self.ev_local_avg = []
|
||||
self.genotypes = {}
|
||||
self.numparams = {}
|
||||
self.la_epochs = {}
|
||||
|
||||
def update(self, epoch, ev, genotype, numparam=0):
|
||||
""" Method to update the local average list.
|
||||
|
||||
Args:
|
||||
epoch (int): current epoch
|
||||
ev (float): current dominant eigenvalue
|
||||
genotype (namedtuple): current genotype
|
||||
|
||||
"""
|
||||
self.ev.append(ev)
|
||||
self.genotypes.update({epoch: genotype})
|
||||
self.numparams.update({epoch: numparam})
|
||||
# set the stop_genotype to the current genotype in case the early stop
|
||||
# procedure decides not to early stop
|
||||
self.stop_genotype = genotype
|
||||
|
||||
# since the local average computation starts after the dominant
|
||||
# eigenvalue in the first epoch is already computed we have to wait
|
||||
# at least until we have 3 eigenvalues in the list.
|
||||
if (len(self.ev) >= int(np.ceil(self.window/2))) and (epoch <
|
||||
self.epochs - 1):
|
||||
# start sliding the window as soon as the number of eigenvalues in
|
||||
# the list becomes equal to the window size
|
||||
if len(self.ev) < self.window:
|
||||
self.ev_local_avg.append(np.mean(self.ev))
|
||||
else:
|
||||
assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window
|
||||
self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
|
||||
self.la_end_idx]))
|
||||
self.la_start_idx += 1
|
||||
self.la_end_idx += 1
|
||||
|
||||
# keep track of the offset between the current epoch and the epoch
|
||||
# corresponding to the local average. NOTE: in the end the size of
|
||||
# self.ev and self.ev_local_avg should be equal
|
||||
self.la_epochs.update({epoch: int(epoch -
|
||||
int(self.ev_freq*np.floor(self.window/2)))})
|
||||
|
||||
elif len(self.ev) < int(np.ceil(self.window/2)):
|
||||
self.la_epochs.update({epoch: -1})
|
||||
|
||||
# since there is an offset between the current epoch and the local
|
||||
# average epoch, loop in the last epoch to compute the local average of
|
||||
# these number of elements: window, window - 1, window - 2, ..., ceil(window/2)
|
||||
elif epoch == self.epochs - 1:
|
||||
for i in range(int(np.ceil(self.window/2))):
|
||||
assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i
|
||||
self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
|
||||
self.la_end_idx + 1]))
|
||||
self.la_start_idx += 1
|
||||
|
||||
def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4, criteria='local_avg'):
|
||||
""" Early stopping criterion
|
||||
|
||||
Args:
|
||||
epoch (int): current epoch
|
||||
factor (float): threshold factor for the ratio between the current
and previous eigenvalue. Default: 1.3
es_start_epoch (int): until this epoch do not consider early
stopping. Default: 10
delta (int): factor influencing which previous local average we
consider for early stopping. Default: 4
|
||||
"""
|
||||
if criteria == 'local_avg':
if int(self.la_epochs[epoch] - self.ev_freq*delta) >= es_start_epoch:
|
||||
current_la = self.ev_local_avg[-1]
|
||||
previous_la = self.ev_local_avg[-1 - delta]
|
||||
self.stop_search = current_la / previous_la > factor
|
||||
if self.stop_search:
|
||||
self.stop_epoch = int(self.la_epochs[epoch] - self.ev_freq*delta)
|
||||
self.stop_genotype = self.genotypes[self.stop_epoch]
|
||||
self.stop_numparam = self.numparams[self.stop_epoch]
|
||||
elif criteria == 'exact':
|
||||
if epoch > es_start_epoch:
|
||||
current_la = self.ev[-1]
|
||||
previous_la = self.ev[-1 - delta]
|
||||
self.stop_search = current_la / previous_la > factor
|
||||
if self.stop_search:
|
||||
self.stop_epoch = epoch - delta
|
||||
self.stop_genotype = self.genotypes[self.stop_epoch]
|
||||
self.stop_numparam = self.numparams[self.stop_epoch]
|
||||
else:
|
||||
print('ERROR IN EARLY STOP: WRONG CRITERIA:', criteria); exit(0)
|
||||
|
||||
|
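# Sketch of how EVLocalAvg is meant to be driven during search, assuming some
# compute_hessian_ev() returning the dominant eigenvalue of the architecture
# Hessian (not defined in this file) and a model exposing genotype():
def track_eigenvalues(model, compute_hessian_ev, epochs=50, ev_freq=2):
    ev_tracker = EVLocalAvg(window=5, ev_freq=ev_freq, total_epochs=epochs)
    for epoch in range(epochs):
        if epoch % ev_freq == 0 or epoch == epochs - 1:
            ev_tracker.update(epoch, compute_hessian_ev(), model.genotype())
            ev_tracker.early_stop(epoch, criteria='local_avg')
            if ev_tracker.stop_search:
                break
    return ev_tracker.stop_epoch, ev_tracker.stop_genotype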
||||
def gen_comb(eids):
|
||||
comb = []
|
||||
for r in range(len(eids)):
|
||||
for c in range(r + 1, len(eids)):
|
||||
comb.append((eids[r], eids[c]))
|
||||
|
||||
return comb
|
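# gen_comb enumerates all unordered pairs from a list of edge ids; for example:
assert gen_comb([0, 1, 2]) == [(0, 1), (0, 2), (1, 2)]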