import sys
import os
import json
import tqdm
import torch
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import random
import glob
import logging
import shutil
import numpy as np

sys.path.insert(0, '../')
from nasbench201.cell_infers.tiny_network import TinyNetwork
from nasbench201.genotypes import Structure
from nas_201_api import NASBench201API as API
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.nas.genotypes import Genotype
import nasbench201.utils as ig_utils
from foresight.pruners import *
from Scorers.scorer import Jocab_Scorer
import torchvision.transforms as transforms
import argparse
from mobilenet_search_space.retrain_architecture.model import Network
from torch.utils.tensorboard import SummaryWriter
from sota.cnn.hdf5 import H5Dataset

parser = argparse.ArgumentParser("sota")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--save_path', type=str, default='../experiments/sota', help='experiment directory')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--ckpt_path', type=str, help='path to the saved networks pool')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--maxiter', default=1, type=int, help='score is the max of this many evaluations of the network')
parser.add_argument('--batch_size', type=int, default=256, help='batch size for alpha')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--validate_rounds', type=int, default=10, help='number of scoring rounds per network')
parser.add_argument('--proj_crit', type=str, default='jacob',
                    choices=['loss', 'acc', 'var', 'cor', 'norm', 'jacob', 'snip', 'fisher', 'synflow',
                             'grad_norm', 'grasp', 'jacob_cov', 'comb', 'meco', 'zico'],
                    help='criteria for projection')
parser.add_argument('--edge_decision', type=str, default='random',
                    choices=['random', 'reverse', 'order', 'global_op_greedy', 'global_op_once',
                             'global_edge_greedy', 'global_edge_once'],
                    help='which edge to be projected next')
args = parser.parse_args()

# fix all RNG seeds and force deterministic cuDNN for reproducible scoring
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)


def load_network_pool(ckpt_path):
    """Load the candidate pool written by a previous search run (one JSON object per line)."""
    with open(os.path.join(ckpt_path, 'networks_pool.json'), 'r') as save_file:
        for line in save_file:
            networks_pool = json.loads(line)
            if 'pool_size' in networks_pool:
                return (networks_pool['search_space'], networks_pool['dataset'],
                        networks_pool['networks'], networks_pool['pool_size'])
            # older pool files lack an explicit 'pool_size' field
            return (networks_pool['search_space'], networks_pool['dataset'],
                    networks_pool['networks'], len(networks_pool['networks']))


#### args augment
search_space, dataset, networks_pool, pool_size = load_network_pool(args.ckpt_path)
# print(search_space, dataset, networks_pool, pool_size)
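# For reference, a networks_pool.json line is expected to look roughly like this
# (illustrative values only, inferred from the reader above and the main loop below;
# the 'genotype' encoding depends on the search space -- an arch string for
# nas-bench-201, space-separated op indices for mobilenet, or a JSON-encoded
# DARTS Genotype otherwise):
#   {"search_space": "nas-bench-201", "dataset": "cifar10", "pool_size": 2,
#    "networks": [{"id": 0, "genotype": "|nor_conv_3x3~0|+..."}, ...]}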
search_space = search_space.strip()
dataset = dataset.strip()
expid = args.save
args.save = '{}/{}-valid-{}-{}-{}-{}'.format(args.save_path, search_space, args.save,
                                             args.seed, pool_size, args.validate_rounds)
if not dataset == 'cifar10':
    args.save += '-' + dataset
if not args.edge_decision == 'random':
    args.save += '-' + args.edge_decision
if not args.proj_crit == 'jacob':
    args.save += '-' + args.proj_crit
scripts_to_save = glob.glob('*.py') + ['../exp_scripts/{}.sh'.format(expid)]
if os.path.exists(args.save):
    if input("WARNING: {} exists, override? [y/n] ".format(args.save)) == 'y':
        print('proceed to override saving directory')
        shutil.rmtree(args.save)
    else:
        exit(0)
ig_utils.create_exp_dir(args.save, scripts_to_save=None)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log.txt'
log_path = os.path.join(args.save, log_file)
logging.info('======> log filename: %s', log_file)
logging.info('load pool from space: %s and dataset: %s', search_space, dataset)
if os.path.exists(log_path):
    if input("WARNING: {} exists, override? [y/n] ".format(log_file)) == 'y':
        print('proceed to override log file directory')
    else:
        exit(0)
fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')

#### macros
if dataset == 'cifar100':
    n_classes = 100
elif dataset == 'imagenet16-120':
    n_classes = 120
elif dataset == 'imagenet':
    n_classes = 1000
else:
    n_classes = 10

if search_space == 'nas-bench-201':
    api = API('../data/NAS-Bench-201-v1_0-e61699.pth')
if search_space == 'nb_macro':
    import pickle as pkl
    # the pickle stores a header followed by (arch, val_top1, test_top1, train_time) tuples
    with open('../data/nbmacro-base-0.pickle', 'rb') as f:
        head = pkl.load(f)
        value = pkl.load(f)
    api = {}
    for v in value:
        h, val_t1, test_t1, t_time = v
        api[h] = test_t1


def main():
    #### data
    if dataset == 'imagenet':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(args.validate_rounds * args.batch_size))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size, num_workers=4, pin_memory=True,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]))
    else:
        if dataset == 'cifar10':
            train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
            train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
        elif dataset == 'cifar100':
            train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
            train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
        elif dataset == 'svhn':
            train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
            # torchvision's SVHN selects subsets via split='train'/'test' rather than a train flag
            train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
            valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
        elif dataset == 'imagenet16-120':
            from nasbench201.DownsampledImageNet import ImageNet16
            mean = [x / 255 for x in [122.68, 116.66, 104.01]]
            std = [x / 255 for x in [63.22, 61.26, 65.09]]
            lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2),
                     transforms.ToTensor(), transforms.Normalize(mean, std)]
            train_transform = transforms.Compose(lists)
            train_data = ImageNet16(root=os.path.join(args.data, 'imagenet16'), train=True,
                                    transform=train_transform, use_num_of_class_only=120)
            valid_data = ImageNet16(root=os.path.join(args.data, 'imagenet16'), train=False,
                                    transform=train_transform, use_num_of_class_only=120)
            assert len(train_data) == 151700
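
        # the sampler draws exactly validate_rounds * batch_size examples, so the
        # 'jacob' scoring loop below sees validate_rounds batches per network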
        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(args.validate_rounds * args.batch_size))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=4)

    gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
    torch.cuda.set_device(gpu)

    if args.proj_crit == 'jacob':
        validate_scorer = Jocab_Scorer(gpu)

    crit_list = []
    net_history = []
    print(len(train_queue))
    for net_config in tqdm.tqdm(networks_pool, desc="networks", position=0):
        net_id = net_config['id']
        net_genotype = net_config['genotype']
        # score each distinct genotype only once
        if net_genotype not in net_history:
            net_history.append(net_genotype)
            network = get_networks_from_genotype(net_genotype, dataset, search_space)
            if args.proj_crit == 'jacob':
                validate_scorer.setup_hooks(network, args.batch_size)
            for step, (input, target) in tqdm.tqdm(enumerate(train_queue),
                                                   desc="validate_rounds", position=1, leave=False):
                input = input.cuda()   # .cuda() is not in-place; the result must be reassigned
                target = target.cuda()
                if args.proj_crit == 'jacob':
                    score = validate_scorer.score(network, input, target)
                else:
                    # score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit)
                    network.requires_feature = False
                    measures = predictive.find_measures(network,
                                                        train_queue,
                                                        ('random', 1, n_classes),
                                                        torch.device("cuda"),
                                                        measure_names=[args.proj_crit])
                    score = measures[args.proj_crit]
                if step == 0:
                    crit_list.append(score)
                else:
                    crit_list[-1] += score
                if args.proj_crit != 'jacob':
                    # foresight measures already iterate over train_queue internally,
                    # so one pass per network is enough
                    break

    best_networks = net_history[np.nanargmax(crit_list)]
    if search_space == 'nas-bench-201':
        cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
            cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, best_networks, logging)

    networks_info = {}
    networks_info['search_space'] = search_space
    networks_info['dataset'] = dataset
    networks_info['networks'] = best_networks
    with open(os.path.join(args.save, 'best_networks.json'), 'w') as save_file:
        json.dump(networks_info, save_file)
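
# Selection note: for the 'jacob' criterion each network's score is the sum over
# validate_rounds batches; np.nanargmax then picks the highest-scoring genotype
# while skipping any NaN scores from degenerate networks.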

#### util functions
def distill(result):
    # parse the fixed-format result string returned by the NAS-Bench-201 API
    result = result.split('\n')
    cifar10 = result[5].replace(' ', '').split(':')
    cifar100 = result[7].replace(' ', '').split(':')
    imagenet16 = result[9].replace(' ', '').split(':')

    cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
    cifar10_test = float(cifar10[2][-7:-2].strip('='))
    cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
    cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
    cifar100_test = float(cifar100[3][-7:-2].strip('='))
    imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
    imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
    imagenet16_test = float(imagenet16[3][-7:-2].strip('='))

    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test


def query(api, genotype, logging):
    result = api.query_by_arch(genotype, hp='200')
    logging.info('{:}'.format(result))
    cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
    logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
    logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
    logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test


def get_networks_from_genotype(genotype_str, dataset, search_space):
    if search_space == 'nas-bench-201':
        net_index = api.query_index_by_arch(genotype_str)
        net_config = api.get_net_config(net_index, 'cifar10-valid')
        print(net_config)
        genotype = Structure.str2structure(net_config['arch_str'])
        network = TinyNetwork(net_config['C'], net_config['N'], genotype, n_classes)
        return network
    elif search_space == 'mobilenet':
        # the mobilenet genotype is a space-separated string of op indices
        rngs = [int(id) for id in genotype_str.split(' ')]
        network = Network(rngs, n_class=n_classes)
        return network
    else:
        # DARTS-style search spaces store the genotype as a JSON-encoded dict
        genotype_config = json.loads(genotype_str)
        genotype = Genotype(normal=genotype_config['normal'],
                            normal_concat=genotype_config['normal_concat'],
                            reduce=genotype_config['reduce'],
                            reduce_concat=genotype_config['reduce_concat'])
        if dataset == 'imagenet':
            network = NetworkImageNet(args.init_channels, n_classes, args.layers, False, genotype)
        else:
            network = NetworkCIFAR(args.init_channels, n_classes, args.layers, False, genotype)
        network.drop_path_prob = 0.
        return network


if __name__ == '__main__':
    main()
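
# Example invocation (hypothetical script name and paths; --ckpt_path must point at
# a directory containing a networks_pool.json produced by a prior search run):
#   python validate_networks.py --ckpt_path ../experiments/sota/pool-run \
#       --dataset cifar10 --proj_crit jacob --validate_rounds 10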