Update WW
This commit is contained in:
		| @@ -3,110 +3,18 @@ | ||||
| ######################################################## | ||||
| # python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||
| ######################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import sys, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| from tqdm import tqdm | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, dict2config, configure2str | ||||
| from datasets     import get_datasets, SearchDataset | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces, CellStructure | ||||
| from log_utils    import time_string | ||||
| from models       import CellStructure | ||||
| from nas_201_api  import NASBench201API as API | ||||
|  | ||||
|    | ||||
| def valid_func(xloader, network, criterion): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   network.eval() | ||||
|   end = time.time() | ||||
|   with torch.no_grad(): | ||||
|     for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|       arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|       # measure data loading time | ||||
|       data_time.update(time.time() - end) | ||||
|       # prediction | ||||
|       _, logits = network(arch_inputs) | ||||
|       arch_loss = criterion(logits, arch_targets) | ||||
|       # record | ||||
|       arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) | ||||
|       arch_losses.update(arch_loss.item(),  arch_inputs.size(0)) | ||||
|       arch_top1.update  (arch_prec1.item(), arch_inputs.size(0)) | ||||
|       arch_top5.update  (arch_prec5.item(), arch_inputs.size(0)) | ||||
|       # measure elapsed time | ||||
|       batch_time.update(time.time() - end) | ||||
|       end = time.time() | ||||
|   return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|   torch.backends.cudnn.enabled   = True | ||||
|   torch.backends.cudnn.benchmark = False | ||||
|   torch.backends.cudnn.deterministic = True | ||||
|   torch.set_num_threads( xargs.workers ) | ||||
|   prepare_seed(xargs.rand_seed) | ||||
|   logger = prepare_logger(args) | ||||
|  | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': | ||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|     cifar_split = load_config(split_Fpath, None, None) | ||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   elif xargs.dataset.startswith('ImageNet16'): | ||||
|     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) | ||||
|     imagenet16_split = load_config(split_Fpath, None, None) | ||||
|     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   else: | ||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | ||||
|   config_path = 'configs/nas-benchmark/algos/DARTS.config' | ||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   # To split data | ||||
|   train_data_v2 = deepcopy(train_data) | ||||
|   train_data_v2.transform = valid_data.transform | ||||
|   valid_data    = train_data_v2 | ||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|   # data loader | ||||
|   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|  | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   logger.log('search-model :\n{:}'.format(search_model)) | ||||
|    | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) | ||||
|   a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||
|   logger.log('w-optimizer : {:}'.format(w_optimizer)) | ||||
|   logger.log('a-optimizer : {:}'.format(a_optimizer)) | ||||
|   logger.log('w-scheduler : {:}'.format(w_scheduler)) | ||||
|   logger.log('criterion   : {:}'.format(criterion)) | ||||
|   flop, param  = get_model_infos(search_model, xshape) | ||||
|   #logger.log('{:}'.format(search_model)) | ||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|   if xargs.arch_nas_dataset is None: | ||||
|     api = None | ||||
|   else: | ||||
|     api = API(xargs.arch_nas_dataset) | ||||
|   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||
|  | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
|  | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
| def check_unique_arch(meta_file): | ||||
|   api = API(str(meta_file)) | ||||
|   | ||||
							
								
								
									
										36
									
								
								exps/NAS-Bench-201/test-weights.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								exps/NAS-Bench-201/test-weights.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ######################################################## | ||||
| # python exps/NAS-Bench-201/test-weights.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||
| ######################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| import torch | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||
| from nas_201_api  import NASBench201API as API | ||||
| from utils import weight_watcher | ||||
|  | ||||
|  | ||||
| def main(meta_file, weight_dir, save_dir): | ||||
|   import pdb; | ||||
|   pdb.set_trace() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | ||||
|   parser.add_argument('--save_dir',   type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--api_path',   type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.') | ||||
|   parser.add_argument('--weight_dir', type=str, default=None, help='The directory path to the weights of every NAS-Bench-201 architecture.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   save_dir = Path(args.save_dir) | ||||
|   save_dir.mkdir(parents=True, exist_ok=True) | ||||
|   meta_file = Path(args.api_path) | ||||
|   weight_dir = Path(args.weight_dir) | ||||
|   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) | ||||
|  | ||||
|   main(meta_file, weight_dir, save_dir) | ||||
|  | ||||
| @@ -9,12 +9,23 @@ from utils import weight_watcher | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|   model = models.vgg19_bn(pretrained=True) | ||||
|   _, summary = weight_watcher.analyze(model, alphas=False) | ||||
|   # print(summary) | ||||
|   for key, value in summary.items(): | ||||
|     print('{:10s} : {:}'.format(key, value)) | ||||
|   # import pdb; pdb.set_trace() | ||||
|   # model = models.vgg19_bn(pretrained=True) | ||||
|   # _, summary = weight_watcher.analyze(model, alphas=False) | ||||
|   # for key, value in summary.items(): | ||||
|   #   print('{:10s} : {:}'.format(key, value)) | ||||
|  | ||||
|   _, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False) | ||||
|   print('vgg-13 : {:}'.format(summary['lognorm'])) | ||||
|   _, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False) | ||||
|   print('vgg-13-BN : {:}'.format(summary['lognorm'])) | ||||
|   _, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False) | ||||
|   print('vgg-16 : {:}'.format(summary['lognorm'])) | ||||
|   _, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False) | ||||
|   print('vgg-16-BN : {:}'.format(summary['lognorm'])) | ||||
|   _, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False) | ||||
|   print('vgg-19 : {:}'.format(summary['lognorm'])) | ||||
|   _, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False) | ||||
|   print('vgg-19-BN : {:}'.format(summary['lognorm'])) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   | ||||
| @@ -304,7 +304,7 @@ def analyze(model: nn.Module, min_size=50, max_size=0, | ||||
|     if isinstance(module, available_module_types()): | ||||
|       names.append(name) | ||||
|       modules.append(module) | ||||
|   print('There are {:} layers to be analyzed in this model.'.format(len(modules))) | ||||
|   # print('There are {:} layers to be analyzed in this model.'.format(len(modules))) | ||||
|   all_results = OrderedDict() | ||||
|   for index, module in enumerate(modules): | ||||
|     if isinstance(module, nn.Linear): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user