Update find_best API
This commit is contained in:
parent
a9eec30b05
commit
8949d0b18e
47
exps/show-dataset.py
Normal file
47
exps/show-dataset.py
Normal file
@ -0,0 +1,47 @@
|
||||
##############################################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||
##############################################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
|
||||
##############################################################################
|
||||
# python ./exps/NATS-Bench/main-tss.py --mode meta #
|
||||
##############################################################################
|
||||
import os, sys, time, torch, random, argparse
|
||||
from typing import List, Text, Dict, Any
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import dict2config, load_config
|
||||
from datasets import get_datasets
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
def show_imagenet_16_120(dataset_dir=None):
|
||||
if dataset_dir is None:
|
||||
torch_home_dir = os.environ['TORCH_HOME'] if 'TORCH_HOME' in os.environ else os.path.join(os.environ['HOME'], '.torch')
|
||||
dataset_dir = os.path.join(torch_home_dir, 'cifar.python', 'ImageNet16')
|
||||
train_data, valid_data, xshape, class_num = get_datasets('ImageNet16-120', dataset_dir, -1)
|
||||
split_info = load_config('configs/nas-benchmark/ImageNet16-120-split.txt', None, None)
|
||||
print('=' * 10 + ' ImageNet-16-120 ' + '=' * 10)
|
||||
print('Training Data: {:}'.format(train_data))
|
||||
print('Evaluation Data: {:}'.format(valid_data))
|
||||
print('Hold-out training: {:} images.'.format(len(split_info.train)))
|
||||
print('Hold-out valid : {:} images.'.format(len(split_info.valid)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# show_imagenet_16_120()
|
||||
api_nats_tss = create(None, 'tss', fast_mode=True, verbose=True)
|
||||
|
||||
valid_acc_12e = []
|
||||
test_acc_12e = []
|
||||
test_acc_200e = []
|
||||
for index in range(10000):
|
||||
info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='12')
|
||||
valid_acc_12e.append(info['valid-accuracy'])
|
||||
test_acc_12e.append(info['test-accuracy'])
|
||||
info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='200')
|
||||
test_acc_200e.append(info['test-accuracy']) # which I reported.
|
@ -92,6 +92,10 @@ class ImageNet16(data.Dataset):
|
||||
#std_data = np.mean(np.mean(std_data, axis=0), axis=0)
|
||||
#print ('Std : {:}'.format(std_data))
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets))))
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, target = self.data[index], self.targets[index] - 1
|
||||
|
||||
@ -114,16 +118,16 @@ class ImageNet16(data.Dataset):
|
||||
return False
|
||||
return True
|
||||
|
||||
#
|
||||
"""
|
||||
if __name__ == '__main__':
|
||||
train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
|
||||
valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
|
||||
train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)
|
||||
valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)
|
||||
|
||||
print ( len(train) )
|
||||
print ( len(valid) )
|
||||
image, label = train[111]
|
||||
trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
|
||||
validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
|
||||
trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200)
|
||||
validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200)
|
||||
print ( len(trainX) )
|
||||
print ( len(validX) )
|
||||
#import pdb; pdb.set_trace()
|
||||
"""
|
||||
|
@ -482,6 +482,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
best_index, highest_accuracy = -1, None
|
||||
evaluated_indexes = sorted(list(self.evaluated_indexes))
|
||||
for arch_index in evaluated_indexes:
|
||||
self._prepare_info(arch_index)
|
||||
arch_info = self.arch2infos_dict[arch_index][hp]
|
||||
info = arch_info.get_compute_costs(dataset) # the information of costs
|
||||
flop, param, latency = info['flops'], info['params'], info['latency']
|
||||
@ -622,6 +623,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
if 0 <= index < len(self.meta_archs):
|
||||
if index not in self.evaluated_indexes:
|
||||
self._prepare_info(index)
|
||||
if index not in self.evaluated_indexes:
|
||||
print('The {:}-th architecture has not been evaluated '
|
||||
'or not saved.'.format(index))
|
||||
|
Loading…
Reference in New Issue
Block a user