From 8949d0b18e3e327faec8aba85ec87c8f39f949b7 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Fri, 20 Nov 2020 09:52:29 +0800 Subject: [PATCH] Update find_best API --- exps/show-dataset.py | 47 +++++++++++++++++++++++++++++ lib/datasets/DownsampledImageNet.py | 16 ++++++---- lib/nats_bench/api_utils.py | 3 ++ 3 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 exps/show-dataset.py diff --git a/exps/show-dataset.py b/exps/show-dataset.py new file mode 100644 index 0000000..8f3e42c --- /dev/null +++ b/exps/show-dataset.py @@ -0,0 +1,47 @@ +############################################################################## +# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # +############################################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 # +############################################################################## +# python ./exps/NATS-Bench/main-tss.py --mode meta # +############################################################################## +import os, sys, time, torch, random, argparse +from typing import List, Text, Dict, Any +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True +from copy import deepcopy +from pathlib import Path + +lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import dict2config, load_config +from datasets import get_datasets +from nats_bench import create + + +def show_imagenet_16_120(dataset_dir=None): + if dataset_dir is None: + torch_home_dir = os.environ['TORCH_HOME'] if 'TORCH_HOME' in os.environ else os.path.join(os.environ['HOME'], '.torch') + dataset_dir = os.path.join(torch_home_dir, 'cifar.python', 'ImageNet16') + train_data, valid_data, xshape, class_num = get_datasets('ImageNet16-120', dataset_dir, -1) + split_info = load_config('configs/nas-benchmark/ImageNet16-120-split.txt', None, None) + print('=' * 10 + ' ImageNet-16-120 ' + '=' * 10) + print('Training Data: {:}'.format(train_data)) + print('Evaluation Data: {:}'.format(valid_data)) + print('Hold-out training: {:} images.'.format(len(split_info.train))) + print('Hold-out valid : {:} images.'.format(len(split_info.valid))) + + +if __name__ == '__main__': + # show_imagenet_16_120() + api_nats_tss = create(None, 'tss', fast_mode=True, verbose=True) + + valid_acc_12e = [] + test_acc_12e = [] + test_acc_200e = [] + for index in range(10000): + info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='12') + valid_acc_12e.append(info['valid-accuracy']) + test_acc_12e.append(info['test-accuracy']) + info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='200') + test_acc_200e.append(info['test-accuracy']) # which I reported. diff --git a/lib/datasets/DownsampledImageNet.py b/lib/datasets/DownsampledImageNet.py index 970336e..26eed32 100644 --- a/lib/datasets/DownsampledImageNet.py +++ b/lib/datasets/DownsampledImageNet.py @@ -92,6 +92,10 @@ class ImageNet16(data.Dataset): #std_data = np.mean(np.mean(std_data, axis=0), axis=0) #print ('Std : {:}'.format(std_data)) + def __repr__(self): + return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets)))) + + def __getitem__(self, index): img, target = self.data[index], self.targets[index] - 1 @@ -114,16 +118,16 @@ class ImageNet16(data.Dataset): return False return True -# +""" if __name__ == '__main__': - train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None) - valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None) + train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None) + valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None) print ( len(train) ) print ( len(valid) ) image, label = train[111] - trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200) - validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200) + trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200) + validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200) print ( len(trainX) ) print ( len(validX) ) - #import pdb; pdb.set_trace() +""" diff --git a/lib/nats_bench/api_utils.py b/lib/nats_bench/api_utils.py index 433a1aa..8e63446 100644 --- a/lib/nats_bench/api_utils.py +++ b/lib/nats_bench/api_utils.py @@ -482,6 +482,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): best_index, highest_accuracy = -1, None evaluated_indexes = sorted(list(self.evaluated_indexes)) for arch_index in evaluated_indexes: + self._prepare_info(arch_index) arch_info = self.arch2infos_dict[arch_index][hp] info = arch_info.get_compute_costs(dataset) # the information of costs flop, param, latency = info['flops'], info['params'], info['latency'] @@ -622,6 +623,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): print('<' * 40 + '------------' + '<' * 40) else: if 0 <= index < len(self.meta_archs): + if index not in self.evaluated_indexes: + self._prepare_info(index) if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated ' 'or not saved.'.format(index))