Update TuNAS
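This commit adds a TuNAS-style REINFORCE search option (--algo tunas) to the size search space: an exponential-moving-average reward baseline, per-cell categorical sampling whose log-probabilities are returned by GenericNAS301Model.forward, and an advantage-weighted architecture loss in search_func. The 'enas' option and the now-unused train_controller are removed, an explicit --use_api flag replaces the silent try/except around API creation, and the visualization script gains a TuNAS entry with seeds 777/888/999.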
@@ -8,6 +8,10 @@
 # python ./exps/algos-v2/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
 # python ./exps/algos-v2/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
 # python ./exps/algos-v2/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
+####
+# python ./exps/algos-v2/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
+# python ./exps/algos-v2/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
+# python ./exps/algos-v2/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
 ######################################################################################
 import os, sys, time, random, argparse
 import numpy as np
@@ -26,7 +30,28 @@ from models       import get_cell_based_tiny_net, get_search_spaces
 from nas_201_api  import NASBench301API as API


-def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
+# Ad-hoc for TuNAS
+class ExponentialMovingAverage(object):
+  """Class that maintains an exponential moving average."""
+
+  def __init__(self, momentum):
+    self._numerator   = 0
+    self._denominator = 0
+    self._momentum    = momentum
+
+  def update(self, value):
+    self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
+    self._denominator = self._momentum * self._denominator + (1 - self._momentum)
+
+  @property
+  def value(self):
+    """Return the current value of the moving average"""
+    return self._numerator / self._denominator
+
+
+RL_BASELINE_EMA = ExponentialMovingAverage(0.95)
+
+
+def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
   data_time, batch_time = AverageMeter(), AverageMeter()
   base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
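The class above keeps a bias-corrected running average: the denominator decays toward 1 from 0 at the same rate as the numerator accumulates, so early readings of .value are not dragged toward the zero initialization. A minimal sketch of the same arithmetic, with made-up accuracy values:

    # Bias-corrected EMA, mirroring ExponentialMovingAverage above.
    momentum = 0.95
    numerator, denominator = 0.0, 0.0
    for value in [60.0, 62.0, 58.0, 64.0]:  # hypothetical per-batch top-1 accuracies
      numerator   = momentum * numerator + (1 - momentum) * value
      denominator = momentum * denominator + (1 - momentum)
      print(numerator / denominator)        # first iteration prints 60.0, not 3.0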
@@ -43,7 +68,7 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer

     # Update the weights
     network.zero_grad()
-    _, logits = network(base_inputs)
+    _, logits, _ = network(base_inputs)
     base_loss = criterion(logits, base_targets)
     base_loss.backward()
     w_optimizer.step()
@@ -55,12 +80,21 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer

     # update the architecture-weight
     network.zero_grad()
-    _, logits = network(arch_inputs)
+    _, logits, log_probs = network(arch_inputs)
+    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
+    if algo == 'tunas':
+      with torch.no_grad():
+        RL_BASELINE_EMA.update(arch_prec1.item())
+        rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
+      rl_log_prob = sum(log_probs)
+      arch_loss = - rl_advantage * rl_log_prob
+    elif algo == 'tas' or algo == 'fbv2':
       arch_loss = criterion(logits, arch_targets)
+    else:
+      raise ValueError('invalid algorithm name: {:}'.format(algo))
     arch_loss.backward()
     a_optimizer.step()
     # record
-    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
     arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
     arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
     arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))
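With algo == 'tunas', this is a single-sample REINFORCE step: the advantage (sampled accuracy minus the EMA baseline) is computed under torch.no_grad, so the only gradient path to the architecture parameters is the summed log-probabilities returned by the forward pass. A self-contained sketch of the estimator, with illustrative shapes and reward values:

    import torch

    # 3 independent decisions with 4 choices each (stand-ins for per-cell width choices).
    arch_params = torch.randn(3, 4, requires_grad=True)
    dist        = torch.distributions.Categorical(logits=arch_params)
    actions     = dist.sample()                 # hard samples carry no gradient themselves
    log_prob    = dist.log_prob(actions).sum()  # joint log-probability of the sampled child
    reward, baseline = 61.0, 59.5               # e.g. top-1 accuracy vs. the EMA baseline
    advantage   = reward - baseline             # a constant w.r.t. the parameters
    loss        = -advantage * log_prob         # grad = -advantage * dlog_prob/dparams
    loss.backward()                             # populates arch_params.grad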
@@ -78,76 +112,6 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
   return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg


-def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger):
-  # config. (containing some necessary arg)
-  #   baseline: The baseline score (i.e. average val_acc) from the previous epoch
-  data_time, batch_time = AverageMeter(), AverageMeter()
-  GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
-
-  controller_num_aggregate = 20
-  controller_train_steps = 50
-  controller_bl_dec = 0.99
-  controller_entropy_weight = 0.0001
-
-  network.eval()
-  network.controller.train()
-  network.controller.zero_grad()
-  loader_iter = iter(xloader)
-  for step in range(controller_train_steps * controller_num_aggregate):
-    try:
-      inputs, targets = next(loader_iter)
-    except:
-      loader_iter = iter(xloader)
-      inputs, targets = next(loader_iter)
-    inputs  = inputs.cuda(non_blocking=True)
-    targets = targets.cuda(non_blocking=True)
-    # measure data loading time
-    data_time.update(time.time() - xend)
-
-    log_prob, entropy, sampled_arch = network.controller()
-    with torch.no_grad():
-      network.set_cal_mode('dynamic', sampled_arch)
-      _, logits = network(inputs)
-      val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
-      val_top1  = val_top1.view(-1) / 100
-    reward = val_top1 + controller_entropy_weight * entropy
-    if prev_baseline is None:
-      baseline = val_top1
-    else:
-      baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward)
-
-    loss = -1 * log_prob * (reward - baseline)
-
-    # account
-    RewardMeter.update(reward.item())
-    BaselineMeter.update(baseline.item())
-    ValAccMeter.update(val_top1.item()*100)
-    LossMeter.update(loss.item())
-    EntropyMeter.update(entropy.item())
-
-    # Average gradient over controller_num_aggregate samples
-    loss = loss / controller_num_aggregate
-    loss.backward(retain_graph=True)
-
-    # measure elapsed time
-    batch_time.update(time.time() - xend)
-    xend = time.time()
-    if (step+1) % controller_num_aggregate == 0:
-      grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0)
-      GradnormMeter.update(grad_norm)
-      optimizer.step()
-      network.controller.zero_grad()
-
-    if step % print_freq == 0:
-      Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, controller_train_steps * controller_num_aggregate)
-      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
-      Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
-      Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
-      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)
-
-  return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
-
-
 def valid_func(xloader, network, criterion, logger):
   data_time, batch_time = AverageMeter(), AverageMeter()
   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
@@ -159,7 +123,7 @@ def valid_func(xloader, network, criterion, logger):
       # measure data loading time
       data_time.update(time.time() - end)
       # prediction
-      _, logits = network(arch_inputs.cuda(non_blocking=True))
+      _, logits, _ = network(arch_inputs.cuda(non_blocking=True))
       arch_loss = criterion(logits, arch_targets)
       # record
       arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
@@ -211,9 +175,9 @@ def main(xargs):
   params = count_parameters_in_MB(search_model)
   logger.log('The parameters of the search model = {:.2f} MB'.format(params))
   logger.log('search-space : {:}'.format(search_space))
-  try:
+  if bool(xargs.use_api):
     api = API(verbose=False)
-  except:
+  else:
     api = None
   logger.log('{:} create API = {:} done'.format(time_string(), api))

@@ -250,7 +214,7 @@ def main(xargs):
       network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
       logger.log('[RESET tau as : {:}]'.format(network.tau))
     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
-                = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
+                = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, xargs.algo, epoch_str, xargs.print_freq, logger)
     search_time.update(time.time() - start_time)
     logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
@@ -305,8 +269,9 @@ if __name__ == '__main__':
   parser.add_argument('--data_path'   ,       type=str,   help='Path to dataset')
   parser.add_argument('--dataset'     ,       type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
   parser.add_argument('--search_space',       type=str,   default='sss', choices=['sss'], help='The search space name.')
-  parser.add_argument('--algo'        ,       type=str,   choices=['tas', 'fbv2', 'enas'], help='The search space name.')
+  parser.add_argument('--algo'        ,       type=str,   choices=['tas', 'fbv2', 'tunas'], help='The search algorithm name.')
   parser.add_argument('--genotype'    ,       type=str,   default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
+  parser.add_argument('--use_api'     ,       type=int,   default=1, choices=[0,1], help='Whether to use the API (loading it costs much memory).')
   # FOR GDAS
   parser.add_argument('--tau_min',            type=float, default=0.1,  help='The minimum tau for Gumbel Softmax.')
   parser.add_argument('--tau_max',            type=float, default=10,   help='The maximum tau for Gumbel Softmax.')
@@ -29,8 +29,8 @@ from log_utils import time_string
 def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
   ss_dir = '{:}-{:}'.format(root_dir, search_space)
   alg2name, alg2path = OrderedDict(), OrderedDict()
-  seeds = [777]
   if search_space == 'tss':
+    seeds = [777]
     alg2name['GDAS'] = 'gdas-affine0_BN0-None'
     alg2name['RSPS'] = 'random-affine0_BN0-None'
     alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None'
@@ -38,8 +38,10 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
     alg2name['ENAS'] = 'enas-affine0_BN0-None'
     alg2name['SETN'] = 'setn-affine0_BN0-None'
   else:
+    seeds = [777, 888, 999]
     alg2name['TAS'] = 'tas-affine0_BN0'
     alg2name['FBNetV2'] = 'fbv2-affine0_BN0'
+    alg2name['TuNAS'] = 'tunas-affine0_BN0'
   for alg, name in alg2name.items():
     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
   alg2data = OrderedDict()
@@ -84,7 +86,7 @@ def visualize_curve(api, vis_save_dir, search_space):
     alg2data = fetch_data(search_space=search_space, dataset=dataset)
     alg2accuracies = OrderedDict()
     epochs = 100
-    colors = ['b', 'g', 'c', 'm', 'y']
+    colors = ['b', 'g', 'c', 'm', 'y', 'r']
     ax.set_xlim(0, epochs)
     # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
     for idx, (alg, data) in enumerate(alg2data.items()):
@@ -47,10 +47,10 @@ class GenericNAS301Model(nn.Module):
   def set_algo(self, algo: Text):
     # used for searching
     assert self._algo is None, 'This function can only be called once.'
-    assert algo in ['fbv2', 'enas', 'tas'], 'invalid algo : {:}'.format(algo)
+    assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo)
     self._algo = algo
     self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs)))
-    if algo == 'fbv2' or algo == 'enas':
+    if algo == 'fbv2' or algo == 'tunas':
       self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)))
       for i in range(len(self._candidate_Cs)):
         self._masks.data[i, :self._candidate_Cs[i]] = 1
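The _masks buffer built here is a binary matrix with one row per candidate width: row i keeps the first _candidate_Cs[i] channels and zeroes the rest, so multiplying a feature map by a row truncates it to that width. A small sketch with a hypothetical candidate list:

    import torch

    candidate_Cs = [8, 16, 24]                       # hypothetical channel choices
    masks = torch.zeros(len(candidate_Cs), max(candidate_Cs))
    for i, c in enumerate(candidate_Cs):
      masks[i, :c] = 1                               # row i keeps the first c channels
    feature  = torch.randn(2, 24, 32, 32)            # N x C x H x W
    narrowed = feature * masks[1].view(1, -1, 1, 1)  # channels 16..23 become zero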
@@ -106,15 +106,17 @@ class GenericNAS301Model(nn.Module):

   def forward(self, inputs):
     feature = inputs

+    log_probs = []
     for i, cell in enumerate(self._cells):
       feature = cell(feature)
-      if self._algo == 'fbv2':
+      # apply different searching algorithms
       idx = max(0, i-1)
+      if self._algo == 'fbv2':
         weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
         mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
         feature = feature * mask
       elif self._algo == 'tas':
-        idx = max(0, i-1)
         selected_cs, selected_probs = select2withP(self._arch_parameters[idx:idx+1], self.tau, num=2)
         with torch.no_grad():
           i1, i2 = selected_cs.cpu().view(-1).tolist()
@@ -128,6 +130,13 @@ class GenericNAS301Model(nn.Module):
         else:
           miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device)
           feature = torch.cat((out, miss), dim=1)
+      elif self._algo == 'tunas':
+        prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1)
+        dist = torch.distributions.Categorical(prob)
+        action = dist.sample()
+        log_probs.append(dist.log_prob(action))
+        mask = self._masks[action.item()].view(1, -1, 1, 1)
+        feature = feature * mask
       else:
         raise ValueError('invalid algorithm : {:}'.format(self._algo))

@@ -136,4 +145,4 @@ class GenericNAS301Model(nn.Module):
     out = out.view(out.size(0), -1)
     logits = self.classifier(out)

-    return out, logits
+    return out, logits, log_probs
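Unlike the fbv2 branch, which blends masks with a differentiable Gumbel-softmax, the tunas branch draws a hard categorical sample per cell, so no gradient flows through the mask multiplication itself; the architecture parameters learn only through the log_probs collected here and consumed by the REINFORCE loss in search_func. One sampling step in isolation, assuming the mask layout from set_algo and hypothetical widths:

    import torch
    import torch.nn as nn

    arch_param = torch.randn(1, 3, requires_grad=True)  # one cell, 3 width choices
    masks = torch.zeros(3, 24)
    for i, c in enumerate([8, 16, 24]):                 # hypothetical candidate widths
      masks[i, :c] = 1
    prob     = nn.functional.softmax(arch_param, dim=-1)
    dist     = torch.distributions.Categorical(prob)
    action   = dist.sample()                            # hard choice, not differentiable
    log_prob = dist.log_prob(action)                    # the REINFORCE gradient path
    feature  = torch.randn(2, 24, 7, 7)
    feature  = feature * masks[action.item()].view(1, -1, 1, 1)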
@@ -60,6 +60,7 @@ class NASBench301API(NASBenchMetaAPI):
     self.reset_time()
     if file_path_or_dict is None:
       file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
+      print('Try to use the default NAS-Bench-301 path from {:}.'.format(file_path_or_dict))
     if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
       file_path_or_dict = str(file_path_or_dict)
       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
	Block a user