Update WW
This commit is contained in:
parent
1fcde3e8ac
commit
ac08b7be1a
@ -3,111 +3,19 @@
|
|||||||
########################################################
|
########################################################
|
||||||
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
||||||
########################################################
|
########################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import sys, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
from config_utils import load_config, dict2config, configure2str
|
from log_utils import time_string
|
||||||
from datasets import get_datasets, SearchDataset
|
from models import CellStructure
|
||||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
|
||||||
from utils import get_model_infos, obtain_accuracy
|
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces, CellStructure
|
|
||||||
from nas_201_api import NASBench201API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
def valid_func(xloader, network, criterion):
|
|
||||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
|
||||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
|
||||||
network.eval()
|
|
||||||
end = time.time()
|
|
||||||
with torch.no_grad():
|
|
||||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
|
||||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
|
||||||
# measure data loading time
|
|
||||||
data_time.update(time.time() - end)
|
|
||||||
# prediction
|
|
||||||
_, logits = network(arch_inputs)
|
|
||||||
arch_loss = criterion(logits, arch_targets)
|
|
||||||
# record
|
|
||||||
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
|
||||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
|
||||||
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
|
||||||
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
|
||||||
# measure elapsed time
|
|
||||||
batch_time.update(time.time() - end)
|
|
||||||
end = time.time()
|
|
||||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
|
||||||
|
|
||||||
|
|
||||||
def main(xargs):
|
|
||||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
|
||||||
torch.backends.cudnn.enabled = True
|
|
||||||
torch.backends.cudnn.benchmark = False
|
|
||||||
torch.backends.cudnn.deterministic = True
|
|
||||||
torch.set_num_threads( xargs.workers )
|
|
||||||
prepare_seed(xargs.rand_seed)
|
|
||||||
logger = prepare_logger(args)
|
|
||||||
|
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
|
||||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
|
||||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
|
||||||
cifar_split = load_config(split_Fpath, None, None)
|
|
||||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
|
||||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
|
||||||
elif xargs.dataset.startswith('ImageNet16'):
|
|
||||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
|
||||||
imagenet16_split = load_config(split_Fpath, None, None)
|
|
||||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
|
||||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
|
||||||
else:
|
|
||||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
|
||||||
config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
|
||||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
|
||||||
# To split data
|
|
||||||
train_data_v2 = deepcopy(train_data)
|
|
||||||
train_data_v2.transform = valid_data.transform
|
|
||||||
valid_data = train_data_v2
|
|
||||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
|
||||||
# data loader
|
|
||||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
|
||||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
|
||||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
|
||||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
|
||||||
|
|
||||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
|
||||||
model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
|
|
||||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
|
||||||
'space' : search_space}, None)
|
|
||||||
search_model = get_cell_based_tiny_net(model_config)
|
|
||||||
logger.log('search-model :\n{:}'.format(search_model))
|
|
||||||
|
|
||||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
|
||||||
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
|
||||||
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
|
||||||
logger.log('a-optimizer : {:}'.format(a_optimizer))
|
|
||||||
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
|
||||||
logger.log('criterion : {:}'.format(criterion))
|
|
||||||
flop, param = get_model_infos(search_model, xshape)
|
|
||||||
#logger.log('{:}'.format(search_model))
|
|
||||||
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
|
||||||
if xargs.arch_nas_dataset is None:
|
|
||||||
api = None
|
|
||||||
else:
|
|
||||||
api = API(xargs.arch_nas_dataset)
|
|
||||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
|
||||||
|
|
||||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
|
||||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
|
||||||
|
|
||||||
logger.close()
|
|
||||||
|
|
||||||
|
|
||||||
def check_unique_arch(meta_file):
|
def check_unique_arch(meta_file):
|
||||||
api = API(str(meta_file))
|
api = API(str(meta_file))
|
||||||
arch_strs = deepcopy(api.meta_archs)
|
arch_strs = deepcopy(api.meta_archs)
|
||||||
|
36
exps/NAS-Bench-201/test-weights.py
Normal file
36
exps/NAS-Bench-201/test-weights.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||||
|
########################################################
|
||||||
|
# python exps/NAS-Bench-201/test-weights.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
||||||
|
########################################################
|
||||||
|
import os, sys, time, glob, random, argparse
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
|
from nas_201_api import NASBench201API as API
|
||||||
|
from utils import weight_watcher
|
||||||
|
|
||||||
|
|
||||||
|
def main(meta_file, weight_dir, save_dir):
|
||||||
|
import pdb;
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
|
||||||
|
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||||
|
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
|
||||||
|
parser.add_argument('--weight_dir', type=str, default=None, help='The directory path to the weights of every NAS-Bench-201 architecture.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
save_dir = Path(args.save_dir)
|
||||||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
meta_file = Path(args.api_path)
|
||||||
|
weight_dir = Path(args.weight_dir)
|
||||||
|
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
||||||
|
|
||||||
|
main(meta_file, weight_dir, save_dir)
|
||||||
|
|
@ -9,12 +9,23 @@ from utils import weight_watcher
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
model = models.vgg19_bn(pretrained=True)
|
# model = models.vgg19_bn(pretrained=True)
|
||||||
_, summary = weight_watcher.analyze(model, alphas=False)
|
# _, summary = weight_watcher.analyze(model, alphas=False)
|
||||||
# print(summary)
|
# for key, value in summary.items():
|
||||||
for key, value in summary.items():
|
# print('{:10s} : {:}'.format(key, value))
|
||||||
print('{:10s} : {:}'.format(key, value))
|
|
||||||
# import pdb; pdb.set_trace()
|
_, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False)
|
||||||
|
print('vgg-13 : {:}'.format(summary['lognorm']))
|
||||||
|
_, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False)
|
||||||
|
print('vgg-13-BN : {:}'.format(summary['lognorm']))
|
||||||
|
_, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False)
|
||||||
|
print('vgg-16 : {:}'.format(summary['lognorm']))
|
||||||
|
_, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False)
|
||||||
|
print('vgg-16-BN : {:}'.format(summary['lognorm']))
|
||||||
|
_, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False)
|
||||||
|
print('vgg-19 : {:}'.format(summary['lognorm']))
|
||||||
|
_, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False)
|
||||||
|
print('vgg-19-BN : {:}'.format(summary['lognorm']))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -304,7 +304,7 @@ def analyze(model: nn.Module, min_size=50, max_size=0,
|
|||||||
if isinstance(module, available_module_types()):
|
if isinstance(module, available_module_types()):
|
||||||
names.append(name)
|
names.append(name)
|
||||||
modules.append(module)
|
modules.append(module)
|
||||||
print('There are {:} layers to be analyzed in this model.'.format(len(modules)))
|
# print('There are {:} layers to be analyzed in this model.'.format(len(modules)))
|
||||||
all_results = OrderedDict()
|
all_results = OrderedDict()
|
||||||
for index, module in enumerate(modules):
|
for index, module in enumerate(modules):
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
|
Loading…
Reference in New Issue
Block a user