Add int search space
This commit is contained in:
		| @@ -76,7 +76,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode): | ||||
|     return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger): | ||||
| def evaluate_for_seed( | ||||
|     arch_config, config, arch, train_loader, valid_loaders, seed, logger | ||||
| ): | ||||
|  | ||||
|     prepare_seed(seed)  # random seed | ||||
|     net = get_cell_based_tiny_net( | ||||
| @@ -94,14 +96,29 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, se | ||||
|     # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|     flop, param = get_model_infos(net, config.xshape) | ||||
|     logger.log("Network : {:}".format(net.get_message()), False) | ||||
|     logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed)) | ||||
|     logger.log( | ||||
|         "{:} Seed-------------------------- {:} --------------------------".format( | ||||
|             time_string(), seed | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) | ||||
|     # train and valid | ||||
|     optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config) | ||||
|     network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda() | ||||
|     # start training | ||||
|     start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup | ||||
|     train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} | ||||
|     start_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     ( | ||||
|         train_losses, | ||||
|         train_acc1es, | ||||
|         train_acc5es, | ||||
|         valid_losses, | ||||
|         valid_acc1es, | ||||
|         valid_acc5es, | ||||
|     ) = ({}, {}, {}, {}, {}, {}) | ||||
|     train_times, valid_times = {}, {} | ||||
|     for epoch in range(total_epoch): | ||||
|         scheduler.update(epoch, 0.0) | ||||
| @@ -126,7 +143,9 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, se | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format( | ||||
|                 time_string(), | ||||
|   | ||||
		Reference in New Issue
	
	Block a user