update

nasbench201/utils.py (new file, 494 lines)
@@ -0,0 +1,494 @@
from __future__ import print_function

import numpy as np
import os
import os.path
import sys
import shutil
import torch
import torchvision.transforms as transforms

from PIL import Image
from torch.autograd import Variable
from torchvision.datasets import VisionDataset
from torchvision.datasets import utils

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


class AvgrageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
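
# Illustrative usage of AvgrageMeter and accuracy (not part of the original
# file; `logits` and `labels` are hypothetical tensors of shape
# [batch, num_classes] and [batch]):
#
#     top1 = AvgrageMeter()
#     prec1, = accuracy(logits, labels, topk=(1,))
#     top1.update(prec1.item(), labels.size(0))
#     print(top1.avg)  # running top-1 accuracy in percent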


class Cutout(object):
    def __init__(self, length, prob=1.0):
        self.length = length
        self.prob = prob

    def __call__(self, img):
        if np.random.binomial(1, self.prob):
            h, w = img.size(1), img.size(2)
            mask = np.ones((h, w), np.float32)
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.
            mask = torch.from_numpy(mask)
            mask = mask.expand_as(img)
            img *= mask
        return img
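
# Illustrative usage (hypothetical; `img` is a [C, H, W] tensor produced by
# transforms.ToTensor()):
#
#     cutout = Cutout(length=16, prob=0.8)
#     img = cutout(img)  # zeroes one random 16x16 square with probability 0.8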


def _data_transforms_svhn(args):
    SVHN_MEAN = [0.4377, 0.4438, 0.4728]
    SVHN_STD = [0.1980, 0.2010, 0.1970]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar100(args):
    CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
    CIFAR_STD = [0.2673, 0.2564, 0.2762]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar10(args):
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform
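
# Illustrative usage (hypothetical; all three builders above share the same
# interface and only read args.cutout, args.cutout_length and args.cutout_prob):
#
#     import torchvision.datasets as dset
#     train_transform, valid_transform = _data_transforms_svhn(args)
#     train_data = dset.SVHN(root=args.data, split='train', download=True,
#                            transform=train_transform)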


def count_parameters_in_MB(model):
    # np.sum over a generator is deprecated in NumPy; use the builtin sum instead.
    return sum(np.prod(v.size()) for name, v in model.named_parameters()
               if "auxiliary" not in name) / 1e6
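
# Illustrative usage (hypothetical; `model` is any nn.Module, and the returned
# value is the parameter count divided by 1e6, i.e. millions of parameters):
#
#     print('param size = %fMB' % count_parameters_in_MB(model))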


def count_parameters_in_Compact(model):
    from sota.cnn.model import Network as CompactModel
    genotype = model.genotype()
    compact_model = CompactModel(36, model._num_classes, 20, True, genotype)
    num_params = count_parameters_in_MB(compact_model)
    return num_params


def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''):
    filename = prefix
    if per_epoch:
        epoch = state['epoch']
        filename += 'checkpoint_{}.pth.tar'.format(epoch)
    else:
        filename += 'checkpoint.pth.tar'
    filename = os.path.join(save, filename)
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(save, 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)
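
# Illustrative usage (hypothetical names; the keys mirror what load_checkpoint
# below expects):
#
#     save_checkpoint({'epoch': epoch + 1,
#                      'state_dict': model.state_dict(),
#                      'best_acc_top1': best_acc_top1,
#                      'optimizer': optimizer.state_dict()},
#                     is_best, args.save)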


def load_checkpoint(model, optimizer, save, epoch=None):
    if epoch is None:
        filename = 'checkpoint.pth.tar'
    else:
        filename = 'checkpoint_{}.pth.tar'.format(epoch)
    filename = os.path.join(save, filename)
    start_epoch = 0
    best_acc_top1 = 0.0  # default so the return below is defined when no checkpoint exists
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        best_acc_top1 = checkpoint['best_acc_top1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, best_acc_top1
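
# Illustrative usage (hypothetical; resumes from '<save_dir>/checkpoint.pth.tar'
# if it exists, otherwise starts from epoch 0):
#
#     model, optimizer, start_epoch, best_acc_top1 = load_checkpoint(
#         model, optimizer, args.save)
#     for epoch in range(start_epoch, args.epochs):
#         ...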


def save(model, model_path):
    torch.save(model.state_dict(), model_path)


def load(model, model_path):
    model.load_state_dict(torch.load(model_path))


def drop_path(x, drop_prob):
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
        x.div_(keep_prob)
        x.mul_(mask)
    return x
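
# Illustrative usage (hypothetical; the mask is allocated with
# torch.cuda.FloatTensor, so `x` is assumed to be a CUDA tensor of shape
# [batch, C, H, W]; the op is in-place and rescaled by 1/keep_prob so the
# expected activation is unchanged):
#
#     if self.training and drop_path_prob > 0.:
#         x = drop_path(x, drop_path_prob)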


def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)
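
# Illustrative usage (hypothetical; snapshots the current *.py files into
# '<save_dir>/scripts' for reproducibility):
#
#     import glob
#     create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))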


class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Note: ``data_batch_5`` is commented out of ``train_list`` in this copy, so
    the training split here contains 40,000 images instead of 50,000.
    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        # ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self):
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not utils.check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not utils.check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        utils.download_and_extract_archive(self.url, self.root,
                                           filename=self.filename,
                                           md5=self.tgz_md5)

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")


def pick_gpu_lowest_memory():
    import gpustat
    stats = gpustat.GPUStatCollection.new_query()
    ids = map(lambda gpu: int(gpu.entry['index']), stats)
    ratios = map(lambda gpu: float(gpu.memory_used) / float(gpu.memory_total), stats)
    bestGPU = min(zip(ids, ratios), key=lambda x: x[1])[0]
    return bestGPU
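
# Illustrative usage (hypothetical; requires the `gpustat` package and at least
# one visible NVIDIA GPU):
#
#     gpu = pick_gpu_lowest_memory()
#     torch.cuda.set_device(gpu)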


#### early stopping (from RobustNAS)
class EVLocalAvg(object):
    def __init__(self, window=5, ev_freq=2, total_epochs=50):
        """ Keep track of the eigenvalues local average.

        Args:
            window (int): number of elements used to compute local average.
                Default: 5
            ev_freq (int): frequency used to compute eigenvalues. Default:
                every 2 epochs
            total_epochs (int): total number of epochs that DARTS runs.
                Default: 50

        """
        self.window = window
        self.ev_freq = ev_freq
        self.epochs = total_epochs

        self.stop_search = False
        self.stop_epoch = total_epochs - 1
        self.stop_genotype = None
        self.stop_numparam = 0

        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

        # start and end index of the local average window
        self.la_start_idx = 0
        self.la_end_idx = self.window

    def reset(self):
        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

    def update(self, epoch, ev, genotype, numparam=0):
        """ Method to update the local average list.

        Args:
            epoch (int): current epoch
            ev (float): current dominant eigenvalue
            genotype (namedtuple): current genotype

        """
        self.ev.append(ev)
        self.genotypes.update({epoch: genotype})
        self.numparams.update({epoch: numparam})
        # set the stop_genotype to the current genotype in case the early stop
        # procedure decides not to early stop
        self.stop_genotype = genotype

        # since the local average computation starts after the dominant
        # eigenvalue in the first epoch is already computed we have to wait
        # at least until we have 3 eigenvalues in the list.
        if (len(self.ev) >= int(np.ceil(self.window/2))) and (epoch <
                                                              self.epochs - 1):
            # start sliding the window as soon as the number of eigenvalues in
            # the list becomes equal to the window size
            if len(self.ev) < self.window:
                self.ev_local_avg.append(np.mean(self.ev))
            else:
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window
                self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
                                                         self.la_end_idx]))
                self.la_start_idx += 1
                self.la_end_idx += 1

            # keep track of the offset between the current epoch and the epoch
            # corresponding to the local average. NOTE: in the end the size of
            # self.ev and self.ev_local_avg should be equal
            self.la_epochs.update({epoch: int(epoch -
                                              int(self.ev_freq*np.floor(self.window/2)))})

        elif len(self.ev) < int(np.ceil(self.window/2)):
            self.la_epochs.update({epoch: -1})

        # since there is an offset between the current epoch and the local
        # average epoch, loop in the last epoch to compute the local average of
        # these number of elements: window, window - 1, window - 2, ..., ceil(window/2)
        elif epoch == self.epochs - 1:
            for i in range(int(np.ceil(self.window/2))):
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i
                self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
                                                         self.la_end_idx + 1]))
                self.la_start_idx += 1

    def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4, criteria='local_avg'):
        """ Early stopping criterion

        Args:
            epoch (int): current epoch
            factor (float): threshold factor for the ratio between the current
                and previous eigenvalue (local average). Default: 1.3
            es_start_epoch (int): until this epoch do not consider early
                stopping. Default: 10
            delta (int): factor influencing which previous local average we
                consider for early stopping. Default: 4
        """
        if criteria == 'local_avg':
            if int(self.la_epochs[epoch] - self.ev_freq*delta) >= es_start_epoch:
                current_la = self.ev_local_avg[-1]
                previous_la = self.ev_local_avg[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = int(self.la_epochs[epoch] - self.ev_freq*delta)
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        elif criteria == 'exact':
            if epoch > es_start_epoch:
                current_la = self.ev[-1]
                previous_la = self.ev[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = epoch - delta
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        else:
            print('ERROR IN EARLY STOP: WRONG CRITERIA:', criteria)
            exit(0)
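
# Illustrative usage in a DARTS-style search loop (hypothetical;
# `compute_hessian_ev` stands in for whatever routine returns the dominant
# eigenvalue of the validation-loss Hessian w.r.t. the architecture parameters):
#
#     ev_tracker = EVLocalAvg(window=5, ev_freq=2, total_epochs=args.epochs)
#     for epoch in range(args.epochs):
#         ...
#         if epoch % 2 == 0:
#             ev = compute_hessian_ev(model)  # hypothetical helper
#             ev_tracker.update(epoch, ev, model.genotype())
#             ev_tracker.early_stop(epoch)
#         if ev_tracker.stop_search:
#             break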


def gen_comb(eids):
    comb = []
    for r in range(len(eids)):
        for c in range(r + 1, len(eids)):
            comb.append((eids[r], eids[c]))

    return comb
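
# Illustrative usage (hypothetical): all unordered pairs of ids, e.g.
#
#     gen_comb([0, 1, 2])  # -> [(0, 1), (0, 2), (1, 2)]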