update
commit 2410fe9f5e (parent 5a1dc89756)
correlation/NAS-Bench-101.py (new file)

# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import pickle
import torch
import argparse
import json
import numpy as np
from thop import profile

from foresight.models import *
from foresight.pruners import *
from foresight.dataset import *


def get_num_classes(args):
    return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120


def parse_arguments():
    parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-101')
    parser.add_argument('--api_loc', default='../data/nasbench_only108.tfrecord',
                        type=str, help='path to API')
    parser.add_argument('--json_loc', default='data/all_graphs.json',
                        type=str, help='path to JSON database')
    parser.add_argument('--outdir', default='./',
                        type=str, help='output directory')
    parser.add_argument('--outfname', default='test',
                        type=str, help='output filename')
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset to use [cifar10, cifar100, ImageNet16-120]')
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on')
    parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
    parser.add_argument('--dataload', type=str, default='random', help='random or grasp supported')
    parser.add_argument('--dataload_info', type=int, default=1,
                        help='number of batches to use for random dataload or number of samples per class for grasp dataload')
    parser.add_argument('--start', type=int, default=5, help='start index')
    parser.add_argument('--end', type=int, default=10, help='end index')
    parser.add_argument('--write_freq', type=int, default=100, help='frequency of write to file')
    args = parser.parse_args()
    args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    return args


def get_op_names(v):
    o = []
    for op in v:
        if op == -1:
            o.append('input')
        elif op == -2:
            o.append('output')
        elif op == 0:
            o.append('conv3x3-bn-relu')
        elif op == 1:
            o.append('conv1x1-bn-relu')
        elif op == 2:
            o.append('maxpool3x3')
    return o


if __name__ == '__main__':
    args = parse_arguments()
    # nasbench = api.NASBench(args.api_loc)
    models = json.load(open(args.json_loc))

    print(f'Running models {args.start} to {args.end} out of {len(models.keys())}')

    train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset,
                                                     args.num_data_workers)

    all_points = []
    pre = 'cf' if 'cifar' in args.dataset else 'im'

    if args.outfname == 'test':
        fn = f'nb1_{pre}{get_num_classes(args)}.p'
    else:
        fn = f'{args.outfname}.p'
    op = os.path.join(args.outdir, fn)

    print('outfile =', op)
    first = True

    # loop over nasbench1 archs (k=hash, v=[adj_matrix, ops])
    idx = 0
    cached_res = []
    for k, v in models.items():

        if idx < args.start:
            idx += 1
            continue
        if idx >= args.end:
            break
        print(f'idx = {idx}')
        idx += 1

        res = {}
        res['hash'] = k

        # model
        spec = nasbench1_spec._ToModelSpec(v[0], get_op_names(v[1]))
        net = nasbench1.Network(spec, stem_out=128, num_stacks=3, num_mods=3, num_classes=get_num_classes(args))
        net.to(args.device)

        measures = predictive.find_measures(net,
                                            train_loader,
                                            (args.dataload, args.dataload_info, get_num_classes(args)),
                                            args.device)
        res['logmeasures'] = measures

        print(res)
        cached_res.append(res)

        # write to file
        if idx % args.write_freq == 0 or idx == args.end or idx == args.start + 10:
            print(f'writing {len(cached_res)} results to {op}')
            pf = open(op, 'ab')
            for cr in cached_res:
                pickle.dump(cr, pf)
            pf.close()
            cached_res = []
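Both correlation scripts append one pickled dict per architecture to the output file (repeated pickle.dump calls on the same handle), so reading the results back means looping until EOF. A minimal reader sketch, assuming the nb1_cf10.p file that the default --outfname/--dataset above would produce:

import pickle

results = []
with open('nb1_cf10.p', 'rb') as f:        # filename assumed from the defaults above
    while True:
        try:
            # one dict per architecture: {'hash': ..., 'logmeasures': {...}}
            results.append(pickle.load(f))
        except EOFError:
            break
print(len(results), 'records loaded')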
correlation/NAS-Bench-201.py (new file)

import argparse
import os
import time
import pickle

import torch

from foresight.dataset import *
from foresight.models import nasbench2
from foresight.pruners import predictive
from foresight.weight_initializers import init_net
from models import get_cell_based_tiny_net


def get_num_classes(args):
    return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120


def parse_arguments():
    parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-201')
    parser.add_argument('--api_loc', default='../data/NAS-Bench-201-v1_0-e61699.pth',
                        type=str, help='path to API')
    parser.add_argument('--outdir', default='./',
                        type=str, help='output directory')
    parser.add_argument('--search_space', type=str, default='tss',
                        help='search space [tss, sss]')  # referenced below when building the net and the output filename
    parser.add_argument('--init_w_type', type=str, default='none',
                        help='weight initialization (before pruning) type [none, xavier, kaiming, zero, one]')
    parser.add_argument('--init_b_type', type=str, default='none',
                        help='bias initialization (before pruning) type [none, xavier, kaiming, zero, one]')
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--dataset', type=str, default='ImageNet16-120',
                        help='dataset to use [cifar10, cifar100, ImageNet16-120]')
    parser.add_argument('--gpu', type=int, default=5, help='GPU index to work on')
    parser.add_argument('--data_size', type=int, default=32, help='data_size')
    parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
    parser.add_argument('--dataload', type=str, default='appoint', help='random, grasp, appoint supported')
    parser.add_argument('--dataload_info', type=int, default=1,
                        help='number of batches to use for random dataload or number of samples per class for grasp dataload')
    parser.add_argument('--seed', type=int, default=42, help='pytorch manual seed')
    parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file')
    parser.add_argument('--start', type=int, default=0, help='start index')
    parser.add_argument('--end', type=int, default=0, help='end index')
    parser.add_argument('--noacc', default=False, action='store_true',
                        help='avoid loading the NAS-Bench-201 api and instead load a pickle file with tuple (index, arch_str)')
    args = parser.parse_args()
    args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    return args


if __name__ == '__main__':
    args = parse_arguments()
    print(args.device)

    if args.noacc:
        api = pickle.load(open(args.api_loc, 'rb'))
    else:
        from nas_201_api import NASBench201API as API
        api = API(args.api_loc)

    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset, args.num_data_workers, resize=args.data_size)
    x, y = next(iter(train_loader))
    # random data
    # x = torch.rand((args.batch_size, 3, args.data_size, args.data_size))
    # y = 0

    cached_res = []
    pre = 'cf' if 'cifar' in args.dataset else 'im'
    pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p'
    op = os.path.join(args.outdir, pfn)

    end = len(api) if args.end == 0 else args.end

    # loop over nasbench2 archs
    for i, arch_str in enumerate(api):

        if i < args.start:
            continue
        if i >= end:
            break

        res = {'i': i, 'arch': arch_str}
        # print(arch_str)
        if args.search_space == 'tss':
            net = nasbench2.get_model_from_arch_str(arch_str, get_num_classes(args))
            arch_str2 = nasbench2.get_arch_str_from_model(net)
            if arch_str != arch_str2:
                print(arch_str)
                print(arch_str2)
                raise ValueError
        elif args.search_space == 'sss':
            config = api.get_net_config(i, args.dataset)
            # print(config)
            net = get_cell_based_tiny_net(config)
        net.to(args.device)
        # print(net)

        init_net(net, args.init_w_type, args.init_b_type)

        # print(x.size(), y)
        measures = get_score(net, x, i, args.device)

        res['meco'] = measures

        if not args.noacc:
            info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset, iepoch=None,
                                     hp='200', is_random=False)

            trainacc = info['train-accuracy']
            valacc = info['valid-accuracy']
            testacc = info['test-accuracy']

            res['trainacc'] = trainacc
            res['valacc'] = valacc
            res['testacc'] = testacc

        print(res)
        cached_res.append(res)

        # write to file
        if i % args.write_freq == 0 or i == len(api) - 1 or i == 10:
            print(f'writing {len(cached_res)} results to {op}')
            pf = open(op, 'ab')
            for cr in cached_res:
                pickle.dump(cr, pf)
            pf.close()
            cached_res = []
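Because each record above stores the zero-cost score ('meco') next to the benchmark accuracies, the rank correlation this directory is named after can be computed directly from the dumped file. A sketch, assuming the records were read back as in the reader sketch after NAS-Bench-101.py:

from scipy.stats import spearmanr

scores = [r['meco'] for r in results if 'testacc' in r]
accs = [r['testacc'] for r in results if 'testacc' in r]
rho, p = spearmanr(scores, accs)           # rank correlation between the proxy and test accuracy
print(f'Spearman rho = {rho:.3f} (p = {p:.3g})')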
exp_scripts/zerocostpt_nb201_pipeline.sh (new file)

#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar10}
seed=${seed:-0}
gpu=${gpu:-"auto"}
pool_size=${pool_size:-10}
batch_size=${batch_size:-256}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
metric=${metric:-'jacob'}
while [ $# -gt 0 ]; do
    if [[ $1 == *"--"* ]]; then
        param="${1/--/}"
        declare $param="$2"
        # echo $1 $2 # optional, to see the parameter:value result
    fi
    shift
done

echo 'id:' $id 'seed:' $seed 'dataset:' $dataset
echo 'gpu:' $gpu

cd ../nasbench201/
python3 networks_proposal.py \
    --dataset $dataset \
    --save $id --gpu $gpu --seed $seed \
    --edge_decision $edge_decision --proj_crit $metric \
    --batch_size $batch_size \
    --pool_size $pool_size

cd ../zerocostnas/
python3 post_validate.py \
    --ckpt_path ../experiments/nas-bench-201/prop-$id-$seed-$pool_size-$metric \
    --save $id --seed $seed --gpu $gpu \
    --edge_decision $edge_decision --proj_crit $metric \
    --batch_size $batch_size \
    --validate_rounds $validate_rounds
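The while/declare loop above turns any "--name value" pair on the command line into an override of the variables defaulted at the top of the script, so a typical (illustrative) invocation is: bash zerocostpt_nb201_pipeline.sh --dataset cifar100 --seed 2 --metric jacob --pool_size 10.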
nasbench201/DownsampledImageNet.py (new file)

import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data
import pickle


def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    print(fpath)
    if not os.path.isfile(fpath): return False
    if md5 is None: return True
    else: return check_md5(fpath, md5)


class ImageNet16(data.Dataset):
    # http://image-net.org/download-images
    # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
    # https://arxiv.org/pdf/1707.08819.pdf

    train_list = [
        ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
        ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
        ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
        ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
        ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
        ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
        ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
        ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
        ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
        ['train_data_batch_10', '8f03f34ac4b42271a294f91bf480f29b'],
    ]
    valid_list = [
        ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
    ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        self.root = root
        self.transform = transform
        self.train = train  # training set or valid set
        if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')

        if self.train: downloaded_list = self.train_list
        else: downloaded_list = self.valid_list
        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for i, (file_name, checksum) in enumerate(downloaded_list):
            file_path = os.path.join(self.root, file_name)
            # print('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])
        self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        if use_num_of_class_only is not None:
            assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
            new_data, new_targets = [], []
            for I, L in zip(self.data, self.targets):
                if 1 <= L <= use_num_of_class_only:
                    new_data.append(I)
                    new_targets.append(L)
            self.data = new_data
            self.targets = new_targets
        # self.mean.append(entry['mean'])
        # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
        # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
        # print('Mean : {:}'.format(self.mean))
        # temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
        # std_data = np.std(temp, axis=0)
        # std_data = np.mean(np.mean(std_data, axis=0), axis=0)
        # print('Std : {:}'.format(std_data))

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index] - 1

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.valid_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, filename)
            if not check_integrity(fpath, md5):
                return False
        return True
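A minimal usage sketch for the dataset class above; the module path, data root and transform are assumptions, and the batch files must already sit under the root, since the class only verifies checksums and does not download:

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from DownsampledImageNet import ImageNet16     # module path assumed

transform = transforms.Compose([transforms.ToTensor()])
# use_num_of_class_only=120 yields the ImageNet16-120 subset used by NAS-Bench-201
train_set = ImageNet16('./data/ImageNet16', train=True, transform=transform, use_num_of_class_only=120)
loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
x, y = next(iter(loader))
print(x.shape)                                 # torch.Size([64, 3, 16, 16]); labels are shifted to 0..119 by __getitem__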
nasbench201/architect_ig.py (new file)

import torch


class Architect(object):
    def __init__(self, model, args):
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=args.arch_learning_rate, betas=(0.5, 0.999),
                                          weight_decay=args.arch_weight_decay)

        self._init_arch_parameters = []
        for alpha in self.model.arch_parameters():
            alpha_init = torch.zeros_like(alpha)
            alpha_init.data.copy_(alpha)
            self._init_arch_parameters.append(alpha_init)

        #### mode
        if args.method in ['darts', 'darts-proj', 'sdarts', 'sdarts-proj']:
            self.method = 'fo'  # first order update
        elif 'so' in args.method:
            print('ERROR: PLEASE USE architect.py for second order darts')
        elif args.method in ['blank', 'blank-proj']:
            self.method = 'blank'
        else:
            print('ERROR: WRONG ARCH UPDATE METHOD', args.method); exit(0)

    def reset_arch_parameters(self):
        for alpha, alpha_init in zip(self.model.arch_parameters(), self._init_arch_parameters):
            alpha.data.copy_(alpha_init.data)

    def step(self, input_train, target_train, input_valid, target_valid, *args, **kwargs):
        if self.method == 'fo':
            shared = self._step_fo(input_train, target_train, input_valid, target_valid)
        elif self.method == 'so':
            raise NotImplementedError
        elif self.method == 'blank':  ## do not update alpha
            shared = None

        return shared

    #### first order
    def _step_fo(self, input_train, target_train, input_valid, target_valid):
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()
        self.optimizer.step()
        return None

    #### darts 2nd order
    def _step_darts_so(self, input_train, target_train, input_valid, target_valid, eta, model_optimizer):
        raise NotImplementedError
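A sketch of how the class above is driven inside a search loop; model, args and the batch source are assumptions (the model must expose arch_parameters() and _loss(), and args must carry the momentum, weight-decay, arch learning-rate and method fields read in __init__):

architect = Architect(model, args)               # args.method = 'darts' selects the first-order branch
for x_tr, y_tr, x_val, y_val in search_batches:  # search_batches is a hypothetical iterator of train/val pairs
    architect.optimizer.zero_grad()
    architect.step(x_tr, y_tr, x_val, y_val)     # backprops the validation loss; only the arch parameters are stepped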
nasbench201/cell_infers/cells.py (new file)

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################

import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS


# Cell for NAS-Bench-201
class InferCell(nn.Module):

    def __init__(self, genotype, C_in, C_out, stride):
        super(InferCell, self).__init__()

        self.layers = nn.ModuleList()
        self.node_IN = []
        self.node_IX = []
        self.genotype = deepcopy(genotype)
        for i in range(1, len(genotype)):
            node_info = genotype[i-1]
            cur_index = []
            cur_innod = []
            for (op_name, op_in) in node_info:
                if op_in == 0:
                    layer = OPS[op_name](C_in, C_out, stride, True, True)
                else:
                    layer = OPS[op_name](C_out, C_out, 1, True, True)
                cur_index.append(len(self.layers))
                cur_innod.append(op_in)
                self.layers.append(layer)
            self.node_IX.append(cur_index)
            self.node_IN.append(cur_innod)
        self.nodes = len(genotype)
        self.in_dim = C_in
        self.out_dim = C_out

    def extra_repr(self):
        string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
        laystr = []
        for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
            y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
            x = '{:}<-({:})'.format(i+1, ','.join(y))
            laystr.append(x)
        return string + ', [{:}]'.format(' | '.join(laystr)) + ', {:}'.format(self.genotype.tostr())

    def forward(self, inputs):
        nodes = [inputs]
        for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
            node_feature = sum(self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods))
            nodes.append(node_feature)
        return nodes[-1]


# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
        super(NASNetInferCell, self).__init__()
        self.reduction = reduction
        if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
        else: self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
        self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)

        if not reduction:
            nodes, concats = genotype['normal'], genotype['normal_concat']
        else:
            nodes, concats = genotype['reduce'], genotype['reduce_concat']
        self._multiplier = len(concats)
        self._concats = concats
        self._steps = len(nodes)
        self._nodes = nodes
        self.edges = nn.ModuleDict()
        for i, node in enumerate(nodes):
            for in_node in node:
                name, j = in_node[0], in_node[1]
                stride = 2 if reduction and j < 2 else 1
                node_str = '{:}<-{:}'.format(i+2, j)
                self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)

    # [TODO] to support drop_prob in this function..
    def forward(self, s0, s1, unused_drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i, node in enumerate(self._nodes):
            clist = []
            for in_node in node:
                name, j = in_node[0], in_node[1]
                node_str = '{:}<-{:}'.format(i+2, j)
                op = self.edges[node_str]
                clist.append(op(states[j]))
            states.append(sum(clist))
        return torch.cat([states[x] for x in self._concats], dim=1)


class AuxiliaryHeadCIFAR(nn.Module):

    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x
nasbench201/cell_infers/tiny_network.py (new file)

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell


# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):

    def __init__(self, C, N, genotype, num_classes):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))

        layer_channels = [C] * N + [C*2] + [C*2] * N + [C*4] + [C*4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
            else:
                cell = InferCell(genotype, C_prev, C_curr, 1)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self._Layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

        self.requires_feature = True

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        if self.requires_feature:
            return logits, out
        else:
            return logits

    def _loss(self, input, target, return_logits=False):
        logits, _ = self(input)
        loss = self._criterion(logits, target)

        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()

        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt
nasbench201/cell_operations.py (new file)

import sys
import torch
import torch.nn as nn
sys.path.insert(0, '../')
from Layers import layers

__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']

OPS = {
    'noise': lambda C_in, C_out, stride, affine, track_running_stats: NoiseOp(stride, 0., 1.),  # C_in, C_out not needed
    'none': lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
    'max_pool_3x3': lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
    'nor_conv_7x7': lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
    'nor_conv_3x3': lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
    'nor_conv_1x1': lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
    'dua_sepc_3x3': lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
    'dua_sepc_5x5': lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
    'dil_sepc_3x3': lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
    'dil_sepc_5x5': lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
    'skip_connect': lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}

CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
#### wrc modified
NAS_BENCH_201_SKIP = ['none', 'skip_connect', 'nor_conv_1x1_skip', 'nor_conv_3x3_skip', 'avg_pool_3x3']
NAS_BENCH_201_SIMPLE = ['skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
NAS_BENCH_201_S2 = ['skip_connect', 'nor_conv_3x3']
NAS_BENCH_201_S4 = ['noise', 'nor_conv_3x3']
NAS_BENCH_201_S10 = ['none', 'nor_conv_3x3']

SearchSpaceNames = {'connect-nas': CONNECT_NAS_BENCHMARK,
                    'nas-bench-201': NAS_BENCH_201,
                    'nas-bench-201-simple': NAS_BENCH_201_SIMPLE,
                    'nas-bench-201-s2': NAS_BENCH_201_S2,
                    'nas-bench-201-s4': NAS_BENCH_201_S4,
                    'nas-bench-201-s10': NAS_BENCH_201_S10,
                    'darts': DARTS_SPACE}


class NoiseOp(nn.Module):
    def __init__(self, stride, mean, std):
        super(NoiseOp, self).__init__()
        self.stride = stride
        self.mean = mean
        self.std = std

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        if self.stride != 1:
            x_new = x[:, :, ::self.stride, ::self.stride]
        else:
            x_new = x
        noise = x_new.data.new(x_new.size()).normal_(self.mean, self.std)
        return noise


class ReLUConvBN(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
        )

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        return self.op(x)

    def score(self):
        score = 0
        for l in self.op:
            if hasattr(l, 'score'):
                score += torch.sum(l.score).cpu().numpy()
        return score


#### wrc modified
class ReLUConvBNSkip(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(ReLUConvBNSkip, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
        )

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        return self.op(x) + x

    def score(self):
        score = 0
        for l in self.op:
            if hasattr(l, 'score'):
                score += torch.sum(l.score).cpu().numpy()
        return score
####


class SepConv(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            layers.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
            layers.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
        )

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        return self.op(x)

    def score(self):
        score = 0
        for l in self.op:
            if hasattr(l, 'score'):
                score += torch.sum(l.score).cpu().numpy()
        return score


class DualSepConv(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(DualSepConv, self).__init__()
        self.op_a = SepConv(C_in, C_in, kernel_size, stride, padding, dilation, affine, track_running_stats)
        self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        x = self.op_a(x)
        x = self.op_b(x)
        return x

    def score(self):
        score = self.op_a.score() + self.op_b.score()
        return score


class ResNetBasicblock(nn.Module):

    def __init__(self, inplanes, planes, stride, affine=True):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1, affine)
        if stride == 2:
            self.downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
        elif inplanes != planes:
            self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
        else:
            self.downsample = None
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def extra_repr(self):
        string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
        return string

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        return residual + basicblock

    def score(self):
        return self.conv_a.score() + self.conv_b.score()


class POOLING(nn.Module):

    def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
        super(POOLING, self).__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
        if mode == 'avg': self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
        else: raise ValueError('Invalid mode={:} in POOLING'.format(mode))

    def forward(self, inputs, block_input=False):
        if block_input:
            inputs = inputs * 0
        if self.preprocess: x = self.preprocess(inputs)
        else: x = inputs
        return self.op(x)

    def score(self):
        if self.preprocess:
            return self.preprocess.score()
        else:
            return 0


class Identity(nn.Module):

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        return x


class Zero(nn.Module):

    def __init__(self, C_in, C_out, stride):
        super(Zero, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.is_zero = True

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        if self.C_in == self.C_out:
            if self.stride == 1: return x.mul(0.)
            else: return x[:, :, ::self.stride, ::self.stride].mul(0.)
        else:  ## this is never called in nasbench201
            shape = list(x.shape)
            shape[1] = self.C_out
            zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
            return zeros

    def extra_repr(self):
        return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)


class FactorizedReduce(nn.Module):

    def __init__(self, C_in, C_out, stride, affine, track_running_stats):
        super(FactorizedReduce, self).__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)
        if stride == 2:
            # assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
            C_outs = [C_out // 2, C_out - C_out // 2]
            self.convs = nn.ModuleList()
            for i in range(2):
                self.convs.append(layers.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        elif stride == 1:
            self.conv = layers.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
        else:
            raise ValueError('Invalid stride : {:}'.format(stride))
        self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0
        if self.stride == 2:
            x = self.relu(x)
            y = self.pad(x)
            out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
        else:
            out = self.conv(x)
        out = self.bn(out)
        return out

    def extra_repr(self):
        return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)

    def score(self):
        if self.stride == 1:
            return self.conv.score()
        else:
            return self.convs[0].score() + self.convs[1].score()
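Every entry in OPS above is a constructor with the same (C_in, C_out, stride, affine, track_running_stats) signature, which is what lets the cells build an edge generically from a search-space name. A small sketch (it assumes the external Layers dependency imported at the top of this file is available):

import torch
import torch.nn as nn

C = 16
edge = nn.ModuleList([OPS[name](C, C, 1, True, True) for name in SearchSpaceNames['nas-bench-201']])
x = torch.randn(2, C, 32, 32)
outs = [op(x) for op in edge]                  # at stride 1 every candidate op preserves the 2x16x32x32 shape
print([tuple(o.shape) for o in outs])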
nasbench201/genotypes.py (new file)

from copy import deepcopy


def get_combination(space, num):
    combs = []
    for i in range(num):
        if i == 0:
            for func in space:
                combs.append([(func, i)])
        else:
            new_combs = []
            for string in combs:
                for func in space:
                    xstring = string + [(func, i)]
                    new_combs.append(xstring)
            combs = new_combs
    return combs


class Structure:

    def __init__(self, genotype):
        assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
        self.node_num = len(genotype) + 1
        self.nodes = []
        self.node_N = []
        for idx, node_info in enumerate(genotype):
            assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
            assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
            for node_in in node_info:
                assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
                assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
            self.node_N.append(len(node_info))
            self.nodes.append(tuple(deepcopy(node_info)))

    def tolist(self, remove_str):
        # convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
        # note that we re-order the input node in this function
        # return the-genotype-list and success [if unsuccessful, it is not a connectivity]
        genotypes = []
        for node_info in self.nodes:
            node_info = list(node_info)
            node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
            node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
            if len(node_info) == 0: return None, False
            genotypes.append(node_info)
        return genotypes, True

    def node(self, index):
        assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
        return self.nodes[index]

    def tostr(self):
        strings = []
        for node_info in self.nodes:
            string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
            string = '|{:}|'.format(string)
            strings.append(string)
        return '+'.join(strings)

    def check_valid(self):
        nodes = {0: True}
        for i, node_info in enumerate(self.nodes):
            sums = []
            for op, xin in node_info:
                if op == 'none' or nodes[xin] is False: x = False
                else: x = True
                sums.append(x)
            nodes[i+1] = sum(sums) > 0
        return nodes[len(self.nodes)]

    def to_unique_str(self, consider_zero=False):
        # this is used to identify the isomorphic cell, which requires the prior knowledge of operation
        # two operations are special, i.e., none and skip_connect
        nodes = {0: '0'}
        for i_node, node_info in enumerate(self.nodes):
            cur_node = []
            for op, xin in node_info:
                if consider_zero is None:
                    x = '('+nodes[xin]+')' + '@{:}'.format(op)
                elif consider_zero:
                    if op == 'none' or nodes[xin] == '#': x = '#'  # zero
                    elif op == 'skip_connect': x = nodes[xin]
                    else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
                else:
                    if op == 'skip_connect': x = nodes[xin]
                    else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
                cur_node.append(x)
            nodes[i_node+1] = '+'.join(sorted(cur_node))
        return nodes[len(self.nodes)]

    def check_valid_op(self, op_names):
        for node_info in self.nodes:
            for inode_edge in node_info:
                # assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
                if inode_edge[0] not in op_names: return False
        return True

    def __repr__(self):
        return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))

    def __len__(self):
        return len(self.nodes) + 1

    def __getitem__(self, index):
        return self.nodes[index]

    @staticmethod
    def str2structure(xstr):
        assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
        nodestrs = xstr.split('+')
        genotypes = []
        for i, node_str in enumerate(nodestrs):
            inputs = list(filter(lambda x: x != '', node_str.split('|')))
            for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
            inputs = (xi.split('~') for xi in inputs)
            input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
            genotypes.append(input_infos)
        return Structure(genotypes)

    @staticmethod
    def str2fullstructure(xstr, default_name='none'):
        assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
        nodestrs = xstr.split('+')
        genotypes = []
        for i, node_str in enumerate(nodestrs):
            inputs = list(filter(lambda x: x != '', node_str.split('|')))
            for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
            inputs = (xi.split('~') for xi in inputs)
            input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
            all_in_nodes = list(x[1] for x in input_infos)
            for j in range(i):
                if j not in all_in_nodes: input_infos.append((default_name, j))
            node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
            genotypes.append(tuple(node_info))
        return Structure(genotypes)

    @staticmethod
    def gen_all(search_space, num, return_ori):
        assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
        assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
        all_archs = get_combination(search_space, 1)
        for i, arch in enumerate(all_archs):
            all_archs[i] = [tuple(arch)]

        for inode in range(2, num):
            cur_nodes = get_combination(search_space, inode)
            new_all_archs = []
            for previous_arch in all_archs:
                for cur_node in cur_nodes:
                    new_all_archs.append(previous_arch + [tuple(cur_node)])
            all_archs = new_all_archs
        if return_ori:
            return all_archs
        else:
            return [Structure(x) for x in all_archs]


ResNet_CODE = Structure(
    [(('nor_conv_3x3', 0), ),                        # node-1
     (('nor_conv_3x3', 1), ),                        # node-2
     (('skip_connect', 0), ('skip_connect', 2))]     # node-3
)

AllConv3x3_CODE = Structure(
    [(('nor_conv_3x3', 0), ),                                            # node-1
     (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)),                         # node-2
     (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))]    # node-3
)

AllFull_CODE = Structure(
    [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)),  # node-1
     (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)),  # node-2
     (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))]  # node-3
)

AllConv1x1_CODE = Structure(
    [(('nor_conv_1x1', 0), ),                                            # node-1
     (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)),                         # node-2
     (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))]    # node-3
)

AllIdentity_CODE = Structure(
    [(('skip_connect', 0), ),                                            # node-1
     (('skip_connect', 0), ('skip_connect', 1)),                         # node-2
     (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))]    # node-3
)

architectures = {'resnet': ResNet_CODE,
                 'all_c3x3': AllConv3x3_CODE,
                 'all_c1x1': AllConv1x1_CODE,
                 'all_idnt': AllIdentity_CODE,
                 'all_full': AllFull_CODE}
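The Structure class is the bridge between NAS-Bench-201 architecture strings and the concrete networks in cell_infers/. A sketch; the import paths, the C=16/N=5 macro settings and the availability of the Layers dependency used by cell_operations are assumptions:

import torch
from nasbench201.genotypes import Structure                    # module path assumed from this repo layout
from nasbench201.cell_infers.tiny_network import TinyNetwork   # module path assumed from this repo layout

arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'
genotype = Structure.str2structure(arch_str)
net = TinyNetwork(C=16, N=5, genotype=genotype, num_classes=10)
logits, feat = net(torch.randn(2, 3, 32, 32))   # requires_feature defaults to True, so forward returns (logits, features)
print(genotype.tostr(), logits.shape, feat.shape)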
619
nasbench201/init_projection.py
Normal file
619
nasbench201/init_projection.py
Normal file
@ -0,0 +1,619 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as f
|
||||||
|
sys.path.insert(0, '../')
|
||||||
|
import nasbench201.utils as ig_utils
|
||||||
|
import logging
|
||||||
|
import torch.utils
|
||||||
|
import copy
|
||||||
|
import scipy.stats as ss
|
||||||
|
from collections import OrderedDict
|
||||||
|
from foresight.pruners import *
|
||||||
|
from op_score import Jocab_Score, get_ntk_n
|
||||||
|
import gc
|
||||||
|
from nasbench201.linear_region import Linear_Region_Collector
|
||||||
|
|
||||||
|
torch.set_printoptions(precision=4, sci_mode=False)
|
||||||
|
np.set_printoptions(precision=4, suppress=True)
|
||||||
|
|
||||||
|
# global-edge-iter: similar toglobal-op-iterbut iteratively selects edge e from E based on the average score of all operations on each edge
|
||||||
|
def global_op_greedy_pt_project(proj_queue, model, args):
|
||||||
|
def project(model, args):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
##get remain eid numbers
|
||||||
|
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||||
|
compare = lambda x, y : x < y
|
||||||
|
|
||||||
|
crit_extrema = None
|
||||||
|
best_eid = None
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for eid in remain_eids:
|
||||||
|
for opid in range(num_op):
|
||||||
|
# projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[eid] = weights[eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
crit = valid_stats
|
||||||
|
|
||||||
|
if crit_extrema is None or compare(crit, crit_extrema):
|
||||||
|
crit_extrema = crit
|
||||||
|
best_opid = opid
|
||||||
|
best_eid = eid
|
||||||
|
|
||||||
|
logging.info('best opid %d', best_opid)
|
||||||
|
return best_eid, best_opid
|
||||||
|
|
||||||
|
tune_epochs = model.arch_parameters()[0].shape[0]
|
||||||
|
|
||||||
|
for epoch in range(tune_epochs):
|
||||||
|
logging.info('epoch %d', epoch)
|
||||||
|
logging.info('project')
|
||||||
|
selected_eid, best_opid = project(model, args)
|
||||||
|
model.project_op(selected_eid, best_opid)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
# global-edge-iter: similar toglobal-op-oncebut uses the average score of operations on edges to obtain the edge discretization order
|
||||||
|
def global_edge_greedy_pt_project(proj_queue, model, args):
|
||||||
|
def select_eid(model, args):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
##get remain eid numbers
|
||||||
|
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||||
|
compare = lambda x, y : x < y
|
||||||
|
|
||||||
|
crit_extrema = None
|
||||||
|
best_eid = None
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for eid in remain_eids:
|
||||||
|
eid_score = []
|
||||||
|
for opid in range(num_op):
|
||||||
|
# projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[eid] = weights[eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
crit = valid_stats
|
||||||
|
eid_score.append(crit)
|
||||||
|
eid_score = np.mean(eid_score)
|
||||||
|
|
||||||
|
if crit_extrema is None or compare(eid_score, crit_extrema):
|
||||||
|
crit_extrema = eid_score
|
||||||
|
best_eid = eid
|
||||||
|
return best_eid
|
||||||
|
|
||||||
|
def project(model, args, selected_eid):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
## select the best operation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit_idx = 3
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
else:
|
||||||
|
crit_idx = 4
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
|
||||||
|
best_opid = 0
|
||||||
|
crit_list = []
|
||||||
|
op_ids = []
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for opid in range(num_op):
|
||||||
|
## projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[selected_eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
crit = valid_stats
|
||||||
|
|
||||||
|
crit_list.append(crit)
|
||||||
|
op_ids.append(opid)
|
||||||
|
|
||||||
|
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||||
|
|
||||||
|
logging.info('best opid %d', best_opid)
|
||||||
|
logging.info(crit_list)
|
||||||
|
return selected_eid, best_opid
|
||||||
|
|
||||||
|
num_edges = model.arch_parameters()[0].shape[0]
|
||||||
|
|
||||||
|
for epoch in range(num_edges):
|
||||||
|
logging.info('epoch %d', epoch)
|
||||||
|
|
||||||
|
logging.info('project')
|
||||||
|
selected_eid = select_eid(model, args)
|
||||||
|
selected_eid, best_opid = project(model, args, selected_eid)
|
||||||
|
model.project_op(selected_eid, best_opid)
|
||||||
|
return
|
||||||
|
|
||||||
|
# global-op-once: evaluates S(A-(e,o)) for all operations only once to obtain a ranking order of the operations, and discretizes the edges E according to this order
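# here S(A-(e,o)) denotes the zero-cost score of the supernet A with operation o masked out on edge e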
def global_op_once_pt_project(proj_queue, model, args):
|
||||||
|
def order(model, args):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
##get remain eid numbers
|
||||||
|
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||||
|
compare = lambda x, y : x < y
|
||||||
|
|
||||||
|
edge_score = OrderedDict()
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for eid in remain_eids:
|
||||||
|
crit_list = []
|
||||||
|
for opid in range(num_op):
|
||||||
|
# projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[eid] = weights[eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
crit = valid_stats
|
||||||
|
|
||||||
|
crit_list.append(crit)
|
||||||
|
edge_score[eid] = np.nanargmin(crit_list)
|
||||||
|
return edge_score
|
||||||
|
|
||||||
|
def project(model, args, selected_eid):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
## select the best operation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit_idx = 3
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
else:
|
||||||
|
crit_idx = 4
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
|
||||||
|
best_opid = 0
|
||||||
|
crit_list = []
|
||||||
|
op_ids = []
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for opid in range(num_op):
|
||||||
|
## projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[selected_eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
crit_list.append(crit)
|
||||||
|
op_ids.append(opid)
|
||||||
|
|
||||||
|
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||||
|
|
||||||
|
logging.info('best opid %d', best_opid)
|
||||||
|
logging.info(crit_list)
|
||||||
|
return selected_eid, best_opid
|
||||||
|
|
||||||
|
num_edges = model.arch_parameters()[0].shape[0]
|
||||||
|
|
||||||
|
eid_order = order(model, args)
|
||||||
|
for epoch in range(num_edges):
|
||||||
|
logging.info('epoch %d', epoch)
|
||||||
|
logging.info('project')
|
||||||
|
selected_eid, _ = eid_order.popitem()
|
||||||
|
selected_eid, best_opid = project(model, args, selected_eid)
|
||||||
|
model.project_op(selected_eid, best_opid)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
# global-edge-once: similar to global-op-once, but uses the average score of the operations on each edge to obtain the edge discretization order
def global_edge_once_pt_project(proj_queue, model, args):
|
||||||
|
def order(model, args):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
##get remain eid numbers
|
||||||
|
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||||
|
compare = lambda x, y : x < y
|
||||||
|
|
||||||
|
edge_score = OrderedDict()
|
||||||
|
crit_extrema = None
|
||||||
|
best_eid = None
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for eid in remain_eids:
|
||||||
|
crit_list = []
|
||||||
|
for opid in range(num_op):
|
||||||
|
# projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[eid] = weights[eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
|
||||||
|
crit_list.append(crit)
|
||||||
|
edge_score[eid] = np.mean(crit_list)
|
||||||
|
return edge_score
|
||||||
|
|
||||||
|
def project(model, args, selected_eid):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
## select the best operation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit_idx = 3
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
else:
|
||||||
|
crit_idx = 4
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
|
||||||
|
best_opid = 0
|
||||||
|
crit_extrema = None
|
||||||
|
crit_list = []
|
||||||
|
op_ids = []
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for opid in range(num_op):
|
||||||
|
## projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[selected_eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
crit_list.append(crit)
|
||||||
|
op_ids.append(opid)
|
||||||
|
|
||||||
|
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||||
|
|
||||||
|
logging.info('best opid %d', best_opid)
|
||||||
|
logging.info(crit_list)
|
||||||
|
return selected_eid, best_opid
|
||||||
|
|
||||||
|
num_edges = model.arch_parameters()[0].shape[0]
|
||||||
|
|
||||||
|
eid_order = order(model, args)
|
||||||
|
for epoch in range(num_edges):
|
||||||
|
logging.info('epoch %d', epoch)
|
||||||
|
logging.info('project')
|
||||||
|
selected_eid, _ = eid_order.popitem()
|
||||||
|
selected_eid, best_opid = project(model, args, selected_eid)
|
||||||
|
model.project_op(selected_eid, best_opid)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
# fixed [reverse, order]: discretizes the edges in a fixed order; in our experiments we discretize from the input towards the output of the cell structure
# random: discretizes the edges in a random order (DARTS-PT)
|
||||||
|
# NOTE: only this method supports using other zero-cost proxy metrics (args.proj_crit) besides 'jacob'
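# edge_decision: 'random' samples the next edge uniformly (DARTS-PT), 'reverse' takes the last remaining edge, any other value takes the first remaining edge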
def pt_project(proj_queue, model, args):
|
||||||
|
def project(model, args):
|
||||||
|
## macros: the cell has 6 edges in total, each with 5 candidate operations
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
## select an edge
|
||||||
|
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||||
|
# print('candidate_flags:', model.candidate_flags)
|
||||||
|
# print(model.candidate_flags)
|
||||||
|
# how to pick the next edge to discretize
if args.edge_decision == "random":
|
||||||
|
# np.random.choice returns an array; take its single element
selected_eid = np.random.choice(remain_eids, size=1)[0]
|
||||||
|
elif args.edge_decision == "reverse":
|
||||||
|
selected_eid = remain_eids[-1]
|
||||||
|
else:
|
||||||
|
selected_eid = remain_eids[0]
|
||||||
|
|
||||||
|
## select the best operation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit_idx = 3
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
else:
|
||||||
|
crit_idx = 4
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
|
||||||
|
if args.dataset == 'cifar100':
|
||||||
|
n_classes = 100
|
||||||
|
elif args.dataset == 'imagenet16-120':
|
||||||
|
n_classes = 120
|
||||||
|
else:
|
||||||
|
n_classes = 10
|
||||||
|
|
||||||
|
best_opid = 0
|
||||||
|
crit_extrema = None
|
||||||
|
crit_list = []
|
||||||
|
op_ids = []
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for opid in range(num_op):
|
||||||
|
## projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[selected_eid])
|
||||||
|
# print(selected_eid, weights[selected_eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||||
|
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
else:
|
||||||
|
cache_weight = model.proj_weights[selected_eid]
|
||||||
|
cache_flag = model.candidate_flags[selected_eid]
|
||||||
|
|
||||||
|
|
||||||
|
for idx in range(num_op):
|
||||||
|
if idx == opid:
|
||||||
|
model.proj_weights[selected_eid][opid] = 0
|
||||||
|
else:
|
||||||
|
model.proj_weights[selected_eid][idx] = 1.0/num_op
|
||||||
|
|
||||||
|
|
||||||
|
model.candidate_flags[selected_eid] = False
|
||||||
|
# print(model.get_projected_weights())
|
||||||
|
|
||||||
|
if args.proj_crit == 'comb':
|
||||||
|
synflow = predictive.find_measures(model,
|
||||||
|
proj_queue,
|
||||||
|
('random', 1, n_classes),
|
||||||
|
torch.device("cuda"),
|
||||||
|
measure_names=['synflow'])
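# ('random', 1, n_classes) is the dataload_info tuple consumed by foresight's predictive.find_measures: (dataload mode, number of batches/samples, number of classes)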
|
||||||
|
var = predictive.find_measures(model,
|
||||||
|
proj_queue,
|
||||||
|
('random', 1, n_classes),
|
||||||
|
torch.device("cuda"),
|
||||||
|
measure_names=['var'])
|
||||||
|
# print(synflow, var)
|
||||||
|
comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
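# combined proxy: log-damped synflow score divided by the activation-variance score; the +1 and +0.1 offsets keep the log and the division well-defined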
|
||||||
|
measures = {'comb': comb}
|
||||||
|
else:
|
||||||
|
measures = predictive.find_measures(model,
|
||||||
|
proj_queue,
|
||||||
|
('random', 1, n_classes),
|
||||||
|
torch.device("cuda"),
|
||||||
|
measure_names=[args.proj_crit])
|
||||||
|
|
||||||
|
# print(measures)
|
||||||
|
for idx in range(num_op):
|
||||||
|
model.proj_weights[selected_eid][idx] = 0
|
||||||
|
model.candidate_flags[selected_eid] = cache_flag
|
||||||
|
crit = measures[args.proj_crit]
|
||||||
|
|
||||||
|
crit_list.append(crit)
|
||||||
|
op_ids.append(opid)
|
||||||
|
|
||||||
|
|
||||||
|
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||||
|
# best_opid = op_ids[np.nanargmax(crit_list)]
|
||||||
|
|
||||||
|
logging.info('best opid %d', best_opid)
|
||||||
|
logging.info('current edge id %d', selected_eid)
|
||||||
|
logging.info(crit_list)
|
||||||
|
return selected_eid, best_opid
|
||||||
|
|
||||||
|
num_edges = model.arch_parameters()[0].shape[0]
|
||||||
|
|
||||||
|
for epoch in range(num_edges):
|
||||||
|
logging.info('epoch %d', epoch)
|
||||||
|
logging.info('project')
|
||||||
|
selected_eid, best_opid = project(model, args)
|
||||||
|
model.project_op(selected_eid, best_opid)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def tenas_project(proj_queue, model, model_thin, args):
|
||||||
|
def project(model, args):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
##get remain eid numbers
|
||||||
|
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||||
|
compare = lambda x, y : x < y
|
||||||
|
|
||||||
|
ntks = []
|
||||||
|
lrs = []
|
||||||
|
edge_op_id = []
|
||||||
|
best_eid = None
|
||||||
|
|
||||||
|
if args.proj_crit == 'tenas':
|
||||||
|
lrc_model = Linear_Region_Collector(input_size=(1000, 1, 3, 3), sample_batch=3, dataset=args.dataset, data_path=args.data, seed=args.seed)
|
||||||
|
for eid in remain_eids:
|
||||||
|
for opid in range(num_op):
|
||||||
|
# projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[eid] = weights[eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'tenas':
|
||||||
|
lrc_model.reinit(ori_models=[model_thin], seed=args.seed, weights=weights)
|
||||||
|
lr = lrc_model.forward_batch_sample()
|
||||||
|
lrc_model.clear()
|
||||||
|
ntk = get_ntk_n(proj_queue, [model], recalbn=0, train_mode=True, num_batch=1, weights=weights)
|
||||||
|
ntks.append(ntk)
|
||||||
|
lrs.append(lr)
|
||||||
|
edge_op_id.append('{}:{}'.format(eid, opid))
|
||||||
|
print('ntks', ntks)
|
||||||
|
print('lrs', lrs)
|
||||||
|
ntks_ranks = ss.rankdata(ntks)
|
||||||
|
lrs_ranks = ss.rankdata(lrs)
|
||||||
|
ntks_ranks = len(ntks_ranks) - ntks_ranks.astype(int)
|
||||||
|
op_ranks = []
|
||||||
|
for i in range(len(edge_op_id)):
|
||||||
|
op_ranks.append(ntks_ranks[i]+lrs_ranks[i])
|
||||||
|
|
||||||
|
best_op_index = edge_op_id[np.nanargmin(op_ranks[0:num_op])]
|
||||||
|
best_eid, best_opid = [int(x) for x in best_op_index.split(':')]
|
||||||
|
|
||||||
|
logging.info(op_ranks)
|
||||||
|
logging.info('best eid %d', best_eid)
|
||||||
|
logging.info('best opid %d', best_opid)
|
||||||
|
return best_eid, best_opid
|
||||||
|
num_edges = model.arch_parameters()[0].shape[0]
|
||||||
|
|
||||||
|
for epoch in range(num_edges):
|
||||||
|
logging.info('epoch %d', epoch)
|
||||||
|
logging.info('project')
|
||||||
|
selected_eid, best_opid = project(model, args)
|
||||||
|
model.project_op(selected_eid, best_opid)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
# new method
# Randomly propose candidate networks and transfer them to the supernet, then perform a global op selection within this subspace
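# 20 random projection runs are accumulated below; the union of their selected (edge, op) pairs is binarized and row-normalized into a subspace, and global_project then greedily fixes one operation per edge inside that subspace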
def shrink_pt_project(proj_queue, model, args):
|
||||||
|
def project(model, args):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
## select an edge
|
||||||
|
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
|
||||||
|
selected_eid = np.random.choice(remain_eids, size=1)[0]
|
||||||
|
|
||||||
|
|
||||||
|
## select the best operation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit_idx = 3
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
else:
|
||||||
|
crit_idx = 4
|
||||||
|
compare = lambda x, y: x < y
|
||||||
|
|
||||||
|
if args.dataset == 'cifar100':
|
||||||
|
n_classes = 100
|
||||||
|
elif args.dataset == 'imagenet16-120':
|
||||||
|
n_classes = 120
|
||||||
|
else:
|
||||||
|
n_classes = 10
|
||||||
|
|
||||||
|
best_opid = 0
|
||||||
|
crit_extrema = None
|
||||||
|
crit_list = []
|
||||||
|
op_ids = []
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for opid in range(num_op):
|
||||||
|
## projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[selected_eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||||
|
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
crit = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
else:
|
||||||
|
cache_weight = model.proj_weights[selected_eid]
|
||||||
|
cache_flag = model.candidate_flags[selected_eid]
|
||||||
|
|
||||||
|
for idx in range(num_op):
|
||||||
|
if idx == opid:
|
||||||
|
model.proj_weights[selected_eid][opid] = 0
|
||||||
|
else:
|
||||||
|
model.proj_weights[selected_eid][idx] = 1.0/num_op
|
||||||
|
model.candidate_flags[selected_eid] = False
|
||||||
|
|
||||||
|
measures = predictive.find_measures(model,
|
||||||
|
proj_queue,
|
||||||
|
('random', 1, n_classes),
|
||||||
|
torch.device("cuda"),
|
||||||
|
measure_names=[args.proj_crit])
|
||||||
|
for idx in range(num_op):
|
||||||
|
model.proj_weights[selected_eid][idx] = 0
|
||||||
|
model.candidate_flags[selected_eid] = cache_flag
|
||||||
|
crit = measures[args.proj_crit]
|
||||||
|
|
||||||
|
crit_list.append(crit)
|
||||||
|
op_ids.append(opid)
|
||||||
|
|
||||||
|
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||||
|
|
||||||
|
logging.info('best opid %d', best_opid)
|
||||||
|
logging.info('current edge id %d', selected_eid)
|
||||||
|
logging.info(crit_list)
|
||||||
|
return selected_eid, best_opid
|
||||||
|
|
||||||
|
def global_project(model, args):
|
||||||
|
## macros
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
|
||||||
|
##get remain eid numbers
|
||||||
|
remain_eids = torch.nonzero(model.subspace_candidate_flags).cpu().numpy().T[0]
|
||||||
|
compare = lambda x, y : x < y
|
||||||
|
|
||||||
|
crit_extrema = None
|
||||||
|
best_eid = None
|
||||||
|
best_opid = None
|
||||||
|
input, target = next(iter(proj_queue))
|
||||||
|
for eid in remain_eids:
|
||||||
|
remain_oids = torch.nonzero(model.proj_weights[eid]).cpu().numpy().T[0]
|
||||||
|
for opid in remain_oids:
|
||||||
|
# projection
|
||||||
|
weights = model.get_projected_weights()
|
||||||
|
proj_mask = torch.ones_like(weights[eid])
|
||||||
|
proj_mask[opid] = 0
|
||||||
|
weights[eid] = weights[eid] * proj_mask
|
||||||
|
## proj evaluation
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
valid_stats = Jocab_Score(model, input, target, weights=weights)
|
||||||
|
crit = valid_stats
|
||||||
|
|
||||||
|
if crit_extrema is None or compare(crit, crit_extrema):
|
||||||
|
crit_extrema = crit
|
||||||
|
best_opid = opid
|
||||||
|
best_eid = eid
|
||||||
|
|
||||||
|
|
||||||
|
logging.info('best eid %d', best_eid)
|
||||||
|
logging.info('best opid %d', best_opid)
|
||||||
|
model.subspace_candidate_flags[best_eid] = False
|
||||||
|
proj_mask = torch.zeros_like(model.proj_weights[best_eid])
|
||||||
|
model.proj_weights[best_eid] = model.proj_weights[best_eid] * proj_mask
|
||||||
|
model.proj_weights[best_eid][best_opid] = 1
|
||||||
|
return best_eid, best_opid
|
||||||
|
|
||||||
|
num_edges = model.arch_parameters()[0].shape[0]
|
||||||
|
|
||||||
|
#subspace
|
||||||
|
logging.info('Start subspace proposal')
|
||||||
|
subspace = copy.deepcopy(model.proj_weights)
|
||||||
|
for i in range(20):
|
||||||
|
model.reset_arch_parameters()
|
||||||
|
for epoch in range(num_edges):
|
||||||
|
logging.info('epoch %d', epoch)
|
||||||
|
logging.info('project')
|
||||||
|
selected_eid, best_opid = project(model, args)
|
||||||
|
model.project_op(selected_eid, best_opid)
|
||||||
|
subspace += model.proj_weights
|
||||||
|
|
||||||
|
model.reset_arch_parameters()
|
||||||
|
subspace = torch.gt(subspace, 0).int().float()
|
||||||
|
subspace = f.normalize(subspace, p=1, dim=1)
|
||||||
|
model.proj_weights += subspace
|
||||||
|
for i in range(num_edges):
|
||||||
|
model.candidate_flags[i] = False
|
||||||
|
logging.info('Start final search in subspace')
|
||||||
|
logging.info(subspace)
|
||||||
|
|
||||||
|
model.subspace_candidate_flags = torch.tensor(len(model._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
|
||||||
|
for epoch in range(num_edges):
|
||||||
|
logging.info('epoch %d', epoch)
|
||||||
|
logging.info('project')
|
||||||
|
selected_eid, best_opid = global_project(model, args)
|
||||||
|
model.printing(logging)
|
||||||
|
#model.project_op(selected_eid, best_opid)
|
||||||
|
return
|
270
nasbench201/linear_region.py
Normal file
@ -0,0 +1,270 @@
import os.path as osp
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.data
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import torchvision.datasets as dset
|
||||||
|
from pdb import set_trace as bp
|
||||||
|
from operator import mul
|
||||||
|
from functools import reduce
|
||||||
|
import copy
|
||||||
|
Dataset2Class = {'cifar10': 10,
|
||||||
|
'cifar100': 100,
|
||||||
|
'imagenet-1k-s': 1000,
|
||||||
|
'imagenet-1k': 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CUTOUT(object):
|
||||||
|
|
||||||
|
def __init__(self, length):
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
h, w = img.size(1), img.size(2)
|
||||||
|
mask = np.ones((h, w), np.float32)
|
||||||
|
y = np.random.randint(h)
|
||||||
|
x = np.random.randint(w)
|
||||||
|
|
||||||
|
y1 = np.clip(y - self.length // 2, 0, h)
|
||||||
|
y2 = np.clip(y + self.length // 2, 0, h)
|
||||||
|
x1 = np.clip(x - self.length // 2, 0, w)
|
||||||
|
x2 = np.clip(x + self.length // 2, 0, w)
|
||||||
|
|
||||||
|
mask[y1: y2, x1: x2] = 0.
|
||||||
|
mask = torch.from_numpy(mask)
|
||||||
|
mask = mask.expand_as(img)
|
||||||
|
img *= mask
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
imagenet_pca = {
|
||||||
|
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
|
||||||
|
'eigvec': np.asarray([
|
||||||
|
[-0.5675, 0.7192, 0.4009],
|
||||||
|
[-0.5808, -0.0045, -0.8140],
|
||||||
|
[-0.5836, -0.6948, 0.4203],
|
||||||
|
])
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RandChannel(object):
|
||||||
|
# randomly pick channels from input
|
||||||
|
def __init__(self, num_channel):
|
||||||
|
self.num_channel = num_channel
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return ('{name}(num_channel={num_channel})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
channel = img.size(0)
|
||||||
|
channel_choice = sorted(np.random.choice(list(range(channel)), size=self.num_channel, replace=False))
|
||||||
|
return torch.index_select(img, 0, torch.Tensor(channel_choice).long())
|
||||||
|
|
||||||
|
|
||||||
|
def get_datasets(name, root, input_size, cutout=-1):
|
||||||
|
assert len(input_size) in [3, 4]
|
||||||
|
if len(input_size) == 4:
|
||||||
|
input_size = input_size[1:]
|
||||||
|
assert input_size[1] == input_size[2]
|
||||||
|
|
||||||
|
if name == 'cifar10':
|
||||||
|
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
|
||||||
|
std = [x / 255 for x in [63.0, 62.1, 66.7]]
|
||||||
|
elif name == 'cifar100':
|
||||||
|
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
|
||||||
|
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||||
|
elif name.startswith('imagenet-1k'):
|
||||||
|
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||||
|
elif name.startswith('ImageNet16'):
|
||||||
|
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
|
||||||
|
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
|
||||||
|
else:
|
||||||
|
raise TypeError("Unknown dataset : {:}".format(name))
|
||||||
|
# print(input_size)
|
||||||
|
# Data Augmentation
|
||||||
|
if name == 'cifar10' or name == 'cifar100':
|
||||||
|
lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
|
||||||
|
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||||
|
train_transform = transforms.Compose(lists)
|
||||||
|
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||||
|
elif name.startswith('ImageNet16'):
|
||||||
|
lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
|
||||||
|
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||||
|
train_transform = transforms.Compose(lists)
|
||||||
|
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||||
|
elif name.startswith('imagenet-1k'):
|
||||||
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
if name == 'imagenet-1k':
|
||||||
|
xlists = []
|
||||||
|
xlists.append(transforms.Resize((32, 32), interpolation=2))
|
||||||
|
xlists.append(transforms.RandomCrop(input_size[1], padding=0))
|
||||||
|
elif name == 'imagenet-1k-s':
|
||||||
|
xlists = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))]
|
||||||
|
xlists = []
|
||||||
|
else: raise ValueError('invalid name : {:}'.format(name))
|
||||||
|
xlists.append(transforms.ToTensor())
|
||||||
|
xlists.append(normalize)
|
||||||
|
xlists.append(RandChannel(input_size[0]))
|
||||||
|
train_transform = transforms.Compose(xlists)
|
||||||
|
test_transform = transforms.Compose([transforms.Resize(40), transforms.CenterCrop(32), transforms.ToTensor(), normalize])
|
||||||
|
else:
|
||||||
|
raise TypeError("Unknown dataset : {:}".format(name))
|
||||||
|
|
||||||
|
if name == 'cifar10':
|
||||||
|
train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
|
||||||
|
test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
|
||||||
|
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||||
|
elif name == 'cifar100':
|
||||||
|
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
||||||
|
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
||||||
|
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||||
|
elif name.startswith('imagenet-1k'):
|
||||||
|
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||||
|
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
|
||||||
|
else: raise TypeError("Unknown dataset : {:}".format(name))
|
||||||
|
|
||||||
|
class_num = Dataset2Class[name]
|
||||||
|
return train_data, test_data, class_num
|
||||||
|
|
||||||
|
|
||||||
|
class LinearRegionCount(object):
"""Counts the number of distinct linear regions activated by a set of samples"""
|
||||||
|
def __init__(self, n_samples):
|
||||||
|
self.ActPattern = {}
|
||||||
|
self.n_LR = -1
|
||||||
|
self.n_samples = n_samples
|
||||||
|
self.ptr = 0
|
||||||
|
self.activations = None
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def update2D(self, activations):
|
||||||
|
n_batch = activations.size()[0]
|
||||||
|
n_neuron = activations.size()[1]
|
||||||
|
self.n_neuron = n_neuron
|
||||||
|
if self.activations is None:
|
||||||
|
self.activations = torch.zeros(self.n_samples, n_neuron).cuda()
|
||||||
|
self.activations[self.ptr:self.ptr+n_batch] = torch.sign(activations) # after ReLU
|
||||||
|
self.ptr += n_batch
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def calc_LR(self):
|
||||||
|
res = torch.matmul(self.activations.half(), (1-self.activations).T.half()) # each element in res: A * (1 - B)
|
||||||
|
res += res.T # make symmetric, each element in res: A * (1 - B) + (1 - A) * B, a non-zero element indicate a pair of two different linear regions
|
||||||
|
res = 1 - torch.sign(res) # a non-zero element now indicate two linear regions are identical
|
||||||
|
res = res.sum(1) # for each sample's linear region: how many identical regions from other samples
|
||||||
|
res = 1. / res.float() # contribution of each redundant (repeated) linear region
|
||||||
|
self.n_LR = res.sum().item() # sum of unique regions (by aggregating contribution of all regions)
|
||||||
|
del self.activations, res
|
||||||
|
self.activations = None
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def update1D(self, activationList):
|
||||||
|
code_string = ''
|
||||||
|
for key, value in activationList.items():
|
||||||
|
n_neuron = value.size()[0]
|
||||||
|
for i in range(n_neuron):
|
||||||
|
if value[i] > 0:
|
||||||
|
code_string += '1'
|
||||||
|
else:
|
||||||
|
code_string += '0'
|
||||||
|
if code_string not in self.ActPattern:
|
||||||
|
self.ActPattern[code_string] = 1
|
||||||
|
|
||||||
|
def getLinearReginCount(self):
|
||||||
|
if self.n_LR == -1:
|
||||||
|
self.calc_LR()
|
||||||
|
return self.n_LR
|
||||||
|
|
||||||
|
|
||||||
|
class Linear_Region_Collector:
|
||||||
|
def __init__(self, models=[], input_size=(64, 3, 32, 32), sample_batch=100, dataset='cifar100', data_path=None, seed=0):
|
||||||
|
self.models = []
|
||||||
|
self.input_size = input_size # BCHW
|
||||||
|
self.sample_batch = sample_batch
|
||||||
|
self.input_numel = reduce(mul, self.input_size, 1)
|
||||||
|
self.interFeature = []
|
||||||
|
self.dataset = dataset
|
||||||
|
self.data_path = data_path
|
||||||
|
self.seed = seed
|
||||||
|
self.reinit(models, input_size, sample_batch, seed)
|
||||||
|
|
||||||
|
def reinit(self, ori_models=None, input_size=None, sample_batch=None, seed=None, weights=None):
|
||||||
|
models = []
|
||||||
|
for network in ori_models:
|
||||||
|
network = network.cuda()
|
||||||
|
net = copy.deepcopy(network)
|
||||||
|
net.proj_weights = weights
|
||||||
|
num_edge, num_op = net.num_edge, net.num_op
|
||||||
|
for i in range(num_edge):
|
||||||
|
net.candidate_flags[i] = False
|
||||||
|
net.eval()
|
||||||
|
models.append(net)
|
||||||
|
|
||||||
|
if models is not None:
|
||||||
|
assert isinstance(models, list)
|
||||||
|
del self.models
|
||||||
|
self.models = models
|
||||||
|
for model in self.models:
|
||||||
|
self.register_hook(model)
|
||||||
|
device = torch.cuda.current_device()
|
||||||
|
model = model.cuda(device=device)
|
||||||
|
self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(models))]
|
||||||
|
if input_size is not None or sample_batch is not None:
|
||||||
|
if input_size is not None:
|
||||||
|
self.input_size = input_size # BCHW
|
||||||
|
self.input_numel = reduce(mul, self.input_size, 1)
|
||||||
|
if sample_batch is not None:
|
||||||
|
self.sample_batch = sample_batch
|
||||||
|
if self.data_path is not None:
|
||||||
|
self.train_data, _, class_num = get_datasets(self.dataset, self.data_path, self.input_size, -1)
|
||||||
|
self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=self.input_size[0], num_workers=16, pin_memory=True, drop_last=True, shuffle=True)
|
||||||
|
self.loader = iter(self.train_loader)
|
||||||
|
if seed is not None and seed != self.seed:
|
||||||
|
self.seed = seed
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
del self.interFeature
|
||||||
|
self.interFeature = []
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(self.models))]
|
||||||
|
del self.interFeature
|
||||||
|
self.interFeature = []
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def register_hook(self, model):
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, nn.ReLU):
|
||||||
|
m.register_forward_hook(hook=self.hook_in_forward)
|
||||||
|
|
||||||
|
def hook_in_forward(self, module, input, output):
|
||||||
|
if isinstance(input, tuple) and len(input[0].size()) == 4:
|
||||||
|
self.interFeature.append(output.detach()) # for ReLU
|
||||||
|
|
||||||
|
def forward_batch_sample(self):
|
||||||
|
for _ in range(self.sample_batch):
|
||||||
|
try:
|
||||||
|
inputs, targets = next(self.loader)
|
||||||
|
except Exception:
|
||||||
|
del self.loader
|
||||||
|
self.loader = iter(self.train_loader)
|
||||||
|
inputs, targets = next(self.loader)
|
||||||
|
for model, LRCount in zip(self.models, self.LRCounts):
|
||||||
|
self.forward(model, LRCount, inputs)
|
||||||
|
output = [LRCount.getLinearReginCount() for LRCount in self.LRCounts]
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(self, model, LRCount, input_data):
|
||||||
|
self.interFeature = []
|
||||||
|
with torch.no_grad():
|
||||||
|
model.forward(input_data.cuda())
|
||||||
|
if len(self.interFeature) == 0: return
|
||||||
|
feature_data = torch.cat([f.view(input_data.size(0), -1) for f in self.interFeature], 1)
|
||||||
|
LRCount.update2D(feature_data)
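# Minimal usage sketch, mirroring the calls made in tenas_project (init_projection.py); argument values are illustrative:
#   lrc = Linear_Region_Collector(input_size=(1000, 1, 3, 3), sample_batch=3,
#                                 dataset=args.dataset, data_path=args.data, seed=args.seed)
#   lrc.reinit(ori_models=[model_thin], seed=args.seed, weights=weights)
#   num_regions = lrc.forward_batch_sample()   # one linear-region count per registered model
#   lrc.clear()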
|
245
nasbench201/networks_proposal.py
Normal file
@ -0,0 +1,245 @@
import os
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, '../')
|
||||||
|
import time
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils
|
||||||
|
import torchvision.datasets as dset
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
import nasbench201.utils as ig_utils
|
||||||
|
from nasbench201.search_model_darts_proj import TinyNetworkDartsProj
|
||||||
|
from nasbench201.cell_operations import SearchSpaceNames
|
||||||
|
from nasbench201.init_projection import pt_project, global_op_greedy_pt_project, global_op_once_pt_project, global_edge_greedy_pt_project, global_edge_once_pt_project, shrink_pt_project, tenas_project
|
||||||
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
torch.set_printoptions(precision=4, sci_mode=False)
|
||||||
|
np.set_printoptions(precision=4, suppress=True)
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser("sota")
|
||||||
|
# data related
|
||||||
|
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
|
||||||
|
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'imagenet16-120'], help='choose dataset')
|
||||||
|
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=64, help='batch size for alpha')
|
||||||
|
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
|
||||||
|
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
|
||||||
|
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
|
||||||
|
parser.add_argument('--seed', type=int, default=2, help='random seed')
|
||||||
|
|
||||||
|
#search space setting
|
||||||
|
parser.add_argument('--search_space', type=str, default='nas-bench-201')
|
||||||
|
|
||||||
|
parser.add_argument('--pool_size', type=int, default=100, help='number of model to proposed')
|
||||||
|
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
|
||||||
|
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
|
||||||
|
|
||||||
|
#system configurations
|
||||||
|
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
|
||||||
|
parser.add_argument('--save', type=str, default='exp', help='experiment name')
|
||||||
|
|
||||||
|
#default opt setting for model
|
||||||
|
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
|
||||||
|
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
|
||||||
|
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||||
|
parser.add_argument('--nesterov', action='store_true', default=True, help='using nestrov momentum for SGD')
|
||||||
|
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
|
||||||
|
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
|
||||||
|
|
||||||
|
#### common
|
||||||
|
parser.add_argument('--fast', action='store_true', default=True, help='skip loading api which is slow')
|
||||||
|
|
||||||
|
#### projection
|
||||||
|
parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'shrink_pt_project'], help='which edge to be projected next')
|
||||||
|
parser.add_argument('--proj_crit', type=str, default="comb", choices=['loss', 'acc', 'jacob', 'snip', 'fisher', 'synflow', 'grad_norm', 'grasp', 'jacob_cov','tenas', 'var', 'cor', 'norm', 'comb', 'meco'], help='criteria for projection')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
#### args augment
|
||||||
|
expid = args.save
|
||||||
|
args.save = '../experiments/nas-bench-201/prop-{}-{}-{}'.format(args.save, args.seed, args.pool_size)
|
||||||
|
if not args.dataset == 'cifar10':
|
||||||
|
args.save += '-' + args.dataset
|
||||||
|
if not args.edge_decision == 'random':
|
||||||
|
args.save += '-' + args.edge_decision
|
||||||
|
if not args.proj_crit == 'jacob':
|
||||||
|
args.save += '-' + args.proj_crit
|
||||||
|
|
||||||
|
#### logging
|
||||||
|
scripts_to_save = glob.glob('*.py') \
|
||||||
|
# + ['../exp_scripts/{}.sh'.format(expid)]
|
||||||
|
if os.path.exists(args.save):
|
||||||
|
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
|
||||||
|
print('proceed to override saving directory')
|
||||||
|
shutil.rmtree(args.save)
|
||||||
|
else:
|
||||||
|
exit(0)
|
||||||
|
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
|
||||||
|
|
||||||
|
log_format = '%(asctime)s %(message)s'
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
||||||
|
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
||||||
|
|
||||||
|
log_file = 'log.txt'
|
||||||
|
log_path = os.path.join(args.save, log_file)
|
||||||
|
logging.info('======> log filename: %s', log_file)
|
||||||
|
|
||||||
|
if os.path.exists(log_path):
|
||||||
|
if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y':
|
||||||
|
print('proceed to override log file directory')
|
||||||
|
else:
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
fh = logging.FileHandler(log_path, mode='w')
|
||||||
|
fh.setFormatter(logging.Formatter(log_format))
|
||||||
|
logging.getLogger().addHandler(fh)
|
||||||
|
writer = SummaryWriter(args.save + '/runs')
|
||||||
|
|
||||||
|
#### macros
|
||||||
|
if args.dataset == 'cifar100':
|
||||||
|
n_classes = 100
|
||||||
|
elif args.dataset == 'imagenet16-120':
|
||||||
|
n_classes = 120
|
||||||
|
else:
|
||||||
|
n_classes = 10
|
||||||
|
|
||||||
|
def main():
|
||||||
|
torch.set_num_threads(3)
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
logging.info('no gpu device available')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
|
||||||
|
torch.cuda.set_device(gpu)
|
||||||
|
cudnn.benchmark = True
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
cudnn.enabled = True
|
||||||
|
torch.cuda.manual_seed(args.seed)
|
||||||
|
logging.info("args = %s", args)
|
||||||
|
logging.info('gpu device = %d' % gpu)
|
||||||
|
|
||||||
|
#### model
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
search_space = SearchSpaceNames[args.search_space]
|
||||||
|
|
||||||
|
# initialize the supernet and a thin (1-channel stem) variant
model = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args)
|
||||||
|
model_thin = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args, stem_channels=1)
|
||||||
|
model = model.cuda()
|
||||||
|
model_thin = model_thin.cuda()
|
||||||
|
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
|
||||||
|
|
||||||
|
#### data
|
||||||
|
if args.dataset == 'cifar10':
|
||||||
|
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
|
||||||
|
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
|
||||||
|
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
|
||||||
|
elif args.dataset == 'cifar100':
|
||||||
|
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
|
||||||
|
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
|
||||||
|
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
|
||||||
|
elif args.dataset == 'imagenet16-120':
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from nasbench201.DownsampledImageNet import ImageNet16
|
||||||
|
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
|
||||||
|
std = [x / 255 for x in [63.22, 61.26, 65.09]]
|
||||||
|
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||||
|
train_transform = transforms.Compose(lists)
|
||||||
|
train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
|
||||||
|
valid_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=False, transform=train_transform, use_num_of_class_only=120)
|
||||||
|
assert len(train_data) == 151700
|
||||||
|
|
||||||
|
num_train = len(train_data)
|
||||||
|
indices = list(range(num_train))
|
||||||
|
split = int(np.floor(args.train_portion * num_train))
|
||||||
|
|
||||||
|
train_queue = torch.utils.data.DataLoader(
|
||||||
|
train_data, batch_size=args.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
|
||||||
|
pin_memory=True)
|
||||||
|
|
||||||
|
|
||||||
|
# format of the network pool dictionary
networks_pool={}
|
||||||
|
networks_pool['search_space'] = args.search_space
|
||||||
|
networks_pool['dataset'] = args.dataset
|
||||||
|
networks_pool['networks'] = []
|
||||||
|
networks_pool['pool_size'] = args.pool_size
|
||||||
|
#### architecture selection / projection
|
||||||
|
for i in range(args.pool_size):
|
||||||
|
network_info={}
|
||||||
|
logging.info('SEARCHING MODEL {}'.format(i + 1))
|
||||||
|
if args.edge_decision == 'global_op_greedy':
|
||||||
|
global_op_greedy_pt_project(train_queue, model, args)
|
||||||
|
elif args.edge_decision == 'global_op_once':
|
||||||
|
global_op_once_pt_project(train_queue, model, args)
|
||||||
|
elif args.edge_decision == 'global_edge_greedy':
|
||||||
|
global_edge_greedy_pt_project(train_queue, model, args)
|
||||||
|
elif args.edge_decision == 'global_edge_once':
|
||||||
|
global_edge_once_pt_project(train_queue, model, args)
|
||||||
|
elif args.edge_decision == 'shrink_pt_project':
|
||||||
|
shrink_pt_project(train_queue, model, args)
|
||||||
|
api = API('../data/NAS-Bench-201-v1_0-e61699.pth')
|
||||||
|
cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
|
||||||
|
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, model.genotype().tostr(), logging)
|
||||||
|
else:
|
||||||
|
if args.proj_crit == 'jacob':
|
||||||
|
pt_project(train_queue, model, args)
|
||||||
|
else:
|
||||||
|
pt_project(train_queue, model, args)
|
||||||
|
# tenas_project(train_queue, model, model_thin, args)
|
||||||
|
|
||||||
|
network_info['id'] = str(i)
|
||||||
|
network_info['genotype'] = model.genotype().tostr()
|
||||||
|
networks_pool['networks'].append(network_info)
|
||||||
|
model.reset_arch_parameters()
|
||||||
|
|
||||||
|
with open(os.path.join(args.save,'networks_pool.json'), 'w') as save_file:
|
||||||
|
json.dump(networks_pool, save_file)
|
||||||
|
|
||||||
|
|
||||||
|
#### util functions
|
||||||
|
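# distill() parses the human-readable result string returned by NASBench201API.query_by_arch into train/valid/test accuracies for cifar10, cifar100 and ImageNet16-120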
def distill(result):
|
||||||
|
result = result.split('\n')
|
||||||
|
cifar10 = result[5].replace(' ', '').split(':')
|
||||||
|
cifar100 = result[7].replace(' ', '').split(':')
|
||||||
|
imagenet16 = result[9].replace(' ', '').split(':')
|
||||||
|
|
||||||
|
cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
|
||||||
|
cifar10_test = float(cifar10[2][-7:-2].strip('='))
|
||||||
|
cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
|
||||||
|
cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
|
||||||
|
cifar100_test = float(cifar100[3][-7:-2].strip('='))
|
||||||
|
imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
|
||||||
|
imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
|
||||||
|
imagenet16_test = float(imagenet16[3][-7:-2].strip('='))
|
||||||
|
|
||||||
|
return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
|
||||||
|
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
|
||||||
|
|
||||||
|
|
||||||
|
def query(api, genotype, logging):
|
||||||
|
result = api.query_by_arch(genotype, hp='200')
|
||||||
|
logging.info('{:}'.format(result))
|
||||||
|
cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
|
||||||
|
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
|
||||||
|
logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
|
||||||
|
logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
|
||||||
|
logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
|
||||||
|
return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
|
||||||
|
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
113
nasbench201/op_score.py
Normal file
@ -0,0 +1,113 @@
import gc
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as f
|
||||||
|
from operator import mul
|
||||||
|
from functools import reduce
|
||||||
|
import copy
|
||||||
|
sys.path.insert(0, '../')
|
||||||
|
|
||||||
|
def Jocab_Score(ori_model, input, target, weights=None):
|
||||||
|
model = copy.deepcopy(ori_model)
|
||||||
|
model.eval()
|
||||||
|
model.proj_weights = weights
|
||||||
|
num_edge, num_op = model.num_edge, model.num_op
|
||||||
|
for i in range(num_edge):
|
||||||
|
model.candidate_flags[i] = False
|
||||||
|
batch_size = input.shape[0]
|
||||||
|
model.K = torch.zeros(batch_size, batch_size).cuda()
|
||||||
|
|
||||||
|
def counting_forward_hook(module, inp, out):
|
||||||
|
try:
|
||||||
|
if isinstance(inp, tuple):
|
||||||
|
inp = inp[0]
|
||||||
|
inp = inp.view(inp.size(0), -1)
|
||||||
|
x = (inp > 0).float()
|
||||||
|
K = x @ x.t()
|
||||||
|
K2 = (1.-x) @ (1.-x.t())
|
||||||
|
model.K = model.K + K + K2
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if 'ReLU' in str(type(module)):
|
||||||
|
module.register_forward_hook(counting_forward_hook)
|
||||||
|
|
||||||
|
input = input.cuda()
|
||||||
|
model(input)
|
||||||
|
score = hooklogdet(model.K.cpu().numpy())
|
||||||
|
del model
|
||||||
|
del input
|
||||||
|
return score
|
||||||
|
|
||||||
|
def hooklogdet(K, labels=None):
|
||||||
|
s, ld = np.linalg.slogdet(K)
|
||||||
|
return ld
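# The kernel K accumulated by counting_forward_hook counts, for each pair of inputs, how many
# ReLU units share the same on/off state; hooklogdet returns the log-determinant of that kernel
# (the sign from slogdet is discarded), i.e. the NASWOT-style score used as the 'jacob' criterion.
# Minimal sketch with hypothetical shapes (64 inputs, 128 ReLU units):
#   x = (torch.randn(64, 128) > 0).float()
#   K = x @ x.t() + (1. - x) @ (1. - x.t())
#   score = hooklogdet(K.numpy())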
|
||||||
|
|
||||||
|
# NTK
|
||||||
|
#------------------------------------------------------------
|
||||||
|
#https://github.com/VITA-Group/TENAS/blob/main/lib/procedures/ntk.py
|
||||||
|
#
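# get_ntk_n builds an empirical NTK from per-sample logit gradients and returns its condition number
# (largest / smallest eigenvalue) for each network; TE-NAS treats a smaller condition number as a sign of better trainability.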
|
||||||
|
def recal_bn(network, xloader, recalbn, device):
|
||||||
|
for m in network.modules():
|
||||||
|
if isinstance(m, torch.nn.BatchNorm2d):
|
||||||
|
m.running_mean.data.fill_(0)
|
||||||
|
m.running_var.data.fill_(0)
|
||||||
|
m.num_batches_tracked.data.zero_()
|
||||||
|
m.momentum = None
|
||||||
|
network.train()
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, (inputs, targets) in enumerate(xloader):
|
||||||
|
if i >= recalbn: break
|
||||||
|
inputs = inputs.cuda(device=device, non_blocking=True)
|
||||||
|
_, _ = network(inputs)
|
||||||
|
return network
|
||||||
|
|
||||||
|
def get_ntk_n(xloader, networks, recalbn=0, train_mode=False, num_batch=-1, weights=None):
|
||||||
|
device = torch.cuda.current_device()
|
||||||
|
ntks = []
|
||||||
|
copied_networks = []
|
||||||
|
for network in networks:
|
||||||
|
network = network.cuda(device=device)
|
||||||
|
net = copy.deepcopy(network)
|
||||||
|
net.proj_weights = weights
|
||||||
|
num_edge, num_op = net.num_edge, net.num_op
|
||||||
|
for i in range(num_edge):
|
||||||
|
net.candidate_flags[i] = False
|
||||||
|
if train_mode:
|
||||||
|
net.train()
|
||||||
|
else:
|
||||||
|
net.eval()
|
||||||
|
copied_networks.append(net)
|
||||||
|
######
|
||||||
|
grads = [[] for _ in range(len(copied_networks))]
|
||||||
|
for i, (inputs, targets) in enumerate(xloader):
|
||||||
|
if num_batch > 0 and i >= num_batch: break
|
||||||
|
inputs = inputs.cuda(device=device, non_blocking=True)
|
||||||
|
for net_idx, network in enumerate(copied_networks):
|
||||||
|
network.zero_grad()
|
||||||
|
inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
|
||||||
|
logit = network(inputs_)
|
||||||
|
if isinstance(logit, tuple):
|
||||||
|
logit = logit[1] # 201 networks: return features and logits
|
||||||
|
for _idx in range(len(inputs_)):
|
||||||
|
logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)
|
||||||
|
grad = []
|
||||||
|
for name, W in network.named_parameters():
|
||||||
|
if 'weight' in name and W.grad is not None:
|
||||||
|
grad.append(W.grad.view(-1).detach())
|
||||||
|
grads[net_idx].append(torch.cat(grad, -1))
|
||||||
|
network.zero_grad()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
######
|
||||||
|
grads = [torch.stack(_grads, 0) for _grads in grads]
|
||||||
|
ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
|
||||||
|
conds = []
|
||||||
|
for ntk in ntks:
|
||||||
|
eigenvalues, _ = torch.symeig(ntk) # ascending
|
||||||
|
conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0))
|
||||||
|
|
||||||
|
del copied_networks
|
||||||
|
return conds
|
182
nasbench201/search_cells.py
Normal file
@ -0,0 +1,182 @@
import math, random, torch
|
||||||
|
import warnings
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from copy import deepcopy
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, '../')
|
||||||
|
from nasbench201.cell_operations import OPS
|
||||||
|
|
||||||
|
|
||||||
|
# This module is used for NAS-Bench-201; it represents a small search space with a complete DAG
|
||||||
|
class NAS201SearchCell(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
|
||||||
|
super(NAS201SearchCell, self).__init__()
|
||||||
|
|
||||||
|
self.op_names = deepcopy(op_names)
|
||||||
|
self.edges = nn.ModuleDict()
|
||||||
|
self.max_nodes = max_nodes
|
||||||
|
self.in_dim = C_in
|
||||||
|
self.out_dim = C_out
|
||||||
|
for i in range(1, max_nodes):
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
if j == 0:
|
||||||
|
xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
|
||||||
|
else:
|
||||||
|
xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
|
||||||
|
self.edges[ node_str ] = nn.ModuleList( xlists )
|
||||||
|
self.edge_keys = sorted(list(self.edges.keys()))
|
||||||
|
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
||||||
|
self.num_edges = len(self.edges)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
|
||||||
|
return string
|
||||||
|
|
||||||
|
def forward(self, inputs, weightss):
|
||||||
|
return self._forward(inputs, weightss)
|
||||||
|
|
||||||
|
def _forward(self, inputs, weightss):
|
||||||
|
with torch.autograd.set_detect_anomaly(True):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
inter_nodes = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = weightss[ self.edge2index[node_str] ]
|
||||||
|
inter_nodes.append(sum(layer(nodes[j], block_input=True)*w if w==0 else layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights)) )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# GDAS
|
||||||
|
def forward_gdas(self, inputs, hardwts, index):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
inter_nodes = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = hardwts[ self.edge2index[node_str] ]
|
||||||
|
argmaxs = index[ self.edge2index[node_str] ].item()
|
||||||
|
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
|
||||||
|
inter_nodes.append( weigsum )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# joint
|
||||||
|
def forward_joint(self, inputs, weightss):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
inter_nodes = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = weightss[ self.edge2index[node_str] ]
|
||||||
|
#aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
|
||||||
|
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) )
|
||||||
|
inter_nodes.append( aggregation )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# uniform random sampling per iteration, SETN
|
||||||
|
def forward_urs(self, inputs):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
while True: # to avoid select zero for all ops
|
||||||
|
sops, has_non_zero = [], False
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
candidates = self.edges[node_str]
|
||||||
|
select_op = random.choice(candidates)
|
||||||
|
sops.append( select_op )
|
||||||
|
if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
|
||||||
|
if has_non_zero: break
|
||||||
|
inter_nodes = []
|
||||||
|
for j, select_op in enumerate(sops):
|
||||||
|
inter_nodes.append( select_op(nodes[j]) )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# select the argmax
|
||||||
|
def forward_select(self, inputs, weightss):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
inter_nodes = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = weightss[ self.edge2index[node_str] ]
|
||||||
|
inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
|
||||||
|
#inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# forward with a specific structure
|
||||||
|
def forward_dynamic(self, inputs, structure):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
cur_op_node = structure.nodes[i-1]
|
||||||
|
inter_nodes = []
|
||||||
|
for op_name, j in cur_op_node:
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
op_index = self.op_names.index( op_name )
|
||||||
|
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def channel_shuffle(x, groups):
|
||||||
|
batchsize, num_channels, height, width = x.data.size()
|
||||||
|
channels_per_group = num_channels // groups
|
||||||
|
# reshape
|
||||||
|
x = x.view(batchsize, groups,
|
||||||
|
channels_per_group, height, width)
|
||||||
|
x = torch.transpose(x, 1, 2).contiguous()
|
||||||
|
# flatten
|
||||||
|
x = x.view(batchsize, -1, height, width)
|
||||||
|
return x
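# channel_shuffle (from ShuffleNet) is used by NAS201SearchCell_PartialChannel below: only 1/k of the
# channels pass through the mixed operation (PC-DARTS style) and the shuffle re-mixes processed and bypassed channels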
|
||||||
|
|
||||||
|
|
||||||
|
class NAS201SearchCell_PartialChannel(NAS201SearchCell):

    def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True, k=4):
        super(NAS201SearchCell, self).__init__()

        self.k = k
        self.op_names = deepcopy(op_names)
        self.edges = nn.ModuleDict()
        self.max_nodes = max_nodes
        self.in_dim = C_in
        self.out_dim = C_out
        for i in range(1, max_nodes):
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                if j == 0:
                    xlists = [OPS[op_name](C_in//self.k, C_out//self.k, stride, affine, track_running_stats) for op_name in op_names]
                else:
                    xlists = [OPS[op_name](C_in//self.k, C_out//self.k, 1, affine, track_running_stats) for op_name in op_names]
                self.edges[ node_str ] = nn.ModuleList( xlists )
        self.edge_keys = sorted(list(self.edges.keys()))
        self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
        self.num_edges = len(self.edges)

    def MixedOp(self, x, ops, weights):
        dim_2 = x.shape[1]
        xtemp = x[:, :dim_2//self.k, :, :]
        xtemp2 = x[:, dim_2//self.k:, :, :]
        temp1 = sum(w * op(xtemp) for w, op in zip(weights, ops))
        ans = torch.cat([temp1, xtemp2], dim=1)
        ans = channel_shuffle(ans, self.k)
        return ans

    def forward(self, inputs, weightss):
        nodes = [inputs]
        for i in range(1, self.max_nodes):
            inter_nodes = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                weights = weightss[ self.edge2index[node_str] ]
                # inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
                inter_nodes.append(self.MixedOp(x=nodes[j], ops=self.edges[node_str], weights=weights))
            nodes.append( sum(inter_nodes) )
        return nodes[-1]

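# --- Editor's illustrative sketch (not part of the original diff) -----------
# The partial-channel trick in MixedOp above: only 1/k of the channels go
# through the weighted candidate ops, the rest bypass them, and
# channel_shuffle mixes the two parts afterwards. A hypothetical standalone
# version with plain modules standing in for OPS:
def _demo_partial_channel_mixed_op(k=4, C=16):
    x = torch.randn(2, C, 8, 8)
    ops = nn.ModuleList([nn.Identity(), nn.Conv2d(C // k, C // k, 3, padding=1)])
    weights = torch.softmax(torch.randn(len(ops)), dim=0)
    xtemp, xtemp2 = x[:, :C // k], x[:, C // k:]
    mixed = sum(w * op(xtemp) for w, op in zip(weights, ops))
    out = channel_shuffle(torch.cat([mixed, xtemp2], dim=1), k)
    assert out.shape == x.shape
    return out
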
202
nasbench201/search_model.py
Normal file
@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from .cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from torch.autograd import Variable


class TinyNetwork(nn.Module):

    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, affine=False, track_running_stats=True, stem_channels=3):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self._num_classes = num_classes
        self._criterion = criterion
        self._args = args
        self._affine = affine
        self._track_running_stats = track_running_stats
        self.stem = nn.Sequential(
            nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))

        layer_channels   = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
                if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
                else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append( cell )
            C_prev = cell.out_dim
        self.num_edge = num_edge
        self.num_op = len(search_space)
        self.op_names = deepcopy( search_space )
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        # self._arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
        self._arch_parameters = Variable(1e-3*torch.randn(num_edge, len(search_space)).cuda(), requires_grad=True)

        ## optimizer
        ## record the memory addresses (ids) of the architecture parameters so they can be told apart from the model weights
        arch_params = set(id(m) for m in self.arch_parameters())
        self._model_params = [m for m in self.parameters() if id(m) not in arch_params]

        # optimizer for the model weights
        self.optimizer = torch.optim.SGD(
            self._model_params,
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov)

    def entropy_y_x(self, p_logit):
        p = F.softmax(p_logit, dim=1)
        return - torch.sum(p * F.log_softmax(p_logit, dim=1)) / p_logit.shape[0]

    def _loss(self, input, target, return_logits=False):
        logits = self(input)
        loss = self._criterion(logits, target)

        return (loss, logits) if return_logits else loss

    def get_weights(self):
        xlist  = list( self.stem.parameters() ) + list( self.cells.parameters() )
        xlist += list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
        xlist += list( self.classifier.parameters() )
        return xlist

    def arch_parameters(self):
        return [self._arch_parameters]

    def get_theta(self):
        return nn.functional.softmax(self._arch_parameters, dim=-1).cpu()

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

    def genotype(self):
        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                with torch.no_grad():
                    weights = self._arch_parameters[ self.edge2index[node_str] ]
                    op_name = self.op_names[ weights.argmax().item() ]
                xlist.append((op_name, j))
            genotypes.append( tuple(xlist) )
        return Structure( genotypes )

    def forward(self, inputs, weights=None):
        sim_nn = []

        weights = nn.functional.softmax(self._arch_parameters, dim=-1) if weights is None else weights

        # `slim` is never set in __init__; treat a missing attribute as False so the default forward pass works
        if getattr(self, 'slim', False):
            weights[1].data.fill_(0)
            weights[3].data.fill_(0)
            weights[4].data.fill_(0)

        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell(feature, weights)
            else:
                feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling( out )
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return logits

    def _save_arch_parameters(self):
        self._saved_arch_parameters = [p.clone() for p in self._arch_parameters]

    def project_arch(self):
        self._save_arch_parameters()
        for p in self.arch_parameters():
            m, n = p.size()
            maxIndexs = p.data.cpu().numpy().argmax(axis=1)
            p.data = self.proximal_step(p, maxIndexs)

    def proximal_step(self, var, maxIndexs=None):
        values = var.data.cpu().numpy()
        m, n = values.shape
        alphas = []
        for i in range(m):
            for j in range(n):
                if j == maxIndexs[i]:
                    alphas.append(values[i][j].copy())
                    values[i][j] = 1
                else:
                    values[i][j] = 0
        return torch.Tensor(values).cuda()

    def restore_arch_parameters(self):
        for i, p in enumerate(self._arch_parameters):
            p.data.copy_(self._saved_arch_parameters[i])
        del self._saved_arch_parameters

    def new(self):
        model_new = TinyNetwork(self._C, self._layerN, self.max_nodes, self._num_classes, self._criterion,
                                self.op_names, self._args, self._affine, self._track_running_stats).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)

        return model_new

    def step(self, input, target, args, shared=None, return_grad=False):
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()

        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt

    def printing(self, logging):
        logging.info(self.get_theta())

    def set_arch_parameters(self, new_alphas):
        for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
            alpha.data.copy_(new_alpha.data)

    def save_arch_parameters(self):
        self._saved_arch_parameters = self._arch_parameters.clone()

    # NOTE: this definition overrides the restore_arch_parameters defined above
    def restore_arch_parameters(self):
        self.set_arch_parameters(self._saved_arch_parameters)

    def reset_optimizer(self, lr, momentum, weight_decay):
        del self.optimizer
        self.optimizer = torch.optim.SGD(
            self.get_weights(),
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=self._args.nesterov)  # `args` is not in scope here; use the args stored in __init__

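# --- Editor's illustrative sketch (not part of the original diff) -----------
# Hypothetical construction of the supernet above. The op names, the args
# fields, and the need for a CUDA device (the architecture parameters are
# created with .cuda()) are assumptions, not taken from this diff.
def _demo_build_tiny_network():
    from argparse import Namespace
    if not torch.cuda.is_available():
        return None  # _arch_parameters are allocated on the GPU
    search_space = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']  # assumed NAS-Bench-201 op names
    args = Namespace(learning_rate=0.025, momentum=0.9, weight_decay=3e-4, nesterov=True)
    net = TinyNetwork(C=16, N=1, max_nodes=4, num_classes=10,
                      criterion=nn.CrossEntropyLoss(), search_space=search_space,
                      args=args).cuda()
    logits = net(torch.randn(2, 3, 32, 32).cuda())
    print(logits.shape, net.genotype())
    return net
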
33
nasbench201/search_model_darts.py
Normal file
@ -0,0 +1,33 @@
import torch
import torch.nn as nn
from .search_cells import NAS201SearchCell as SearchCell
from .search_model import TinyNetwork as TinyNetwork


class TinyNetworkDarts(TinyNetwork):
    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
                 affine=False, track_running_stats=True, stem_channels=3):
        super(TinyNetworkDarts, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args,
                                               affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels)

        self.theta_map = lambda x: torch.softmax(x, dim=-1)

    def get_theta(self):
        return self.theta_map(self._arch_parameters).cpu()

    def forward(self, inputs):
        weights = self.theta_map(self._arch_parameters)
        feature = self.stem(inputs)

        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell(feature, weights)
            else:
                feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling( out )
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return logits
80
nasbench201/search_model_darts_proj.py
Normal file
@ -0,0 +1,80 @@
import torch
from .search_cells import NAS201SearchCell as SearchCell
from .search_model import TinyNetwork as TinyNetwork
from .genotypes import Structure
from torch.autograd import Variable

class TinyNetworkDartsProj(TinyNetwork):
    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
                 affine=False, track_running_stats=True, stem_channels=3):
        super(TinyNetworkDartsProj, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args,
                                                   affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels)
        self.theta_map = lambda x: torch.softmax(x, dim=-1)

        #### for edgewise projection
        self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
        self.proj_weights = torch.zeros_like(self._arch_parameters)

    def project_op(self, eid, opid):
        self.proj_weights[eid][opid] = 1 ## hard by default
        self.candidate_flags[eid] = False

    def get_projected_weights(self):
        weights = self.theta_map(self._arch_parameters)

        ## proj op
        for eid in range(len(self._arch_parameters)):
            if not self.candidate_flags[eid]:
                weights[eid].data.copy_(self.proj_weights[eid])

        return weights

    def forward(self, inputs, weights=None):
        with torch.autograd.set_detect_anomaly(True):
            if weights is None:
                weights = self.get_projected_weights()

            feature = self.stem(inputs)
            for i, cell in enumerate(self.cells):
                if isinstance(cell, SearchCell):
                    feature = cell(feature, weights)
                else:
                    feature = cell(feature)

            out = self.lastact(feature)
            out = self.global_pooling( out )
            out = out.view(out.size(0), -1)
            logits = self.classifier(out)

            return logits

    #### utils
    def get_theta(self):
        return self.get_projected_weights()

    def arch_parameters(self):
        return [self._arch_parameters]

    def set_arch_parameters(self, new_alphas):
        for eid, alpha in enumerate(self.arch_parameters()):
            alpha.data.copy_(new_alphas[eid])

    def reset_arch_parameters(self):
        self._arch_parameters = Variable(1e-3*torch.randn(self.num_edge, len(self.op_names)).cuda(), requires_grad=True)
        self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
        self.proj_weights = torch.zeros_like(self._arch_parameters)

    def genotype(self):
        proj_weights = self.get_projected_weights()

        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                with torch.no_grad():
                    weights = proj_weights[ self.edge2index[node_str] ]
                    op_name = self.op_names[ weights.argmax().item() ]
                xlist.append((op_name, j))
            genotypes.append( tuple(xlist) )
        return Structure( genotypes )
494
nasbench201/utils.py
Normal file
@ -0,0 +1,494 @@
from __future__ import print_function

import numpy as np
import os
import os.path
import sys
import shutil
import torch
import torchvision.transforms as transforms

from PIL import Image
from torch.autograd import Variable
from torchvision.datasets import VisionDataset
from torchvision.datasets import utils

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


class AvgrageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

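# --- Editor's illustrative sketch (not part of the original diff) -----------
# How AvgrageMeter and accuracy above are typically combined when logging
# top-1/top-5 during evaluation (synthetic logits/targets, hypothetical):
def _demo_accuracy_meter():
    top1, top5 = AvgrageMeter(), AvgrageMeter()
    logits = torch.randn(8, 10)
    target = torch.randint(0, 10, (8,))
    prec1, prec5 = accuracy(logits, target, topk=(1, 5))
    top1.update(prec1.item(), n=logits.size(0))
    top5.update(prec5.item(), n=logits.size(0))
    print('top1 %.2f  top5 %.2f' % (top1.avg, top5.avg))
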
class Cutout(object):
    def __init__(self, length, prob=1.0):
        self.length = length
        self.prob = prob

    def __call__(self, img):
        if np.random.binomial(1, self.prob):
            h, w = img.size(1), img.size(2)
            mask = np.ones((h, w), np.float32)
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.
            mask = torch.from_numpy(mask)
            mask = mask.expand_as(img)
            img *= mask
        # return the image unchanged when the cutout is not applied
        return img

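# --- Editor's illustrative sketch (not part of the original diff) -----------
# Cutout above zeroes a random length x length square with probability prob.
# A hypothetical check on a CHW tensor (as produced by ToTensor):
def _demo_cutout():
    img = torch.ones(3, 32, 32)
    out = Cutout(length=8, prob=1.0)(img)
    # at most 8*8 pixels per channel are zeroed (fewer near the border)
    print(int((out[0] == 0).sum()))
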
def _data_transforms_svhn(args):
    SVHN_MEAN = [0.4377, 0.4438, 0.4728]
    SVHN_STD = [0.1980, 0.2010, 0.1970]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar100(args):
    CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
    CIFAR_STD = [0.2673, 0.2564, 0.2762]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar10(args):
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform

def count_parameters_in_MB(model):
    return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6


def count_parameters_in_Compact(model):
    from sota.cnn.model import Network as CompactModel
    genotype = model.genotype()
    compact_model = CompactModel(36, model._num_classes, 20, True, genotype)
    num_params = count_parameters_in_MB(compact_model)
    return num_params

def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''):
    filename = prefix
    if per_epoch:
        epoch = state['epoch']
        filename += 'checkpoint_{}.pth.tar'.format(epoch)
    else:
        filename += 'checkpoint.pth.tar'
    filename = os.path.join(save, filename)
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(save, 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)


def load_checkpoint(model, optimizer, save, epoch=None):
    if epoch is None:
        filename = 'checkpoint.pth.tar'
    else:
        filename = 'checkpoint_{}.pth.tar'.format(epoch)
    filename = os.path.join(save, filename)
    start_epoch = 0
    best_acc_top1 = 0.0  # keep a defined value when no checkpoint is found
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        best_acc_top1 = checkpoint['best_acc_top1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, best_acc_top1

def save(model, model_path):
    torch.save(model.state_dict(), model_path)


def load(model, model_path):
    model.load_state_dict(torch.load(model_path))

def drop_path(x, drop_prob):
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
        x.div_(keep_prob)
        x.mul_(mask)
    return x

def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)

class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        #['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self):
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not utils.check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not utils.check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        utils.download_and_extract_archive(self.url, self.root,
                                           filename=self.filename,
                                           md5=self.tgz_md5)

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")

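# --- Editor's illustrative sketch (not part of the original diff) -----------
# Using the trimmed CIFAR10 above (note that data_batch_5 is commented out of
# train_list, so the training split holds 40k images). The root path and the
# transform are hypothetical, and download=True fetches the archive:
def _demo_cifar10(root='./data'):
    train_transform = transforms.Compose([transforms.ToTensor()])
    dset = CIFAR10(root=root, train=True, download=True, transform=train_transform)
    img, label = dset[0]
    print(len(dset), img.shape, label)
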
def pick_gpu_lowest_memory():
    import gpustat
    stats = gpustat.GPUStatCollection.new_query()
    ids = map(lambda gpu: int(gpu.entry['index']), stats)
    ratios = map(lambda gpu: float(gpu.memory_used)/float(gpu.memory_total), stats)
    bestGPU = min(zip(ids, ratios), key=lambda x: x[1])[0]
    return bestGPU

#### early stopping (from RobustNAS)
class EVLocalAvg(object):
    def __init__(self, window=5, ev_freq=2, total_epochs=50):
        """ Keep track of the eigenvalues local average.

        Args:
            window (int): number of elements used to compute local average.
                Default: 5
            ev_freq (int): frequency used to compute eigenvalues. Default:
                every 2 epochs
            total_epochs (int): total number of epochs that DARTS runs.
                Default: 50
        """
        self.window = window
        self.ev_freq = ev_freq
        self.epochs = total_epochs

        self.stop_search = False
        self.stop_epoch = total_epochs - 1
        self.stop_genotype = None
        self.stop_numparam = 0

        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

        # start and end index of the local average window
        self.la_start_idx = 0
        self.la_end_idx = self.window

    def reset(self):
        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

    def update(self, epoch, ev, genotype, numparam=0):
        """ Method to update the local average list.

        Args:
            epoch (int): current epoch
            ev (float): current dominant eigenvalue
            genotype (namedtuple): current genotype

        """
        self.ev.append(ev)
        self.genotypes.update({epoch: genotype})
        self.numparams.update({epoch: numparam})
        # set the stop_genotype to the current genotype in case the early stop
        # procedure decides not to early stop
        self.stop_genotype = genotype

        # since the local average computation starts after the dominant
        # eigenvalue in the first epoch is already computed we have to wait
        # at least until we have 3 eigenvalues in the list.
        if (len(self.ev) >= int(np.ceil(self.window/2))) and (epoch <
                self.epochs - 1):
            # start sliding the window as soon as the number of eigenvalues in
            # the list becomes equal to the window size
            if len(self.ev) < self.window:
                self.ev_local_avg.append(np.mean(self.ev))
            else:
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window
                self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
                                                         self.la_end_idx]))
                self.la_start_idx += 1
                self.la_end_idx += 1

            # keep track of the offset between the current epoch and the epoch
            # corresponding to the local average. NOTE: in the end the size of
            # self.ev and self.ev_local_avg should be equal
            self.la_epochs.update({epoch: int(epoch -
                int(self.ev_freq*np.floor(self.window/2)))})

        elif len(self.ev) < int(np.ceil(self.window/2)):
            self.la_epochs.update({epoch: -1})

        # since there is an offset between the current epoch and the local
        # average epoch, loop in the last epoch to compute the local average of
        # these number of elements: window, window - 1, window - 2, ..., ceil(window/2)
        elif epoch == self.epochs - 1:
            for i in range(int(np.ceil(self.window/2))):
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i
                self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
                                                         self.la_end_idx + 1]))
                self.la_start_idx += 1

    def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4, criteria='local_avg'):
        """ Early stopping criterion.

        Args:
            epoch (int): current epoch
            factor (float): threshold factor for the ratio between the current
                and previous eigenvalue. Default: 1.3
            es_start_epoch (int): until this epoch do not consider early
                stopping. Default: 10
            delta (int): factor influencing which previous local average we
                consider for early stopping. Default: 4
        """
        if criteria == 'local_avg':
            if int(self.la_epochs[epoch] - self.ev_freq*delta) >= es_start_epoch:
                current_la = self.ev_local_avg[-1]
                previous_la = self.ev_local_avg[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = int(self.la_epochs[epoch] - self.ev_freq*delta)
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        elif criteria == 'exact':
            if epoch > es_start_epoch:
                current_la = self.ev[-1]
                previous_la = self.ev[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = epoch - delta
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        else:
            print('ERROR IN EARLY STOP: WRONG CRITERIA:', criteria); exit(0)

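# --- Editor's illustrative sketch (not part of the original diff) -----------
# Feeding EVLocalAvg above with a synthetic eigenvalue trace: the local
# average is tracked per epoch, and early_stop fires once the smoothed value
# grows by more than `factor` relative to the value `delta` steps earlier.
# The eigenvalue schedule and the placeholder genotype strings are made up.
def _demo_ev_local_avg():
    tracker = EVLocalAvg(window=5, ev_freq=2, total_epochs=50)
    for epoch in range(40):
        ev = 0.1 if epoch < 30 else 0.1 * (epoch - 28)  # sharp rise after epoch 30
        tracker.update(epoch, ev, genotype='geno_{}'.format(epoch))
        tracker.early_stop(epoch)
        if tracker.stop_search:
            print('stop at epoch', tracker.stop_epoch, tracker.stop_genotype)
            break
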
def gen_comb(eids):
    comb = []
    for r in range(len(eids)):
        for c in range(r + 1, len(eids)):
            comb.append((eids[r], eids[c]))

    return comb

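# --- Editor's illustrative sketch (not part of the original diff) -----------
# gen_comb above enumerates all unordered pairs of edge ids, e.g.
#   gen_comb([0, 1, 2]) == [(0, 1), (0, 2), (1, 2)]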