Update NATS (sss) algorithms -- warmup
exps/NATS-algos/search-size.py
@@ -1,6 +1,11 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
 ######################################################################################
+# In this file, we aim to evaluate three kinds of channel-searching strategies:
+# - TAS (interpolation), FBNetV2 (masking + Gumbel-Softmax), and TuNAS (masking + sampling)
+####
+# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
+####
 # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
 # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
 # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
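The new `--warmup_ratio` flag in the usage lines above controls how long the supernet trains before the architecture controller is enabled. A minimal sketch of its semantics, mirroring the parser change later in this commit:

```python
import argparse

# Minimal sketch of the new flag (assumed standalone here; the real
# parser lives in search-size.py). warmup_ratio=0.25 means the first
# quarter of the search epochs update only the supernet weights.
parser = argparse.ArgumentParser()
parser.add_argument('--warmup_ratio', type=float,
                    help='The warmup ratio; if None, warmup is not used.')
args = parser.parse_args(['--warmup_ratio', '0.25'])
print(args.warmup_ratio)  # -> 0.25
```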
@@ -51,7 +56,7 @@ class ExponentialMovingAverage(object):
 RL_BASELINE_EMA = ExponentialMovingAverage(0.95)
 
 
-def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
+def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, enable_controller, algo, epoch_str, print_freq, logger):
   data_time, batch_time = AverageMeter(), AverageMeter()
   base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
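`RL_BASELINE_EMA` above serves as a moving-average reward baseline for the REINFORCE-style `tunas` update. A minimal sketch of such a baseline, with the class internals assumed since this diff does not show them:

```python
# A minimal EMA baseline sketch (internals assumed, not shown in the
# diff): each update blends the new reward into a running value, with
# decay 0.95 as instantiated above.
class ExponentialMovingAverage(object):
  def __init__(self, momentum):
    self._momentum = momentum
    self._value = 0.0

  def update(self, value):
    self._value = self._momentum * self._value + (1.0 - self._momentum) * value

  @property
  def value(self):
    return self._value

ema = ExponentialMovingAverage(0.95)
for reward in (0.1, 0.5, 0.9):
  ema.update(reward)
print(ema.value)  # a smoothed estimate of recent rewards
```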
@@ -80,6 +85,7 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
 
     # update the architecture-weight
     network.zero_grad()
+    a_optimizer.zero_grad()
     _, logits, log_probs = network(arch_inputs)
     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
     if algo == 'tunas':
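For `algo == 'tunas'`, the architecture loss is a policy-gradient term rather than cross-entropy. A hedged sketch of that branch, with the exact reward shaping assumed since the diff truncates it:

```python
import torch

# Hedged sketch of a REINFORCE-style 'tunas' objective (reward shaping
# assumed). arch_prec1 stands in for the sampled architecture's top-1
# accuracy, log_prob for sum(log_probs) from the forward pass, and
# baseline for RL_BASELINE_EMA.value.
arch_prec1 = torch.tensor(61.3)
log_prob = torch.tensor(-0.7, requires_grad=True)
baseline = 58.0
arch_loss = -(arch_prec1 - baseline) * log_prob  # maximize advantage-weighted log-prob
arch_loss.backward()
print(log_prob.grad)  # -> tensor(-3.3000): positive advantage raises the log-prob
```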
@@ -92,6 +98,7 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
       arch_loss = criterion(logits, arch_targets)
     else:
       raise ValueError('invalid algorithm name: {:}'.format(algo))
+    if enable_controller:
       arch_loss.backward()
       a_optimizer.step()
     # record
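The effect of the `enable_controller` gate is simply to skip the backward/step on the architecture optimizer during warmup; the forward pass and accuracy metrics still run. A runnable toy sketch (toy tensors, not the real network):

```python
import torch
import torch.nn as nn

# Toy sketch of the gate added above: with enable_controller=False the
# loss is computed but the architecture parameters receive no update.
arch_param = nn.Parameter(torch.zeros(4, 10))  # stand-in for the arch weights
a_optimizer = torch.optim.Adam([arch_param], lr=0.1)
criterion = nn.CrossEntropyLoss()
arch_targets = torch.randint(0, 10, (4,))

for enable_controller in (False, True):
  a_optimizer.zero_grad()
  logits = arch_param + 0.0                    # stand-in for network(arch_inputs)
  arch_loss = criterion(logits, arch_targets)
  if enable_controller:                        # the gate from this commit
    arch_loss.backward()
    a_optimizer.step()
  print(enable_controller, arch_param.abs().sum().item())  # 0.0, then non-zero
```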
@@ -208,13 +215,22 @@ def main(xargs):
     w_scheduler.update(epoch, 0.0)
     need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch-epoch), True))
     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
-    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
+    if xargs.warmup_ratio is None or xargs.warmup_ratio <= float(epoch) / total_epoch:
+      enable_controller = True
+      network.set_warmup_ratio(None)
+    else:
+      enable_controller = False
+      network.set_warmup_ratio(1.0 - float(epoch) / total_epoch / xargs.warmup_ratio)
+
+    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
+
     if xargs.algo == 'fbv2' or xargs.algo == 'tas':
       network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
       logger.log('[RESET tau as : {:}]'.format(network.tau))
     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
-                = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, xargs.algo, epoch_str, xargs.print_freq, logger)
+                = search_func(search_loader, network, criterion, w_scheduler,
+                              w_optimizer, a_optimizer, enable_controller, xargs.algo, epoch_str, xargs.print_freq, logger)
     search_time.update(time.time() - start_time)
     logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
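The schedule above keeps the controller disabled while `epoch/total_epoch < warmup_ratio` and linearly decays the network's warmup ratio to zero over that window. A small sketch of the values it produces:

```python
# Sketch of the warmup schedule added above: with warmup_ratio=0.25 and
# 100 epochs, the controller is off for epochs 0-24 while the
# keep-the-largest-mask probability decays linearly from 1.0 toward 0.0.
def warmup_schedule(epoch, total_epoch, warmup_ratio):
  if warmup_ratio is None or warmup_ratio <= float(epoch) / total_epoch:
    return True, None  # controller enabled, warmup off
  return False, 1.0 - float(epoch) / total_epoch / warmup_ratio

for epoch in (0, 10, 24, 25, 99):
  print(epoch, warmup_schedule(epoch, 100, 0.25))
# 0 (False, 1.0) ... 24 (False, ~0.04)  25 (True, None)  99 (True, None)
```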
@@ -275,6 +291,8 @@ if __name__ == '__main__':
   # FOR GDAS
   parser.add_argument('--tau_min',            type=float, default=0.1,  help='The minimum tau for Gumbel Softmax.')
   parser.add_argument('--tau_max',            type=float, default=10,   help='The maximum tau for Gumbel Softmax.')
+  # FOR ALL
+  parser.add_argument('--warmup_ratio',       type=float,               help='The warmup ratio; if None, warmup is not used.')
   #
   parser.add_argument('--track_running_stats',type=int,   default=0, choices=[0,1],help='Whether to use track_running_stats in the BN layer.')
   parser.add_argument('--affine'      ,       type=int,   default=0, choices=[0,1],help='Whether to use affine=True or False in the BN layer.')
@@ -291,7 +309,7 @@ if __name__ == '__main__':
   parser.add_argument('--rand_seed',          type=int,   help='manual seed')
   args = parser.parse_args()
   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
-  dirname = '{:}-affine{:}_BN{:}-AWD{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay)
+  dirname = '{:}-affine{:}_BN{:}-AWD{:}-WARM{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay, args.warmup_ratio)
   if args.overwite_epochs is not None:
     dirname = dirname + '-E{:}'.format(args.overwite_epochs)
   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, dirname)
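Since `--warmup_ratio` now feeds into the output directory name, runs with different warmup settings no longer collide. For illustration (argument values assumed):

```python
# Illustration of the new save-dir naming with assumed argument values;
# note that '{:}'.format(None) renders as 'None', which produces the
# '-WARMNone' default suffix the plotting script below relies on.
algo, affine, bn, awd, warmup = 'tunas', 0, 0, 0.0, 0.25
dirname = '{:}-affine{:}_BN{:}-AWD{:}-WARM{:}'.format(algo, affine, bn, awd, warmup)
print(dirname)  # -> tunas-affine0_BN0-AWD0.0-WARM0.25
```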

@@ -26,7 +26,7 @@ from nats_bench import create
 from log_utils import time_string
 
 
-def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
+def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'):
   ss_dir = '{:}-{:}'.format(root_dir, search_space)
   alg2name, alg2path = OrderedDict(), OrderedDict()
   seeds = [777, 888, 999]
@@ -39,9 +39,9 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
     alg2name['ENAS'] = 'enas-affine0_BN0-None'
     alg2name['SETN'] = 'setn-affine0_BN0-None'
   else:
-    alg2name['TAS'] = 'tas-affine0_BN0'
-    alg2name['FBNetV2'] = 'fbv2-affine0_BN0'
-    alg2name['TuNAS'] = 'tunas-affine0_BN0'
+    alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
+    alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
+    alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
   for alg, name in alg2name.items():
     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
   alg2data = OrderedDict()
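The `suffix` default ties this reader to the new directory naming. A sketch of the checkpoint paths it will probe (dataset and seed values assumed):

```python
import os
from collections import OrderedDict

# Sketch of the checkpoint layout fetch_data now expects; the default
# suffix '-AWD0.0-WARMNone' matches size-search runs launched without
# warmup and with arch_weight_decay=0.0 (dataset/seed here are assumed).
ss_dir, dataset = './output/search-sss', 'cifar10'
suffix = '-AWD0.0-WARMNone'
alg2name = OrderedDict([('TAS', 'tas-affine0_BN0' + suffix),
                        ('FBNetV2', 'fbv2-affine0_BN0' + suffix),
                        ('TuNAS', 'tunas-affine0_BN0' + suffix)])
for alg, name in alg2name.items():
  template = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
  print(alg, template.format(777))
```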

@@ -1,6 +1,10 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
 #####################################################
+# Here, we utilize three techniques to search for the number of channels:
+# - feature interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
+# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
 from typing import List, Text, Any
 import random, torch
 import torch.nn as nn
@@ -43,6 +47,7 @@ class GenericNAS301Model(nn.Module):
     # algorithm related
     self.register_buffer('_tau', torch.zeros(1))
     self._algo        = None
+    self._warmup_ratio = None
 
   def set_algo(self, algo: Text):
     # used for searching
@@ -62,6 +67,13 @@ class GenericNAS301Model(nn.Module):
   def set_tau(self, tau):
     self._tau.data[:] = tau
 
+  @property
+  def warmup_ratio(self):
+    return self._warmup_ratio
+
+  def set_warmup_ratio(self, ratio: float):
+    self._warmup_ratio = ratio
+
   @property
   def weights(self):
     xlist = list(self._cells.parameters())
@@ -112,7 +124,13 @@ class GenericNAS301Model(nn.Module):
       feature = cell(feature)
       # apply different searching algorithms
       idx = max(0, i-1)
-      if self._algo == 'fbv2':
+      if self._warmup_ratio is not None:
+        if random.random() < self._warmup_ratio:
+          mask = self._masks[-1]
+        else:
+          mask = self._masks[random.randint(0, len(self._masks)-1)]
+        feature = feature * mask.view(1, -1, 1, 1)
+      elif self._algo == 'fbv2':
         weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
         mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
         feature = feature * mask
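During warmup the forward pass bypasses the learned architecture distribution: with probability `warmup_ratio` it keeps the last, largest mask (all channels), otherwise it samples a candidate mask uniformly. A self-contained sketch with assumed mask shapes:

```python
import random
import torch

# Self-contained sketch of the warmup branch above, with assumed shapes:
# binary masks keep the first c channels, and masks are ordered so the
# last one keeps every channel.
max_channels, candidates = 8, (2, 4, 6, 8)
masks = []
for c in candidates:
  m = torch.zeros(max_channels)
  m[:c] = 1.0
  masks.append(m)

warmup_ratio = 0.6                   # high early in training, decays to 0
feature = torch.randn(1, max_channels, 4, 4)
if random.random() < warmup_ratio:
  mask = masks[-1]                   # keep all channels
else:
  mask = masks[random.randint(0, len(masks) - 1)]  # uniform candidate
feature = feature * mask.view(1, -1, 1, 1)
print(int(mask.sum().item()), 'channels kept')
```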

scripts-search/NATS/search-size.sh (new file, 30 lines)
@@ -0,0 +1,30 @@
+#!/bin/bash
+# bash ./scripts-search/NATS/search-size.sh 0 777
+echo script name: $0
+echo $# arguments
+if [ "$#" -ne 2 ]; then
+  echo "Illegal number of input parameters: $#"
+  echo "Need 2 parameters for GPU-device and seed"
+  exit 1
+fi
+if [ "$TORCH_HOME" = "" ]; then
+  echo "Must set the TORCH_HOME environment variable for the data directory"
+  exit 1
+else
+  echo "TORCH_HOME : $TORCH_HOME"
+fi
+
+device=$1
+seed=$2
+
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed ${seed}
+
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed ${seed}
+
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed ${seed}