update
This commit is contained in:
133
correlation/NAS-Bench-101.py
Normal file
133
correlation/NAS-Bench-101.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import pickle
|
||||
import torch
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
from thop import profile
|
||||
|
||||
from foresight.models import *
|
||||
from foresight.pruners import *
|
||||
from foresight.dataset import *
|
||||
|
||||
|
||||
def get_num_classes(args):
|
||||
return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-101')
|
||||
parser.add_argument('--api_loc', default='../data/nasbench_only108.tfrecord',
|
||||
type=str, help='path to API')
|
||||
parser.add_argument('--json_loc', default='data/all_graphs.json',
|
||||
type=str, help='path to JSON database')
|
||||
parser.add_argument('--outdir', default='./',
|
||||
type=str, help='output directory')
|
||||
parser.add_argument('--outfname', default='test',
|
||||
type=str, help='output filename')
|
||||
parser.add_argument('--batch_size', default=256, type=int)
|
||||
parser.add_argument('--dataset', type=str, default='cifar10',
|
||||
help='dataset to use [cifar10, cifar100, ImageNet16-120]')
|
||||
parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on')
|
||||
parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
|
||||
parser.add_argument('--dataload', type=str, default='random', help='random or grasp supported')
|
||||
parser.add_argument('--dataload_info', type=int, default=1,
|
||||
help='number of batches to use for random dataload or number of samples per class for grasp dataload')
|
||||
parser.add_argument('--start', type=int, default=5, help='start index')
|
||||
parser.add_argument('--end', type=int, default=10, help='end index')
|
||||
parser.add_argument('--write_freq', type=int, default=100, help='frequency of write to file')
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
|
||||
return args
|
||||
|
||||
|
||||
def get_op_names(v):
|
||||
o = []
|
||||
for op in v:
|
||||
if op == -1:
|
||||
o.append('input')
|
||||
elif op == -2:
|
||||
o.append('output')
|
||||
elif op == 0:
|
||||
o.append('conv3x3-bn-relu')
|
||||
elif op == 1:
|
||||
o.append('conv1x1-bn-relu')
|
||||
elif op == 2:
|
||||
o.append('maxpool3x3')
|
||||
return o
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
# nasbench = api.NASBench(args.api_loc)
|
||||
models = json.load(open(args.json_loc))
|
||||
|
||||
print(f'Running models {args.start} to {args.end} out of {len(models.keys())}')
|
||||
|
||||
train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset,
|
||||
args.num_data_workers)
|
||||
|
||||
all_points = []
|
||||
pre = 'cf' if 'cifar' in args.dataset else 'im'
|
||||
|
||||
if args.outfname == 'test':
|
||||
fn = f'nb1_{pre}{get_num_classes(args)}.p'
|
||||
else:
|
||||
fn = f'{args.outfname}.p'
|
||||
op = os.path.join(args.outdir, fn)
|
||||
|
||||
print('outfile =', op)
|
||||
first = True
|
||||
|
||||
# loop over nasbench1 archs (k=hash, v=[adj_matrix, ops])
|
||||
idx = 0
|
||||
cached_res = []
|
||||
for k, v in models.items():
|
||||
|
||||
if idx < args.start:
|
||||
idx += 1
|
||||
continue
|
||||
if idx >= args.end:
|
||||
break
|
||||
print(f'idx = {idx}')
|
||||
idx += 1
|
||||
|
||||
res = {}
|
||||
res['hash'] = k
|
||||
|
||||
# model
|
||||
spec = nasbench1_spec._ToModelSpec(v[0], get_op_names(v[1]))
|
||||
net = nasbench1.Network(spec, stem_out=128, num_stacks=3, num_mods=3, num_classes=get_num_classes(args))
|
||||
net.to(args.device)
|
||||
|
||||
measures = predictive.find_measures(net,
|
||||
train_loader,
|
||||
(args.dataload, args.dataload_info, get_num_classes(args)),
|
||||
args.device)
|
||||
res['logmeasures'] = measures
|
||||
|
||||
print(res)
|
||||
cached_res.append(res)
|
||||
|
||||
# write to file
|
||||
if idx % args.write_freq == 0 or idx == args.end or idx == args.start + 10:
|
||||
print(f'writing {len(cached_res)} results to {op}')
|
||||
pf = open(op, 'ab')
|
||||
for cr in cached_res:
|
||||
pickle.dump(cr, pf)
|
||||
pf.close()
|
||||
cached_res = []
|
128
correlation/NAS-Bench-201.py
Normal file
128
correlation/NAS-Bench-201.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import time
|
||||
|
||||
from foresight.dataset import *
|
||||
from foresight.models import nasbench2
|
||||
from foresight.pruners import predictive
|
||||
from foresight.weight_initializers import init_net
|
||||
from models import get_cell_based_tiny_net
|
||||
import pickle
|
||||
|
||||
|
||||
def get_num_classes(args):
|
||||
return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-201')
|
||||
parser.add_argument('--api_loc', default='../data/NAS-Bench-201-v1_0-e61699.pth',
|
||||
type=str, help='path to API')
|
||||
parser.add_argument('--outdir', default='./',
|
||||
type=str, help='output directory')
|
||||
parser.add_argument('--init_w_type', type=str, default='none',
|
||||
help='weight initialization (before pruning) type [none, xavier, kaiming, zero, one]')
|
||||
parser.add_argument('--init_b_type', type=str, default='none',
|
||||
help='bias initialization (before pruning) type [none, xavier, kaiming, zero, one]')
|
||||
parser.add_argument('--batch_size', default=64, type=int)
|
||||
parser.add_argument('--dataset', type=str, default='ImageNet16-120',
|
||||
help='dataset to use [cifar10, cifar100, ImageNet16-120]')
|
||||
parser.add_argument('--gpu', type=int, default=5, help='GPU index to work on')
|
||||
parser.add_argument('--data_size', type=int, default=32, help='data_size')
|
||||
parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
|
||||
parser.add_argument('--dataload', type=str, default='appoint', help='random, grasp, appoint supported')
|
||||
parser.add_argument('--dataload_info', type=int, default=1,
|
||||
help='number of batches to use for random dataload or number of samples per class for grasp dataload')
|
||||
parser.add_argument('--seed', type=int, default=42, help='pytorch manual seed')
|
||||
parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file')
|
||||
parser.add_argument('--start', type=int, default=0, help='start index')
|
||||
parser.add_argument('--end', type=int, default=0, help='end index')
|
||||
parser.add_argument('--noacc', default=False, action='store_true',
|
||||
help='avoid loading NASBench2 api an instead load a pickle file with tuple (index, arch_str)')
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
print(args.device)
|
||||
|
||||
if args.noacc:
|
||||
api = pickle.load(open(args.api_loc,'rb'))
|
||||
else:
|
||||
from nas_201_api import NASBench201API as API
|
||||
api = API(args.api_loc)
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset, args.num_data_workers, resize=args.data_size)
|
||||
x, y = next(iter(train_loader))
|
||||
# random data
|
||||
# x = torch.rand((args.batch_size, 3, args.data_size, args.data_size))
|
||||
# y = 0
|
||||
|
||||
cached_res = []
|
||||
pre = 'cf' if 'cifar' in args.dataset else 'im'
|
||||
pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p'
|
||||
op = os.path.join(args.outdir, pfn)
|
||||
|
||||
end = len(api) if args.end == 0 else args.end
|
||||
|
||||
# loop over nasbench2 archs
|
||||
for i, arch_str in enumerate(api):
|
||||
|
||||
if i < args.start:
|
||||
continue
|
||||
if i >= end:
|
||||
break
|
||||
|
||||
res = {'i': i, 'arch': arch_str}
|
||||
# print(arch_str)
|
||||
if args.search_space == 'tss':
|
||||
net = nasbench2.get_model_from_arch_str(arch_str, get_num_classes(args))
|
||||
arch_str2 = nasbench2.get_arch_str_from_model(net)
|
||||
if arch_str != arch_str2:
|
||||
print(arch_str)
|
||||
print(arch_str2)
|
||||
raise ValueError
|
||||
elif args.search_space == 'sss':
|
||||
config = api.get_net_config(i, args.dataset)
|
||||
# print(config)
|
||||
net = get_cell_based_tiny_net(config)
|
||||
net.to(args.device)
|
||||
# print(net)
|
||||
|
||||
init_net(net, args.init_w_type, args.init_b_type)
|
||||
|
||||
# print(x.size(), y)
|
||||
measures = get_score(net, x, i, args.device)
|
||||
|
||||
res['meco'] = measures
|
||||
|
||||
if not args.noacc:
|
||||
info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset, iepoch=None,
|
||||
hp='200', is_random=False)
|
||||
|
||||
trainacc = info['train-accuracy']
|
||||
valacc = info['valid-accuracy']
|
||||
testacc = info['test-accuracy']
|
||||
|
||||
res['trainacc'] = trainacc
|
||||
res['valacc'] = valacc
|
||||
res['testacc'] = testacc
|
||||
|
||||
print(res)
|
||||
cached_res.append(res)
|
||||
|
||||
# write to file
|
||||
if i % args.write_freq == 0 or i == len(api) - 1 or i == 10:
|
||||
print(f'writing {len(cached_res)} results to {op}')
|
||||
pf = open(op, 'ab')
|
||||
for cr in cached_res:
|
||||
pickle.dump(cr, pf)
|
||||
pf.close()
|
||||
cached_res = []
|
Reference in New Issue
Block a user