This commit is contained in:
HamsterMimi 2023-05-04 13:42:06 +08:00
parent 5a1dc89756
commit 2410fe9f5e
18 changed files with 3384 additions and 0 deletions

View File

@@ -0,0 +1,133 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import os
import pickle
import torch
import argparse
import json
import numpy as np
from thop import profile
from foresight.models import *
from foresight.pruners import *
from foresight.dataset import *
def get_num_classes(args):
return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120
def parse_arguments():
parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-101')
parser.add_argument('--api_loc', default='../data/nasbench_only108.tfrecord',
type=str, help='path to API')
parser.add_argument('--json_loc', default='data/all_graphs.json',
type=str, help='path to JSON database')
parser.add_argument('--outdir', default='./',
type=str, help='output directory')
parser.add_argument('--outfname', default='test',
type=str, help='output filename')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--dataset', type=str, default='cifar10',
help='dataset to use [cifar10, cifar100, ImageNet16-120]')
parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on')
parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
parser.add_argument('--dataload', type=str, default='random', help='random or grasp supported')
parser.add_argument('--dataload_info', type=int, default=1,
help='number of batches to use for random dataload or number of samples per class for grasp dataload')
parser.add_argument('--start', type=int, default=5, help='start index')
parser.add_argument('--end', type=int, default=10, help='end index')
parser.add_argument('--write_freq', type=int, default=100, help='frequency of write to file')
args = parser.parse_args()
args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
return args
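# Usage sketch (the script filename here is hypothetical; the defaults above apply otherwise):
#   python nb101_zero_cost.py --dataset cifar10 --start 0 --end 100 --outfname nb1_scores
# This scores architectures with index 0..99 from the JSON database and appends the results to ./nb1_scores.p.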
def get_op_names(v):
o = []
for op in v:
if op == -1:
o.append('input')
elif op == -2:
o.append('output')
elif op == 0:
o.append('conv3x3-bn-relu')
elif op == 1:
o.append('conv1x1-bn-relu')
elif op == 2:
o.append('maxpool3x3')
return o
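# Example: get_op_names([-1, 0, 2, 1, -2]) -> ['input', 'conv3x3-bn-relu', 'maxpool3x3', 'conv1x1-bn-relu', 'output']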
if __name__ == '__main__':
args = parse_arguments()
# nasbench = api.NASBench(args.api_loc)
models = json.load(open(args.json_loc))
print(f'Running models {args.start} to {args.end} out of {len(models.keys())}')
train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset,
args.num_data_workers)
all_points = []
pre = 'cf' if 'cifar' in args.dataset else 'im'
if args.outfname == 'test':
fn = f'nb1_{pre}{get_num_classes(args)}.p'
else:
fn = f'{args.outfname}.p'
op = os.path.join(args.outdir, fn)
print('outfile =', op)
first = True
# loop over nasbench1 archs (k=hash, v=[adj_matrix, ops])
idx = 0
cached_res = []
for k, v in models.items():
if idx < args.start:
idx += 1
continue
if idx >= args.end:
break
print(f'idx = {idx}')
idx += 1
res = {}
res['hash'] = k
# model
spec = nasbench1_spec._ToModelSpec(v[0], get_op_names(v[1]))
net = nasbench1.Network(spec, stem_out=128, num_stacks=3, num_mods=3, num_classes=get_num_classes(args))
net.to(args.device)
measures = predictive.find_measures(net,
train_loader,
(args.dataload, args.dataload_info, get_num_classes(args)),
args.device)
res['logmeasures'] = measures
print(res)
cached_res.append(res)
# write to file
if idx % args.write_freq == 0 or idx == args.end or idx == args.start + 10:
print(f'writing {len(cached_res)} results to {op}')
pf = open(op, 'ab')
for cr in cached_res:
pickle.dump(cr, pf)
pf.close()
cached_res = []
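# The scores are appended to the pickle file one record at a time, so reading them back takes
# repeated pickle.load calls until EOF. A minimal sketch (hypothetical helper, not part of the script):
def load_results(path):
    results = []
    with open(path, 'rb') as f:
        while True:
            try:
                results.append(pickle.load(f))
            except EOFError:
                break
    return results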

View File

@@ -0,0 +1,128 @@
import argparse
import torch
import os
import time
from foresight.dataset import *
from foresight.models import nasbench2
from foresight.pruners import predictive
from foresight.weight_initializers import init_net
from models import get_cell_based_tiny_net
import pickle
def get_num_classes(args):
return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120
def parse_arguments():
parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-201')
parser.add_argument('--api_loc', default='../data/NAS-Bench-201-v1_0-e61699.pth',
type=str, help='path to API')
parser.add_argument('--outdir', default='./',
type=str, help='output directory')
parser.add_argument('--init_w_type', type=str, default='none',
help='weight initialization (before pruning) type [none, xavier, kaiming, zero, one]')
parser.add_argument('--init_b_type', type=str, default='none',
help='bias initialization (before pruning) type [none, xavier, kaiming, zero, one]')
# search space selection used below ('tss' topology / 'sss' size); the default 'tss' is an assumption
parser.add_argument('--search_space', type=str, default='tss', help="search space: 'tss' or 'sss'")
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--dataset', type=str, default='ImageNet16-120',
help='dataset to use [cifar10, cifar100, ImageNet16-120]')
parser.add_argument('--gpu', type=int, default=5, help='GPU index to work on')
parser.add_argument('--data_size', type=int, default=32, help='data_size')
parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
parser.add_argument('--dataload', type=str, default='appoint', help='random, grasp, appoint supported')
parser.add_argument('--dataload_info', type=int, default=1,
help='number of batches to use for random dataload or number of samples per class for grasp dataload')
parser.add_argument('--seed', type=int, default=42, help='pytorch manual seed')
parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file')
parser.add_argument('--start', type=int, default=0, help='start index')
parser.add_argument('--end', type=int, default=0, help='end index')
parser.add_argument('--noacc', default=False, action='store_true',
help='avoid loading the NASBench2 api and instead load a pickle file with tuple (index, arch_str)')
args = parser.parse_args()
args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
return args
if __name__ == '__main__':
args = parse_arguments()
print(args.device)
if args.noacc:
api = pickle.load(open(args.api_loc,'rb'))
else:
from nas_201_api import NASBench201API as API
api = API(args.api_loc)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset, args.num_data_workers, resize=args.data_size)
x, y = next(iter(train_loader))
# random data
# x = torch.rand((args.batch_size, 3, args.data_size, args.data_size))
# y = 0
cached_res = []
pre = 'cf' if 'cifar' in args.dataset else 'im'
pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p'
op = os.path.join(args.outdir, pfn)
end = len(api) if args.end == 0 else args.end
# loop over nasbench2 archs
for i, arch_str in enumerate(api):
if i < args.start:
continue
if i >= end:
break
res = {'i': i, 'arch': arch_str}
# print(arch_str)
if args.search_space == 'tss':
net = nasbench2.get_model_from_arch_str(arch_str, get_num_classes(args))
arch_str2 = nasbench2.get_arch_str_from_model(net)
if arch_str != arch_str2:
print(arch_str)
print(arch_str2)
raise ValueError
elif args.search_space == 'sss':
config = api.get_net_config(i, args.dataset)
# print(config)
net = get_cell_based_tiny_net(config)
net.to(args.device)
# print(net)
init_net(net, args.init_w_type, args.init_b_type)
# print(x.size(), y)
measures = get_score(net, x, i, args.device)
res['meco'] = measures
if not args.noacc:
info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset, iepoch=None,
hp='200', is_random=False)
trainacc = info['train-accuracy']
valacc = info['valid-accuracy']
testacc = info['test-accuracy']
res['trainacc'] = trainacc
res['valacc'] = valacc
res['testacc'] = testacc
print(res)
cached_res.append(res)
# write to file
if i % args.write_freq == 0 or i == len(api) - 1 or i == 10:
print(f'writing {len(cached_res)} results to {op}')
pf = open(op, 'ab')
for cr in cached_res:
pickle.dump(cr, pf)
pf.close()
cached_res = []

View File

@@ -0,0 +1,38 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar10}
seed=${seed:-0}
gpu=${gpu:-"auto"}
pool_size=${pool_size:-10}
batch_size=${batch_size:-256}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
metric=${metric:-'jacob'}
while [ $# -gt 0 ]; do
if [[ $1 == *"--"* ]]; then
param="${1/--/}"
declare $param="$2"
# echo $1 $2 // Optional to see the parameter:value result
fi
shift
done
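# Usage sketch (the actual filename is whatever this script is saved as; it becomes the run id):
#   bash <this_script>.sh --dataset cifar100 --seed 1 --gpu 0 --batch_size 128 --metric jacob
# Any default declared above can be overridden with a matching --key value pair.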
echo 'id:' $id 'seed:' $seed 'dataset:' $dataset
echo 'gpu:' $gpu
cd ../nasbench201/
python3 networks_proposal.py \
--dataset $dataset \
--save $id --gpu $gpu --seed $seed \
--edge_decision $edge_decision --proj_crit $metric \
--batch_size $batch_size \
--pool_size $pool_size
cd ../zerocostnas/
python3 post_validate.py \
--ckpt_path ../experiments/nas-bench-201/prop-$id-$seed-$pool_size-$metric \
--save $id --seed $seed --gpu $gpu \
--edge_decision $edge_decision --proj_crit $metric \
--batch_size $batch_size \
--validate_rounds $validate_rounds

View File

@@ -0,0 +1,110 @@
import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data
import pickle
def calculate_md5(fpath, chunk_size=1024 * 1024):
md5 = hashlib.md5()
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
md5.update(chunk)
return md5.hexdigest()
def check_md5(fpath, md5, **kwargs):
return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath, md5=None):
print(fpath)
if not os.path.isfile(fpath): return False
if md5 is None: return True
else : return check_md5(fpath, md5)
class ImageNet16(data.Dataset):
# http://image-net.org/download-images
# A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
# https://arxiv.org/pdf/1707.08819.pdf
train_list = [
['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
]
valid_list = [
['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
]
def __init__(self, root, train, transform, use_num_of_class_only=None):
self.root = root
self.transform = transform
self.train = train # training set or valid set
if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')
if self.train: downloaded_list = self.train_list
else : downloaded_list = self.valid_list
self.data = []
self.targets = []
# now load the picked numpy arrays
for i, (file_name, checksum) in enumerate(downloaded_list):
file_path = os.path.join(self.root, file_name)
#print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
self.targets.extend(entry['labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
if use_num_of_class_only is not None:
assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
new_data, new_targets = [], []
for I, L in zip(self.data, self.targets):
if 1 <= L <= use_num_of_class_only:
new_data.append( I )
new_targets.append( L )
self.data = new_data
self.targets = new_targets
# self.mean.append(entry['mean'])
#self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
#self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
#print ('Mean : {:}'.format(self.mean))
#temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
#std_data = np.std(temp, axis=0)
#std_data = np.mean(np.mean(std_data, axis=0), axis=0)
#print ('Std : {:}'.format(std_data))
def __getitem__(self, index):
img, target = self.data[index], self.targets[index] - 1
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.valid_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, filename)
if not check_integrity(fpath, md5):
return False
return True
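# A minimal usage sketch (assumes torchvision.transforms as `transforms` and the batch files listed above in `root`):
#   mean = [x / 255 for x in [122.68, 116.66, 104.01]]
#   std = [x / 255 for x in [63.22, 61.26, 65.09]]
#   transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
#   train_set = ImageNet16('./data/ImageNet16', True, transform, use_num_of_class_only=120)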

View File

@@ -0,0 +1,52 @@
import torch
class Architect(object):
def __init__(self, model, args):
self.network_momentum = args.momentum
self.network_weight_decay = args.weight_decay
self.model = model
self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
lr=args.arch_learning_rate, betas=(0.5, 0.999),
weight_decay=args.arch_weight_decay)
self._init_arch_parameters = []
for alpha in self.model.arch_parameters():
alpha_init = torch.zeros_like(alpha)
alpha_init.data.copy_(alpha)
self._init_arch_parameters.append(alpha_init)
#### mode
if args.method in ['darts', 'darts-proj', 'sdarts', 'sdarts-proj']:
self.method = 'fo' # first order update
elif 'so' in args.method:
print('ERROR: PLEASE USE architect.py for second order darts')
elif args.method in ['blank', 'blank-proj']:
self.method = 'blank'
else:
print('ERROR: WRONG ARCH UPDATE METHOD', args.method); exit(0)
def reset_arch_parameters(self):
for alpha, alpha_init in zip(self.model.arch_parameters(), self._init_arch_parameters):
alpha.data.copy_(alpha_init.data)
def step(self, input_train, target_train, input_valid, target_valid, *args, **kwargs):
if self.method == 'fo':
shared = self._step_fo(input_train, target_train, input_valid, target_valid)
elif self.method == 'so':
raise NotImplementedError
elif self.method == 'blank': ## do not update alpha
shared = None
return shared
#### first order
def _step_fo(self, input_train, target_train, input_valid, target_valid):
loss = self.model._loss(input_valid, target_valid)
loss.backward()
self.optimizer.step()
return None
#### darts 2nd order
def _step_darts_so(self, input_train, target_train, input_valid, target_valid, eta, model_optimizer):
raise NotImplementedError

View File

@@ -0,0 +1,120 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
def __init__(self, genotype, C_in, C_out, stride):
super(InferCell, self).__init__()
self.layers = nn.ModuleList()
self.node_IN = []
self.node_IX = []
self.genotype = deepcopy(genotype)
for i in range(1, len(genotype)):
node_info = genotype[i-1]
cur_index = []
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](C_in , C_out, stride, True, True)
else:
layer = OPS[op_name](C_out, C_out, 1, True, True)
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )
self.node_IX.append( cur_index )
self.node_IN.append( cur_innod )
self.nodes = len(genotype)
self.in_dim = C_in
self.out_dim = C_out
def extra_repr(self):
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
laystr = []
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
x = '{:}<-({:})'.format(i+1, ','.join(y))
laystr.append( x )
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
def forward(self, inputs):
nodes = [inputs]
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
nodes.append( node_feature )
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetInferCell, self).__init__()
self.reduction = reduction
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
if not reduction:
nodes, concats = genotype['normal'], genotype['normal_concat']
else:
nodes, concats = genotype['reduce'], genotype['reduce_concat']
self._multiplier = len(concats)
self._concats = concats
self._steps = len(nodes)
self._nodes = nodes
self.edges = nn.ModuleDict()
for i, node in enumerate(nodes):
for in_node in node:
name, j = in_node[0], in_node[1]
stride = 2 if reduction and j < 2 else 1
node_str = '{:}<-{:}'.format(i+2, j)
self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
# [TODO] to support drop_prob in this function..
def forward(self, s0, s1, unused_drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i, node in enumerate(self._nodes):
clist = []
for in_node in node:
name, j = in_node[0], in_node[1]
node_str = '{:}<-{:}'.format(i+2, j)
op = self.edges[ node_str ]
clist.append( op(states[j]) )
states.append( sum(clist) )
return torch.cat([states[x] for x in self._concats], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x

View File

@@ -0,0 +1,82 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
def __init__(self, C, N, genotype, num_classes):
super(TinyNetwork, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
else:
cell = InferCell(genotype, C_prev, C_curr, 1)
self.cells.append( cell )
C_prev = cell.out_dim
self._Layer= len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.requires_feature = True
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
if self.requires_feature:
return logits, out
else:
return logits
def _loss(self, input, target, return_logits=False):
logits, _ = self(input)
loss = self._criterion(logits, target)
return (loss, logits) if return_logits else loss
def step(self, input, target, args, shared=None, return_grad=False):
Lt, logit_t = self._loss(input, target, return_logits=True)
Lt.backward()
if args.grad_clip != 0:
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
self.optimizer.step()
if return_grad:
grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
return logit_t, Lt, grad
else:
return logit_t, Lt

View File

@@ -0,0 +1,289 @@
import sys
import torch
import torch.nn as nn
sys.path.insert(0, '../')
from Layers import layers
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
'noise' : lambda C_in, C_out, stride, affine, track_running_stats: NoiseOp(stride, 0., 1.), # C_in, C_out not needed
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
#### wrc modified
NAS_BENCH_201_SKIP = ['none', 'skip_connect', 'nor_conv_1x1_skip', 'nor_conv_3x3_skip', 'avg_pool_3x3']
NAS_BENCH_201_SIMPLE = ['skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
NAS_BENCH_201_S2 = ['skip_connect', 'nor_conv_3x3']
NAS_BENCH_201_S4 = ['noise', 'nor_conv_3x3']
NAS_BENCH_201_S10 = ['none', 'nor_conv_3x3']
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'nas-bench-201': NAS_BENCH_201,
'nas-bench-201-simple': NAS_BENCH_201_SIMPLE,
'nas-bench-201-s2': NAS_BENCH_201_S2,
'nas-bench-201-s4': NAS_BENCH_201_S4,
'nas-bench-201-s10': NAS_BENCH_201_S10,
'darts' : DARTS_SPACE}
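# Every entry in OPS shares the signature (C_in, C_out, stride, affine, track_running_stats), e.g.:
#   op = OPS['nor_conv_3x3'](16, 16, 1, True, True)   # ReLU-Conv3x3-BN block, stride 1
#   out = op(torch.randn(2, 16, 32, 32))              # -> shape (2, 16, 32, 32)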
class NoiseOp(nn.Module):
def __init__(self, stride, mean, std):
super(NoiseOp, self).__init__()
self.stride = stride
self.mean = mean
self.std = std
def forward(self, x, block_input=False):
if block_input:
x = x * 0
if self.stride != 1:
x_new = x[:,:,::self.stride,::self.stride]
else:
x_new = x
noise = x_new.data.new(x_new.size()).normal_(self.mean, self.std)
return noise
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
)
def forward(self, x, block_input=False):
if block_input:
x = x * 0
return self.op(x)
def score(self):
score = 0
for l in self.op:
if hasattr(l, 'score'):
score += torch.sum(l.score).cpu().numpy()
return score
#### wrc modified
class ReLUConvBNSkip(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(ReLUConvBNSkip, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
)
def forward(self, x, block_input=False):
if block_input:
x = x * 0
return self.op(x) + x
def score(self):
score = 0
for l in self.op:
if hasattr(l, 'score'):
score += torch.sum(l.score).cpu().numpy()
return score
####
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
layers.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
layers.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
)
def forward(self, x, block_input=False):
if block_input:
x = x * 0
return self.op(x)
def score(self):
score = 0
for l in self.op:
if hasattr(l, 'score'):
score += torch.sum(l.score).cpu().numpy()
return score
class DualSepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(DualSepConv, self).__init__()
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
def forward(self, x, block_input=False):
if block_input:
x = x * 0
x = self.op_a(x)
x = self.op_b(x)
return x
def score(self):
score = self.op_a.score() + self.op_b.score()
return score
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def extra_repr(self):
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
return string
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
return residual + basicblock
def score(self):
return self.conv_a.score() + self.conv_b.score()
class POOLING(nn.Module):
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
def forward(self, inputs, block_input=False):
if block_input:
inputs = inputs * 0
if self.preprocess: x = self.preprocess(inputs)
else : x = inputs
return self.op(x)
def score(self):
if self.preprocess :
return self.preprocess.score()
else:
return 0
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x, block_input=False):
if block_input:
x = x * 0
return x
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x, block_input=False):
if block_input:
x = x*0
if self.C_in == self.C_out:
if self.stride == 1: return x.mul(0.)
else : return x[:,:,::self.stride,::self.stride].mul(0.)
else: ## this is never called in nasbench201
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(layers.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 1:
self.conv = layers.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
def forward(self, x, block_input=False):
if block_input:
x = x * 0
if self.stride == 2:
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
else:
out = self.conv(x)
out = self.bn(out)
return out
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
def score(self):
if self.stride == 1:
return self.conv.score()
else:
return self.convs[0].score()+self.convs[1].score()

nasbench201/genotypes.py (194 lines, new file)
View File

@@ -0,0 +1,194 @@
from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append( [(func, i)] )
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append( xstring )
combs = new_combs
return combs
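# Example: get_combination(['A', 'B'], 2) ->
#   [[('A', 0), ('A', 1)], [('A', 0), ('B', 1)], [('B', 0), ('A', 1)], [('B', 0), ('B', 1)]]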
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
self.node_N.append( len(node_info) )
self.nodes.append( tuple(deepcopy(node_info)) )
def tolist(self, remove_str):
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
# note that we re-order the input node in this function
# return (genotype-list, success); if unsuccessful, the genotype is not a valid connectivity
genotypes = []
for node_info in self.nodes:
node_info = list( node_info )
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0: return None, False
genotypes.append( node_info )
return genotypes, True
def node(self, index):
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
return self.nodes[index]
def tostr(self):
strings = []
for node_info in self.nodes:
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
string = '|{:}|'.format(string)
strings.append( string )
return '+'.join(strings)
def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == 'none' or nodes[xin] is False: x = False
else: x = True
sums.append( x )
nodes[i+1] = sum(sums) > 0
return nodes[len(self.nodes)]
def to_unique_str(self, consider_zero=False):
# this is used to identify isomorphic cells, which requires prior knowledge of the operations
# two operations are special, i.e., none and skip_connect
nodes = {0: '0'}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = '('+nodes[xin]+')' + '@{:}'.format(op)
elif consider_zero:
if op == 'none' or nodes[xin] == '#': x = '#' # zero
elif op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
else:
if op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
cur_node.append(x)
nodes[i_node+1] = '+'.join( sorted(cur_node) )
return nodes[ len(self.nodes) ]
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names: return False
return True
def __repr__(self):
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return Structure( genotypes )
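# Example: a NAS-Bench-201 style architecture string round-trips through str2structure/tostr:
#   s = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_1x1~2|'
#   assert Structure.str2structure(s).tostr() == s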
@staticmethod
def str2fullstructure(xstr, default_name='none'):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes= list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes: input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append( tuple(node_info) )
return Structure( genotypes )
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [ tuple(arch) ]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append( previous_arch + [tuple(cur_node)] )
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
ResNet_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 1), ), # node-2
(('skip_connect', 0), ('skip_connect', 2))] # node-3
)
AllConv3x3_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
)
AllFull_CODE = Structure(
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
)
AllConv1x1_CODE = Structure(
[(('nor_conv_1x1', 0), ), # node-1
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
)
AllIdentity_CODE = Structure(
[(('skip_connect', 0), ), # node-1
(('skip_connect', 0), ('skip_connect', 1)), # node-2
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
)
architectures = {'resnet' : ResNet_CODE,
'all_c3x3': AllConv3x3_CODE,
'all_c1x1': AllConv1x1_CODE,
'all_idnt': AllIdentity_CODE,
'all_full': AllFull_CODE}

View File

@@ -0,0 +1,619 @@
import os
import sys
import numpy as np
import torch
import torch.nn.functional as f
sys.path.insert(0, '../')
import nasbench201.utils as ig_utils
import logging
import torch.utils
import copy
import scipy.stats as ss
from collections import OrderedDict
from foresight.pruners import *
from op_score import Jocab_Score, get_ntk_n
import gc
from nasbench201.linear_region import Linear_Region_Collector
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)
# global-op-iter: iteratively evaluates S(A(e,o)) for every remaining (edge, operation) pair and greedily discretizes the best-scoring one
def global_op_greedy_pt_project(proj_queue, model, args):
def project(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
crit_extrema = None
best_eid = None
input, target = next(iter(proj_queue))
for eid in remain_eids:
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
if crit_extrema is None or compare(crit, crit_extrema):
crit_extrema = crit
best_opid = opid
best_eid = eid
logging.info('best opid %d', best_opid)
return best_eid, best_opid
tune_epochs = model.arch_parameters()[0].shape[0]
for epoch in range(tune_epochs):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = project(model, args)
model.project_op(selected_eid, best_opid)
return
# global-edge-iter: similar to global-op-iter but iteratively selects edge e from E based on the average score of all operations on that edge, then discretizes its best operation
def global_edge_greedy_pt_project(proj_queue, model, args):
def select_eid(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
crit_extrema = None
best_eid = None
input, target = next(iter(proj_queue))
for eid in remain_eids:
eid_score = []
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
eid_score.append(crit)
eid_score = np.mean(eid_score)
if crit_extrema is None or compare(eid_score, crit_extrema):
crit_extrema = eid_score
best_eid = eid
return best_eid
def project(model, args, selected_eid):
## macros
num_edge, num_op = model.num_edge, model.num_op
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
best_opid = 0
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
logging.info('best opid %d', best_opid)
logging.info(crit_list)
return selected_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid = select_eid(model, args)
selected_eid, best_opid = project(model, args, selected_eid)
model.project_op(selected_eid, best_opid)
return
# global-op-once: evaluates S(A(e,o)) for all operations only once to obtain a ranking of the operations, and discretizes the edges E according to this order
def global_op_once_pt_project(proj_queue, model, args):
def order(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
edge_score = OrderedDict()
input, target = next(iter(proj_queue))
for eid in remain_eids:
crit_list = []
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
crit_list.append(crit)
edge_score[eid] = np.nanargmin(crit_list)
return edge_score
def project(model, args, selected_eid):
## macros
num_edge, num_op = model.num_edge, model.num_op
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
best_opid = 0
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
logging.info('best opid %d', best_opid)
logging.info(crit_list)
return selected_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
eid_order = order(model, args)
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, _ = eid_order.popitem()
selected_eid, best_opid = project(model, args, selected_eid)
model.project_op(selected_eid, best_opid)
return
# global-edge-once: similar to global-op-once but uses the average score of operations on edges to obtain the edge discretization order
def global_edge_once_pt_project(proj_queue, model, args):
def order(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
edge_score = OrderedDict()
crit_extrema = None
best_eid = None
input, target = next(iter(proj_queue))
for eid in remain_eids:
crit_list = []
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
crit_list.append(crit)
edge_score[eid] = np.mean(crit_list)
return edge_score
def project(model, args, selected_eid):
## macros
num_edge, num_op = model.num_edge, model.num_op
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
best_opid = 0
crit_extrema = None
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
logging.info('best opid %d', best_opid)
logging.info(crit_list)
return selected_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
eid_order = order(model, args)
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, _ = eid_order.popitem()
selected_eid, best_opid = project(model, args, selected_eid)
model.project_op(selected_eid, best_opid)
return
# fixed [reverse, order]: discretizes the edges in a fixed order; in our experiments we discretize from the input towards the output of the cell structure
# random: discretizes the edges in a random order (DARTS-PT)
# NOTE: only this method allows using other zero-cost proxy metrics
def pt_project(proj_queue, model, args):
def project(model, args):
## macros: 6 edges in total, each with 5 candidate operations
num_edge, num_op = model.num_edge, model.num_op
## select an edge
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
# print('candidate_flags:', model.candidate_flags)
# print(model.candidate_flags)
# how the edge is selected
if args.edge_decision == "random":
# np.random.choice returns an array; take its single element
selected_eid = np.random.choice(remain_eids, size=1)[0]
elif args.edge_decision == "reverse":
selected_eid = remain_eids[-1]
else:
selected_eid = remain_eids[0]
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet16-120':
n_classes = 120
else:
n_classes = 10
best_opid = 0
crit_extrema = None
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
# print(selected_eid, weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
else:
cache_weight = model.proj_weights[selected_eid]
cache_flag = model.candidate_flags[selected_eid]
for idx in range(num_op):
if idx == opid:
model.proj_weights[selected_eid][opid] = 0
else:
model.proj_weights[selected_eid][idx] = 1.0/num_op
model.candidate_flags[selected_eid] = False
# print(model.get_projected_weights())
if args.proj_crit == 'comb':
synflow = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=['synflow'])
var = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=['var'])
# print(synflow, var)
comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
measures = {'comb': comb}
else:
measures = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[args.proj_crit])
# print(measures)
for idx in range(num_op):
model.proj_weights[selected_eid][idx] = 0
model.candidate_flags[selected_eid] = cache_flag
crit = measures[args.proj_crit]
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
# best_opid = op_ids[np.nanargmax(crit_list)]
logging.info('best opid %d', best_opid)
logging.info('current edge id %d', selected_eid)
logging.info(crit_list)
return selected_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = project(model, args)
model.project_op(selected_eid, best_opid)
return
def tenas_project(proj_queue, model, model_thin, args):
def project(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
ntks = []
lrs = []
edge_op_id = []
best_eid = None
if args.proj_crit == 'tenas':
lrc_model = Linear_Region_Collector(input_size=(1000, 1, 3, 3), sample_batch=3, dataset=args.dataset, data_path=args.data, seed=args.seed)
for eid in remain_eids:
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'tenas':
lrc_model.reinit(ori_models=[model_thin], seed=args.seed, weights=weights)
lr = lrc_model.forward_batch_sample()
lrc_model.clear()
ntk = get_ntk_n(proj_queue, [model], recalbn=0, train_mode=True, num_batch=1, weights=weights)
ntks.append(ntk)
lrs.append(lr)
edge_op_id.append('{}:{}'.format(eid, opid))
print('ntks', ntks)
print('lrs', lrs)
ntks_ranks = ss.rankdata(ntks)
lrs_ranks = ss.rankdata(lrs)
ntks_ranks = len(ntks_ranks) - ntks_ranks.astype(int)
op_ranks = []
for i in range(len(edge_op_id)):
op_ranks.append(ntks_ranks[i]+lrs_ranks[i])
best_op_index = edge_op_id[np.nanargmin(op_ranks[0:num_op])]
best_eid, best_opid = [int(x) for x in best_op_index.split(':')]
logging.info(op_ranks)
logging.info('best eid %d', best_eid)
logging.info('best opid %d', best_opid)
return best_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = project(model, args)
model.project_op(selected_eid, best_opid)
return
# new method
# Randomly propose candidate networks and transfer them to the supernet, then perform global op selection in this subspace
def shrink_pt_project(proj_queue, model, args):
def project(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
## select an edge
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
selected_eid = np.random.choice(remain_eids, size=1)[0]
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet16-120':
n_classes = 120
else:
n_classes = 10
best_opid = 0
crit_extrema = None
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
else:
cache_weight = model.proj_weights[selected_eid]
cache_flag = model.candidate_flags[selected_eid]
for idx in range(num_op):
if idx == opid:
model.proj_weights[selected_eid][opid] = 0
else:
model.proj_weights[selected_eid][idx] = 1.0/num_op
model.candidate_flags[selected_eid] = False
measures = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[args.proj_crit])
for idx in range(num_op):
model.proj_weights[selected_eid][idx] = 0
model.candidate_flags[selected_eid] = cache_flag
crit = measures[args.proj_crit]
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
logging.info('best opid %d', best_opid)
logging.info('current edge id %d', selected_eid)
logging.info(crit_list)
return selected_eid, best_opid
def global_project(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.subspace_candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
crit_extrema = None
best_eid = None
best_opid = None
input, target = next(iter(proj_queue))
for eid in remain_eids:
remain_oids = torch.nonzero(model.proj_weights[eid]).cpu().numpy().T[0]
for opid in remain_oids:
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
if crit_extrema is None or compare(crit, crit_extrema):
crit_extrema = crit
best_opid = opid
best_eid = eid
logging.info('best eid %d', best_eid)
logging.info('best opid %d', best_opid)
model.subspace_candidate_flags[best_eid] = False
proj_mask = torch.zeros_like(model.proj_weights[best_eid])
model.proj_weights[best_eid] = model.proj_weights[best_eid] * proj_mask
model.proj_weights[best_eid][best_opid] = 1
return best_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
#subspace
logging.info('Start subspace proposal')
subspace = copy.deepcopy(model.proj_weights)
for i in range(20):
model.reset_arch_parameters()
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = project(model, args)
model.project_op(selected_eid, best_opid)
subspace += model.proj_weights
model.reset_arch_parameters()
subspace = torch.gt(subspace, 0).int().float()
subspace = f.normalize(subspace, p=1, dim=1)
model.proj_weights += subspace
for i in range(num_edges):
model.candidate_flags[i] = False
logging.info('Start final search in subspace')
logging.info(subspace)
model.subspace_candidate_flags = torch.tensor(len(model._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = global_project(model, args)
model.printing(logging)
#model.project_op(selected_eid, best_opid)
return

View File

@@ -0,0 +1,270 @@
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as dset
from pdb import set_trace as bp
from operator import mul
from functools import reduce
import copy
Dataset2Class = {'cifar10': 10,
'cifar100': 100,
'imagenet-1k-s': 1000,
'imagenet-1k': 1000,
}
class CUTOUT(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
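# Typical usage sketch: append CUTOUT to a train-time pipeline after ToTensor (it operates on tensors), e.g.
#   transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.ToTensor(), CUTOUT(16)])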
imagenet_pca = {
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
'eigvec': np.asarray([
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
])
}
class RandChannel(object):
# randomly pick channels from input
def __init__(self, num_channel):
self.num_channel = num_channel
def __repr__(self):
return ('{name}(num_channel={num_channel})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
channel = img.size(0)
channel_choice = sorted(np.random.choice(list(range(channel)), size=self.num_channel, replace=False))
return torch.index_select(img, 0, torch.Tensor(channel_choice).long())
def get_datasets(name, root, input_size, cutout=-1):
assert len(input_size) in [3, 4]
if len(input_size) == 4:
input_size = input_size[1:]
assert input_size[1] == input_size[2]
if name == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name.startswith('imagenet-1k'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('ImageNet16'):
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
else:
raise TypeError("Unknow dataset : {:}".format(name))
# print(input_size)
# Data Argumentation
if name == 'cifar10' or name == 'cifar100':
lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
elif name.startswith('ImageNet16'):
lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
elif name.startswith('imagenet-1k'):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
if name == 'imagenet-1k':
xlists = []
xlists.append(transforms.Resize((32, 32), interpolation=2))
xlists.append(transforms.RandomCrop(input_size[1], padding=0))
elif name == 'imagenet-1k-s':
xlists = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))]
else: raise ValueError('invalid name : {:}'.format(name))
xlists.append(transforms.ToTensor())
xlists.append(normalize)
xlists.append(RandChannel(input_size[0]))
train_transform = transforms.Compose(xlists)
test_transform = transforms.Compose([transforms.Resize(40), transforms.CenterCrop(32), transforms.ToTensor(), normalize])
else:
raise TypeError("Unknow dataset : {:}".format(name))
if name == 'cifar10':
train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'cifar100':
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name.startswith('imagenet-1k'):
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
else: raise TypeError("Unknow dataset : {:}".format(name))
class_num = Dataset2Class[name]
return train_data, test_data, class_num
class LinearRegionCount(object):
"""Computes and stores the average and current value"""
def __init__(self, n_samples):
self.ActPattern = {}
self.n_LR = -1
self.n_samples = n_samples
self.ptr = 0
self.activations = None
@torch.no_grad()
def update2D(self, activations):
n_batch = activations.size()[0]
n_neuron = activations.size()[1]
self.n_neuron = n_neuron
if self.activations is None:
self.activations = torch.zeros(self.n_samples, n_neuron).cuda()
self.activations[self.ptr:self.ptr+n_batch] = torch.sign(activations) # after ReLU
self.ptr += n_batch
@torch.no_grad()
def calc_LR(self):
res = torch.matmul(self.activations.half(), (1-self.activations).T.half()) # each element in res: A * (1 - B)
res += res.T # make symmetric; each element in res: A * (1 - B) + (1 - A) * B, a non-zero element indicates a pair of two different linear regions
res = 1 - torch.sign(res) # a non-zero element now indicates that the two linear regions are identical
res = res.sum(1) # for each sample's linear region: how many identical regions from other samples
res = 1. / res.float() # contribution of each redundant (repeated) linear region
self.n_LR = res.sum().item() # sum of unique regions (by aggregating contribution of all regions)
del self.activations, res
self.activations = None
torch.cuda.empty_cache()
@torch.no_grad()
def update1D(self, activationList):
code_string = ''
for key, value in activationList.items():
n_neuron = value.size()[0]
for i in range(n_neuron):
if value[i] > 0:
code_string += '1'
else:
code_string += '0'
if code_string not in self.ActPattern:
self.ActPattern[code_string] = 1
def getLinearReginCount(self):
if self.n_LR == -1:
self.calc_LR()
return self.n_LR
class Linear_Region_Collector:
def __init__(self, models=[], input_size=(64, 3, 32, 32), sample_batch=100, dataset='cifar100', data_path=None, seed=0):
self.models = []
self.input_size = input_size # BCHW
self.sample_batch = sample_batch
self.input_numel = reduce(mul, self.input_size, 1)
self.interFeature = []
self.dataset = dataset
self.data_path = data_path
self.seed = seed
self.reinit(models, input_size, sample_batch, seed)
def reinit(self, ori_models=None, input_size=None, sample_batch=None, seed=None, weights=None):
models = []
for network in ori_models:
network = network.cuda()
net = copy.deepcopy(network)
net.proj_weights = weights
num_edge, num_op = net.num_edge, net.num_op
for i in range(num_edge):
net.candidate_flags[i] = False
net.eval()
models.append(net)
if models is not None:
assert isinstance(models, list)
del self.models
self.models = models
for model in self.models:
self.register_hook(model)
device = torch.cuda.current_device()
model = model.cuda(device=device)
self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(models))]
if input_size is not None or sample_batch is not None:
if input_size is not None:
self.input_size = input_size # BCHW
self.input_numel = reduce(mul, self.input_size, 1)
if sample_batch is not None:
self.sample_batch = sample_batch
if self.data_path is not None:
self.train_data, _, class_num = get_datasets(self.dataset, self.data_path, self.input_size, -1)
self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=self.input_size[0], num_workers=16, pin_memory=True, drop_last=True, shuffle=True)
self.loader = iter(self.train_loader)
if seed is not None and seed != self.seed:
self.seed = seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
del self.interFeature
self.interFeature = []
torch.cuda.empty_cache()
def clear(self):
self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(self.models))]
del self.interFeature
self.interFeature = []
torch.cuda.empty_cache()
def register_hook(self, model):
for m in model.modules():
if isinstance(m, nn.ReLU):
m.register_forward_hook(hook=self.hook_in_forward)
def hook_in_forward(self, module, input, output):
if isinstance(input, tuple) and len(input[0].size()) == 4:
self.interFeature.append(output.detach()) # for ReLU
def forward_batch_sample(self):
for _ in range(self.sample_batch):
try:
inputs, targets = next(self.loader)
except Exception:
del self.loader
self.loader = iter(self.train_loader)
inputs, targets = next(self.loader)
for model, LRCount in zip(self.models, self.LRCounts):
self.forward(model, LRCount, inputs)
output = [LRCount.getLinearReginCount() for LRCount in self.LRCounts]
return output
def forward(self, model, LRCount, input_data):
self.interFeature = []
with torch.no_grad():
model.forward(input_data.cuda())
if len(self.interFeature) == 0: return
feature_data = torch.cat([f.view(input_data.size(0), -1) for f in self.interFeature], 1)
LRCount.update2D(feature_data)
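# A minimal sketch (standalone toy codes, assumed shapes) of the unique-pattern count in
# LinearRegionCount.calc_LR above: for binary activation codes A, the matrix
# A(1-A)^T + (1-A)A^T is zero exactly where two samples share the same ReLU sign pattern,
# so 1 - sign(.) marks identical pairs and summing 1/row_count gives the number of unique
# linear regions.
if __name__ == '__main__':
    acts = torch.tensor([[1., 0., 1.],
                         [1., 0., 1.],    # same pattern as sample 0
                         [0., 1., 1.]])   # a different pattern
    res = acts @ (1 - acts).t()
    res = res + res.t()
    same = 1 - torch.sign(res)            # 1 where two patterns are identical
    n_unique = (1. / same.sum(1)).sum()   # -> 2.0 unique linear regions
    print(n_unique.item())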

View File

@ -0,0 +1,245 @@
import os
import sys
sys.path.insert(0, '../')
import time
import glob
import json
import shutil
import logging
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
import nasbench201.utils as ig_utils
from nasbench201.search_model_darts_proj import TinyNetworkDartsProj
from nasbench201.cell_operations import SearchSpaceNames
from nasbench201.init_projection import pt_project, global_op_greedy_pt_project, global_op_once_pt_project, global_edge_greedy_pt_project, global_edge_once_pt_project, shrink_pt_project, tenas_project
from nas_201_api import NASBench201API as API
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)
parser = argparse.ArgumentParser("sota")
# data related
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'imagenet16-120'], help='choose dataset')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--batch_size', type=int, default=64, help='batch size for alpha')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--seed', type=int, default=2, help='random seed')
#search space setting
parser.add_argument('--search_space', type=str, default='nas-bench-201')
parser.add_argument('--pool_size', type=int, default=100, help='number of model to proposed')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
#system configurations
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
#default opt setting for model
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--nesterov', action='store_true', default=True, help='using nestrov momentum for SGD')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
#### common
parser.add_argument('--fast', action='store_true', default=True, help='skip loading api which is slow')
#### projection
parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'shrink_pt_project'], help='which edge to be projected next')
parser.add_argument('--proj_crit', type=str, default="comb", choices=['loss', 'acc', 'jacob', 'snip', 'fisher', 'synflow', 'grad_norm', 'grasp', 'jacob_cov','tenas', 'var', 'cor', 'norm', 'comb', 'meco'], help='criteria for projection')
args = parser.parse_args()
#### args augment
expid = args.save
args.save = '../experiments/nas-bench-201/prop-{}-{}-{}'.format(args.save, args.seed, args.pool_size)
if not args.dataset == 'cifar10':
args.save += '-' + args.dataset
if not args.edge_decision == 'random':
args.save += '-' + args.edge_decision
if not args.proj_crit == 'jacob':
args.save += '-' + args.proj_crit
#### logging
scripts_to_save = glob.glob('*.py')
# + ['../exp_scripts/{}.sh'.format(expid)]
if os.path.exists(args.save):
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
print('proceed to override saving directory')
shutil.rmtree(args.save)
else:
exit(0)
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log.txt'
log_path = os.path.join(args.save, log_file)
logging.info('======> log filename: %s', log_file)
if os.path.exists(log_path):
if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y':
print('proceed to override log file directory')
else:
exit(0)
fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')
#### macros
if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet16-120':
n_classes = 120
else:
n_classes = 10
def main():
torch.set_num_threads(3)
if not torch.cuda.is_available():
logging.info('no gpu device available')
sys.exit(1)
np.random.seed(args.seed)
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
torch.cuda.set_device(gpu)
cudnn.benchmark = True
torch.manual_seed(args.seed)
cudnn.enabled = True
torch.cuda.manual_seed(args.seed)
logging.info("args = %s", args)
logging.info('gpu device = %d' % gpu)
#### model
criterion = nn.CrossEntropyLoss()
search_space = SearchSpaceNames[args.search_space]
# initialize the supernet
model = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args)
model_thin = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args, stem_channels=1)
model = model.cuda()
model_thin = model_thin.cuda()
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
#### data
if args.dataset == 'cifar10':
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
elif args.dataset == 'cifar100':
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
elif args.dataset == 'imagenet16-120':
import torchvision.transforms as transforms
from nasbench201.DownsampledImageNet import ImageNet16
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26, 65.09]]
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
train_transform = transforms.Compose(lists)
train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
valid_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=False, transform=train_transform, use_num_of_class_only=120)
assert len(train_data) == 151700
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
pin_memory=True)
#format network pool diction
networks_pool={}
networks_pool['search_space'] = args.search_space
networks_pool['dataset'] = args.dataset
networks_pool['networks'] = []
networks_pool['pool_size'] = args.pool_size
#### architecture selection / projection
for i in range(args.pool_size):
network_info={}
logging.info('SEARCHING MODEL {}'.format(i+1))
if args.edge_decision == 'global_op_greedy':
global_op_greedy_pt_project(train_queue, model, args)
elif args.edge_decision == 'global_op_once':
global_op_once_pt_project(train_queue, model, args)
elif args.edge_decision == 'global_edge_greedy':
global_edge_greedy_pt_project(train_queue, model, args)
elif args.edge_decision == 'global_edge_once':
global_edge_once_pt_project(train_queue, model, args)
elif args.edge_decision == 'shrink_pt_project':
shrink_pt_project(train_queue, model, args)
api = API('../data/NAS-Bench-201-v1_0-e61699.pth')
cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, model.genotype().tostr(), logging)
else:
if args.proj_crit == 'jacob':
pt_project(train_queue, model, args)
else:
pt_project(train_queue, model, args)
# tenas_project(train_queue, model, model_thin, args)
network_info['id'] = str(i)
network_info['genotype'] = model.genotype().tostr()
networks_pool['networks'].append(network_info)
model.reset_arch_parameters()
with open(os.path.join(args.save,'networks_pool.json'), 'w') as save_file:
json.dump(networks_pool, save_file)
#### util functions
def distill(result):
result = result.split('\n')
cifar10 = result[5].replace(' ', '').split(':')
cifar100 = result[7].replace(' ', '').split(':')
imagenet16 = result[9].replace(' ', '').split(':')
cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
cifar10_test = float(cifar10[2][-7:-2].strip('='))
cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
cifar100_test = float(cifar100[3][-7:-2].strip('='))
imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
imagenet16_test = float(imagenet16[3][-7:-2].strip('='))
return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
def query(api, genotype, logging):
result = api.query_by_arch(genotype, hp='200')
logging.info('{:}'.format(result))
cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
if __name__ == '__main__':
main()
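# Example invocation (a sketch; the script filename is assumed here, the flags are the
# argparse options defined above):
#   python networks_proposal.py --dataset cifar10 --pool_size 100 \
#       --edge_decision random --proj_crit jacob --seed 2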

113
nasbench201/op_score.py Normal file
View File

@ -0,0 +1,113 @@
import gc
import numpy as np
import os
import sys
import torch
import torch.nn.functional as f
from operator import mul
from functools import reduce
import copy
sys.path.insert(0, '../')
def Jocab_Score(ori_model, input, target, weights=None):
model = copy.deepcopy(ori_model)
model.eval()
model.proj_weights = weights
num_edge, num_op = model.num_edge, model.num_op
for i in range(num_edge):
model.candidate_flags[i] = False
batch_size = input.shape[0]
model.K = torch.zeros(batch_size, batch_size).cuda()
def counting_forward_hook(module, inp, out):
try:
if isinstance(inp, tuple):
inp = inp[0]
inp = inp.view(inp.size(0), -1)
x = (inp > 0).float()
K = x @ x.t()
K2 = (1.-x) @ (1.-x.t())
model.K = model.K + K + K2
except Exception:  # skip modules whose hooked input cannot be flattened and binarised
pass
for name, module in model.named_modules():
if 'ReLU' in str(type(module)):
module.register_forward_hook(counting_forward_hook)
input = input.cuda()
model(input)
score = hooklogdet(model.K.cpu().numpy())
del model
del input
return score
def hooklogdet(K, labels=None):
s, ld = np.linalg.slogdet(K)
return ld
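# A minimal sketch (standalone toy input, not the supernet) of the score computed above:
# binarised post-ReLU activation patterns x give a kernel K = x x^T + (1-x)(1-x)^T that
# counts agreeing units per pair of inputs, and the score is the log-determinant of K
# (more distinct patterns -> better conditioned K -> larger logdet).
if __name__ == '__main__':
    x = (torch.randn(8, 32) > 0).float()          # 8 inputs, 32 post-ReLU units, binarised
    K = x @ x.t() + (1. - x) @ (1. - x).t()
    print(hooklogdet(K.numpy()))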
# NTK
#------------------------------------------------------------
#https://github.com/VITA-Group/TENAS/blob/main/lib/procedures/ntk.py
#
def recal_bn(network, xloader, recalbn, device):
for m in network.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.running_mean.data.fill_(0)
m.running_var.data.fill_(0)
m.num_batches_tracked.data.zero_()
m.momentum = None
network.train()
with torch.no_grad():
for i, (inputs, targets) in enumerate(xloader):
if i >= recalbn: break
inputs = inputs.cuda(device=device, non_blocking=True)
_, _ = network(inputs)
return network
def get_ntk_n(xloader, networks, recalbn=0, train_mode=False, num_batch=-1, weights=None):
device = torch.cuda.current_device()
ntks = []
copied_networks = []
for network in networks:
network = network.cuda(device=device)
net = copy.deepcopy(network)
net.proj_weights = weights
num_edge, num_op = net.num_edge, net.num_op
for i in range(num_edge):
net.candidate_flags[i] = False
if train_mode:
net.train()
else:
net.eval()
copied_networks.append(net)
######
grads = [[] for _ in range(len(copied_networks))]
for i, (inputs, targets) in enumerate(xloader):
if num_batch > 0 and i >= num_batch: break
inputs = inputs.cuda(device=device, non_blocking=True)
for net_idx, network in enumerate(copied_networks):
network.zero_grad()
inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
logit = network(inputs_)
if isinstance(logit, tuple):
logit = logit[1] # 201 networks: return features and logits
for _idx in range(len(inputs_)):
logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)
grad = []
for name, W in network.named_parameters():
if 'weight' in name and W.grad is not None:
grad.append(W.grad.view(-1).detach())
grads[net_idx].append(torch.cat(grad, -1))
network.zero_grad()
torch.cuda.empty_cache()
######
grads = [torch.stack(_grads, 0) for _grads in grads]
ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
conds = []
for ntk in ntks:
eigenvalues, _ = torch.symeig(ntk) # ascending
conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0))
del copied_networks
return conds
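# A minimal sketch (toy 2-layer net, assumed sizes) of the quantity returned above: the
# per-sample gradients w.r.t. all weights form G, the empirical NTK is G G^T, and the
# reported value is its condition number lambda_max / lambda_min (TE-NAS treats a smaller
# value as more trainable). The sketch uses torch.linalg.eigvalsh in place of the older
# torch.symeig call used above.
if __name__ == '__main__':
    toy = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
    xs = torch.randn(8, 16)
    grads = []
    for i in range(xs.size(0)):
        toy.zero_grad()
        toy(xs[i:i + 1]).sum().backward()
        grads.append(torch.cat([p.grad.view(-1).clone() for p in toy.parameters()]))
    G = torch.stack(grads, 0)
    ntk = G @ G.t()
    eig = torch.linalg.eigvalsh(ntk)     # ascending eigenvalues
    print((eig[-1] / eig[0]).item())     # NTK condition number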

182
nasbench201/search_cells.py Normal file
View File

@ -0,0 +1,182 @@
import math, random, torch
import warnings
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import sys
sys.path.insert(0, '../')
from nasbench201.cell_operations import OPS
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
class NAS201SearchCell(nn.Module):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
super(NAS201SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
else:
xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
self.edges[ node_str ] = nn.ModuleList( xlists )
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def extra_repr(self):
string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
return string
def forward(self, inputs, weightss):
return self._forward(inputs, weightss)
def _forward(self, inputs, weightss):
with torch.autograd.set_detect_anomaly(True):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
inter_nodes.append(sum(layer(nodes[j], block_input=True)*w if w==0 else layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights)) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# GDAS
def forward_gdas(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = hardwts[ self.edge2index[node_str] ]
argmaxs = index[ self.edge2index[node_str] ].item()
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
inter_nodes.append( weigsum )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
#aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) )
inter_nodes.append( aggregation )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# uniform random sampling per iteration, SETN
def forward_urs(self, inputs):
nodes = [inputs]
for i in range(1, self.max_nodes):
while True: # to avoid select zero for all ops
sops, has_non_zero = [], False
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
candidates = self.edges[node_str]
select_op = random.choice(candidates)
sops.append( select_op )
if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
if has_non_zero: break
inter_nodes = []
for j, select_op in enumerate(sops):
inter_nodes.append( select_op(nodes[j]) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# select the argmax
def forward_select(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
#inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# forward with a specific structure
def forward_dynamic(self, inputs, structure):
nodes = [inputs]
for i in range(1, self.max_nodes):
cur_op_node = structure.nodes[i-1]
inter_nodes = []
for op_name, j in cur_op_node:
node_str = '{:}<-{:}'.format(i, j)
op_index = self.op_names.index( op_name )
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups,
channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
class NAS201SearchCell_PartialChannel(NAS201SearchCell):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True, k=4):
super(NAS201SearchCell, self).__init__()
self.k = k
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
xlists = [OPS[op_name](C_in//self.k , C_out//self.k, stride, affine, track_running_stats) for op_name in op_names]
else:
xlists = [OPS[op_name](C_in//self.k , C_out//self.k, 1, affine, track_running_stats) for op_name in op_names]
self.edges[ node_str ] = nn.ModuleList( xlists )
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def MixedOp(self, x, ops, weights):
dim_2 = x.shape[1]
xtemp = x[ : , : dim_2//self.k, :, :]
xtemp2 = x[ : , dim_2//self.k:, :, :]
temp1 = sum(w * op(xtemp) for w, op in zip(weights, ops))
ans = torch.cat([temp1,xtemp2],dim=1)
ans = channel_shuffle(ans,self.k)
return ans
def forward(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
# inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
inter_nodes.append(self.MixedOp(x=nodes[j], ops=self.edges[node_str], weights=weights))
nodes.append( sum(inter_nodes) )
return nodes[-1]
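# A minimal sketch (toy tensor, groups=2) of channel_shuffle as used by the partial-channel
# cell above: channels [0 1 2 3] are regrouped to [0 2 1 3], so the mixed 1/k channel slice
# and the untouched slice concatenated in MixedOp get interleaved before the next node.
if __name__ == '__main__':
    toy = torch.arange(4.).view(1, 4, 1, 1)
    print(channel_shuffle(toy, 2).view(-1))   # tensor([0., 2., 1., 3.])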

202
nasbench201/search_model.py Normal file
View File

@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from .cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from torch.autograd import Variable
class TinyNetwork(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, affine=False, track_running_stats=True, stem_channels=3):
super(TinyNetwork, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self._num_classes = num_classes
self._criterion = criterion
self._args = args
self._affine = affine
self._track_running_stats = track_running_stats
self.stem = nn.Sequential(
nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.num_edge = num_edge
self.num_op = len(search_space)
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
# self._arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self._arch_parameters = Variable(1e-3*torch.randn(num_edge, len(search_space)).cuda(), requires_grad=True)
## optimizer
## record the id (memory address) of each arch parameter so it can be excluded below
arch_params = set(id(m) for m in self.arch_parameters())
self._model_params = [m for m in self.parameters() if id(m) not in arch_params]
# optimizer for the model (weight) parameters
self.optimizer = torch.optim.SGD(
self._model_params,
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov= args.nesterov)
def entropy_y_x(self, p_logit):
p = F.softmax(p_logit, dim=1)
return - torch.sum(p * F.log_softmax(p_logit, dim=1)) / p_logit.shape[0]
def _loss(self, input, target, return_logits=False):
logits = self(input)
loss = self._criterion(logits, target)
return (loss, logits) if return_logits else loss
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def arch_parameters(self):
return [self._arch_parameters]
def get_theta(self):
return nn.functional.softmax(self._arch_parameters, dim=-1).cpu()
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self._arch_parameters[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def forward(self, inputs, weights=None):
sim_nn = []
weights = nn.functional.softmax(self._arch_parameters, dim=-1) if weights is None else weights
if self.slim:
weights[1].data.fill_(0)
weights[3].data.fill_(0)
weights[4].data.fill_(0)
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell(feature, weights)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits
def _save_arch_parameters(self):
self._saved_arch_parameters = [p.clone() for p in self._arch_parameters]
def project_arch(self):
self._save_arch_parameters()
for p in self.arch_parameters():
m, n = p.size()
maxIndexs = p.data.cpu().numpy().argmax(axis=1)
p.data = self.proximal_step(p, maxIndexs)
def proximal_step(self, var, maxIndexs=None):
values = var.data.cpu().numpy()
m, n = values.shape
alphas = []
for i in range(m):
for j in range(n):
if j == maxIndexs[i]:
alphas.append(values[i][j].copy())
values[i][j] = 1
else:
values[i][j] = 0
return torch.Tensor(values).cuda()
def restore_arch_parameters(self):
for i, p in enumerate(self._arch_parameters):
p.data.copy_(self._saved_arch_parameters[i])
del self._saved_arch_parameters
def new(self):
model_new = TinyNetwork(self._C, self._layerN, self.max_nodes, self._num_classes, self._criterion,
self.op_names, self._args, self._affine, self._track_running_stats).cuda()
for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
x.data.copy_(y.data)
return model_new
def step(self, input, target, args, shared=None, return_grad=False):
Lt, logit_t = self._loss(input, target, return_logits=True)
Lt.backward()
if args.grad_clip != 0:
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
self.optimizer.step()
if return_grad:
grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
return logit_t, Lt, grad
else:
return logit_t, Lt
def printing(self, logging):
logging.info(self.get_theta())
def set_arch_parameters(self, new_alphas):
for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
alpha.data.copy_(new_alpha.data)
def save_arch_parameters(self):
self._saved_arch_parameters = self._arch_parameters.clone()
def restore_arch_parameters(self):
self.set_arch_parameters(self._saved_arch_parameters)
def reset_optimizer(self, lr, momentum, weight_decay):
del self.optimizer
self.optimizer = torch.optim.SGD(
self.get_weights(),
lr,
momentum=momentum,
weight_decay=weight_decay,
nesterov= args.nesterov)

View File

@ -0,0 +1,33 @@
import torch
import torch.nn as nn
from .search_cells import NAS201SearchCell as SearchCell
from .search_model import TinyNetwork as TinyNetwork
class TinyNetworkDarts(TinyNetwork):
def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
affine=False, track_running_stats=True, stem_channels=3):
super(TinyNetworkDarts, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args,
affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels)
self.theta_map = lambda x: torch.softmax(x, dim=-1)
def get_theta(self):
return self.theta_map(self._arch_parameters).cpu()
def forward(self, inputs):
weights = self.theta_map(self._arch_parameters)
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell(feature, weights)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits

View File

@ -0,0 +1,80 @@
import torch
from .search_cells import NAS201SearchCell as SearchCell
from .search_model import TinyNetwork as TinyNetwork
from .genotypes import Structure
from torch.autograd import Variable
class TinyNetworkDartsProj(TinyNetwork):
def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
affine=False, track_running_stats=True, stem_channels=3):
super(TinyNetworkDartsProj, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args,
affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels)
self.theta_map = lambda x: torch.softmax(x, dim=-1)
#### for edgewise projection
self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
self.proj_weights = torch.zeros_like(self._arch_parameters)
def project_op(self, eid, opid):
self.proj_weights[eid][opid] = 1 ## hard by default
self.candidate_flags[eid] = False
def get_projected_weights(self):
weights = self.theta_map(self._arch_parameters)
## proj op
for eid in range(len(self._arch_parameters)):
if not self.candidate_flags[eid]:
weights[eid].data.copy_(self.proj_weights[eid])
return weights
def forward(self, inputs, weights=None):
with torch.autograd.set_detect_anomaly(True):
if weights is None:
weights = self.get_projected_weights()
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell(feature, weights)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits
#### utils
def get_theta(self):
return self.get_projected_weights()
def arch_parameters(self):
return [self._arch_parameters]
def set_arch_parameters(self, new_alphas):
for eid, alpha in enumerate(self.arch_parameters()):
alpha.data.copy_(new_alphas[eid])
def reset_arch_parameters(self):
self._arch_parameters = Variable(1e-3*torch.randn(self.num_edge, len(self.op_names)).cuda(), requires_grad=True)
self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
self.proj_weights = torch.zeros_like(self._arch_parameters)
def genotype(self):
proj_weights = self.get_projected_weights()
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = proj_weights[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
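# A minimal sketch (standalone tensors, hypothetical sizes) of the projection state kept
# above: candidate_flags marks edges that are still searchable, project_op hard-commits one
# op per edge by writing a one-hot row into proj_weights, and get_projected_weights
# overwrites the softmax row of every committed edge with that one-hot vector.
if __name__ == '__main__':
    alphas = 1e-3 * torch.randn(6, 5)                     # 6 edges, 5 ops
    proj_weights = torch.zeros_like(alphas)
    candidate_flags = torch.ones(6, dtype=torch.bool)
    eid, opid = 0, 2
    proj_weights[eid][opid] = 1                           # project_op(eid, opid)
    candidate_flags[eid] = False
    weights = torch.softmax(alphas, dim=-1)
    weights[~candidate_flags] = proj_weights[~candidate_flags]
    print(weights[0])                                     # one-hot on op 2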

494
nasbench201/utils.py Normal file
View File

@ -0,0 +1,494 @@
from __future__ import print_function
import numpy as np
import os
import os.path
import sys
import shutil
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torchvision.datasets import VisionDataset
from torchvision.datasets import utils
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
class AvgrageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.avg = 0
self.sum = 0
self.cnt = 0
def update(self, val, n=1):
self.sum += val * n
self.cnt += n
self.avg = self.sum / self.cnt
def accuracy(output, target, topk=(1,)):
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].contiguous().view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class Cutout(object):
def __init__(self, length, prob=1.0):
self.length = length
self.prob = prob
def __call__(self, img):
if np.random.binomial(1, self.prob):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
def _data_transforms_svhn(args):
SVHN_MEAN = [0.4377, 0.4438, 0.4728]
SVHN_STD = [0.1980, 0.2010, 0.1970]
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(SVHN_MEAN, SVHN_STD),
])
if args.cutout:
train_transform.transforms.append(Cutout(args.cutout_length,
args.cutout_prob))
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(SVHN_MEAN, SVHN_STD),
])
return train_transform, valid_transform
def _data_transforms_cifar100(args):
CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR_STD = [0.2673, 0.2564, 0.2762]
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
if args.cutout:
train_transform.transforms.append(Cutout(args.cutout_length,
args.cutout_prob))
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
return train_transform, valid_transform
def _data_transforms_cifar10(args):
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
if args.cutout:
train_transform.transforms.append(Cutout(args.cutout_length,
args.cutout_prob))
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
return train_transform, valid_transform
def count_parameters_in_MB(model):
return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
def count_parameters_in_Compact(model):
from sota.cnn.model import Network as CompactModel
genotype = model.genotype()
compact_model = CompactModel(36, model._num_classes, 20, True, genotype)
num_params = count_parameters_in_MB(compact_model)
return num_params
def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''):
filename = prefix
if per_epoch:
epoch = state['epoch']
filename += 'checkpoint_{}.pth.tar'.format(epoch)
else:
filename += 'checkpoint.pth.tar'
filename = os.path.join(save, filename)
torch.save(state, filename)
if is_best:
best_filename = os.path.join(save, 'model_best.pth.tar')
shutil.copyfile(filename, best_filename)
def load_checkpoint(model, optimizer, save, epoch=None):
if epoch is None:
filename = 'checkpoint.pth.tar'
else:
filename = 'checkpoint_{}.pth.tar'.format(epoch)
filename = os.path.join(save, filename)
start_epoch = 0
if os.path.isfile(filename):
print("=> loading checkpoint '{}'".format(filename))
checkpoint = torch.load(filename)
start_epoch = checkpoint['epoch']
best_acc_top1 = checkpoint['best_acc_top1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(filename, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(filename))
return model, optimizer, start_epoch, best_acc_top1
def save(model, model_path):
torch.save(model.state_dict(), model_path)
def load(model, model_path):
model.load_state_dict(torch.load(model_path))
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
x.div_(keep_prob)
x.mul_(mask)
return x
def create_exp_dir(path, scripts_to_save=None):
if not os.path.exists(path):
os.makedirs(path)
print('Experiment dir : {}'.format(path))
if scripts_to_save is not None:
os.mkdir(os.path.join(path, 'scripts'))
for script in scripts_to_save:
dst_file = os.path.join(path, 'scripts', os.path.basename(script))
shutil.copyfile(script, dst_file)
class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
#['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
meta = {
'filename': 'batches.meta',
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
self.data = []
self.targets = []
# now load the picked numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
self._load_meta()
def _load_meta(self):
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not utils.check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it')
with open(path, 'rb') as infile:
if sys.version_info[0] == 2:
data = pickle.load(infile)
else:
data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not utils.check_integrity(fpath, md5):
return False
return True
def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
return
utils.download_and_extract_archive(self.url, self.root,
filename=self.filename,
md5=self.tgz_md5)
def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test")
def pick_gpu_lowest_memory():
import gpustat
stats = gpustat.GPUStatCollection.new_query()
ids = map(lambda gpu: int(gpu.entry['index']), stats)
ratios = map(lambda gpu: float(gpu.memory_used)/float(gpu.memory_total), stats)
bestGPU = min(zip(ids, ratios), key=lambda x: x[1])[0]
return bestGPU
#### early stopping (from RobustNAS)
class EVLocalAvg(object):
def __init__(self, window=5, ev_freq=2, total_epochs=50):
""" Keep track of the eigenvalues local average.
Args:
window (int): number of elements used to compute local average.
Default: 5
ev_freq (int): frequency used to compute eigenvalues. Default:
every 2 epochs
total_epochs (int): total number of epochs that DARTS runs.
Default: 50
"""
self.window = window
self.ev_freq = ev_freq
self.epochs = total_epochs
self.stop_search = False
self.stop_epoch = total_epochs - 1
self.stop_genotype = None
self.stop_numparam = 0
self.ev = []
self.ev_local_avg = []
self.genotypes = {}
self.numparams = {}
self.la_epochs = {}
# start and end index of the local average window
self.la_start_idx = 0
self.la_end_idx = self.window
def reset(self):
self.ev = []
self.ev_local_avg = []
self.genotypes = {}
self.numparams = {}
self.la_epochs = {}
def update(self, epoch, ev, genotype, numparam=0):
""" Method to update the local average list.
Args:
epoch (int): current epoch
ev (float): current dominant eigenvalue
genotype (namedtuple): current genotype
"""
self.ev.append(ev)
self.genotypes.update({epoch: genotype})
self.numparams.update({epoch: numparam})
# set the stop_genotype to the current genotype in case the early stop
# procedure decides not to early stop
self.stop_genotype = genotype
# since the local average computation starts after the dominant
# eigenvalue in the first epoch is already computed we have to wait
# at least until we have 3 eigenvalues in the list.
if (len(self.ev) >= int(np.ceil(self.window/2))) and (epoch <
self.epochs - 1):
# start sliding the window as soon as the number of eigenvalues in
# the list becomes equal to the window size
if len(self.ev) < self.window:
self.ev_local_avg.append(np.mean(self.ev))
else:
assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window
self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
self.la_end_idx]))
self.la_start_idx += 1
self.la_end_idx += 1
# keep track of the offset between the current epoch and the epoch
# corresponding to the local average. NOTE: in the end the size of
# self.ev and self.ev_local_avg should be equal
self.la_epochs.update({epoch: int(epoch -
int(self.ev_freq*np.floor(self.window/2)))})
elif len(self.ev) < int(np.ceil(self.window/2)):
self.la_epochs.update({epoch: -1})
# since there is an offset between the current epoch and the local
# average epoch, loop in the last epoch to compute the local average of
# these number of elements: window, window - 1, window - 2, ..., ceil(window/2)
elif epoch == self.epochs - 1:
for i in range(int(np.ceil(self.window/2))):
assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i
self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
self.la_end_idx + 1]))
self.la_start_idx += 1
def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4, criteria='local_avg'):
""" Early stopping criterion
Args:
epoch (int): current epoch
factor (float): threshold factor for the ratio between the current
and previous eigenvalue. Default: 1.3
es_start_epoch (int): until this epoch do not consider early
stopping. Default: 10
delta (int): factor influencing which previous local average we
consider for early stopping. Default: 4
"""
if criteria == 'local_avg':
if int(self.la_epochs[epoch] - self.ev_freq*delta) >= es_start_epoch:
if criteria == 'local_avg':
current_la = self.ev_local_avg[-1]
previous_la = self.ev_local_avg[-1 - delta]
self.stop_search = current_la / previous_la > factor
if self.stop_search:
self.stop_epoch = int(self.la_epochs[epoch] - self.ev_freq*delta)
self.stop_genotype = self.genotypes[self.stop_epoch]
self.stop_numparam = self.numparams[self.stop_epoch]
elif criteria == 'exact':
if epoch > es_start_epoch:
current_la = self.ev[-1]
previous_la = self.ev[-1 - delta]
self.stop_search = current_la / previous_la > factor
if self.stop_search:
self.stop_epoch = epoch - delta
self.stop_genotype = self.genotypes[self.stop_epoch]
self.stop_numparam = self.numparams[self.stop_epoch]
else:
print('ERROR IN EARLY STOP: WRONG CRITERIA:', criteria); exit(0)
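# A minimal sketch (hand-made eigenvalue trace, assumed values) of the early-stop test above:
# search stops once the latest local average of the dominant eigenvalue exceeds `factor`
# times the local average recorded `delta` steps earlier.
if __name__ == '__main__':
    ev_local_avg = [0.20, 0.21, 0.22, 0.24, 0.40]
    factor, delta = 1.3, 4
    stop_search = ev_local_avg[-1] / ev_local_avg[-1 - delta] > factor
    print(stop_search)   # True: 0.40 / 0.20 = 2.0 > 1.3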
def gen_comb(eids):
comb = []
for r in range(len(eids)):
for c in range(r + 1, len(eids)):
comb.append((eids[r], eids[c]))
return comb
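# A minimal sketch of gen_comb: it enumerates every unordered pair of the given edge ids.
if __name__ == '__main__':
    print(gen_comb([0, 1, 2]))   # [(0, 1), (0, 2), (1, 2)]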