update NAS-Bench-102 baselines
This commit is contained in:
		| @@ -1,10 +1,10 @@ | |||||||
| { | { | ||||||
|   "scheduler": ["str",   "cos"], |   "scheduler": ["str",   "cos"], | ||||||
|  |   "LR"       : ["float", "0.025"], | ||||||
|   "eta_min"  : ["float", "0.001"], |   "eta_min"  : ["float", "0.001"], | ||||||
|   "epochs"   : ["int",   "50"], |   "epochs"   : ["int",   "50"], | ||||||
|   "warmup"   : ["int",   "0"], |   "warmup"   : ["int",   "0"], | ||||||
|   "optim"    : ["str",   "SGD"], |   "optim"    : ["str",   "SGD"], | ||||||
|   "LR"       : ["float", "0.025"], |  | ||||||
|   "decay"    : ["float", "0.0005"], |   "decay"    : ["float", "0.0005"], | ||||||
|   "momentum" : ["float", "0.9"], |   "momentum" : ["float", "0.9"], | ||||||
|   "nesterov" : ["bool",  "1"], |   "nesterov" : ["bool",  "1"], | ||||||
|   | |||||||
| @@ -2,7 +2,7 @@ | |||||||
|   "scheduler": ["str",   "cos"], |   "scheduler": ["str",   "cos"], | ||||||
|   "LR"       : ["float", "0.05"], |   "LR"       : ["float", "0.05"], | ||||||
|   "eta_min"  : ["float", "0.0005"], |   "eta_min"  : ["float", "0.0005"], | ||||||
|   "epochs"   : ["int",   "310"], |   "epochs"   : ["int",   "250"], | ||||||
|   "T_max"    : ["int",   "10"], |   "T_max"    : ["int",   "10"], | ||||||
|   "warmup"   : ["int",   "0"], |   "warmup"   : ["int",   "0"], | ||||||
|   "optim"    : ["str",   "SGD"], |   "optim"    : ["str",   "SGD"], | ||||||
|   | |||||||
| @@ -1,10 +1,10 @@ | |||||||
| { | { | ||||||
|   "scheduler": ["str",   "cos"], |   "scheduler": ["str",   "cos"], | ||||||
|  |   "LR"       : ["float", "0.025"], | ||||||
|   "eta_min"  : ["float", "0.001"], |   "eta_min"  : ["float", "0.001"], | ||||||
|   "epochs"   : ["int",   "250"], |   "epochs"   : ["int",   "250"], | ||||||
|   "warmup"   : ["int",   "0"], |   "warmup"   : ["int",   "0"], | ||||||
|   "optim"    : ["str",   "SGD"], |   "optim"    : ["str",   "SGD"], | ||||||
|   "LR"       : ["float", "0.025"], |  | ||||||
|   "decay"    : ["float", "0.0005"], |   "decay"    : ["float", "0.0005"], | ||||||
|   "momentum" : ["float", "0.9"], |   "momentum" : ["float", "0.9"], | ||||||
|   "nesterov" : ["bool",  "1"], |   "nesterov" : ["bool",  "1"], | ||||||
|   | |||||||
| @@ -1,10 +1,10 @@ | |||||||
| { | { | ||||||
|   "scheduler": ["str",   "cos"], |   "scheduler": ["str",   "cos"], | ||||||
|  |   "LR"       : ["float", "0.025"], | ||||||
|   "eta_min"  : ["float", "0.001"], |   "eta_min"  : ["float", "0.001"], | ||||||
|   "epochs"   : ["int",   "250"], |   "epochs"   : ["int",   "250"], | ||||||
|   "warmup"   : ["int",   "0"], |   "warmup"   : ["int",   "0"], | ||||||
|   "optim"    : ["str",   "SGD"], |   "optim"    : ["str",   "SGD"], | ||||||
|   "LR"       : ["float", "0.025"], |  | ||||||
|   "decay"    : ["float", "0.0005"], |   "decay"    : ["float", "0.0005"], | ||||||
|   "momentum" : ["float", "0.9"], |   "momentum" : ["float", "0.9"], | ||||||
|   "nesterov" : ["bool",  "1"], |   "nesterov" : ["bool",  "1"], | ||||||
|   | |||||||
| @@ -1,10 +1,10 @@ | |||||||
| { | { | ||||||
|   "scheduler": ["str",   "cos"], |   "scheduler": ["str",   "cos"], | ||||||
|  |   "LR"       : ["float", "0.025"], | ||||||
|   "eta_min"  : ["float", "0.001"], |   "eta_min"  : ["float", "0.001"], | ||||||
|   "epochs"   : ["int",   "250"], |   "epochs"   : ["int",   "250"], | ||||||
|   "warmup"   : ["int",   "0"], |   "warmup"   : ["int",   "0"], | ||||||
|   "optim"    : ["str",   "SGD"], |   "optim"    : ["str",   "SGD"], | ||||||
|   "LR"       : ["float", "0.025"], |  | ||||||
|   "decay"    : ["float", "0.0005"], |   "decay"    : ["float", "0.0005"], | ||||||
|   "momentum" : ["float", "0.9"], |   "momentum" : ["float", "0.9"], | ||||||
|   "nesterov" : ["bool",  "1"], |   "nesterov" : ["bool",  "1"], | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | |||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| from models       import get_cell_based_tiny_net, get_search_spaces | from models       import get_cell_based_tiny_net, get_search_spaces | ||||||
|  | from nas_102_api  import NASBench102API as API | ||||||
|  |  | ||||||
|  |  | ||||||
| def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger): | def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger): | ||||||
| @@ -224,6 +225,12 @@ def main(xargs): | |||||||
|   #flop, param  = get_model_infos(shared_cnn, xshape) |   #flop, param  = get_model_infos(shared_cnn, xshape) | ||||||
|   #logger.log('{:}'.format(shared_cnn)) |   #logger.log('{:}'.format(shared_cnn)) | ||||||
|   #logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) |   #logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||||
|  |   logger.log('search-space : {:}'.format(search_space)) | ||||||
|  |   if xargs.arch_nas_dataset is None: | ||||||
|  |     api = None | ||||||
|  |   else: | ||||||
|  |     api = API(xargs.arch_nas_dataset) | ||||||
|  |   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||||
|   shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda() |   shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda() | ||||||
|  |  | ||||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') |   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||||
| @@ -247,7 +254,7 @@ def main(xargs): | |||||||
|     start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None |     start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None | ||||||
|  |  | ||||||
|   # start training |   # start training | ||||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup |   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||||
|   for epoch in range(start_epoch, total_epoch): |   for epoch in range(start_epoch, total_epoch): | ||||||
|     w_scheduler.update(epoch, 0.0) |     w_scheduler.update(epoch, 0.0) | ||||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) |     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) | ||||||
| @@ -263,7 +270,8 @@ def main(xargs): | |||||||
|                                                                      'ctl_entropy_w': xargs.controller_entropy_weight,  |                                                                      'ctl_entropy_w': xargs.controller_entropy_weight,  | ||||||
|                                                                      'ctl_bl_dec'   : xargs.controller_bl_dec}, None), \ |                                                                      'ctl_bl_dec'   : xargs.controller_bl_dec}, None), \ | ||||||
|                                                         epoch_str, xargs.print_freq, logger) |                                                         epoch_str, xargs.print_freq, logger) | ||||||
|     logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline)) |     search_time.update(time.time() - start_time) | ||||||
|  |     logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum)) | ||||||
|     best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader) |     best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader) | ||||||
|     shared_cnn.module.update_arch(best_arch) |     shared_cnn.module.update_arch(best_arch) | ||||||
|     _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion) |     _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion) | ||||||
| @@ -298,6 +306,7 @@ def main(xargs): | |||||||
|     if find_best: |     if find_best: | ||||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc)) |       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc)) | ||||||
|       copy_checkpoint(model_base_path, model_best_path, logger) |       copy_checkpoint(model_base_path, model_best_path, logger) | ||||||
|  |     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||||
|     # measure elapsed time |     # measure elapsed time | ||||||
|     epoch_time.update(time.time() - start_time) |     epoch_time.update(time.time() - start_time) | ||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
| @@ -306,27 +315,15 @@ def main(xargs): | |||||||
|   logger.log('During searching, the best architecture is {:}'.format(genotypes['best'])) |   logger.log('During searching, the best architecture is {:}'.format(genotypes['best'])) | ||||||
|   logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best'])) |   logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best'])) | ||||||
|   logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples)) |   logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples)) | ||||||
|  |   start_time = time.time() | ||||||
|   final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples) |   final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples) | ||||||
|  |   search_time.update(time.time() - start_time) | ||||||
|   shared_cnn.module.update_arch(final_arch) |   shared_cnn.module.update_arch(final_arch) | ||||||
|   final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion) |   final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion) | ||||||
|   logger.log('The Selected Final Architecture : {:}'.format(final_arch)) |   logger.log('The Selected Final Architecture : {:}'.format(final_arch)) | ||||||
|   logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5)) |   logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5)) | ||||||
|   # check the performance from the architecture dataset |   logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch)) | ||||||
|   #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): |   if api is not None: logger.log('{:}'.format( api.query_by_arch(final_arch) )) | ||||||
|   #  logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) |  | ||||||
|   #else: |  | ||||||
|   #  nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset) |  | ||||||
|   #  geno = genotypes[total_epoch-1] |  | ||||||
|   #  logger.log('The last model is {:}'.format(geno)) |  | ||||||
|   #  info = nas_bench.query_by_arch( geno ) |  | ||||||
|   #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) |  | ||||||
|   #  else           : logger.log('{:}'.format(info)) |  | ||||||
|   #  logger.log('-'*100) |  | ||||||
|   #  geno = genotypes['best'] |  | ||||||
|   #  logger.log('The best model is {:}'.format(geno)) |  | ||||||
|   #  info = nas_bench.query_by_arch( geno ) |  | ||||||
|   #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) |  | ||||||
|   #  else           : logger.log('{:}'.format(info)) |  | ||||||
|   logger.close() |   logger.close() | ||||||
|    |    | ||||||
|  |  | ||||||
|   | |||||||
| @@ -93,8 +93,8 @@ def main(xargs): | |||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||||
|   else: |   else: | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | ||||||
|   config_path = 'configs/nas-benchmark/algos/GDAS.config' |   #config_path = 'configs/nas-benchmark/algos/GDAS.config' | ||||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) |   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||||
|   # data loader |   # data loader | ||||||
|   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||||
| @@ -105,7 +105,7 @@ def main(xargs): | |||||||
|   model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells, |   model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells, | ||||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, |                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||||
|                               'space'    : search_space, |                               'space'    : search_space, | ||||||
|                               'affine'   : False, 'track_running_stats': True}, None) |                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||||
|   search_model = get_cell_based_tiny_net(model_config) |   search_model = get_cell_based_tiny_net(model_config) | ||||||
|   logger.log('search-model :\n{:}'.format(search_model)) |   logger.log('search-model :\n{:}'.format(search_model)) | ||||||
|    |    | ||||||
| @@ -156,7 +156,7 @@ def main(xargs): | |||||||
|     search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \ |     search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \ | ||||||
|               = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) |               = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) | ||||||
|     search_time.update(time.time() - start_time) |     search_time.update(time.time() - start_time) | ||||||
|     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) |     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) | ||||||
|     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 )) |     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 )) | ||||||
|     # check the best accuracy |     # check the best accuracy | ||||||
|     valid_accuracies[epoch] = valid_a_top1 |     valid_accuracies[epoch] = valid_a_top1 | ||||||
| @@ -210,6 +210,8 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') |   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') |   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') |   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||||
|  |   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||||
|  |   parser.add_argument('--config_path',        type=str,   help='The path of the configuration.') | ||||||
|   # architecture leraning rate |   # architecture leraning rate | ||||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') |   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||||
|   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') |   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | |||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| from models       import get_cell_based_tiny_net, get_search_spaces | from models       import get_cell_based_tiny_net, get_search_spaces | ||||||
|  | from nas_102_api  import NASBench102API as API | ||||||
|  |  | ||||||
|  |  | ||||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger): | def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger): | ||||||
| @@ -130,6 +131,9 @@ def main(xargs): | |||||||
|   logger.log('w-optimizer : {:}'.format(w_optimizer)) |   logger.log('w-optimizer : {:}'.format(w_optimizer)) | ||||||
|   logger.log('w-scheduler : {:}'.format(w_scheduler)) |   logger.log('w-scheduler : {:}'.format(w_scheduler)) | ||||||
|   logger.log('criterion   : {:}'.format(criterion)) |   logger.log('criterion   : {:}'.format(criterion)) | ||||||
|  |   if xargs.arch_nas_dataset is None: api = None | ||||||
|  |   else                             : api = API(xargs.arch_nas_dataset) | ||||||
|  |   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||||
|  |  | ||||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') |   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() |   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||||
| @@ -149,7 +153,7 @@ def main(xargs): | |||||||
|     start_epoch, valid_accuracies = 0, {'best': -1} |     start_epoch, valid_accuracies = 0, {'best': -1} | ||||||
|  |  | ||||||
|   # start training |   # start training | ||||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup |   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||||
|   for epoch in range(start_epoch, total_epoch): |   for epoch in range(start_epoch, total_epoch): | ||||||
|     w_scheduler.update(epoch, 0.0) |     w_scheduler.update(epoch, 0.0) | ||||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) |     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) | ||||||
| @@ -157,7 +161,8 @@ def main(xargs): | |||||||
|     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) |     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) | ||||||
|  |  | ||||||
|     search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger) |     search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger) | ||||||
|     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) |     search_time.update(time.time() - start_time) | ||||||
|  |     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) | ||||||
|     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) |     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||||
|     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) |     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||||
|     # check the best accuracy |     # check the best accuracy | ||||||
| @@ -188,7 +193,8 @@ def main(xargs): | |||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|  |  | ||||||
|   logger.log('\n' + '-'*200) |   logger.log('\n' + '-'*200) | ||||||
|  |   logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum)) | ||||||
|  |   start_time = time.time() | ||||||
|   best_arch, best_acc = None, -1 |   best_arch, best_acc = None, -1 | ||||||
|   for iarch in range(xargs.select_num): |   for iarch in range(xargs.select_num): | ||||||
|     arch = search_model.random_genotype( True ) |     arch = search_model.random_genotype( True ) | ||||||
| @@ -197,24 +203,10 @@ def main(xargs): | |||||||
|     if best_arch is None or best_acc < valid_a_top1: |     if best_arch is None or best_acc < valid_a_top1: | ||||||
|       best_arch, best_acc = arch, valid_a_top1 |       best_arch, best_acc = arch, valid_a_top1 | ||||||
|  |  | ||||||
|   logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc)) |   search_time.update(time.time() - start_time) | ||||||
|  |   logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum)) | ||||||
|   logger.log('\n' + '-'*100) |   if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) )) | ||||||
|   """ |  | ||||||
|   # check the performance from the architecture dataset |  | ||||||
|   if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): |  | ||||||
|     logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) |  | ||||||
|   else: |  | ||||||
|     nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset) |  | ||||||
|     geno      = best_arch |  | ||||||
|     logger.log('The last model is {:}'.format(geno)) |  | ||||||
|     info = nas_bench.query_by_arch( geno ) |  | ||||||
|     if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) |  | ||||||
|     else           : logger.log('{:}'.format(info)) |  | ||||||
|     logger.log('-'*100) |  | ||||||
|   logger.close() |   logger.close() | ||||||
|   """ |  | ||||||
|    |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -52,14 +52,18 @@ def main(xargs, nas_bench): | |||||||
|   random_arch = random_architecture_func(xargs.max_nodes, search_space) |   random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||||
|   #x =random_arch() ; y = mutate_arch(x) |   #x =random_arch() ; y = mutate_arch(x) | ||||||
|   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) |   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) | ||||||
|   best_arch, best_acc = None, -1 |   best_arch, best_acc, total_time_cost, history = None, -1, 0, [] | ||||||
|   for idx in range(xargs.random_num): |   #for idx in range(xargs.random_num): | ||||||
|  |   while total_time_cost < xargs.time_budget: | ||||||
|     arch = random_arch() |     arch = random_arch() | ||||||
|     accuracy = train_and_eval(arch, nas_bench, extra_info) |     accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info) | ||||||
|  |     if total_time_cost + cost_time > xargs.time_budget: break | ||||||
|  |     else: total_time_cost += cost_time | ||||||
|  |     history.append(arch) | ||||||
|     if best_arch is None or best_acc < accuracy: |     if best_arch is None or best_acc < accuracy: | ||||||
|       best_acc, best_arch = accuracy, arch |       best_acc, best_arch = accuracy, arch | ||||||
|     logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy)) |     logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) | ||||||
|   logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc)) |   logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost)) | ||||||
|    |    | ||||||
|   info = nas_bench.query_by_arch( best_arch ) |   info = nas_bench.query_by_arch( best_arch ) | ||||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) |   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) | ||||||
| @@ -79,7 +83,8 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') |   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') |   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') |   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||||
|   parser.add_argument('--random_num',         type=int,   help='The number of random selected architectures.') |   #parser.add_argument('--random_num',         type=int,   help='The number of random selected architectures.') | ||||||
|  |   parser.add_argument('--time_budget',        type=int,   help='The total time cost budge for searching (in seconds).') | ||||||
|   # log |   # log | ||||||
|   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') |   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') | ||||||
|   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.') |   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.') | ||||||
|   | |||||||
| @@ -60,12 +60,12 @@ def train_and_eval(arch, nas_bench, extra_info): | |||||||
|     arch_index = nas_bench.query_index_by_arch( arch ) |     arch_index = nas_bench.query_index_by_arch( arch ) | ||||||
|     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) |     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||||
|     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True) |     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True) | ||||||
|     import pdb; pdb.set_trace() |     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||||
|     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs |     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs | ||||||
|   else: |   else: | ||||||
|     # train a model from scratch. |     # train a model from scratch. | ||||||
|     raise ValueError('NOT IMPLEMENT YET') |     raise ValueError('NOT IMPLEMENT YET') | ||||||
|   return valid_acc |   return valid_acc, time_cost | ||||||
|  |  | ||||||
|  |  | ||||||
| def random_architecture_func(max_nodes, op_names): | def random_architecture_func(max_nodes, op_names): | ||||||
| @@ -101,7 +101,7 @@ def mutate_arch_func(op_names): | |||||||
|   return mutate_arch_func |   return mutate_arch_func | ||||||
|  |  | ||||||
|  |  | ||||||
| def regularized_evolution(cycles, population_size, sample_size, random_arch, mutate_arch, nas_bench, extra_info): | def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info): | ||||||
|   """Algorithm for regularized evolution (i.e. aging evolution). |   """Algorithm for regularized evolution (i.e. aging evolution). | ||||||
|    |    | ||||||
|   Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image |   Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image | ||||||
| @@ -111,27 +111,30 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut | |||||||
|     cycles: the number of cycles the algorithm should run for. |     cycles: the number of cycles the algorithm should run for. | ||||||
|     population_size: the number of individuals to keep in the population. |     population_size: the number of individuals to keep in the population. | ||||||
|     sample_size: the number of individuals that should participate in each tournament. |     sample_size: the number of individuals that should participate in each tournament. | ||||||
|  |     time_budget: the upper bound of searching cost | ||||||
|  |  | ||||||
|   Returns: |   Returns: | ||||||
|     history: a list of `Model` instances, representing all the models computed |     history: a list of `Model` instances, representing all the models computed | ||||||
|         during the evolution experiment. |         during the evolution experiment. | ||||||
|   """ |   """ | ||||||
|   population = collections.deque() |   population = collections.deque() | ||||||
|   history = []  # Not used by the algorithm, only used to report results. |   history, total_time_cost = [], 0  # Not used by the algorithm, only used to report results. | ||||||
|  |  | ||||||
|   # Initialize the population with random models. |   # Initialize the population with random models. | ||||||
|   while len(population) < population_size: |   while len(population) < population_size: | ||||||
|     model = Model() |     model = Model() | ||||||
|     model.arch = random_arch() |     model.arch = random_arch() | ||||||
|     model.accuracy = train_and_eval(model.arch, nas_bench, extra_info) |     model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info) | ||||||
|     population.append(model) |     population.append(model) | ||||||
|     history.append(model) |     history.append(model) | ||||||
|  |     total_time_cost += time_cost | ||||||
|  |  | ||||||
|   # Carry out evolution in cycles. Each cycle produces a model and removes |   # Carry out evolution in cycles. Each cycle produces a model and removes | ||||||
|   # another. |   # another. | ||||||
|   while len(history) < cycles: |   #while len(history) < cycles: | ||||||
|  |   while total_time_cost < time_budget: | ||||||
|     # Sample randomly chosen models from the current population. |     # Sample randomly chosen models from the current population. | ||||||
|     sample = [] |     start_time, sample = time.time(), [] | ||||||
|     while len(sample) < sample_size: |     while len(sample) < sample_size: | ||||||
|       # Inefficient, but written this way for clarity. In the case of neural |       # Inefficient, but written this way for clarity. In the case of neural | ||||||
|       # nets, the efficiency of this line is irrelevant because training neural |       # nets, the efficiency of this line is irrelevant because training neural | ||||||
| @@ -145,13 +148,18 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut | |||||||
|     # Create the child model and store it. |     # Create the child model and store it. | ||||||
|     child = Model() |     child = Model() | ||||||
|     child.arch = mutate_arch(parent.arch) |     child.arch = mutate_arch(parent.arch) | ||||||
|     child.accuracy = train_and_eval(child.arch, nas_bench, extra_info) |     total_time_cost += time.time() - start_time | ||||||
|  |     child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info) | ||||||
|  |     if total_time_cost + time_cost > time_budget: # return | ||||||
|  |       return history, total_time_cost | ||||||
|  |     else: | ||||||
|  |       total_time_cost += time_cost | ||||||
|     population.append(child) |     population.append(child) | ||||||
|     history.append(child) |     history.append(child) | ||||||
|  |  | ||||||
|     # Remove the oldest model. |     # Remove the oldest model. | ||||||
|     population.popleft() |     population.popleft() | ||||||
|   return history |   return history, total_time_cost | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(xargs, nas_bench): | def main(xargs, nas_bench): | ||||||
| @@ -188,8 +196,9 @@ def main(xargs, nas_bench): | |||||||
|   mutate_arch = mutate_arch_func(search_space) |   mutate_arch = mutate_arch_func(search_space) | ||||||
|   #x =random_arch() ; y = mutate_arch(x) |   #x =random_arch() ; y = mutate_arch(x) | ||||||
|   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) |   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) | ||||||
|   history = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info) |   logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) | ||||||
|   logger.log('{:} regularized_evolution finish with history of {:} arch.'.format(time_string(), len(history))) |   history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info) | ||||||
|  |   logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s.'.format(time_string(), len(history), total_cost)) | ||||||
|   best_arch = max(history, key=lambda i: i.accuracy) |   best_arch = max(history, key=lambda i: i.accuracy) | ||||||
|   best_arch = best_arch.arch |   best_arch = best_arch.arch | ||||||
|   logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) |   logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) | ||||||
| @@ -216,6 +225,7 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--ea_population',      type=int,   help='The population size in EA.') |   parser.add_argument('--ea_population',      type=int,   help='The population size in EA.') | ||||||
|   parser.add_argument('--ea_sample_size',     type=int,   help='The sample size in EA.') |   parser.add_argument('--ea_sample_size',     type=int,   help='The sample size in EA.') | ||||||
|   parser.add_argument('--ea_fast_by_api',     type=int,   help='Use our API to speed up the experiments or not.') |   parser.add_argument('--ea_fast_by_api',     type=int,   help='Use our API to speed up the experiments or not.') | ||||||
|  |   parser.add_argument('--time_budget',        type=int,   help='The total time cost budge for searching (in seconds).') | ||||||
|   # log |   # log | ||||||
|   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') |   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') | ||||||
|   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.') |   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.') | ||||||
|   | |||||||
| @@ -17,6 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | |||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| from models       import get_cell_based_tiny_net, get_search_spaces | from models       import get_cell_based_tiny_net, get_search_spaces | ||||||
|  | from nas_102_api  import NASBench102API as API | ||||||
|  |  | ||||||
|  |  | ||||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): | def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): | ||||||
| @@ -162,7 +163,8 @@ def main(xargs): | |||||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) |   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||||
|   model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells, |   model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells, | ||||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, |                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||||
|                               'space'    : search_space}, None) |                               'space'    : search_space, | ||||||
|  |                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||||
|   logger.log('search space : {:}'.format(search_space)) |   logger.log('search space : {:}'.format(search_space)) | ||||||
|   search_model = get_cell_based_tiny_net(model_config) |   search_model = get_cell_based_tiny_net(model_config) | ||||||
|    |    | ||||||
| @@ -175,6 +177,12 @@ def main(xargs): | |||||||
|   flop, param  = get_model_infos(search_model, xshape) |   flop, param  = get_model_infos(search_model, xshape) | ||||||
|   #logger.log('{:}'.format(search_model)) |   #logger.log('{:}'.format(search_model)) | ||||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) |   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||||
|  |   logger.log('search-space : {:}'.format(search_space)) | ||||||
|  |   if xargs.arch_nas_dataset is None: | ||||||
|  |     api = None | ||||||
|  |   else: | ||||||
|  |     api = API(xargs.arch_nas_dataset) | ||||||
|  |   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||||
|  |  | ||||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') |   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() |   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||||
| @@ -196,7 +204,7 @@ def main(xargs): | |||||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} |     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} | ||||||
|  |  | ||||||
|   # start training |   # start training | ||||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup |   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||||
|   for epoch in range(start_epoch, total_epoch): |   for epoch in range(start_epoch, total_epoch): | ||||||
|     w_scheduler.update(epoch, 0.0) |     w_scheduler.update(epoch, 0.0) | ||||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) |     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) | ||||||
| @@ -205,7 +213,8 @@ def main(xargs): | |||||||
|  |  | ||||||
|     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ |     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ | ||||||
|                 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) |                 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) | ||||||
|     logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) |     search_time.update(time.time() - start_time) | ||||||
|  |     logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) | ||||||
|     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5)) |     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5)) | ||||||
|  |  | ||||||
|     genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) |     genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) | ||||||
| @@ -243,52 +252,23 @@ def main(xargs): | |||||||
|           }, logger.path('info'), logger) |           }, logger.path('info'), logger) | ||||||
|     with torch.no_grad(): |     with torch.no_grad(): | ||||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) |       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||||
|  |     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||||
|     # measure elapsed time |     # measure elapsed time | ||||||
|     epoch_time.update(time.time() - start_time) |     epoch_time.update(time.time() - start_time) | ||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|  |  | ||||||
|   #logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best'])) |   # the final post procedure : count the time | ||||||
|  |   start_time = time.time() | ||||||
|   genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) |   genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) | ||||||
|  |   search_time.update(time.time() - start_time) | ||||||
|   network.module.set_cal_mode('dynamic', genotype) |   network.module.set_cal_mode('dynamic', genotype) | ||||||
|   valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) |   valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) | ||||||
|   logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) |   logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) | ||||||
|   # sampling |  | ||||||
|   """ |  | ||||||
|   with torch.no_grad(): |  | ||||||
|     logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) |  | ||||||
|   selected_archs = set() |  | ||||||
|   while len(selected_archs) < xargs.select_num: |  | ||||||
|     architecture = search_model.dync_genotype() |  | ||||||
|     selected_archs.add( architecture ) |  | ||||||
|   logger.log('select {:} architectures based on the learned arch-parameters'.format( len(selected_archs) )) |  | ||||||
|  |  | ||||||
|   best_arch, best_acc = None, -1 |  | ||||||
|   state_dict = deepcopy( network.state_dict() ) |  | ||||||
|   for index, arch in enumerate(selected_archs): |  | ||||||
|     with torch.no_grad(): |  | ||||||
|       search_model.set_cal_mode('dynamic', arch) |  | ||||||
|       network.load_state_dict( deepcopy(state_dict) ) |  | ||||||
|       valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) |  | ||||||
|     logger.log('{:} [{:03d}/{:03d}] : {:125s}, loss={:.3f}, accuracy={:.3f}%'.format(time_string(), index, len(selected_archs), str(arch), valid_a_loss , valid_a_top1)) |  | ||||||
|     if best_arch is None or best_acc < valid_a_top1: |  | ||||||
|       best_arch, best_acc = arch, valid_a_top1 |  | ||||||
|   logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc)) |  | ||||||
|   """ |  | ||||||
|  |  | ||||||
|   logger.log('\n' + '-'*100) |   logger.log('\n' + '-'*100) | ||||||
|   # check the performance from the architecture dataset |   # check the performance from the architecture dataset | ||||||
|   """ |   logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype)) | ||||||
|   if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): |   if api is not None: logger.log('{:}'.format( api.query_by_arch(genotype) )) | ||||||
|     logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) |  | ||||||
|   else: |  | ||||||
|     nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset) |  | ||||||
|     geno      = best_arch |  | ||||||
|     logger.log('The last model is {:}'.format(geno)) |  | ||||||
|     info = nas_bench.query_by_arch( geno ) |  | ||||||
|     if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) |  | ||||||
|     else           : logger.log('{:}'.format(info)) |  | ||||||
|     logger.log('-'*100) |  | ||||||
|   """ |  | ||||||
|   logger.close() |   logger.close() | ||||||
|    |    | ||||||
|  |  | ||||||
| @@ -303,7 +283,8 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') |   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') |   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||||
|   parser.add_argument('--select_num',         type=int,   help='The number of selected architectures to evaluate.') |   parser.add_argument('--select_num',         type=int,   help='The number of selected architectures to evaluate.') | ||||||
|   parser.add_argument('--config_path',        type=str,   help='.') |   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||||
|  |   parser.add_argument('--config_path',        type=str,   help='The path of the configuration.') | ||||||
|   # architecture leraning rate |   # architecture leraning rate | ||||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') |   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||||
|   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') |   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||||
|   | |||||||
| @@ -20,6 +20,9 @@ def get_cell_based_tiny_net(config): | |||||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] |   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] | ||||||
|   if super_type == 'basic' and config.name in group_names: |   if super_type == 'basic' and config.name in group_names: | ||||||
|     from .cell_searchs import nas_super_nets |     from .cell_searchs import nas_super_nets | ||||||
|  |     try: | ||||||
|  |       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||||
|  |     except: | ||||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) |       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||||
|   elif super_type == 'l2s-base' and config.name in group_names: |   elif super_type == 'l2s-base' and config.name in group_names: | ||||||
|     from .l2s_cell_searchs import nas_super_nets |     from .l2s_cell_searchs import nas_super_nets | ||||||
|   | |||||||
| @@ -11,7 +11,8 @@ from .genotypes        import Structure | |||||||
|  |  | ||||||
| class TinyNetworkGDAS(nn.Module): | class TinyNetworkGDAS(nn.Module): | ||||||
|  |  | ||||||
|   def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True): |   #def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True): | ||||||
|  |   def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): | ||||||
|     super(TinyNetworkGDAS, self).__init__() |     super(TinyNetworkGDAS, self).__init__() | ||||||
|     self._C        = C |     self._C        = C | ||||||
|     self._layerN   = N |     self._layerN   = N | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ from .genotypes        import Structure | |||||||
|  |  | ||||||
| class TinyNetworkSETN(nn.Module): | class TinyNetworkSETN(nn.Module): | ||||||
|  |  | ||||||
|   def __init__(self, C, N, max_nodes, num_classes, search_space): |   def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): | ||||||
|     super(TinyNetworkSETN, self).__init__() |     super(TinyNetworkSETN, self).__init__() | ||||||
|     self._C        = C |     self._C        = C | ||||||
|     self._layerN   = N |     self._layerN   = N | ||||||
| @@ -31,7 +31,7 @@ class TinyNetworkSETN(nn.Module): | |||||||
|       if reduction: |       if reduction: | ||||||
|         cell = ResNetBasicblock(C_prev, C_curr, 2) |         cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||||
|       else: |       else: | ||||||
|         cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) |         cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) | ||||||
|         if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index |         if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index | ||||||
|         else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) |         else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) | ||||||
|       self.cells.append( cell ) |       self.cells.append( cell ) | ||||||
|   | |||||||
| @@ -34,6 +34,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \ | |||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--tau_max 10 --tau_min 0.1 \ | 	--config_path configs/nas-benchmark/algos/GDAS.config \ | ||||||
|  | 	--tau_max 10 --tau_min 0.1 --track_running_stats 1 \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -35,5 +35,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ | |||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--ea_cycles 30 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \ | 	--time_budget 12000 \ | ||||||
|  | 	--ea_cycles 100 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -34,5 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \ | |||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--random_num 100 \ | 	--time_budget 12000 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|  | #	--random_num 100 \ | ||||||
|   | |||||||
| @@ -36,6 +36,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \ | |||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--config_path configs/nas-benchmark/algos/SETN.config \ | 	--config_path configs/nas-benchmark/algos/SETN.config \ | ||||||
|  | 	--track_running_stats 1 \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
| 	--select_num 100 \ | 	--select_num 100 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user