From 2410fe9f5edd82d50b6e4df08c59ddd2a6948b93 Mon Sep 17 00:00:00 2001 From: HamsterMimi Date: Thu, 4 May 2023 13:42:06 +0800 Subject: [PATCH] update --- correlation/NAS-Bench-101.py | 133 +++++ correlation/NAS-Bench-201.py | 128 +++++ exp_scripts/zerocostpt_nb201_pipeline.sh | 38 ++ nasbench201/DownsampledImageNet.py | 110 ++++ nasbench201/architect_ig.py | 52 ++ nasbench201/cell_infers/cells.py | 120 +++++ nasbench201/cell_infers/tiny_network.py | 82 +++ nasbench201/cell_operations.py | 289 +++++++++++ nasbench201/genotypes.py | 194 +++++++ nasbench201/init_projection.py | 619 +++++++++++++++++++++++ nasbench201/linear_region.py | 270 ++++++++++ nasbench201/networks_proposal.py | 245 +++++++++ nasbench201/op_score.py | 113 +++++ nasbench201/search_cells.py | 182 +++++++ nasbench201/search_model.py | 202 ++++++++ nasbench201/search_model_darts.py | 33 ++ nasbench201/search_model_darts_proj.py | 80 +++ nasbench201/utils.py | 494 ++++++++++++++++++ 18 files changed, 3384 insertions(+) create mode 100644 correlation/NAS-Bench-101.py create mode 100644 correlation/NAS-Bench-201.py create mode 100644 exp_scripts/zerocostpt_nb201_pipeline.sh create mode 100644 nasbench201/DownsampledImageNet.py create mode 100644 nasbench201/architect_ig.py create mode 100644 nasbench201/cell_infers/cells.py create mode 100644 nasbench201/cell_infers/tiny_network.py create mode 100644 nasbench201/cell_operations.py create mode 100644 nasbench201/genotypes.py create mode 100644 nasbench201/init_projection.py create mode 100644 nasbench201/linear_region.py create mode 100644 nasbench201/networks_proposal.py create mode 100644 nasbench201/op_score.py create mode 100644 nasbench201/search_cells.py create mode 100644 nasbench201/search_model.py create mode 100644 nasbench201/search_model_darts.py create mode 100644 nasbench201/search_model_darts_proj.py create mode 100644 nasbench201/utils.py diff --git a/correlation/NAS-Bench-101.py b/correlation/NAS-Bench-101.py new file mode 100644 index 0000000..e9e7f11 --- /dev/null +++ b/correlation/NAS-Bench-101.py @@ -0,0 +1,133 @@ +# Copyright 2021 Samsung Electronics Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import pickle +import torch +import argparse +import json +import numpy as np +from thop import profile + +from foresight.models import * +from foresight.pruners import * +from foresight.dataset import * + + +def get_num_classes(args): + return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120 + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-101') + parser.add_argument('--api_loc', default='../data/nasbench_only108.tfrecord', + type=str, help='path to API') + parser.add_argument('--json_loc', default='data/all_graphs.json', + type=str, help='path to JSON database') + parser.add_argument('--outdir', default='./', + type=str, help='output directory') + parser.add_argument('--outfname', default='test', + type=str, help='output filename') + parser.add_argument('--batch_size', default=256, type=int) + parser.add_argument('--dataset', type=str, default='cifar10', + help='dataset to use [cifar10, cifar100, ImageNet16-120]') + parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on') + parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders') + parser.add_argument('--dataload', type=str, default='random', help='random or grasp supported') + parser.add_argument('--dataload_info', type=int, default=1, + help='number of batches to use for random dataload or number of samples per class for grasp dataload') + parser.add_argument('--start', type=int, default=5, help='start index') + parser.add_argument('--end', type=int, default=10, help='end index') + parser.add_argument('--write_freq', type=int, default=100, help='frequency of write to file') + args = parser.parse_args() + args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") + return args + + +def get_op_names(v): + o = [] + for op in v: + if op == -1: + o.append('input') + elif op == -2: + o.append('output') + elif op == 0: + o.append('conv3x3-bn-relu') + elif op == 1: + o.append('conv1x1-bn-relu') + elif op == 2: + o.append('maxpool3x3') + return o + + +if __name__ == '__main__': + args = parse_arguments() + # nasbench = api.NASBench(args.api_loc) + models = json.load(open(args.json_loc)) + + print(f'Running models {args.start} to {args.end} out of {len(models.keys())}') + + train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset, + args.num_data_workers) + + all_points = [] + pre = 'cf' if 'cifar' in args.dataset else 'im' + + if args.outfname == 'test': + fn = f'nb1_{pre}{get_num_classes(args)}.p' + else: + fn = f'{args.outfname}.p' + op = os.path.join(args.outdir, fn) + + print('outfile =', op) + first = True + + # loop over nasbench1 archs (k=hash, v=[adj_matrix, ops]) + idx = 0 + cached_res = [] + for k, v in models.items(): + + if idx < args.start: + idx += 1 + continue + if idx >= args.end: + break + print(f'idx = {idx}') + idx += 1 + + res = {} + res['hash'] = k + + # model + spec = nasbench1_spec._ToModelSpec(v[0], get_op_names(v[1])) + net = nasbench1.Network(spec, stem_out=128, num_stacks=3, num_mods=3, num_classes=get_num_classes(args)) + net.to(args.device) + + measures = predictive.find_measures(net, + train_loader, + (args.dataload, args.dataload_info, get_num_classes(args)), + args.device) + res['logmeasures'] = measures + + print(res) + cached_res.append(res) + + # write to file + if idx % args.write_freq == 0 or idx == args.end or idx == args.start + 10: + print(f'writing {len(cached_res)} results to {op}') + pf = open(op, 'ab') + for cr in cached_res: + pickle.dump(cr, pf) + pf.close() + cached_res = [] diff --git a/correlation/NAS-Bench-201.py b/correlation/NAS-Bench-201.py new file mode 100644 index 0000000..2e5f5ae --- /dev/null +++ b/correlation/NAS-Bench-201.py @@ -0,0 +1,128 @@ +import argparse +import os + +import time + +from foresight.dataset import * +from foresight.models import nasbench2 +from foresight.pruners import predictive +from foresight.weight_initializers import init_net +from models import get_cell_based_tiny_net +import pickle + + +def get_num_classes(args): + return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120 + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-201') + parser.add_argument('--api_loc', default='../data/NAS-Bench-201-v1_0-e61699.pth', + type=str, help='path to API') + parser.add_argument('--outdir', default='./', + type=str, help='output directory') + parser.add_argument('--init_w_type', type=str, default='none', + help='weight initialization (before pruning) type [none, xavier, kaiming, zero, one]') + parser.add_argument('--init_b_type', type=str, default='none', + help='bias initialization (before pruning) type [none, xavier, kaiming, zero, one]') + parser.add_argument('--batch_size', default=64, type=int) + parser.add_argument('--dataset', type=str, default='ImageNet16-120', + help='dataset to use [cifar10, cifar100, ImageNet16-120]') + parser.add_argument('--gpu', type=int, default=5, help='GPU index to work on') + parser.add_argument('--data_size', type=int, default=32, help='data_size') + parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders') + parser.add_argument('--dataload', type=str, default='appoint', help='random, grasp, appoint supported') + parser.add_argument('--dataload_info', type=int, default=1, + help='number of batches to use for random dataload or number of samples per class for grasp dataload') + parser.add_argument('--seed', type=int, default=42, help='pytorch manual seed') + parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file') + parser.add_argument('--start', type=int, default=0, help='start index') + parser.add_argument('--end', type=int, default=0, help='end index') + parser.add_argument('--noacc', default=False, action='store_true', + help='avoid loading NASBench2 api an instead load a pickle file with tuple (index, arch_str)') + args = parser.parse_args() + args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") + return args + + +if __name__ == '__main__': + args = parse_arguments() + print(args.device) + + if args.noacc: + api = pickle.load(open(args.api_loc,'rb')) + else: + from nas_201_api import NASBench201API as API + api = API(args.api_loc) + + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset, args.num_data_workers, resize=args.data_size) + x, y = next(iter(train_loader)) + # random data + # x = torch.rand((args.batch_size, 3, args.data_size, args.data_size)) + # y = 0 + + cached_res = [] + pre = 'cf' if 'cifar' in args.dataset else 'im' + pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p' + op = os.path.join(args.outdir, pfn) + + end = len(api) if args.end == 0 else args.end + + # loop over nasbench2 archs + for i, arch_str in enumerate(api): + + if i < args.start: + continue + if i >= end: + break + + res = {'i': i, 'arch': arch_str} + # print(arch_str) + if args.search_space == 'tss': + net = nasbench2.get_model_from_arch_str(arch_str, get_num_classes(args)) + arch_str2 = nasbench2.get_arch_str_from_model(net) + if arch_str != arch_str2: + print(arch_str) + print(arch_str2) + raise ValueError + elif args.search_space == 'sss': + config = api.get_net_config(i, args.dataset) + # print(config) + net = get_cell_based_tiny_net(config) + net.to(args.device) + # print(net) + + init_net(net, args.init_w_type, args.init_b_type) + + # print(x.size(), y) + measures = get_score(net, x, i, args.device) + + res['meco'] = measures + + if not args.noacc: + info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset, iepoch=None, + hp='200', is_random=False) + + trainacc = info['train-accuracy'] + valacc = info['valid-accuracy'] + testacc = info['test-accuracy'] + + res['trainacc'] = trainacc + res['valacc'] = valacc + res['testacc'] = testacc + + print(res) + cached_res.append(res) + + # write to file + if i % args.write_freq == 0 or i == len(api) - 1 or i == 10: + print(f'writing {len(cached_res)} results to {op}') + pf = open(op, 'ab') + for cr in cached_res: + pickle.dump(cr, pf) + pf.close() + cached_res = [] diff --git a/exp_scripts/zerocostpt_nb201_pipeline.sh b/exp_scripts/zerocostpt_nb201_pipeline.sh new file mode 100644 index 0000000..98bfc8b --- /dev/null +++ b/exp_scripts/zerocostpt_nb201_pipeline.sh @@ -0,0 +1,38 @@ +#!/bin/bash +script_name=`basename "$0"` +id=${script_name%.*} +dataset=${dataset:-cifar10} +seed=${seed:-0} +gpu=${gpu:-"auto"} +pool_size=${pool_size:-10} +batch_size=${batch_size:-256} +edge_decision=${edge_decision:-'random'} +validate_rounds=${validate_rounds:-100} +metric=${metric:-'jacob'} +while [ $# -gt 0 ]; do + if [[ $1 == *"--"* ]]; then + param="${1/--/}" + declare $param="$2" + # echo $1 $2 // Optional to see the parameter:value result + fi + shift +done + +echo 'id:' $id 'seed:' $seed 'dataset:' $dataset +echo 'gpu:' $gpu + +cd ../nasbench201/ +python3 networks_proposal.py \ + --dataset $dataset \ + --save $id --gpu $gpu --seed $seed \ + --edge_decision $edge_decision --proj_crit $metric \ + --batch_size $batch_size\ + --pool_size $pool_size \ + +cd ../zerocostnas/ +python3 post_validate.py\ + --ckpt_path ../experiments/nas-bench-201/prop-$id-$seed-$pool_size-$metric\ + --save $id --seed $seed --gpu $gpu\ + --edge_decision $edge_decision --proj_crit $metric \ + --batch_size $batch_size\ + --validate_rounds $validate_rounds\ diff --git a/nasbench201/DownsampledImageNet.py b/nasbench201/DownsampledImageNet.py new file mode 100644 index 0000000..fbcd502 --- /dev/null +++ b/nasbench201/DownsampledImageNet.py @@ -0,0 +1,110 @@ +import os, sys, hashlib, torch +import numpy as np +from PIL import Image +import torch.utils.data as data +import pickle + + +def calculate_md5(fpath, chunk_size=1024 * 1024): + md5 = hashlib.md5() + with open(fpath, 'rb') as f: + for chunk in iter(lambda: f.read(chunk_size), b''): + md5.update(chunk) + return md5.hexdigest() + + +def check_md5(fpath, md5, **kwargs): + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath, md5=None): + print(fpath) + if not os.path.isfile(fpath): return False + if md5 is None: return True + else : return check_md5(fpath, md5) + + +class ImageNet16(data.Dataset): + # http://image-net.org/download-images + # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets + # https://arxiv.org/pdf/1707.08819.pdf + + train_list = [ + ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'], + ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'], + ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'], + ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'], + ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'], + ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'], + ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'], + ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'], + ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'], + ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'], + ] + valid_list = [ + ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'], + ] + + def __init__(self, root, train, transform, use_num_of_class_only=None): + self.root = root + self.transform = transform + self.train = train # training set or valid set + if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.') + + if self.train: downloaded_list = self.train_list + else : downloaded_list = self.valid_list + self.data = [] + self.targets = [] + + # now load the picked numpy arrays + for i, (file_name, checksum) in enumerate(downloaded_list): + file_path = os.path.join(self.root, file_name) + #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) + with open(file_path, 'rb') as f: + if sys.version_info[0] == 2: + entry = pickle.load(f) + else: + entry = pickle.load(f, encoding='latin1') + self.data.append(entry['data']) + self.targets.extend(entry['labels']) + self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + if use_num_of_class_only is not None: + assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only) + new_data, new_targets = [], [] + for I, L in zip(self.data, self.targets): + if 1 <= L <= use_num_of_class_only: + new_data.append( I ) + new_targets.append( L ) + self.data = new_data + self.targets = new_targets + # self.mean.append(entry['mean']) + #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) + #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) + #print ('Mean : {:}'.format(self.mean)) + #temp = self.data - np.reshape(self.mean, (1, 1, 1, 3)) + #std_data = np.std(temp, axis=0) + #std_data = np.mean(np.mean(std_data, axis=0), axis=0) + #print ('Std : {:}'.format(std_data)) + + def __getitem__(self, index): + img, target = self.data[index], self.targets[index] - 1 + + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + return img, target + + def __len__(self): + return len(self.data) + + def _check_integrity(self): + root = self.root + for fentry in (self.train_list + self.valid_list): + filename, md5 = fentry[0], fentry[1] + fpath = os.path.join(root, filename) + if not check_integrity(fpath, md5): + return False + return True \ No newline at end of file diff --git a/nasbench201/architect_ig.py b/nasbench201/architect_ig.py new file mode 100644 index 0000000..d3c23f1 --- /dev/null +++ b/nasbench201/architect_ig.py @@ -0,0 +1,52 @@ +import torch + + +class Architect(object): + def __init__(self, model, args): + self.network_momentum = args.momentum + self.network_weight_decay = args.weight_decay + self.model = model + self.optimizer = torch.optim.Adam(self.model.arch_parameters(), + lr=args.arch_learning_rate, betas=(0.5, 0.999), + weight_decay=args.arch_weight_decay) + + self._init_arch_parameters = [] + for alpha in self.model.arch_parameters(): + alpha_init = torch.zeros_like(alpha) + alpha_init.data.copy_(alpha) + self._init_arch_parameters.append(alpha_init) + + #### mode + if args.method in ['darts', 'darts-proj', 'sdarts', 'sdarts-proj']: + self.method = 'fo' # first order update + elif 'so' in args.method: + print('ERROR: PLEASE USE architect.py for second order darts') + elif args.method in ['blank', 'blank-proj']: + self.method = 'blank' + else: + print('ERROR: WRONG ARCH UPDATE METHOD', args.method); exit(0) + + def reset_arch_parameters(self): + for alpha, alpha_init in zip(self.model.arch_parameters(), self._init_arch_parameters): + alpha.data.copy_(alpha_init.data) + + def step(self, input_train, target_train, input_valid, target_valid, *args, **kwargs): + if self.method == 'fo': + shared = self._step_fo(input_train, target_train, input_valid, target_valid) + elif self.method == 'so': + raise NotImplementedError + elif self.method == 'blank': ## do not update alpha + shared = None + + return shared + + #### first order + def _step_fo(self, input_train, target_train, input_valid, target_valid): + loss = self.model._loss(input_valid, target_valid) + loss.backward() + self.optimizer.step() + return None + + #### darts 2nd order + def _step_darts_so(self, input_train, target_train, input_valid, target_valid, eta, model_optimizer): + raise NotImplementedError \ No newline at end of file diff --git a/nasbench201/cell_infers/cells.py b/nasbench201/cell_infers/cells.py new file mode 100644 index 0000000..2dbb925 --- /dev/null +++ b/nasbench201/cell_infers/cells.py @@ -0,0 +1,120 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # +##################################################### + +import torch +import torch.nn as nn +from copy import deepcopy +from ..cell_operations import OPS + + +# Cell for NAS-Bench-201 +class InferCell(nn.Module): + + def __init__(self, genotype, C_in, C_out, stride): + super(InferCell, self).__init__() + + self.layers = nn.ModuleList() + self.node_IN = [] + self.node_IX = [] + self.genotype = deepcopy(genotype) + for i in range(1, len(genotype)): + node_info = genotype[i-1] + cur_index = [] + cur_innod = [] + for (op_name, op_in) in node_info: + if op_in == 0: + layer = OPS[op_name](C_in , C_out, stride, True, True) + else: + layer = OPS[op_name](C_out, C_out, 1, True, True) + cur_index.append( len(self.layers) ) + cur_innod.append( op_in ) + self.layers.append( layer ) + self.node_IX.append( cur_index ) + self.node_IN.append( cur_innod ) + self.nodes = len(genotype) + self.in_dim = C_in + self.out_dim = C_out + + def extra_repr(self): + string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) + laystr = [] + for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): + y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)] + x = '{:}<-({:})'.format(i+1, ','.join(y)) + laystr.append( x ) + return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr()) + + def forward(self, inputs): + nodes = [inputs] + for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): + node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) + nodes.append( node_feature ) + return nodes[-1] + + + +# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 +class NASNetInferCell(nn.Module): + + def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): + super(NASNetInferCell, self).__init__() + self.reduction = reduction + if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) + else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) + self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) + + if not reduction: + nodes, concats = genotype['normal'], genotype['normal_concat'] + else: + nodes, concats = genotype['reduce'], genotype['reduce_concat'] + self._multiplier = len(concats) + self._concats = concats + self._steps = len(nodes) + self._nodes = nodes + self.edges = nn.ModuleDict() + for i, node in enumerate(nodes): + for in_node in node: + name, j = in_node[0], in_node[1] + stride = 2 if reduction and j < 2 else 1 + node_str = '{:}<-{:}'.format(i+2, j) + self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats) + + # [TODO] to support drop_prob in this function.. + def forward(self, s0, s1, unused_drop_prob): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) + + states = [s0, s1] + for i, node in enumerate(self._nodes): + clist = [] + for in_node in node: + name, j = in_node[0], in_node[1] + node_str = '{:}<-{:}'.format(i+2, j) + op = self.edges[ node_str ] + clist.append( op(states[j]) ) + states.append( sum(clist) ) + return torch.cat([states[x] for x in self._concats], dim=1) + + +class AuxiliaryHeadCIFAR(nn.Module): + + def __init__(self, C, num_classes): + """assuming input size 8x8""" + super(AuxiliaryHeadCIFAR, self).__init__() + self.features = nn.Sequential( + nn.ReLU(inplace=True), + nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 + nn.Conv2d(C, 128, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 768, 2, bias=False), + nn.BatchNorm2d(768), + nn.ReLU(inplace=True) + ) + self.classifier = nn.Linear(768, num_classes) + + def forward(self, x): + x = self.features(x) + x = self.classifier(x.view(x.size(0),-1)) + return x diff --git a/nasbench201/cell_infers/tiny_network.py b/nasbench201/cell_infers/tiny_network.py new file mode 100644 index 0000000..b50e3a8 --- /dev/null +++ b/nasbench201/cell_infers/tiny_network.py @@ -0,0 +1,82 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # +##################################################### +import torch.nn as nn +from ..cell_operations import ResNetBasicblock +from .cells import InferCell + + +# The macro structure for architectures in NAS-Bench-201 +class TinyNetwork(nn.Module): + + def __init__(self, C, N, genotype, num_classes): + super(TinyNetwork, self).__init__() + self._C = C + self._layerN = N + + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C)) + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev = C + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2, True) + else: + cell = InferCell(genotype, C_prev, C_curr, 1) + self.cells.append( cell ) + C_prev = cell.out_dim + self._Layer= len(self.cells) + + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + + self.requires_feature = True + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def forward(self, inputs): + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + if self.requires_feature: + return logits, out + else: + return logits + + def _loss(self, input, target, return_logits=False): + logits, _ = self(input) + loss = self._criterion(logits, target) + + return (loss, logits) if return_logits else loss + + def step(self, input, target, args, shared=None, return_grad=False): + Lt, logit_t = self._loss(input, target, return_logits=True) + Lt.backward() + if args.grad_clip != 0: + nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip) + self.optimizer.step() + + if return_grad: + grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()]) + return logit_t, Lt, grad + else: + return logit_t, Lt \ No newline at end of file diff --git a/nasbench201/cell_operations.py b/nasbench201/cell_operations.py new file mode 100644 index 0000000..e83cb68 --- /dev/null +++ b/nasbench201/cell_operations.py @@ -0,0 +1,289 @@ +import sys +import torch +import torch.nn as nn +sys.path.insert(0, '../') +from Layers import layers +__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames'] + +OPS = { + 'noise' : lambda C_in, C_out, stride, affine, track_running_stats: NoiseOp(stride, 0., 1.), # C_in, C_out not needed + 'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride), + 'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats), + 'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats), + 'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats), + 'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), + 'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats), + 'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), + 'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats), + 'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats), + 'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats), + 'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), +} + +CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] +NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] +DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3'] +#### wrc modified +NAS_BENCH_201_SKIP = ['none', 'skip_connect', 'nor_conv_1x1_skip', 'nor_conv_3x3_skip', 'avg_pool_3x3'] +NAS_BENCH_201_SIMPLE = ['skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] +NAS_BENCH_201_S2 = ['skip_connect', 'nor_conv_3x3'] +NAS_BENCH_201_S4 = ['noise', 'nor_conv_3x3'] +NAS_BENCH_201_S10 = ['none', 'nor_conv_3x3'] + +SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK, + 'nas-bench-201': NAS_BENCH_201, + 'nas-bench-201-simple': NAS_BENCH_201_SIMPLE, + 'nas-bench-201-s2': NAS_BENCH_201_S2, + 'nas-bench-201-s4': NAS_BENCH_201_S4, + 'nas-bench-201-s10': NAS_BENCH_201_S10, + 'darts' : DARTS_SPACE} + +class NoiseOp(nn.Module): + def __init__(self, stride, mean, std): + super(NoiseOp, self).__init__() + self.stride = stride + self.mean = mean + self.std = std + + def forward(self, x, block_input=False): + if block_input: + x = x * 0 + if self.stride != 1: + x_new = x[:,:,::self.stride,::self.stride] + else: + x_new = x + noise = x_new.data.new(x_new.size()).normal_(self.mean, self.std) + return noise + +class ReLUConvBN(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): + super(ReLUConvBN, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False), + nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) + ) + + def forward(self, x, block_input=False): + if block_input: + x = x * 0 + return self.op(x) + + def score(self): + score = 0 + for l in self.op: + if hasattr(l, 'score'): + score += torch.sum(l.score).cpu().numpy() + return score + +#### wrc modified +class ReLUConvBNSkip(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): + super(ReLUConvBNSkip, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False), + nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) + ) + + def forward(self, x, block_input=False): + if block_input: + x = x * 0 + return self.op(x) + x + + def score(self): + score = 0 + for l in self.op: + if hasattr(l, 'score'): + score += torch.sum(l.score).cpu().numpy() + return score +#### + +class SepConv(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): + super(SepConv, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + layers.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), + layers.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats), + ) + + def forward(self, x, block_input=False): + if block_input: + x = x * 0 + return self.op(x) + + def score(self): + score = 0 + for l in self.op: + if hasattr(l, 'score'): + score += torch.sum(l.score).cpu().numpy() + return score + + +class DualSepConv(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): + super(DualSepConv, self).__init__() + self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats) + self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats) + + def forward(self, x, block_input=False): + if block_input: + x = x * 0 + x = self.op_a(x) + x = self.op_b(x) + return x + + def score(self): + score = self.op_a.score() + self.op_b.score() + return score + + +class ResNetBasicblock(nn.Module): + + def __init__(self, inplanes, planes, stride, affine=True): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) + self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine) + self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine) + if stride == 2: + self.downsample = nn.Sequential( + nn.AvgPool2d(kernel_size=2, stride=2, padding=0), + nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False)) + elif inplanes != planes: + self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine) + else: + self.downsample = None + self.in_dim = inplanes + self.out_dim = planes + self.stride = stride + self.num_conv = 2 + + def extra_repr(self): + string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__) + return string + + def forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + return residual + basicblock + + def score(self): + return self.conv_a.score() + self.conv_b.score() + + + + +class POOLING(nn.Module): + + def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True): + super(POOLING, self).__init__() + if C_in == C_out: + self.preprocess = None + else: + self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine, track_running_stats) + if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) + elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1) + else : raise ValueError('Invalid mode={:} in POOLING'.format(mode)) + + def forward(self, inputs, block_input=False): + if block_input: + inputs = inputs * 0 + if self.preprocess: x = self.preprocess(inputs) + else : x = inputs + return self.op(x) + + def score(self): + if self.preprocess : + return self.preprocess.score() + else: + return 0 + + +class Identity(nn.Module): + + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x, block_input=False): + if block_input: + x = x * 0 + return x + + +class Zero(nn.Module): + + def __init__(self, C_in, C_out, stride): + super(Zero, self).__init__() + self.C_in = C_in + self.C_out = C_out + self.stride = stride + self.is_zero = True + + def forward(self, x, block_input=False): + if block_input: + x = x*0 + if self.C_in == self.C_out: + if self.stride == 1: return x.mul(0.) + else : return x[:,:,::self.stride,::self.stride].mul(0.) + else: ## this is never called in nasbench201 + shape = list(x.shape) + shape[1] = self.C_out + zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device) + return zeros + + def extra_repr(self): + return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) + + +class FactorizedReduce(nn.Module): + + def __init__(self, C_in, C_out, stride, affine, track_running_stats): + super(FactorizedReduce, self).__init__() + self.stride = stride + self.C_in = C_in + self.C_out = C_out + self.relu = nn.ReLU(inplace=False) + if stride == 2: + #assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) + C_outs = [C_out // 2, C_out - C_out // 2] + self.convs = nn.ModuleList() + for i in range(2): + self.convs.append(layers.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) ) + self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) + elif stride == 1: + self.conv = layers.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False) + else: + raise ValueError('Invalid stride : {:}'.format(stride)) + self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) + + def forward(self, x, block_input=False): + if block_input: + x = x * 0 + if self.stride == 2: + x = self.relu(x) + y = self.pad(x) + out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1) + else: + out = self.conv(x) + out = self.bn(out) + return out + + def extra_repr(self): + return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) + + def score(self): + if self.stride == 1: + return self.conv.score() + else: + return self.convs[0].score()+self.convs[1].score() \ No newline at end of file diff --git a/nasbench201/genotypes.py b/nasbench201/genotypes.py new file mode 100644 index 0000000..04a8379 --- /dev/null +++ b/nasbench201/genotypes.py @@ -0,0 +1,194 @@ +from copy import deepcopy + + +def get_combination(space, num): + combs = [] + for i in range(num): + if i == 0: + for func in space: + combs.append( [(func, i)] ) + else: + new_combs = [] + for string in combs: + for func in space: + xstring = string + [(func, i)] + new_combs.append( xstring ) + combs = new_combs + return combs + + +class Structure: + + def __init__(self, genotype): + assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype)) + self.node_num = len(genotype) + 1 + self.nodes = [] + self.node_N = [] + for idx, node_info in enumerate(genotype): + assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info)) + assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info)) + for node_in in node_info: + assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in)) + assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in) + self.node_N.append( len(node_info) ) + self.nodes.append( tuple(deepcopy(node_info)) ) + + def tolist(self, remove_str): + # convert this class to the list, if remove_str is 'none', then remove the 'none' operation. + # note that we re-order the input node in this function + # return the-genotype-list and success [if unsuccess, it is not a connectivity] + genotypes = [] + for node_info in self.nodes: + node_info = list( node_info ) + node_info = sorted(node_info, key=lambda x: (x[1], x[0])) + node_info = tuple(filter(lambda x: x[0] != remove_str, node_info)) + if len(node_info) == 0: return None, False + genotypes.append( node_info ) + return genotypes, True + + def node(self, index): + assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self)) + return self.nodes[index] + + def tostr(self): + strings = [] + for node_info in self.nodes: + string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info]) + string = '|{:}|'.format(string) + strings.append( string ) + return '+'.join(strings) + + def check_valid(self): + nodes = {0: True} + for i, node_info in enumerate(self.nodes): + sums = [] + for op, xin in node_info: + if op == 'none' or nodes[xin] is False: x = False + else: x = True + sums.append( x ) + nodes[i+1] = sum(sums) > 0 + return nodes[len(self.nodes)] + + def to_unique_str(self, consider_zero=False): + # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation + # two operations are special, i.e., none and skip_connect + nodes = {0: '0'} + for i_node, node_info in enumerate(self.nodes): + cur_node = [] + for op, xin in node_info: + if consider_zero is None: + x = '('+nodes[xin]+')' + '@{:}'.format(op) + elif consider_zero: + if op == 'none' or nodes[xin] == '#': x = '#' # zero + elif op == 'skip_connect': x = nodes[xin] + else: x = '('+nodes[xin]+')' + '@{:}'.format(op) + else: + if op == 'skip_connect': x = nodes[xin] + else: x = '('+nodes[xin]+')' + '@{:}'.format(op) + cur_node.append(x) + nodes[i_node+1] = '+'.join( sorted(cur_node) ) + return nodes[ len(self.nodes) ] + + def check_valid_op(self, op_names): + for node_info in self.nodes: + for inode_edge in node_info: + #assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) + if inode_edge[0] not in op_names: return False + return True + + def __repr__(self): + return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__)) + + def __len__(self): + return len(self.nodes) + 1 + + def __getitem__(self, index): + return self.nodes[index] + + @staticmethod + def str2structure(xstr): + assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) + nodestrs = xstr.split('+') + genotypes = [] + for i, node_str in enumerate(nodestrs): + inputs = list(filter(lambda x: x != '', node_str.split('|'))) + for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) + inputs = ( xi.split('~') for xi in inputs ) + input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) + genotypes.append( input_infos ) + return Structure( genotypes ) + + @staticmethod + def str2fullstructure(xstr, default_name='none'): + assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) + nodestrs = xstr.split('+') + genotypes = [] + for i, node_str in enumerate(nodestrs): + inputs = list(filter(lambda x: x != '', node_str.split('|'))) + for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) + inputs = ( xi.split('~') for xi in inputs ) + input_infos = list( (op, int(IDX)) for (op, IDX) in inputs) + all_in_nodes= list(x[1] for x in input_infos) + for j in range(i): + if j not in all_in_nodes: input_infos.append((default_name, j)) + node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) + genotypes.append( tuple(node_info) ) + return Structure( genotypes ) + + @staticmethod + def gen_all(search_space, num, return_ori): + assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space)) + assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num) + all_archs = get_combination(search_space, 1) + for i, arch in enumerate(all_archs): + all_archs[i] = [ tuple(arch) ] + + for inode in range(2, num): + cur_nodes = get_combination(search_space, inode) + new_all_archs = [] + for previous_arch in all_archs: + for cur_node in cur_nodes: + new_all_archs.append( previous_arch + [tuple(cur_node)] ) + all_archs = new_all_archs + if return_ori: + return all_archs + else: + return [Structure(x) for x in all_archs] + + + +ResNet_CODE = Structure( + [(('nor_conv_3x3', 0), ), # node-1 + (('nor_conv_3x3', 1), ), # node-2 + (('skip_connect', 0), ('skip_connect', 2))] # node-3 + ) + +AllConv3x3_CODE = Structure( + [(('nor_conv_3x3', 0), ), # node-1 + (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2 + (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3 + ) + +AllFull_CODE = Structure( + [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1 + (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2 + (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3 + ) + +AllConv1x1_CODE = Structure( + [(('nor_conv_1x1', 0), ), # node-1 + (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2 + (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3 + ) + +AllIdentity_CODE = Structure( + [(('skip_connect', 0), ), # node-1 + (('skip_connect', 0), ('skip_connect', 1)), # node-2 + (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3 + ) + +architectures = {'resnet' : ResNet_CODE, + 'all_c3x3': AllConv3x3_CODE, + 'all_c1x1': AllConv1x1_CODE, + 'all_idnt': AllIdentity_CODE, + 'all_full': AllFull_CODE} \ No newline at end of file diff --git a/nasbench201/init_projection.py b/nasbench201/init_projection.py new file mode 100644 index 0000000..77a333d --- /dev/null +++ b/nasbench201/init_projection.py @@ -0,0 +1,619 @@ +import os +import sys +import numpy as np +import torch +import torch.nn.functional as f +sys.path.insert(0, '../') +import nasbench201.utils as ig_utils +import logging +import torch.utils +import copy +import scipy.stats as ss +from collections import OrderedDict +from foresight.pruners import * +from op_score import Jocab_Score, get_ntk_n +import gc +from nasbench201.linear_region import Linear_Region_Collector + +torch.set_printoptions(precision=4, sci_mode=False) +np.set_printoptions(precision=4, suppress=True) + +# global-edge-iter: similar toglobal-op-iterbut iteratively selects edge e from E based on the average score of all operations on each edge +def global_op_greedy_pt_project(proj_queue, model, args): + def project(model, args): + ## macros + num_edge, num_op = model.num_edge, model.num_op + + ##get remain eid numbers + remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] + compare = lambda x, y : x < y + + crit_extrema = None + best_eid = None + input, target = next(iter(proj_queue)) + for eid in remain_eids: + for opid in range(num_op): + # projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[eid]) + proj_mask[opid] = 0 + weights[eid] = weights[eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'jacob': + valid_stats = Jocab_Score(model, input, target, weights=weights) + crit = valid_stats + + if crit_extrema is None or compare(crit, crit_extrema): + crit_extrema = crit + best_opid = opid + best_eid = eid + + logging.info('best opid %d', best_opid) + return best_eid, best_opid + + tune_epochs = model.arch_parameters()[0].shape[0] + + for epoch in range(tune_epochs): + logging.info('epoch %d', epoch) + logging.info('project') + selected_eid, best_opid = project(model, args) + model.project_op(selected_eid, best_opid) + + return + +# global-edge-iter: similar toglobal-op-oncebut uses the average score of operations on edges to obtain the edge discretization order +def global_edge_greedy_pt_project(proj_queue, model, args): + def select_eid(model, args): + ## macros + num_edge, num_op = model.num_edge, model.num_op + + ##get remain eid numbers + remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] + compare = lambda x, y : x < y + + crit_extrema = None + best_eid = None + input, target = next(iter(proj_queue)) + for eid in remain_eids: + eid_score = [] + for opid in range(num_op): + # projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[eid]) + proj_mask[opid] = 0 + weights[eid] = weights[eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'jacob': + valid_stats = Jocab_Score(model, input, target, weights=weights) + crit = valid_stats + eid_score.append(crit) + eid_score = np.mean(eid_score) + + if crit_extrema is None or compare(eid_score, crit_extrema): + crit_extrema = eid_score + best_eid = eid + return best_eid + + def project(model, args, selected_eid): + ## macros + num_edge, num_op = model.num_edge, model.num_op + + ## select the best operation + if args.proj_crit == 'jacob': + crit_idx = 3 + compare = lambda x, y: x < y + else: + crit_idx = 4 + compare = lambda x, y: x < y + + best_opid = 0 + crit_list = [] + op_ids = [] + input, target = next(iter(proj_queue)) + for opid in range(num_op): + ## projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[selected_eid]) + proj_mask[opid] = 0 + weights[selected_eid] = weights[selected_eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'jacob': + valid_stats = Jocab_Score(model, input, target, weights=weights) + crit = valid_stats + + crit_list.append(crit) + op_ids.append(opid) + + best_opid = op_ids[np.nanargmin(crit_list)] + + logging.info('best opid %d', best_opid) + logging.info(crit_list) + return selected_eid, best_opid + + num_edges = model.arch_parameters()[0].shape[0] + + for epoch in range(num_edges): + logging.info('epoch %d', epoch) + + logging.info('project') + selected_eid = select_eid(model, args) + selected_eid, best_opid = project(model, args, selected_eid) + model.project_op(selected_eid, best_opid) + return + +# global-op-once: only evaluates S(A−(e,o)) for all operations once to obtain a ranking order of the operations, and discretizes the edgesEaccording to this order +def global_op_once_pt_project(proj_queue, model, args): + def order(model, args): + ## macros + num_edge, num_op = model.num_edge, model.num_op + + ##get remain eid numbers + remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] + compare = lambda x, y : x < y + + edge_score = OrderedDict() + input, target = next(iter(proj_queue)) + for eid in remain_eids: + crit_list = [] + for opid in range(num_op): + # projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[eid]) + proj_mask[opid] = 0 + weights[eid] = weights[eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'jacob': + valid_stats = Jocab_Score(model, input, target, weights=weights) + crit = valid_stats + + crit_list.append(crit) + edge_score[eid] = np.nanargmin(crit_list) + return edge_score + + def project(model, args, selected_eid): + ## macros + num_edge, num_op = model.num_edge, model.num_op + ## select the best operation + if args.proj_crit == 'jacob': + crit_idx = 3 + compare = lambda x, y: x < y + else: + crit_idx = 4 + compare = lambda x, y: x < y + + best_opid = 0 + crit_list = [] + op_ids = [] + input, target = next(iter(proj_queue)) + for opid in range(num_op): + ## projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[selected_eid]) + proj_mask[opid] = 0 + weights[selected_eid] = weights[selected_eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'jacob': + crit = Jocab_Score(model, input, target, weights=weights) + crit_list.append(crit) + op_ids.append(opid) + + best_opid = op_ids[np.nanargmin(crit_list)] + + logging.info('best opid %d', best_opid) + logging.info(crit_list) + return selected_eid, best_opid + + num_edges = model.arch_parameters()[0].shape[0] + + eid_order = order(model, args) + for epoch in range(num_edges): + logging.info('epoch %d', epoch) + logging.info('project') + selected_eid, _ = eid_order.popitem() + selected_eid, best_opid = project(model, args, selected_eid) + model.project_op(selected_eid, best_opid) + + return + +# global-edge-once: similar toglobal-op-oncebut uses the average score of operations on dges to obtain the edge discretization order +def global_edge_once_pt_project(proj_queue, model, args): + def order(model, args): + ## macros + num_edge, num_op = model.num_edge, model.num_op + + ##get remain eid numbers + remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] + compare = lambda x, y : x < y + + edge_score = OrderedDict() + crit_extrema = None + best_eid = None + input, target = next(iter(proj_queue)) + for eid in remain_eids: + crit_list = [] + for opid in range(num_op): + # projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[eid]) + proj_mask[opid] = 0 + weights[eid] = weights[eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'jacob': + crit = Jocab_Score(model, input, target, weights=weights) + + crit_list.append(crit) + edge_score[eid] = np.mean(crit_list) + return edge_score + + def project(model, args, selected_eid): + ## macros + num_edge, num_op = model.num_edge, model.num_op + ## select the best operation + if args.proj_crit == 'jacob': + crit_idx = 3 + compare = lambda x, y: x < y + else: + crit_idx = 4 + compare = lambda x, y: x < y + + best_opid = 0 + crit_extrema = None + crit_list = [] + op_ids = [] + input, target = next(iter(proj_queue)) + for opid in range(num_op): + ## projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[selected_eid]) + proj_mask[opid] = 0 + weights[selected_eid] = weights[selected_eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'jacob': + crit = Jocab_Score(model, input, target, weights=weights) + crit_list.append(crit) + op_ids.append(opid) + + best_opid = op_ids[np.nanargmin(crit_list)] + + logging.info('best opid %d', best_opid) + logging.info(crit_list) + return selected_eid, best_opid + + num_edges = model.arch_parameters()[0].shape[0] + + eid_order = order(model, args) + for epoch in range(num_edges): + logging.info('epoch %d', epoch) + logging.info('project') + selected_eid, _ = eid_order.popitem() + selected_eid, best_opid = project(model, args, selected_eid) + model.project_op(selected_eid, best_opid) + + return + +# fixed [reverse, order]: discretizes the edges in a fixed order, where in our experiments we discretize from the222input towards the output of the cell struct +# random: discretizes the edges in a random order (DARTS-PT) +# NOTE: Only this methods allows use other zero-cost proxy metrics +def pt_project(proj_queue, model, args): + def project(model, args): + ## macros,一共6条边,每条边有5个操作 + num_edge, num_op = model.num_edge, model.num_op + + ## select an edge + remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] + # print('candidate_flags:', model.candidate_flags) + # print(model.candidate_flags) + # 选边的方法 + if args.edge_decision == "random": + # 选出来了一个数组,取其中的一个元素 + selected_eid = np.random.choice(remain_eids, size=1)[0] + elif args.edge_decision == "reverse": + selected_eid = remain_eids[-1] + else: + selected_eid = remain_eids[0] + + ## select the best operation + if args.proj_crit == 'jacob': + crit_idx = 3 + compare = lambda x, y: x < y + else: + crit_idx = 4 + compare = lambda x, y: x < y + + if args.dataset == 'cifar100': + n_classes = 100 + elif args.dataset == 'imagenet16-120': + n_classes = 120 + else: + n_classes = 10 + + best_opid = 0 + crit_extrema = None + crit_list = [] + op_ids = [] + input, target = next(iter(proj_queue)) + for opid in range(num_op): + ## projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[selected_eid]) + # print(selected_eid, weights[selected_eid]) + proj_mask[opid] = 0 + weights[selected_eid] = weights[selected_eid] * proj_mask + + + ## proj evaluation + if args.proj_crit == 'jacob': + crit = Jocab_Score(model, input, target, weights=weights) + else: + cache_weight = model.proj_weights[selected_eid] + cache_flag = model.candidate_flags[selected_eid] + + + for idx in range(num_op): + if idx == opid: + model.proj_weights[selected_eid][opid] = 0 + else: + model.proj_weights[selected_eid][idx] = 1.0/num_op + + + model.candidate_flags[selected_eid] = False + # print(model.get_projected_weights()) + + if args.proj_crit == 'comb': + synflow = predictive.find_measures(model, + proj_queue, + ('random', 1, n_classes), + torch.device("cuda"), + measure_names=['synflow']) + var = predictive.find_measures(model, + proj_queue, + ('random', 1, n_classes), + torch.device("cuda"), + measure_names=['var']) + # print(synflow, var) + comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1) + measures = {'comb': comb} + else: + measures = predictive.find_measures(model, + proj_queue, + ('random', 1, n_classes), + torch.device("cuda"), + measure_names=[args.proj_crit]) + + # print(measures) + for idx in range(num_op): + model.proj_weights[selected_eid][idx] = 0 + model.candidate_flags[selected_eid] = cache_flag + crit = measures[args.proj_crit] + + crit_list.append(crit) + op_ids.append(opid) + + + best_opid = op_ids[np.nanargmin(crit_list)] + # best_opid = op_ids[np.nanargmax(crit_list)] + + logging.info('best opid %d', best_opid) + logging.info('current edge id %d', selected_eid) + logging.info(crit_list) + return selected_eid, best_opid + + num_edges = model.arch_parameters()[0].shape[0] + + for epoch in range(num_edges): + logging.info('epoch %d', epoch) + logging.info('project') + selected_eid, best_opid = project(model, args) + model.project_op(selected_eid, best_opid) + + return + +def tenas_project(proj_queue, model, model_thin, args): + def project(model, args): + ## macros + num_edge, num_op = model.num_edge, model.num_op + + ##get remain eid numbers + remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] + compare = lambda x, y : x < y + + ntks = [] + lrs = [] + edge_op_id = [] + best_eid = None + + if args.proj_crit == 'tenas': + lrc_model = Linear_Region_Collector(input_size=(1000, 1, 3, 3), sample_batch=3, dataset=args.dataset, data_path=args.data, seed=args.seed) + for eid in remain_eids: + for opid in range(num_op): + # projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[eid]) + proj_mask[opid] = 0 + weights[eid] = weights[eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'tenas': + lrc_model.reinit(ori_models=[model_thin], seed=args.seed, weights=weights) + lr = lrc_model.forward_batch_sample() + lrc_model.clear() + ntk = get_ntk_n(proj_queue, [model], recalbn=0, train_mode=True, num_batch=1, weights=weights) + ntks.append(ntk) + lrs.append(lr) + edge_op_id.append('{}:{}'.format(eid, opid)) + print('ntls', ntks) + print('lrs', lrs) + ntks_ranks = ss.rankdata(ntks) + lrs_ranks = ss.rankdata(lrs) + ntks_ranks = len(ntks_ranks) - ntks_ranks.astype(int) + op_ranks = [] + for i in range(len(edge_op_id)): + op_ranks.append(ntks_ranks[i]+lrs_ranks[i]) + + best_op_index = edge_op_id[np.nanargmin(op_ranks[0:num_op])] + best_eid, best_opid = [int(x) for x in best_op_index.split(':')] + + logging.info(op_ranks) + logging.info('best eid %d', best_eid) + logging.info('best opid %d', best_opid) + return best_eid, best_opid + num_edges = model.arch_parameters()[0].shape[0] + + for epoch in range(num_edges): + logging.info('epoch %d', epoch) + logging.info('project') + selected_eid, best_opid = project(model, args) + model.project_op(selected_eid, best_opid) + + return + +#new methods +#Randomly propose candidate of networks and transfer it to supernet, then perform global op selection in this subspace +def shrink_pt_project(proj_queue, model, args): + def project(model, args): + ## macros + num_edge, num_op = model.num_edge, model.num_op + + ## select an edge + remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] + selected_eid = np.random.choice(remain_eids, size=1)[0] + + + ## select the best operation + if args.proj_crit == 'jacob': + crit_idx = 3 + compare = lambda x, y: x < y + else: + crit_idx = 4 + compare = lambda x, y: x < y + + if args.dataset == 'cifar100': + n_classes = 100 + elif args.dataset == 'imagenet16-120': + n_classes = 120 + else: + n_classes = 10 + + best_opid = 0 + crit_extrema = None + crit_list = [] + op_ids = [] + input, target = next(iter(proj_queue)) + for opid in range(num_op): + ## projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[selected_eid]) + proj_mask[opid] = 0 + weights[selected_eid] = weights[selected_eid] * proj_mask + + ## proj evaluation + if args.proj_crit == 'jacob': + crit = Jocab_Score(model, input, target, weights=weights) + else: + cache_weight = model.proj_weights[selected_eid] + cache_flag = model.candidate_flags[selected_eid] + + for idx in range(num_op): + if idx == opid: + model.proj_weights[selected_eid][opid] = 0 + else: + model.proj_weights[selected_eid][idx] = 1.0/num_op + model.candidate_flags[selected_eid] = False + + measures = predictive.find_measures(model, + train_queue, + ('random', 1, n_classes), + torch.device("cuda"), + measure_names=[args.proj_crit]) + for idx in range(num_op): + model.proj_weights[selected_eid][idx] = 0 + model.candidate_flags[selected_eid] = cache_flag + crit = measures[args.proj_crit] + + crit_list.append(crit) + op_ids.append(opid) + + best_opid = op_ids[np.nanargmin(crit_list)] + + logging.info('best opid %d', best_opid) + logging.info('current edge id %d', selected_eid) + logging.info(crit_list) + return selected_eid, best_opid + + def global_project(model, args): + ## macros + num_edge, num_op = model.num_edge, model.num_op + + ##get remain eid numbers + remain_eids = torch.nonzero(model.subspace_candidate_flags).cpu().numpy().T[0] + compare = lambda x, y : x < y + + crit_extrema = None + best_eid = None + best_opid = None + input, target = next(iter(proj_queue)) + for eid in remain_eids: + remain_oids = torch.nonzero(model.proj_weights[eid]).cpu().numpy().T[0] + for opid in remain_oids: + # projection + weights = model.get_projected_weights() + proj_mask = torch.ones_like(weights[eid]) + proj_mask[opid] = 0 + weights[eid] = weights[eid] * proj_mask + ## proj evaluation + if args.proj_crit == 'jacob': + valid_stats = Jocab_Score(model, input, target, weights=weights) + crit = valid_stats + + if crit_extrema is None or compare(crit, crit_extrema): + crit_extrema = crit + best_opid = opid + best_eid = eid + + + logging.info('best eid %d', best_eid) + logging.info('best opid %d', best_opid) + model.subspace_candidate_flags[best_eid] = False + proj_mask = torch.zeros_like(model.proj_weights[best_eid]) + model.proj_weights[best_eid] = model.proj_weights[best_eid] * proj_mask + model.proj_weights[best_eid][best_opid] = 1 + return best_eid, best_opid + + num_edges = model.arch_parameters()[0].shape[0] + + #subspace + logging.info('Start subspace proposal') + subspace = copy.deepcopy(model.proj_weights) + for i in range(20): + model.reset_arch_parameters() + for epoch in range(num_edges): + logging.info('epoch %d', epoch) + logging.info('project') + selected_eid, best_opid = project(model, args) + model.project_op(selected_eid, best_opid) + subspace += model.proj_weights + + model.reset_arch_parameters() + subspace = torch.gt(subspace, 0).int().float() + subspace = f.normalize(subspace, p=1, dim=1) + model.proj_weights += subspace + for i in range(num_edges): + model.candidate_flags[i] = False + logging.info('Start final search in subspace') + logging.info(subspace) + + model.subspace_candidate_flags = torch.tensor(len(model._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda() + for epoch in range(num_edges): + logging.info('epoch %d', epoch) + logging.info('project') + selected_eid, best_opid = global_project(model, args) + model.printing(logging) + #model.project_op(selected_eid, best_opid) + return \ No newline at end of file diff --git a/nasbench201/linear_region.py b/nasbench201/linear_region.py new file mode 100644 index 0000000..f0e8067 --- /dev/null +++ b/nasbench201/linear_region.py @@ -0,0 +1,270 @@ +import os.path as osp +import numpy as np +import torch +import torch.nn as nn +import torch.utils.data +import torchvision.transforms as transforms +import torchvision.datasets as dset +from pdb import set_trace as bp +from operator import mul +from functools import reduce +import copy +Dataset2Class = {'cifar10': 10, + 'cifar100': 100, + 'imagenet-1k-s': 1000, + 'imagenet-1k': 1000, +} + + +class CUTOUT(object): + + def __init__(self, length): + self.length = length + + def __repr__(self): + return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) + + def __call__(self, img): + h, w = img.size(1), img.size(2) + mask = np.ones((h, w), np.float32) + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img *= mask + return img + + +imagenet_pca = { + 'eigval': np.asarray([0.2175, 0.0188, 0.0045]), + 'eigvec': np.asarray([ + [-0.5675, 0.7192, 0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, 0.4203], + ]) +} + + +class RandChannel(object): + # randomly pick channels from input + def __init__(self, num_channel): + self.num_channel = num_channel + + def __repr__(self): + return ('{name}(num_channel={num_channel})'.format(name=self.__class__.__name__, **self.__dict__)) + + def __call__(self, img): + channel = img.size(0) + channel_choice = sorted(np.random.choice(list(range(channel)), size=self.num_channel, replace=False)) + return torch.index_select(img, 0, torch.Tensor(channel_choice).long()) + + +def get_datasets(name, root, input_size, cutout=-1): + assert len(input_size) in [3, 4] + if len(input_size) == 4: + input_size = input_size[1:] + assert input_size[1] == input_size[2] + + if name == 'cifar10': + mean = [x / 255 for x in [125.3, 123.0, 113.9]] + std = [x / 255 for x in [63.0, 62.1, 66.7]] + elif name == 'cifar100': + mean = [x / 255 for x in [129.3, 124.1, 112.4]] + std = [x / 255 for x in [68.2, 65.4, 70.4]] + elif name.startswith('imagenet-1k'): + mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] + elif name.startswith('ImageNet16'): + mean = [x / 255 for x in [122.68, 116.66, 104.01]] + std = [x / 255 for x in [63.22, 61.26 , 65.09]] + else: + raise TypeError("Unknow dataset : {:}".format(name)) + #ßprint(input_size) + # Data Argumentation + if name == 'cifar10' or name == 'cifar100': + lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])] + if cutout > 0 : lists += [CUTOUT(cutout)] + train_transform = transforms.Compose(lists) + test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + elif name.startswith('ImageNet16'): + lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])] + if cutout > 0 : lists += [CUTOUT(cutout)] + train_transform = transforms.Compose(lists) + test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + elif name.startswith('imagenet-1k'): + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + if name == 'imagenet-1k': + xlists = [] + xlists.append(transforms.Resize((32, 32), interpolation=2)) + xlists.append(transforms.RandomCrop(input_size[1], padding=0)) + elif name == 'imagenet-1k-s': + xlists = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))] + xlists = [] + else: raise ValueError('invalid name : {:}'.format(name)) + xlists.append(transforms.ToTensor()) + xlists.append(normalize) + xlists.append(RandChannel(input_size[0])) + train_transform = transforms.Compose(xlists) + test_transform = transforms.Compose([transforms.Resize(40), transforms.CenterCrop(32), transforms.ToTensor(), normalize]) + else: + raise TypeError("Unknow dataset : {:}".format(name)) + + if name == 'cifar10': + train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True) + test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True) + assert len(train_data) == 50000 and len(test_data) == 10000 + elif name == 'cifar100': + train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) + test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True) + assert len(train_data) == 50000 and len(test_data) == 10000 + elif name.startswith('imagenet-1k'): + train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) + test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform) + else: raise TypeError("Unknow dataset : {:}".format(name)) + + class_num = Dataset2Class[name] + return train_data, test_data, class_num + + +class LinearRegionCount(object): + """Computes and stores the average and current value""" + def __init__(self, n_samples): + self.ActPattern = {} + self.n_LR = -1 + self.n_samples = n_samples + self.ptr = 0 + self.activations = None + + @torch.no_grad() + def update2D(self, activations): + n_batch = activations.size()[0] + n_neuron = activations.size()[1] + self.n_neuron = n_neuron + if self.activations is None: + self.activations = torch.zeros(self.n_samples, n_neuron).cuda() + self.activations[self.ptr:self.ptr+n_batch] = torch.sign(activations) # after ReLU + self.ptr += n_batch + + @torch.no_grad() + def calc_LR(self): + res = torch.matmul(self.activations.half(), (1-self.activations).T.half()) # each element in res: A * (1 - B) + res += res.T # make symmetric, each element in res: A * (1 - B) + (1 - A) * B, a non-zero element indicate a pair of two different linear regions + res = 1 - torch.sign(res) # a non-zero element now indicate two linear regions are identical + res = res.sum(1) # for each sample's linear region: how many identical regions from other samples + res = 1. / res.float() # contribution of each redudant (repeated) linear region + self.n_LR = res.sum().item() # sum of unique regions (by aggregating contribution of all regions) + del self.activations, res + self.activations = None + torch.cuda.empty_cache() + + @torch.no_grad() + def update1D(self, activationList): + code_string = '' + for key, value in activationList.items(): + n_neuron = value.size()[0] + for i in range(n_neuron): + if value[i] > 0: + code_string += '1' + else: + code_string += '0' + if code_string not in self.ActPattern: + self.ActPattern[code_string] = 1 + + def getLinearReginCount(self): + if self.n_LR == -1: + self.calc_LR() + return self.n_LR + + +class Linear_Region_Collector: + def __init__(self, models=[], input_size=(64, 3, 32, 32), sample_batch=100, dataset='cifar100', data_path=None, seed=0): + self.models = [] + self.input_size = input_size # BCHW + self.sample_batch = sample_batch + self.input_numel = reduce(mul, self.input_size, 1) + self.interFeature = [] + self.dataset = dataset + self.data_path = data_path + self.seed = seed + self.reinit(models, input_size, sample_batch, seed) + + def reinit(self, ori_models=None, input_size=None, sample_batch=None, seed=None, weights=None): + models = [] + for network in ori_models: + network = network.cuda() + net = copy.deepcopy(network) + net.proj_weights = weights + num_edge, num_op = net.num_edge, net.num_op + for i in range(num_edge): + net.candidate_flags[i] = False + net.eval() + models.append(net) + + if models is not None: + assert isinstance(models, list) + del self.models + self.models = models + for model in self.models: + self.register_hook(model) + device = torch.cuda.current_device() + model = model.cuda(device=device) + self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(models))] + if input_size is not None or sample_batch is not None: + if input_size is not None: + self.input_size = input_size # BCHW + self.input_numel = reduce(mul, self.input_size, 1) + if sample_batch is not None: + self.sample_batch = sample_batch + if self.data_path is not None: + self.train_data, _, class_num = get_datasets(self.dataset, self.data_path, self.input_size, -1) + self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=self.input_size[0], num_workers=16, pin_memory=True, drop_last=True, shuffle=True) + self.loader = iter(self.train_loader) + if seed is not None and seed != self.seed: + self.seed = seed + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + del self.interFeature + self.interFeature = [] + torch.cuda.empty_cache() + + def clear(self): + self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(self.models))] + del self.interFeature + self.interFeature = [] + torch.cuda.empty_cache() + + def register_hook(self, model): + for m in model.modules(): + if isinstance(m, nn.ReLU): + m.register_forward_hook(hook=self.hook_in_forward) + + def hook_in_forward(self, module, input, output): + if isinstance(input, tuple) and len(input[0].size()) == 4: + self.interFeature.append(output.detach()) # for ReLU + + def forward_batch_sample(self): + for _ in range(self.sample_batch): + try: + inputs, targets = self.loader.next() + except Exception: + del self.loader + self.loader = iter(self.train_loader) + inputs, targets = self.loader.next() + for model, LRCount in zip(self.models, self.LRCounts): + self.forward(model, LRCount, inputs) + output = [LRCount.getLinearReginCount() for LRCount in self.LRCounts] + return output + + def forward(self, model, LRCount, input_data): + self.interFeature = [] + with torch.no_grad(): + model.forward(input_data.cuda()) + if len(self.interFeature) == 0: return + feature_data = torch.cat([f.view(input_data.size(0), -1) for f in self.interFeature], 1) + LRCount.update2D(feature_data) diff --git a/nasbench201/networks_proposal.py b/nasbench201/networks_proposal.py new file mode 100644 index 0000000..9cc42a4 --- /dev/null +++ b/nasbench201/networks_proposal.py @@ -0,0 +1,245 @@ +import os +import sys +sys.path.insert(0, '../') +import time +import glob +import json +import shutil +import logging +import argparse +import numpy as np + +import torch +import torch.nn as nn +import torch.utils +import torchvision.datasets as dset +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +from torch.autograd import Variable + +import nasbench201.utils as ig_utils +from nasbench201.search_model_darts_proj import TinyNetworkDartsProj +from nasbench201.cell_operations import SearchSpaceNames +from nasbench201.init_projection import pt_project, global_op_greedy_pt_project, global_op_once_pt_project, global_edge_greedy_pt_project, global_edge_once_pt_project, shrink_pt_project, tenas_project +from nas_201_api import NASBench201API as API + +torch.set_printoptions(precision=4, sci_mode=False) +np.set_printoptions(precision=4, suppress=True) + + +parser = argparse.ArgumentParser("sota") +# data related +parser.add_argument('--data', type=str, default='../data', help='location of the data corpus') +parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'imagenet16-120'], help='choose dataset') +parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data') +parser.add_argument('--batch_size', type=int, default=64, help='batch size for alpha') +parser.add_argument('--cutout', action='store_true', default=True, help='use cutout') +parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') +parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability') +parser.add_argument('--seed', type=int, default=2, help='random seed') + +#search space setting +parser.add_argument('--search_space', type=str, default='nas-bench-201') + +parser.add_argument('--pool_size', type=int, default=100, help='number of model to proposed') +parser.add_argument('--init_channels', type=int, default=16, help='num of init channels') +parser.add_argument('--layers', type=int, default=8, help='total number of layers') + +#system configurations +parser.add_argument('--gpu', type=str, default='auto', help='gpu device id') +parser.add_argument('--save', type=str, default='exp', help='experiment name') + +#default opt setting for model +parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate') +parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate') +parser.add_argument('--momentum', type=float, default=0.9, help='momentum') +parser.add_argument('--nesterov', action='store_true', default=True, help='using nestrov momentum for SGD') +parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay') +parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping') + +#### common +parser.add_argument('--fast', action='store_true', default=True, help='skip loading api which is slow') + +#### projection +parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'shrink_pt_project'], help='which edge to be projected next') +parser.add_argument('--proj_crit', type=str, default="comb", choices=['loss', 'acc', 'jacob', 'snip', 'fisher', 'synflow', 'grad_norm', 'grasp', 'jacob_cov','tenas', 'var', 'cor', 'norm', 'comb', 'meco'], help='criteria for projection') +args = parser.parse_args() + +#### args augment +expid = args.save +args.save = '../experiments/nas-bench-201/prop-{}-{}-{}'.format(args.save, args.seed, args.pool_size) +if not args.dataset == 'cifar10': + args.save += '-' + args.dataset +if not args.edge_decision == 'random': + args.save += '-' + args.edge_decision +if not args.proj_crit == 'jacob': + args.save += '-' + args.proj_crit + +#### logging +scripts_to_save = glob.glob('*.py') \ + # + ['../exp_scripts/{}.sh'.format(expid)] +if os.path.exists(args.save): + if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y': + print('proceed to override saving directory') + shutil.rmtree(args.save) + else: + exit(0) +ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save) + +log_format = '%(asctime)s %(message)s' +logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt='%m/%d %I:%M:%S %p') + +log_file = 'log.txt' +log_path = os.path.join(args.save, log_file) +logging.info('======> log filename: %s', log_file) + +if os.path.exists(log_path): + if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y': + print('proceed to override log file directory') + else: + exit(0) + +fh = logging.FileHandler(log_path, mode='w') +fh.setFormatter(logging.Formatter(log_format)) +logging.getLogger().addHandler(fh) +writer = SummaryWriter(args.save + '/runs') + +#### macros +if args.dataset == 'cifar100': + n_classes = 100 +elif args.dataset == 'imagenet16-120': + n_classes = 120 +else: + n_classes = 10 + +def main(): + torch.set_num_threads(3) + if not torch.cuda.is_available(): + logging.info('no gpu device available') + sys.exit(1) + + np.random.seed(args.seed) + gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu) + torch.cuda.set_device(gpu) + cudnn.benchmark = True + torch.manual_seed(args.seed) + cudnn.enabled = True + torch.cuda.manual_seed(args.seed) + logging.info("args = %s", args) + logging.info('gpu device = %d' % gpu) + + #### model + criterion = nn.CrossEntropyLoss() + search_space = SearchSpaceNames[args.search_space] + + # 初始化超网络 + model = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args) + model_thin = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args, stem_channels=1) + model = model.cuda() + model_thin = model_thin.cuda() + logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model)) + + #### data + if args.dataset == 'cifar10': + train_transform, valid_transform = ig_utils._data_transforms_cifar10(args) + train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform) + valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform) + elif args.dataset == 'cifar100': + train_transform, valid_transform = ig_utils._data_transforms_cifar100(args) + train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform) + valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform) + elif args.dataset == 'imagenet16-120': + import torchvision.transforms as transforms + from nasbench201.DownsampledImageNet import ImageNet16 + mean = [x / 255 for x in [122.68, 116.66, 104.01]] + std = [x / 255 for x in [63.22, 61.26, 65.09]] + lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)] + train_transform = transforms.Compose(lists) + train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120) + valid_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=False, transform=train_transform, use_num_of_class_only=120) + assert len(train_data) == 151700 + + num_train = len(train_data) + indices = list(range(num_train)) + split = int(np.floor(args.train_portion * num_train)) + + train_queue = torch.utils.data.DataLoader( + train_data, batch_size=args.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]), + pin_memory=True) + + + #format network pool diction + networks_pool={} + networks_pool['search_space'] = args.search_space + networks_pool['dataset'] = args.dataset + networks_pool['networks'] = [] + networks_pool['pool_size'] = args.pool_size + #### architecture selection / projection + for i in range(args.pool_size): + network_info={} + logging.info('{} MODEL HAS SEARCHED'.format(i+1)) + if args.edge_decision == 'global_op_greedy': + global_op_greedy_pt_project(train_queue, model, args) + elif args.edge_decision == 'global_op_once': + global_op_once_pt_project(train_queue, model, args) + elif args.edge_decision == 'global_edge_greedy': + global_edge_greedy_pt_project(train_queue, model, args) + elif args.edge_decision == 'global_edge_once': + global_edge_once_pt_project(train_queue, model, args) + elif args.edge_decision == 'shrink_pt_project': + shrink_pt_project(train_queue, model, args) + api = API('../data/NAS-Bench-201-v1_0-e61699.pth') + cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \ + cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, model.genotype().tostr(), logging) + else: + if args.proj_crit == 'jacob': + pt_project(train_queue, model, args) + else: + pt_project(train_queue, model, args) + # tenas_project(train_queue, model, model_thin, args) + + network_info['id'] = str(i) + network_info['genotype'] = model.genotype().tostr() + networks_pool['networks'].append(network_info) + model.reset_arch_parameters() + + with open(os.path.join(args.save,'networks_pool.json'), 'w') as save_file: + json.dump(networks_pool, save_file) + + +#### util functions +def distill(result): + result = result.split('\n') + cifar10 = result[5].replace(' ', '').split(':') + cifar100 = result[7].replace(' ', '').split(':') + imagenet16 = result[9].replace(' ', '').split(':') + + cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('=')) + cifar10_test = float(cifar10[2][-7:-2].strip('=')) + cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('=')) + cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('=')) + cifar100_test = float(cifar100[3][-7:-2].strip('=')) + imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('=')) + imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('=')) + imagenet16_test = float(imagenet16[3][-7:-2].strip('=')) + + return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \ + cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test + + +def query(api, genotype, logging): + result = api.query_by_arch(genotype, hp='200') + logging.info('{:}'.format(result)) + cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \ + cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result) + logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test) + logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test) + logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test) + return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \ + cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test + + +if __name__ == '__main__': + main() diff --git a/nasbench201/op_score.py b/nasbench201/op_score.py new file mode 100644 index 0000000..71ba44c --- /dev/null +++ b/nasbench201/op_score.py @@ -0,0 +1,113 @@ +import gc +import numpy as np +import os +import sys +import torch +import torch.nn.functional as f +from operator import mul +from functools import reduce +import copy +sys.path.insert(0, '../') + +def Jocab_Score(ori_model, input, target, weights=None): + model = copy.deepcopy(ori_model) + model.eval() + model.proj_weights = weights + num_edge, num_op = model.num_edge, model.num_op + for i in range(num_edge): + model.candidate_flags[i] = False + batch_size = input.shape[0] + model.K = torch.zeros(batch_size, batch_size).cuda() + + def counting_forward_hook(module, inp, out): + try: + if isinstance(inp, tuple): + inp = inp[0] + inp = inp.view(inp.size(0), -1) + x = (inp > 0).float() + K = x @ x.t() + K2 = (1.-x) @ (1.-x.t()) + model.K = model.K + K + K2 + except: + pass + + for name, module in model.named_modules(): + if 'ReLU' in str(type(module)): + module.register_forward_hook(counting_forward_hook) + + input = input.cuda() + model(input) + score = hooklogdet(model.K.cpu().numpy()) + del model + del input + return score + +def hooklogdet(K, labels=None): + s, ld = np.linalg.slogdet(K) + return ld + +# NTK +#------------------------------------------------------------ +#https://github.com/VITA-Group/TENAS/blob/main/lib/procedures/ntk.py +# +def recal_bn(network, xloader, recalbn, device): + for m in network.modules(): + if isinstance(m, torch.nn.BatchNorm2d): + m.running_mean.data.fill_(0) + m.running_var.data.fill_(0) + m.num_batches_tracked.data.zero_() + m.momentum = None + network.train() + with torch.no_grad(): + for i, (inputs, targets) in enumerate(xloader): + if i >= recalbn: break + inputs = inputs.cuda(device=device, non_blocking=True) + _, _ = network(inputs) + return network + +def get_ntk_n(xloader, networks, recalbn=0, train_mode=False, num_batch=-1, weights=None): + device = torch.cuda.current_device() + ntks = [] + copied_networks = [] + for network in networks: + network = network.cuda(device=device) + net = copy.deepcopy(network) + net.proj_weights = weights + num_edge, num_op = net.num_edge, net.num_op + for i in range(num_edge): + net.candidate_flags[i] = False + if train_mode: + net.train() + else: + net.eval() + copied_networks.append(net) + ###### + grads = [[] for _ in range(len(copied_networks))] + for i, (inputs, targets) in enumerate(xloader): + if num_batch > 0 and i >= num_batch: break + inputs = inputs.cuda(device=device, non_blocking=True) + for net_idx, network in enumerate(copied_networks): + network.zero_grad() + inputs_ = inputs.clone().cuda(device=device, non_blocking=True) + logit = network(inputs_) + if isinstance(logit, tuple): + logit = logit[1] # 201 networks: return features and logits + for _idx in range(len(inputs_)): + logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True) + grad = [] + for name, W in network.named_parameters(): + if 'weight' in name and W.grad is not None: + grad.append(W.grad.view(-1).detach()) + grads[net_idx].append(torch.cat(grad, -1)) + network.zero_grad() + torch.cuda.empty_cache() + ###### + grads = [torch.stack(_grads, 0) for _grads in grads] + ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads] + conds = [] + for ntk in ntks: + eigenvalues, _ = torch.symeig(ntk) # ascending + conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0)) + + del copied_networks + return conds \ No newline at end of file diff --git a/nasbench201/search_cells.py b/nasbench201/search_cells.py new file mode 100644 index 0000000..84f6214 --- /dev/null +++ b/nasbench201/search_cells.py @@ -0,0 +1,182 @@ +import math, random, torch +import warnings +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +import sys +sys.path.insert(0, '../') +from nasbench201.cell_operations import OPS + + +# This module is used for NAS-Bench-201, represents a small search space with a complete DAG +class NAS201SearchCell(nn.Module): + + def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True): + super(NAS201SearchCell, self).__init__() + + self.op_names = deepcopy(op_names) + self.edges = nn.ModuleDict() + self.max_nodes = max_nodes + self.in_dim = C_in + self.out_dim = C_out + for i in range(1, max_nodes): + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + if j == 0: + xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names] + else: + xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names] + self.edges[ node_str ] = nn.ModuleList( xlists ) + self.edge_keys = sorted(list(self.edges.keys())) + self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} + self.num_edges = len(self.edges) + + def extra_repr(self): + string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) + return string + + def forward(self, inputs, weightss): + return self._forward(inputs, weightss) + + def _forward(self, inputs, weightss): + with torch.autograd.set_detect_anomaly(True): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = weightss[ self.edge2index[node_str] ] + inter_nodes.append(sum(layer(nodes[j], block_input=True)*w if w==0 else layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights)) ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # GDAS + def forward_gdas(self, inputs, hardwts, index): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = hardwts[ self.edge2index[node_str] ] + argmaxs = index[ self.edge2index[node_str] ].item() + weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) + inter_nodes.append( weigsum ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # joint + def forward_joint(self, inputs, weightss): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = weightss[ self.edge2index[node_str] ] + #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() + aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) + inter_nodes.append( aggregation ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # uniform random sampling per iteration, SETN + def forward_urs(self, inputs): + nodes = [inputs] + for i in range(1, self.max_nodes): + while True: # to avoid select zero for all ops + sops, has_non_zero = [], False + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + candidates = self.edges[node_str] + select_op = random.choice(candidates) + sops.append( select_op ) + if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True + if has_non_zero: break + inter_nodes = [] + for j, select_op in enumerate(sops): + inter_nodes.append( select_op(nodes[j]) ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # select the argmax + def forward_select(self, inputs, weightss): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = weightss[ self.edge2index[node_str] ] + inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) ) + #inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # forward with a specific structure + def forward_dynamic(self, inputs, structure): + nodes = [inputs] + for i in range(1, self.max_nodes): + cur_op_node = structure.nodes[i-1] + inter_nodes = [] + for op_name, j in cur_op_node: + node_str = '{:}<-{:}'.format(i, j) + op_index = self.op_names.index( op_name ) + inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + # reshape + x = x.view(batchsize, groups, + channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + # flatten + x = x.view(batchsize, -1, height, width) + return x + + +class NAS201SearchCell_PartialChannel(NAS201SearchCell): + + def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True, k=4): + super(NAS201SearchCell, self).__init__() + + self.k = k + self.op_names = deepcopy(op_names) + self.edges = nn.ModuleDict() + self.max_nodes = max_nodes + self.in_dim = C_in + self.out_dim = C_out + for i in range(1, max_nodes): + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + if j == 0: + xlists = [OPS[op_name](C_in//self.k , C_out//self.k, stride, affine, track_running_stats) for op_name in op_names] + else: + xlists = [OPS[op_name](C_in//self.k , C_out//self.k, 1, affine, track_running_stats) for op_name in op_names] + self.edges[ node_str ] = nn.ModuleList( xlists ) + self.edge_keys = sorted(list(self.edges.keys())) + self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} + self.num_edges = len(self.edges) + + def MixedOp(self, x, ops, weights): + dim_2 = x.shape[1] + xtemp = x[ : , : dim_2//self.k, :, :] + xtemp2 = x[ : , dim_2//self.k:, :, :] + temp1 = sum(w * op(xtemp) for w, op in zip(weights, ops)) + ans = torch.cat([temp1,xtemp2],dim=1) + ans = channel_shuffle(ans,self.k) + return ans + + def forward(self, inputs, weightss): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = weightss[ self.edge2index[node_str] ] + # inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) + inter_nodes.append(self.MixedOp(x=nodes[j], ops=self.edges[node_str], weights=weights)) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + diff --git a/nasbench201/search_model.py b/nasbench201/search_model.py new file mode 100644 index 0000000..68e7000 --- /dev/null +++ b/nasbench201/search_model.py @@ -0,0 +1,202 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from .cell_operations import ResNetBasicblock +from .search_cells import NAS201SearchCell as SearchCell +from .genotypes import Structure +from torch.autograd import Variable + +class TinyNetwork(nn.Module): + + def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, affine=False, track_running_stats=True, stem_channels=3): + super(TinyNetwork, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self._num_classes = num_classes + self._criterion = criterion + self._args = args + self._affine = affine + self._track_running_stats = track_running_stats + self.stem = nn.Sequential( + nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C)) + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + self.cells.append( cell ) + C_prev = cell.out_dim + self.num_edge = num_edge + self.num_op = len(search_space) + self.op_names = deepcopy( search_space ) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + # self._arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + self._arch_parameters = Variable(1e-3*torch.randn(num_edge, len(search_space)).cuda(), requires_grad=True) + + ## optimizer + ## 记录的是m在内存中的地址,以示区分 + arch_params = set(id(m) for m in self.arch_parameters()) + self._model_params = [m for m in self.parameters() if id(m) not in arch_params] + + # 模型参数优化器 + self.optimizer = torch.optim.SGD( + self._model_params, + args.learning_rate, + momentum=args.momentum, + weight_decay=args.weight_decay, + nesterov= args.nesterov) + + + def entropy_y_x(self, p_logit): + p = F.softmax(p_logit, dim=1) + return - torch.sum(p * F.log_softmax(p_logit, dim=1)) / p_logit.shape[0] + + def _loss(self, input, target, return_logits=False): + logits = self(input) + loss = self._criterion(logits, target) + + return (loss, logits) if return_logits else loss + + def get_weights(self): + xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) + xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) + xlist+= list( self.classifier.parameters() ) + return xlist + + def arch_parameters(self): + return [self._arch_parameters] + + def get_theta(self): + return nn.functional.softmax(self._arch_parameters, dim=-1).cpu() + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + with torch.no_grad(): + weights = self._arch_parameters[ self.edge2index[node_str] ] + op_name = self.op_names[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return Structure( genotypes ) + + def forward(self, inputs, weights=None): + sim_nn = [] + + weights = nn.functional.softmax(self._arch_parameters, dim=-1) if weights is None else weights + + if self.slim: + weights[1].data.fill_(0) + weights[3].data.fill_(0) + weights[4].data.fill_(0) + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell(feature, weights) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return logits + + def _save_arch_parameters(self): + self._saved_arch_parameters = [p.clone() for p in self._arch_parameters] + + def project_arch(self): + self._save_arch_parameters() + for p in self.arch_parameters(): + m, n = p.size() + maxIndexs = p.data.cpu().numpy().argmax(axis=1) + p.data = self.proximal_step(p, maxIndexs) + + def proximal_step(self, var, maxIndexs=None): + values = var.data.cpu().numpy() + m, n = values.shape + alphas = [] + for i in range(m): + for j in range(n): + if j == maxIndexs[i]: + alphas.append(values[i][j].copy()) + values[i][j] = 1 + else: + values[i][j] = 0 + return torch.Tensor(values).cuda() + + def restore_arch_parameters(self): + for i, p in enumerate(self._arch_parameters): + p.data.copy_(self._saved_arch_parameters[i]) + del self._saved_arch_parameters + + def new(self): + model_new = TinyNetwork(self._C, self._layerN, self.max_nodes, self._num_classes, self._criterion, + self.op_names, self._args, self._affine, self._track_running_stats).cuda() + for x, y in zip(model_new.arch_parameters(), self.arch_parameters()): + x.data.copy_(y.data) + + return model_new + + def step(self, input, target, args, shared=None, return_grad=False): + Lt, logit_t = self._loss(input, target, return_logits=True) + Lt.backward() + if args.grad_clip != 0: + nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip) + self.optimizer.step() + + if return_grad: + grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()]) + return logit_t, Lt, grad + else: + return logit_t, Lt + + def printing(self, logging): + logging.info(self.get_theta()) + + def set_arch_parameters(self, new_alphas): + for alpha, new_alpha in zip(self.arch_parameters(), new_alphas): + alpha.data.copy_(new_alpha.data) + + def save_arch_parameters(self): + self._saved_arch_parameters = self._arch_parameters.clone() + + def restore_arch_parameters(self): + self.set_arch_parameters(self._saved_arch_parameters) + + def reset_optimizer(self, lr, momentum, weight_decay): + del self.optimizer + self.optimizer = torch.optim.SGD( + self.get_weights(), + lr, + momentum=momentum, + weight_decay=weight_decay, + nesterov= args.nesterov) \ No newline at end of file diff --git a/nasbench201/search_model_darts.py b/nasbench201/search_model_darts.py new file mode 100644 index 0000000..60bca66 --- /dev/null +++ b/nasbench201/search_model_darts.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from .search_cells import NAS201SearchCell as SearchCell +from .search_model import TinyNetwork as TinyNetwork + + +class TinyNetworkDarts(TinyNetwork): + def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, + affine=False, track_running_stats=True, stem_channels=3): + super(TinyNetworkDarts, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args, + affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels) + + self.theta_map = lambda x: torch.softmax(x, dim=-1) + + def get_theta(self): + return self.theta_map(self._arch_parameters).cpu() + + def forward(self, inputs): + weights = self.theta_map(self._arch_parameters) + feature = self.stem(inputs) + + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell(feature, weights) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return logits diff --git a/nasbench201/search_model_darts_proj.py b/nasbench201/search_model_darts_proj.py new file mode 100644 index 0000000..5c08276 --- /dev/null +++ b/nasbench201/search_model_darts_proj.py @@ -0,0 +1,80 @@ +import torch +from .search_cells import NAS201SearchCell as SearchCell +from .search_model import TinyNetwork as TinyNetwork +from .genotypes import Structure +from torch.autograd import Variable + +class TinyNetworkDartsProj(TinyNetwork): + def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, + affine=False, track_running_stats=True, stem_channels=3): + super(TinyNetworkDartsProj, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args, + affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels) + self.theta_map = lambda x: torch.softmax(x, dim=-1) + + #### for edgewise projection + self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda() + self.proj_weights = torch.zeros_like(self._arch_parameters) + + def project_op(self, eid, opid): + self.proj_weights[eid][opid] = 1 ## hard by default + self.candidate_flags[eid] = False + + def get_projected_weights(self): + weights = self.theta_map(self._arch_parameters) + + ## proj op + for eid in range(len(self._arch_parameters)): + if not self.candidate_flags[eid]: + weights[eid].data.copy_(self.proj_weights[eid]) + + return weights + + def forward(self, inputs, weights=None): + with torch.autograd.set_detect_anomaly(True): + if weights is None: + weights = self.get_projected_weights() + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell(feature, weights) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return logits + + #### utils + def get_theta(self): + return self.get_projected_weights() + + def arch_parameters(self): + return [self._arch_parameters] + + def set_arch_parameters(self, new_alphas): + for eid, alpha in enumerate(self.arch_parameters()): + alpha.data.copy_(new_alphas[eid]) + + def reset_arch_parameters(self): + self._arch_parameters = Variable(1e-3*torch.randn(self.num_edge, len(self.op_names)).cuda(), requires_grad=True) + self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda() + self.proj_weights = torch.zeros_like(self._arch_parameters) + + def genotype(self): + proj_weights = self.get_projected_weights() + + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + with torch.no_grad(): + weights = proj_weights[ self.edge2index[node_str] ] + op_name = self.op_names[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return Structure( genotypes ) \ No newline at end of file diff --git a/nasbench201/utils.py b/nasbench201/utils.py new file mode 100644 index 0000000..c1d43d7 --- /dev/null +++ b/nasbench201/utils.py @@ -0,0 +1,494 @@ +from __future__ import print_function + +import numpy as np +import os +import os.path +import sys +import shutil +import torch +import torchvision.transforms as transforms + +from PIL import Image +from torch.autograd import Variable +from torchvision.datasets import VisionDataset +from torchvision.datasets import utils + +if sys.version_info[0] == 2: + import cPickle as pickle +else: + import pickle + + +class AvgrageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +def accuracy(output, target, topk=(1,)): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].contiguous().view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class Cutout(object): + def __init__(self, length, prob=1.0): + self.length = length + self.prob = prob + + def __call__(self, img): + if np.random.binomial(1, self.prob): + h, w = img.size(1), img.size(2) + mask = np.ones((h, w), np.float32) + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img *= mask + return img + +def _data_transforms_svhn(args): + SVHN_MEAN = [0.4377, 0.4438, 0.4728] + SVHN_STD = [0.1980, 0.2010, 0.1970] + + train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(SVHN_MEAN, SVHN_STD), + ]) + if args.cutout: + train_transform.transforms.append(Cutout(args.cutout_length, + args.cutout_prob)) + + valid_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(SVHN_MEAN, SVHN_STD), + ]) + return train_transform, valid_transform + + +def _data_transforms_cifar100(args): + CIFAR_MEAN = [0.5071, 0.4865, 0.4409] + CIFAR_STD = [0.2673, 0.2564, 0.2762] + + train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(CIFAR_MEAN, CIFAR_STD), + ]) + if args.cutout: + train_transform.transforms.append(Cutout(args.cutout_length, + args.cutout_prob)) + + valid_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(CIFAR_MEAN, CIFAR_STD), + ]) + return train_transform, valid_transform + + +def _data_transforms_cifar10(args): + CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] + CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] + + train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(CIFAR_MEAN, CIFAR_STD), + ]) + if args.cutout: + train_transform.transforms.append(Cutout(args.cutout_length, + args.cutout_prob)) + + valid_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(CIFAR_MEAN, CIFAR_STD), + ]) + return train_transform, valid_transform + + +def count_parameters_in_MB(model): + return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6 + + +def count_parameters_in_Compact(model): + from sota.cnn.model import Network as CompactModel + genotype = model.genotype() + compact_model = CompactModel(36, model._num_classes, 20, True, genotype) + num_params = count_parameters_in_MB(compact_model) + return num_params + + +def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''): + filename = prefix + if per_epoch: + epoch = state['epoch'] + filename += 'checkpoint_{}.pth.tar'.format(epoch) + else: + filename += 'checkpoint.pth.tar' + filename = os.path.join(save, filename) + torch.save(state, filename) + if is_best: + best_filename = os.path.join(save, 'model_best.pth.tar') + shutil.copyfile(filename, best_filename) + + +def load_checkpoint(model, optimizer, save, epoch=None): + if epoch is None: + filename = 'checkpoint.pth.tar' + else: + filename = 'checkpoint_{}.pth.tar'.format(epoch) + filename = os.path.join(save, filename) + start_epoch = 0 + if os.path.isfile(filename): + print("=> loading checkpoint '{}'".format(filename)) + checkpoint = torch.load(filename) + start_epoch = checkpoint['epoch'] + best_acc_top1 = checkpoint['best_acc_top1'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(filename, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(filename)) + + return model, optimizer, start_epoch, best_acc_top1 + + +def save(model, model_path): + torch.save(model.state_dict(), model_path) + + +def load(model, model_path): + model.load_state_dict(torch.load(model_path)) + + +def drop_path(x, drop_prob): + if drop_prob > 0.: + keep_prob = 1. - drop_prob + mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) + x.div_(keep_prob) + x.mul_(mask) + return x + + +def create_exp_dir(path, scripts_to_save=None): + if not os.path.exists(path): + os.makedirs(path) + print('Experiment dir : {}'.format(path)) + + if scripts_to_save is not None: + os.mkdir(os.path.join(path, 'scripts')) + for script in scripts_to_save: + dst_file = os.path.join(path, 'scripts', os.path.basename(script)) + shutil.copyfile(script, dst_file) + + +class CIFAR10(VisionDataset): + """`CIFAR10 `_ Dataset. + + Args: + root (string): Root directory of dataset where directory + ``cifar-10-batches-py`` exists or will be saved to if download is set to True. + train (bool, optional): If True, creates dataset from training set, otherwise + creates from test set. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + """ + base_folder = 'cifar-10-batches-py' + url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + filename = "cifar-10-python.tar.gz" + tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + #['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ] + + test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'], + ] + meta = { + 'filename': 'batches.meta', + 'key': 'label_names', + 'md5': '5ff9c542aee3614f3951f8cda6e48888', + } + + def __init__(self, root, train=True, transform=None, target_transform=None, + download=False): + + super(CIFAR10, self).__init__(root, transform=transform, + target_transform=target_transform) + + self.train = train # training set or test set + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + if self.train: + downloaded_list = self.train_list + else: + downloaded_list = self.test_list + + self.data = [] + self.targets = [] + + # now load the picked numpy arrays + for file_name, checksum in downloaded_list: + file_path = os.path.join(self.root, self.base_folder, file_name) + with open(file_path, 'rb') as f: + if sys.version_info[0] == 2: + entry = pickle.load(f) + else: + entry = pickle.load(f, encoding='latin1') + self.data.append(entry['data']) + if 'labels' in entry: + self.targets.extend(entry['labels']) + else: + self.targets.extend(entry['fine_labels']) + + self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + + self._load_meta() + + def _load_meta(self): + path = os.path.join(self.root, self.base_folder, self.meta['filename']) + if not utils.check_integrity(path, self.meta['md5']): + raise RuntimeError('Dataset metadata file not found or corrupted.' + + ' You can use download=True to download it') + with open(path, 'rb') as infile: + if sys.version_info[0] == 2: + data = pickle.load(infile) + else: + data = pickle.load(infile, encoding='latin1') + self.classes = data[self.meta['key']] + self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. + """ + img, target = self.data[index], self.targets[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.data) + + def _check_integrity(self): + root = self.root + for fentry in (self.train_list + self.test_list): + filename, md5 = fentry[0], fentry[1] + fpath = os.path.join(root, self.base_folder, filename) + if not utils.check_integrity(fpath, md5): + return False + return True + + def download(self): + if self._check_integrity(): + print('Files already downloaded and verified') + return + utils.download_and_extract_archive(self.url, self.root, + filename=self.filename, + md5=self.tgz_md5) + + def extra_repr(self): + return "Split: {}".format("Train" if self.train is True else "Test") + + +def pick_gpu_lowest_memory(): + import gpustat + stats = gpustat.GPUStatCollection.new_query() + ids = map(lambda gpu: int(gpu.entry['index']), stats) + ratios = map(lambda gpu: float(gpu.memory_used)/float(gpu.memory_total), stats) + bestGPU = min(zip(ids, ratios), key=lambda x: x[1])[0] + return bestGPU + + +#### early stopping (from RobustNAS) +class EVLocalAvg(object): + def __init__(self, window=5, ev_freq=2, total_epochs=50): + """ Keep track of the eigenvalues local average. + Args: + window (int): number of elements used to compute local average. + Default: 5 + ev_freq (int): frequency used to compute eigenvalues. Default: + every 2 epochs + total_epochs (int): total number of epochs that DARTS runs. + Default: 50 + """ + self.window = window + self.ev_freq = ev_freq + self.epochs = total_epochs + + self.stop_search = False + self.stop_epoch = total_epochs - 1 + self.stop_genotype = None + self.stop_numparam = 0 + + self.ev = [] + self.ev_local_avg = [] + self.genotypes = {} + self.numparams = {} + self.la_epochs = {} + + # start and end index of the local average window + self.la_start_idx = 0 + self.la_end_idx = self.window + + def reset(self): + self.ev = [] + self.ev_local_avg = [] + self.genotypes = {} + self.numparams = {} + self.la_epochs = {} + + def update(self, epoch, ev, genotype, numparam=0): + """ Method to update the local average list. + + Args: + epoch (int): current epoch + ev (float): current dominant eigenvalue + genotype (namedtuple): current genotype + + """ + self.ev.append(ev) + self.genotypes.update({epoch: genotype}) + self.numparams.update({epoch: numparam}) + # set the stop_genotype to the current genotype in case the early stop + # procedure decides not to early stop + self.stop_genotype = genotype + + # since the local average computation starts after the dominant + # eigenvalue in the first epoch is already computed we have to wait + # at least until we have 3 eigenvalues in the list. + if (len(self.ev) >= int(np.ceil(self.window/2))) and (epoch < + self.epochs - 1): + # start sliding the window as soon as the number of eigenvalues in + # the list becomes equal to the window size + if len(self.ev) < self.window: + self.ev_local_avg.append(np.mean(self.ev)) + else: + assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window + self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx: + self.la_end_idx])) + self.la_start_idx += 1 + self.la_end_idx += 1 + + # keep track of the offset between the current epoch and the epoch + # corresponding to the local average. NOTE: in the end the size of + # self.ev and self.ev_local_avg should be equal + self.la_epochs.update({epoch: int(epoch - + int(self.ev_freq*np.floor(self.window/2)))}) + + elif len(self.ev) < int(np.ceil(self.window/2)): + self.la_epochs.update({epoch: -1}) + + # since there is an offset between the current epoch and the local + # average epoch, loop in the last epoch to compute the local average of + # these number of elements: window, window - 1, window - 2, ..., ceil(window/2) + elif epoch == self.epochs - 1: + for i in range(int(np.ceil(self.window/2))): + assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i + self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx: + self.la_end_idx + 1])) + self.la_start_idx += 1 + + def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4, criteria='local_avg'): + """ Early stopping criterion + + Args: + epoch (int): current epoch + factor (float): threshold factor for the ration between the current + and prefious eigenvalue. Default: 1.3 + es_start_epoch (int): until this epoch do not consider early + stopping. Default: 20 + delta (int): factor influencing which previous local average we + consider for early stopping. Default: 2 + """ + if criteria == 'local_avg': + if int(self.la_epochs[epoch] - self.ev_freq*delta) >= es_start_epoch: + if criteria == 'local_avg': + current_la = self.ev_local_avg[-1] + previous_la = self.ev_local_avg[-1 - delta] + self.stop_search = current_la / previous_la > factor + if self.stop_search: + self.stop_epoch = int(self.la_epochs[epoch] - self.ev_freq*delta) + self.stop_genotype = self.genotypes[self.stop_epoch] + self.stop_numparam = self.numparams[self.stop_epoch] + elif criteria == 'exact': + if epoch > es_start_epoch: + current_la = self.ev[-1] + previous_la = self.ev[-1 - delta] + self.stop_search = current_la / previous_la > factor + if self.stop_search: + self.stop_epoch = epoch - delta + self.stop_genotype = self.genotypes[self.stop_epoch] + self.stop_numparam = self.numparams[self.stop_epoch] + else: + print('ERROR IN EARLY STOP: WRONG CRITERIA:', criteria); exit(0) + + +def gen_comb(eids): + comb = [] + for r in range(len(eids)): + for c in range(r + 1, len(eids)): + comb.append((eids[r], eids[c])) + + return comb