Update q-config and black for procedures/utils
This commit is contained in:
		| @@ -1,25 +1,36 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .starts     import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint | ||||
| from .starts import prepare_seed | ||||
| from .starts import prepare_logger | ||||
| from .starts import get_machine_info | ||||
| from .starts import save_checkpoint | ||||
| from .starts import copy_checkpoint | ||||
| from .optimizers import get_optim_scheduler | ||||
| from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed | ||||
| from .funcs_nasbench import pure_evaluate as bench_pure_evaluate | ||||
| from .funcs_nasbench import get_nas_bench_loaders | ||||
|  | ||||
| def get_procedures(procedure): | ||||
|   from .basic_main     import basic_train, basic_valid | ||||
|   from .search_main    import search_train, search_valid | ||||
|   from .search_main_v2 import search_train_v2 | ||||
|   from .simple_KD_main import simple_KD_train, simple_KD_valid | ||||
|  | ||||
|   train_funcs = {'basic' : basic_train, \ | ||||
|                  'search': search_train,'Simple-KD': simple_KD_train, \ | ||||
|                  'search-v2': search_train_v2} | ||||
|   valid_funcs = {'basic' : basic_valid, \ | ||||
|                  'search': search_valid,'Simple-KD': simple_KD_valid, \ | ||||
|                  'search-v2': search_valid} | ||||
|    | ||||
|   train_func  = train_funcs[procedure] | ||||
|   valid_func  = valid_funcs[procedure] | ||||
|   return train_func, valid_func | ||||
| def get_procedures(procedure): | ||||
|     from .basic_main import basic_train, basic_valid | ||||
|     from .search_main import search_train, search_valid | ||||
|     from .search_main_v2 import search_train_v2 | ||||
|     from .simple_KD_main import simple_KD_train, simple_KD_valid | ||||
|  | ||||
|     train_funcs = { | ||||
|         "basic": basic_train, | ||||
|         "search": search_train, | ||||
|         "Simple-KD": simple_KD_train, | ||||
|         "search-v2": search_train_v2, | ||||
|     } | ||||
|     valid_funcs = { | ||||
|         "basic": basic_valid, | ||||
|         "search": search_valid, | ||||
|         "Simple-KD": simple_KD_valid, | ||||
|         "search-v2": search_valid, | ||||
|     } | ||||
|  | ||||
|     train_func = train_funcs[procedure] | ||||
|     valid_func = valid_funcs[procedure] | ||||
|     return train_func, valid_func | ||||
|   | ||||
| @@ -3,73 +3,100 @@ | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils     import obtain_accuracy | ||||
| from utils import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger): | ||||
|   loss, acc1, acc5 = procedure(xloader, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger) | ||||
|   return loss, acc1, acc5 | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger): | ||||
|   with torch.no_grad(): | ||||
|     loss, acc1, acc5 = procedure(xloader, network, criterion, None, None, 'valid', None, extra_info, print_freq, logger) | ||||
|   return loss, acc1, acc5 | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): | ||||
|   data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   if mode == 'train': | ||||
|     network.train() | ||||
|   elif mode == 'valid': | ||||
|     network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|    | ||||
|   #logger.log('[{:5s}] config ::  auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) | ||||
|   logger.log('[{:5s}] config ::  auxiliary={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1)) | ||||
|   end = time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|     # calculate prediction and loss | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|  | ||||
|     features, logits = network(inputs) | ||||
|     if isinstance(logits, list): | ||||
|       assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits)) | ||||
|       logits, logits_aux = logits | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|       logits, logits_aux = logits, None | ||||
|     loss             = criterion(logits, targets) | ||||
|     if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0: | ||||
|       loss_aux = criterion(logits_aux, targets) | ||||
|       loss += config.auxiliary * loss_aux | ||||
|      | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       optimizer.step() | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),  inputs.size(0)) | ||||
|     top1.update  (prec1.item(), inputs.size(0)) | ||||
|     top5.update  (prec5.item(), inputs.size(0)) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     # logger.log('[{:5s}] config ::  auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) | ||||
|     logger.log( | ||||
|         "[{:5s}] config ::  auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|     if i % print_freq == 0 or (i+1) == len(xloader): | ||||
|       Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader)) | ||||
|       if scheduler is not None: | ||||
|         Sstr += ' {:}'.format(scheduler.get_min_info()) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) | ||||
|       Istr = 'Size={:}'.format(list(inputs.size())) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr) | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|   logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg)) | ||||
|   return losses.avg, top1.avg, top5.avg | ||||
|         features, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits)) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
|         loss = criterion(logits, targets) | ||||
|         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | ||||
|             loss_aux = criterion(logits_aux, targets) | ||||
|             loss += config.auxiliary * loss_aux | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|             Sstr = ( | ||||
|                 " {:5s} ".format(mode.upper()) | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|             ) | ||||
|             if scheduler is not None: | ||||
|                 Sstr += " {:}".format(scheduler.get_min_info()) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Istr = "Size={:}".format(list(inputs.size())) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
|   | ||||
| @@ -5,199 +5,348 @@ import os, time, copy, torch, pathlib | ||||
|  | ||||
| import datasets | ||||
| from config_utils import load_config | ||||
| from procedures   import prepare_seed, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net | ||||
| from procedures import prepare_seed, get_optim_scheduler | ||||
| from utils import get_model_infos, obtain_accuracy | ||||
| from log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from models import get_cell_based_tiny_net | ||||
|  | ||||
|  | ||||
| __all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders'] | ||||
| __all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"] | ||||
|  | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   latencies, device = [], torch.cuda.current_device() | ||||
|   network.eval() | ||||
|   with torch.no_grad(): | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|       targets = targets.cuda(device=device, non_blocking=True) | ||||
|       inputs  = inputs.cuda(device=device, non_blocking=True) | ||||
|       data_time.update(time.time() - end) | ||||
|       # forward | ||||
|       features, logits = network(inputs) | ||||
|       loss             = criterion(logits, targets) | ||||
|       batch_time.update(time.time() - end) | ||||
|       if batch is None or batch == inputs.size(0): | ||||
|         batch = inputs.size(0) | ||||
|         latencies.append( batch_time.val - data_time.val ) | ||||
|       # record loss and accuracy | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       losses.update(loss.item(),  inputs.size(0)) | ||||
|       top1.update  (prec1.item(), inputs.size(0)) | ||||
|       top5.update  (prec5.item(), inputs.size(0)) | ||||
|       end = time.time() | ||||
|   if len(latencies) > 2: latencies = latencies[1:] | ||||
|   return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|     data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     latencies, device = [], torch.cuda.current_device() | ||||
|     network.eval() | ||||
|     with torch.no_grad(): | ||||
|         end = time.time() | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             targets = targets.cuda(device=device, non_blocking=True) | ||||
|             inputs = inputs.cuda(device=device, non_blocking=True) | ||||
|             data_time.update(time.time() - end) | ||||
|             # forward | ||||
|             features, logits = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             batch_time.update(time.time() - end) | ||||
|             if batch is None or batch == inputs.size(0): | ||||
|                 batch = inputs.size(0) | ||||
|                 latencies.append(batch_time.val - data_time.val) | ||||
|             # record loss and accuracy | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|             end = time.time() | ||||
|     if len(latencies) > 2: | ||||
|         latencies = latencies[1:] | ||||
|     return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   if mode == 'train'  : network.train() | ||||
|   elif mode == 'valid': network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|   device = torch.cuda.current_device() | ||||
|   data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|     device = torch.cuda.current_device() | ||||
|     data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|     targets = targets.cuda(device=device, non_blocking=True) | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|     # forward | ||||
|     features, logits = network(inputs) | ||||
|     loss             = criterion(logits, targets) | ||||
|     # backward | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       optimizer.step() | ||||
|     # record loss and accuracy | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),  inputs.size(0)) | ||||
|     top1.update  (prec1.item(), inputs.size(0)) | ||||
|     top5.update  (prec5.item(), inputs.size(0)) | ||||
|     # count time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|   return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|         targets = targets.cuda(device=device, non_blocking=True) | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|         # forward | ||||
|         features, logits = network(inputs) | ||||
|         loss = criterion(logits, targets) | ||||
|         # backward | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|         # record loss and accuracy | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|         # count time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger): | ||||
|  | ||||
|   prepare_seed(seed) # random seed | ||||
|   net = get_cell_based_tiny_net(arch_config) | ||||
|   #net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|   flop, param  = get_model_infos(net, opt_config.xshape) | ||||
|   logger.log('Network : {:}'.format(net.get_message()), False) | ||||
|   logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed)) | ||||
|   logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param)) | ||||
|   # train and valid | ||||
|   optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|   default_device = torch.cuda.current_device() | ||||
|   network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device) | ||||
|   criterion = criterion.cuda(device=default_device) | ||||
|   # start training | ||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup | ||||
|   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} | ||||
|   train_times , valid_times, lrs = {}, {}, {} | ||||
|   for epoch in range(total_epoch): | ||||
|     scheduler.update(epoch, 0.0) | ||||
|     lr = min(scheduler.get_lr()) | ||||
|     train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train') | ||||
|     train_losses[epoch] = train_loss | ||||
|     train_acc1es[epoch] = train_acc1  | ||||
|     train_acc5es[epoch] = train_acc5 | ||||
|     train_times [epoch] = train_tm | ||||
|     lrs[epoch] = lr | ||||
|     with torch.no_grad(): | ||||
|       for key, xloder in valid_loaders.items(): | ||||
|         valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloder  , network, criterion,      None,      None, 'valid') | ||||
|         valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss | ||||
|         valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1  | ||||
|         valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5 | ||||
|         valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm | ||||
|     prepare_seed(seed)  # random seed | ||||
|     net = get_cell_based_tiny_net(arch_config) | ||||
|     # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|     flop, param = get_model_infos(net, opt_config.xshape) | ||||
|     logger.log("Network : {:}".format(net.get_message()), False) | ||||
|     logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed)) | ||||
|     logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) | ||||
|     # train and valid | ||||
|     optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|     default_device = torch.cuda.current_device() | ||||
|     network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device) | ||||
|     criterion = criterion.cuda(device=default_device) | ||||
|     # start training | ||||
|     start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup | ||||
|     train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} | ||||
|     train_times, valid_times, lrs = {}, {}, {} | ||||
|     for epoch in range(total_epoch): | ||||
|         scheduler.update(epoch, 0.0) | ||||
|         lr = min(scheduler.get_lr()) | ||||
|         train_loss, train_acc1, train_acc5, train_tm = procedure( | ||||
|             train_loader, network, criterion, scheduler, optimizer, "train" | ||||
|         ) | ||||
|         train_losses[epoch] = train_loss | ||||
|         train_acc1es[epoch] = train_acc1 | ||||
|         train_acc5es[epoch] = train_acc5 | ||||
|         train_times[epoch] = train_tm | ||||
|         lrs[epoch] = lr | ||||
|         with torch.no_grad(): | ||||
|             for key, xloder in valid_loaders.items(): | ||||
|                 valid_loss, valid_acc1, valid_acc5, valid_tm = procedure( | ||||
|                     xloder, network, criterion, None, None, "valid" | ||||
|                 ) | ||||
|                 valid_losses["{:}@{:}".format(key, epoch)] = valid_loss | ||||
|                 valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1 | ||||
|                 valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5 | ||||
|                 valid_times["{:}@{:}".format(key, epoch)] = valid_tm | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) ) | ||||
|     logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5, lr)) | ||||
|   info_seed = {'flop' : flop, | ||||
|                'param': param, | ||||
|                'arch_config' : arch_config._asdict(), | ||||
|                'opt_config'  : opt_config._asdict(), | ||||
|                'total_epoch' : total_epoch , | ||||
|                'train_losses': train_losses, | ||||
|                'train_acc1es': train_acc1es, | ||||
|                'train_acc5es': train_acc5es, | ||||
|                'train_times' : train_times, | ||||
|                'valid_losses': valid_losses, | ||||
|                'valid_acc1es': valid_acc1es, | ||||
|                'valid_acc5es': valid_acc5es, | ||||
|                'valid_times' : valid_times, | ||||
|                'learning_rates': lrs, | ||||
|                'net_state_dict': net.state_dict(), | ||||
|                'net_string'  : '{:}'.format(net), | ||||
|                'finish-train': True | ||||
|               } | ||||
|   return info_seed | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)) | ||||
|         logger.log( | ||||
|             "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format( | ||||
|                 time_string(), | ||||
|                 need_time, | ||||
|                 epoch, | ||||
|                 total_epoch, | ||||
|                 train_loss, | ||||
|                 train_acc1, | ||||
|                 train_acc5, | ||||
|                 valid_loss, | ||||
|                 valid_acc1, | ||||
|                 valid_acc5, | ||||
|                 lr, | ||||
|             ) | ||||
|         ) | ||||
|     info_seed = { | ||||
|         "flop": flop, | ||||
|         "param": param, | ||||
|         "arch_config": arch_config._asdict(), | ||||
|         "opt_config": opt_config._asdict(), | ||||
|         "total_epoch": total_epoch, | ||||
|         "train_losses": train_losses, | ||||
|         "train_acc1es": train_acc1es, | ||||
|         "train_acc5es": train_acc5es, | ||||
|         "train_times": train_times, | ||||
|         "valid_losses": valid_losses, | ||||
|         "valid_acc1es": valid_acc1es, | ||||
|         "valid_acc5es": valid_acc5es, | ||||
|         "valid_times": valid_times, | ||||
|         "learning_rates": lrs, | ||||
|         "net_state_dict": net.state_dict(), | ||||
|         "net_string": "{:}".format(net), | ||||
|         "finish-train": True, | ||||
|     } | ||||
|     return info_seed | ||||
|  | ||||
|  | ||||
| def get_nas_bench_loaders(workers): | ||||
|  | ||||
|   torch.set_num_threads(workers) | ||||
|     torch.set_num_threads(workers) | ||||
|  | ||||
|   root_dir  = (pathlib.Path(__file__).parent / '..' / '..').resolve() | ||||
|   torch_dir = pathlib.Path(os.environ['TORCH_HOME']) | ||||
|   # cifar | ||||
|   cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config' | ||||
|   cifar_config = load_config(cifar_config_path, None, None) | ||||
|   get_datasets = datasets.get_datasets  # a function to return the dataset | ||||
|   break_line = '-' * 150 | ||||
|   print ('{:} Create data-loader for all datasets'.format(time_string())) | ||||
|   print (break_line) | ||||
|   TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1) | ||||
|   print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num)) | ||||
|   cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None) | ||||
|   assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14] | ||||
|   temp_dataset = copy.deepcopy(TRAIN_CIFAR10) | ||||
|   temp_dataset.transform = VALID_CIFAR10.transform | ||||
|   # data loader | ||||
|   trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True) | ||||
|   train_cifar10_loader    = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True) | ||||
|   valid_cifar10_loader    = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True) | ||||
|   test__cifar10_loader    = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True) | ||||
|   print ('CIFAR-10  : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size)) | ||||
|   print (break_line) | ||||
|   # CIFAR-100 | ||||
|   TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1) | ||||
|   print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num)) | ||||
|   cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None) | ||||
|   assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24] | ||||
|   train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True) | ||||
|   valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True) | ||||
|   test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True) | ||||
|   print ('CIFAR-100  : train-loader has {:3d} batch'.format(len(train_cifar100_loader))) | ||||
|   print ('CIFAR-100  : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader))) | ||||
|   print ('CIFAR-100  : test--loader has {:3d} batch'.format(len(test__cifar100_loader))) | ||||
|   print (break_line) | ||||
|     root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve() | ||||
|     torch_dir = pathlib.Path(os.environ["TORCH_HOME"]) | ||||
|     # cifar | ||||
|     cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config" | ||||
|     cifar_config = load_config(cifar_config_path, None, None) | ||||
|     get_datasets = datasets.get_datasets  # a function to return the dataset | ||||
|     break_line = "-" * 150 | ||||
|     print("{:} Create data-loader for all datasets".format(time_string())) | ||||
|     print(break_line) | ||||
|     TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1) | ||||
|     print( | ||||
|         "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None) | ||||
|     assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [ | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
|         4, | ||||
|         6, | ||||
|         8, | ||||
|         9, | ||||
|         10, | ||||
|         12, | ||||
|         14, | ||||
|     ] | ||||
|     temp_dataset = copy.deepcopy(TRAIN_CIFAR10) | ||||
|     temp_dataset.transform = VALID_CIFAR10.transform | ||||
|     # data loader | ||||
|     trainval_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True | ||||
|     ) | ||||
|     train_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         temp_dataset, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar10_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : trval-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(trainval_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : train-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(train_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : valid-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(valid_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : test--loader has {:3d} batch with {:} per batch".format( | ||||
|             len(test__cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print(break_line) | ||||
|     # CIFAR-100 | ||||
|     TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1) | ||||
|     print( | ||||
|         "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None) | ||||
|     assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [ | ||||
|         0, | ||||
|         2, | ||||
|         6, | ||||
|         7, | ||||
|         9, | ||||
|         11, | ||||
|         12, | ||||
|         17, | ||||
|         20, | ||||
|         24, | ||||
|     ] | ||||
|     train_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True | ||||
|     ) | ||||
|     valid_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print("CIFAR-100  : train-loader has {:3d} batch".format(len(train_cifar100_loader))) | ||||
|     print("CIFAR-100  : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))) | ||||
|     print("CIFAR-100  : test--loader has {:3d} batch".format(len(test__cifar100_loader))) | ||||
|     print(break_line) | ||||
|  | ||||
|   imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config' | ||||
|   imagenet16_config = load_config(imagenet16_config_path, None, None) | ||||
|   TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1) | ||||
|   print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num)) | ||||
|   imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None) | ||||
|   assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20] | ||||
|   train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True) | ||||
|   valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True) | ||||
|   test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True) | ||||
|   print ('ImageNet-16-120  : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size)) | ||||
|   print ('ImageNet-16-120  : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size)) | ||||
|   print ('ImageNet-16-120  : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size)) | ||||
|     imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config" | ||||
|     imagenet16_config = load_config(imagenet16_config_path, None, None) | ||||
|     TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets( | ||||
|         "ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None) | ||||
|     assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [ | ||||
|         0, | ||||
|         4, | ||||
|         5, | ||||
|         10, | ||||
|         11, | ||||
|         13, | ||||
|         14, | ||||
|         15, | ||||
|         17, | ||||
|         20, | ||||
|     ] | ||||
|     train_imagenet_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_imagenet_loader = torch.utils.data.DataLoader( | ||||
|         VALID_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__imagenet_loader = torch.utils.data.DataLoader( | ||||
|         VALID_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : train-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(train_imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : valid-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(valid_imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : test--loader has {:3d} batch with {:} per batch".format( | ||||
|             len(test__imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|   # 'cifar10', 'cifar100', 'ImageNet16-120' | ||||
|   loaders = {'cifar10@trainval': trainval_cifar10_loader, | ||||
|              'cifar10@train'   : train_cifar10_loader, | ||||
|              'cifar10@valid'   : valid_cifar10_loader, | ||||
|              'cifar10@test'    : test__cifar10_loader, | ||||
|              'cifar100@train'  : train_cifar100_loader, | ||||
|              'cifar100@valid'  : valid_cifar100_loader, | ||||
|              'cifar100@test'   : test__cifar100_loader, | ||||
|              'ImageNet16-120@train': train_imagenet_loader, | ||||
|              'ImageNet16-120@valid': valid_imagenet_loader, | ||||
|              'ImageNet16-120@test' : test__imagenet_loader} | ||||
|   return loaders | ||||
|     # 'cifar10', 'cifar100', 'ImageNet16-120' | ||||
|     loaders = { | ||||
|         "cifar10@trainval": trainval_cifar10_loader, | ||||
|         "cifar10@train": train_cifar10_loader, | ||||
|         "cifar10@valid": valid_cifar10_loader, | ||||
|         "cifar10@test": test__cifar10_loader, | ||||
|         "cifar100@train": train_cifar100_loader, | ||||
|         "cifar100@valid": valid_cifar100_loader, | ||||
|         "cifar100@test": test__cifar100_loader, | ||||
|         "ImageNet16-120@train": train_imagenet_loader, | ||||
|         "ImageNet16-120@valid": valid_imagenet_loader, | ||||
|         "ImageNet16-120@test": test__imagenet_loader, | ||||
|     } | ||||
|     return loaders | ||||
|   | ||||
| @@ -8,197 +8,201 @@ from torch.optim import Optimizer | ||||
|  | ||||
|  | ||||
| class _LRScheduler(object): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs): | ||||
|         if not isinstance(optimizer, Optimizer): | ||||
|             raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__)) | ||||
|         self.optimizer = optimizer | ||||
|         for group in optimizer.param_groups: | ||||
|             group.setdefault("initial_lr", group["lr"]) | ||||
|         self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups)) | ||||
|         self.max_epochs = epochs | ||||
|         self.warmup_epochs = warmup_epochs | ||||
|         self.current_epoch = 0 | ||||
|         self.current_iter = 0 | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs): | ||||
|     if not isinstance(optimizer, Optimizer): | ||||
|       raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__)) | ||||
|     self.optimizer = optimizer | ||||
|     for group in optimizer.param_groups: | ||||
|       group.setdefault('initial_lr', group['lr']) | ||||
|     self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) | ||||
|     self.max_epochs = epochs | ||||
|     self.warmup_epochs  = warmup_epochs | ||||
|     self.current_epoch  = 0 | ||||
|     self.current_iter   = 0 | ||||
|     def extra_repr(self): | ||||
|         return "" | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return '' | ||||
|     def __repr__(self): | ||||
|         return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) + ", {:})".format( | ||||
|             self.extra_repr() | ||||
|         ) | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__) | ||||
|               + ', {:})'.format(self.extra_repr())) | ||||
|     def state_dict(self): | ||||
|         return {key: value for key, value in self.__dict__.items() if key != "optimizer"} | ||||
|  | ||||
|   def state_dict(self): | ||||
|     return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.__dict__.update(state_dict) | ||||
|  | ||||
|   def load_state_dict(self, state_dict): | ||||
|     self.__dict__.update(state_dict) | ||||
|     def get_lr(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|   def get_lr(self): | ||||
|     raise NotImplementedError | ||||
|     def get_min_info(self): | ||||
|         lrs = self.get_lr() | ||||
|         return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format( | ||||
|             min(lrs), max(lrs), self.current_epoch, self.current_iter | ||||
|         ) | ||||
|  | ||||
|   def get_min_info(self): | ||||
|     lrs = self.get_lr() | ||||
|     return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter) | ||||
|  | ||||
|   def get_min_lr(self): | ||||
|     return min( self.get_lr() ) | ||||
|  | ||||
|   def update(self, cur_epoch, cur_iter): | ||||
|     if cur_epoch is not None: | ||||
|       assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch) | ||||
|       self.current_epoch = cur_epoch | ||||
|     if cur_iter is not None: | ||||
|       assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter) | ||||
|       self.current_iter  = cur_iter | ||||
|     for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): | ||||
|       param_group['lr'] = lr | ||||
|     def get_min_lr(self): | ||||
|         return min(self.get_lr()) | ||||
|  | ||||
|     def update(self, cur_epoch, cur_iter): | ||||
|         if cur_epoch is not None: | ||||
|             assert isinstance(cur_epoch, int) and cur_epoch >= 0, "invalid cur-epoch : {:}".format(cur_epoch) | ||||
|             self.current_epoch = cur_epoch | ||||
|         if cur_iter is not None: | ||||
|             assert isinstance(cur_iter, float) and cur_iter >= 0, "invalid cur-iter : {:}".format(cur_iter) | ||||
|             self.current_iter = cur_iter | ||||
|         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): | ||||
|             param_group["lr"] = lr | ||||
|  | ||||
|  | ||||
| class CosineAnnealingLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min): | ||||
|         self.T_max = T_max | ||||
|         self.eta_min = eta_min | ||||
|         super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min): | ||||
|     self.T_max = T_max | ||||
|     self.eta_min = eta_min | ||||
|     super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     lrs = [] | ||||
|     for base_lr in self.base_lrs: | ||||
|       if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs: | ||||
|         last_epoch = self.current_epoch - self.warmup_epochs | ||||
|         #if last_epoch < self.T_max: | ||||
|         #if last_epoch < self.max_epochs: | ||||
|         lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2 | ||||
|         #else: | ||||
|         #  lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2 | ||||
|       elif self.current_epoch >= self.max_epochs: | ||||
|         lr = self.eta_min | ||||
|       else: | ||||
|         lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|       lrs.append( lr ) | ||||
|     return lrs | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, T-max={:}, eta-min={:}".format("cosine", self.T_max, self.eta_min) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 # if last_epoch < self.T_max: | ||||
|                 # if last_epoch < self.max_epochs: | ||||
|                 lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2 | ||||
|                 # else: | ||||
|                 #  lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2 | ||||
|             elif self.current_epoch >= self.max_epochs: | ||||
|                 lr = self.eta_min | ||||
|             else: | ||||
|                 lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class MultiStepLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas): | ||||
|         assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(len(milestones), len(gammas)) | ||||
|         self.milestones = milestones | ||||
|         self.gammas = gammas | ||||
|         super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas): | ||||
|     assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas)) | ||||
|     self.milestones = milestones | ||||
|     self.gammas     = gammas | ||||
|     super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format( | ||||
|             "multistep", self.milestones, self.gammas, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     lrs = [] | ||||
|     for base_lr in self.base_lrs: | ||||
|       if self.current_epoch >= self.warmup_epochs: | ||||
|         last_epoch = self.current_epoch - self.warmup_epochs | ||||
|         idx = bisect_right(self.milestones, last_epoch) | ||||
|         lr = base_lr | ||||
|         for x in self.gammas[:idx]: lr *= x | ||||
|       else: | ||||
|         lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|       lrs.append( lr ) | ||||
|     return lrs | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 idx = bisect_right(self.milestones, last_epoch) | ||||
|                 lr = base_lr | ||||
|                 for x in self.gammas[:idx]: | ||||
|                     lr *= x | ||||
|             else: | ||||
|                 lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class ExponentialLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, gamma): | ||||
|         self.gamma = gamma | ||||
|         super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs, gamma): | ||||
|     self.gamma      = gamma | ||||
|     super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, gamma={:}, base-lrs={:}".format("exponential", self.gamma, self.base_lrs) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     lrs = [] | ||||
|     for base_lr in self.base_lrs: | ||||
|       if self.current_epoch >= self.warmup_epochs: | ||||
|         last_epoch = self.current_epoch - self.warmup_epochs | ||||
|         assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch) | ||||
|         lr = base_lr * (self.gamma ** last_epoch) | ||||
|       else: | ||||
|         lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|       lrs.append( lr ) | ||||
|     return lrs | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 lr = base_lr * (self.gamma ** last_epoch) | ||||
|             else: | ||||
|                 lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class LinearLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR): | ||||
|         self.max_LR = max_LR | ||||
|         self.min_LR = min_LR | ||||
|         super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR): | ||||
|     self.max_LR = max_LR | ||||
|     self.min_LR = min_LR | ||||
|     super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     lrs = [] | ||||
|     for base_lr in self.base_lrs: | ||||
|       if self.current_epoch >= self.warmup_epochs: | ||||
|         last_epoch = self.current_epoch - self.warmup_epochs | ||||
|         assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch) | ||||
|         ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR | ||||
|         lr = base_lr * (1-ratio) | ||||
|       else: | ||||
|         lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|       lrs.append( lr ) | ||||
|     return lrs | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format( | ||||
|             "LinearLR", self.max_LR, self.min_LR, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR | ||||
|                 lr = base_lr * (1 - ratio) | ||||
|             else: | ||||
|                 lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class CrossEntropyLabelSmooth(nn.Module): | ||||
|     def __init__(self, num_classes, epsilon): | ||||
|         super(CrossEntropyLabelSmooth, self).__init__() | ||||
|         self.num_classes = num_classes | ||||
|         self.epsilon = epsilon | ||||
|         self.logsoftmax = nn.LogSoftmax(dim=1) | ||||
|  | ||||
|   def __init__(self, num_classes, epsilon): | ||||
|     super(CrossEntropyLabelSmooth, self).__init__() | ||||
|     self.num_classes = num_classes | ||||
|     self.epsilon = epsilon | ||||
|     self.logsoftmax = nn.LogSoftmax(dim=1) | ||||
|  | ||||
|   def forward(self, inputs, targets): | ||||
|     log_probs = self.logsoftmax(inputs) | ||||
|     targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) | ||||
|     targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes | ||||
|     loss = (-targets * log_probs).mean(0).sum() | ||||
|     return loss | ||||
|  | ||||
|     def forward(self, inputs, targets): | ||||
|         log_probs = self.logsoftmax(inputs) | ||||
|         targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) | ||||
|         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes | ||||
|         loss = (-targets * log_probs).mean(0).sum() | ||||
|         return loss | ||||
|  | ||||
|  | ||||
| def get_optim_scheduler(parameters, config): | ||||
|   assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config) | ||||
|   if config.optim == 'SGD': | ||||
|     optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov) | ||||
|   elif config.optim == 'RMSprop': | ||||
|     optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay) | ||||
|   else: | ||||
|     raise ValueError('invalid optim : {:}'.format(config.optim)) | ||||
|     assert ( | ||||
|         hasattr(config, "optim") and hasattr(config, "scheduler") and hasattr(config, "criterion") | ||||
|     ), "config must have optim / scheduler / criterion keys instead of {:}".format(config) | ||||
|     if config.optim == "SGD": | ||||
|         optim = torch.optim.SGD( | ||||
|             parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov | ||||
|         ) | ||||
|     elif config.optim == "RMSprop": | ||||
|         optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay) | ||||
|     else: | ||||
|         raise ValueError("invalid optim : {:}".format(config.optim)) | ||||
|  | ||||
|   if config.scheduler == 'cos': | ||||
|     T_max = getattr(config, 'T_max', config.epochs) | ||||
|     scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min) | ||||
|   elif config.scheduler == 'multistep': | ||||
|     scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas) | ||||
|   elif config.scheduler == 'exponential': | ||||
|     scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma) | ||||
|   elif config.scheduler == 'linear': | ||||
|     scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min) | ||||
|   else: | ||||
|     raise ValueError('invalid scheduler : {:}'.format(config.scheduler)) | ||||
|     if config.scheduler == "cos": | ||||
|         T_max = getattr(config, "T_max", config.epochs) | ||||
|         scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min) | ||||
|     elif config.scheduler == "multistep": | ||||
|         scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas) | ||||
|     elif config.scheduler == "exponential": | ||||
|         scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma) | ||||
|     elif config.scheduler == "linear": | ||||
|         scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min) | ||||
|     else: | ||||
|         raise ValueError("invalid scheduler : {:}".format(config.scheduler)) | ||||
|  | ||||
|   if config.criterion == 'Softmax': | ||||
|     criterion = torch.nn.CrossEntropyLoss() | ||||
|   elif config.criterion == 'SmoothSoftmax': | ||||
|     criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth) | ||||
|   else: | ||||
|     raise ValueError('invalid criterion : {:}'.format(config.criterion)) | ||||
|   return optim, scheduler, criterion | ||||
|     if config.criterion == "Softmax": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|     elif config.criterion == "SmoothSoftmax": | ||||
|         criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth) | ||||
|     else: | ||||
|         raise ValueError("invalid criterion : {:}".format(config.criterion)) | ||||
|     return optim, scheduler, criterion | ||||
|   | ||||
| @@ -7,11 +7,12 @@ from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
| from qlib.log import set_log_basic_config | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
|  | ||||
| def update_gpu(config, gpu): | ||||
|     config = config.copy() | ||||
|     if "task" in config and "GPU" in config["task"]["model"]: | ||||
|     if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]: | ||||
|         config["task"]["model"]["GPU"] = gpu | ||||
|     elif "model" in config and "GPU" in config["model"]: | ||||
|         config["model"]["GPU"] = gpu | ||||
| @@ -29,11 +30,6 @@ def update_market(config, market): | ||||
|  | ||||
| def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|  | ||||
|     # model initiaiton | ||||
|     print("") | ||||
|     print("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri)) | ||||
|     print("dataset={:}".format(dataset)) | ||||
|  | ||||
|     model = init_instance_by_config(task_config["model"]) | ||||
|  | ||||
|     # start exp | ||||
| @@ -41,6 +37,10 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|  | ||||
|         log_file = R.get_recorder().root_uri / "{:}.log".format(experiment_name) | ||||
|         set_log_basic_config(log_file) | ||||
|         logger = get_module_logger("q.run_exp") | ||||
|         logger.info("task_config={:}".format(task_config)) | ||||
|         logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri)) | ||||
|         logger.info("dataset={:}".format(dataset)) | ||||
|  | ||||
|         # train model | ||||
|         R.log_params(**flatten_dict(task_config)) | ||||
|   | ||||
| @@ -3,124 +3,170 @@ | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils     import obtain_accuracy | ||||
| from models    import change_key | ||||
| from utils import obtain_accuracy | ||||
| from models import change_key | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|   expected_flop = torch.mean( expected_flop ) | ||||
|     expected_flop = torch.mean(expected_flop) | ||||
|  | ||||
|   if flop_cur < flop_need - flop_tolerant:   # Too Small FLOP | ||||
|     loss = - torch.log( expected_flop ) | ||||
|   #elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|   elif flop_cur > flop_need: # Too Large FLOP | ||||
|     loss = torch.log( expected_flop ) | ||||
|   else: # Required FLOP | ||||
|     loss = None | ||||
|   if loss is None: return 0, 0 | ||||
|   else           : return loss, loss.item() | ||||
|     if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP | ||||
|         loss = -torch.log(expected_flop) | ||||
|     # elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|     elif flop_cur > flop_need:  # Too Large FLOP | ||||
|         loss = torch.log(expected_flop) | ||||
|     else:  # Required FLOP | ||||
|         loss = None | ||||
|     if loss is None: | ||||
|         return 0, 0 | ||||
|     else: | ||||
|         return loss, loss.item() | ||||
|  | ||||
|  | ||||
| def search_train(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|   epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant'] | ||||
| def search_train( | ||||
|     search_loader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     base_optimizer, | ||||
|     arch_optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|     epoch_str, flop_need, flop_weight, flop_tolerant = ( | ||||
|         extra_info["epoch-str"], | ||||
|         extra_info["FLOP-exp"], | ||||
|         extra_info["FLOP-weight"], | ||||
|         extra_info["FLOP-tolerant"], | ||||
|     ) | ||||
|  | ||||
|   network.train() | ||||
|   logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight)) | ||||
|   end = time.time() | ||||
|   network.apply( change_key('search_mode', 'search') ) | ||||
|   for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): | ||||
|     scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|     # calculate prediction and loss | ||||
|     base_targets = base_targets.cuda(non_blocking=True) | ||||
|     arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|      | ||||
|     # update the weights | ||||
|     base_optimizer.zero_grad() | ||||
|     logits, expected_flop = network(base_inputs) | ||||
|     #network.apply( change_key('search_mode', 'basic') ) | ||||
|     #features, logits = network(base_inputs) | ||||
|     base_loss = criterion(logits, base_targets) | ||||
|     base_loss.backward() | ||||
|     base_optimizer.step() | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|     base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|     top1.update       (prec1.item(), base_inputs.size(0)) | ||||
|     top5.update       (prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     logits, expected_flop = network(arch_inputs) | ||||
|     flop_cur  = network.module.get_flop('genotype', None, None) | ||||
|     flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) | ||||
|     acls_loss = criterion(logits, arch_targets) | ||||
|     arch_loss = acls_loss + flop_loss * flop_weight | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|    | ||||
|     # record | ||||
|     arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|     arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|     arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0)) | ||||
|      | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     network.train() | ||||
|     logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight)) | ||||
|     end = time.time() | ||||
|     if step % print_freq == 0 or (step+1) == len(search_loader): | ||||
|       Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5) | ||||
|       Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr) | ||||
|       #Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|       #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|       #print(network.module.get_arch_info()) | ||||
|       #print(network.module.width_attentions[0]) | ||||
|       #print(network.module.width_attentions[1]) | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): | ||||
|         scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|         # calculate prediction and loss | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|   logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg)) | ||||
|   return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
|         # update the weights | ||||
|         base_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(base_inputs) | ||||
|         # network.apply( change_key('search_mode', 'basic') ) | ||||
|         # features, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         base_optimizer.step() | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         top1.update(prec1.item(), base_inputs.size(0)) | ||||
|         top5.update(prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture | ||||
|         arch_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(arch_inputs) | ||||
|         flop_cur = network.module.get_flop("genotype", None, None) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) | ||||
|         acls_loss = criterion(logits, arch_targets) | ||||
|         arch_loss = acls_loss + flop_loss * flop_weight | ||||
|         arch_loss.backward() | ||||
|         arch_optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|         arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|         if step % print_freq == 0 or (step + 1) == len(search_loader): | ||||
|             Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=base_losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format( | ||||
|                 aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr) | ||||
|             # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|             # print(network.module.get_arch_info()) | ||||
|             # print(network.module.width_attentions[0]) | ||||
|             # print(network.module.width_attentions[1]) | ||||
|  | ||||
|     logger.log( | ||||
|         " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format( | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             baseloss=base_losses.avg, | ||||
|             archloss=arch_losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
|  | ||||
|  | ||||
| def search_valid(xloader, network, criterion, extra_info, print_freq, logger): | ||||
|   data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|  | ||||
|   network.eval() | ||||
|   network.apply( change_key('search_mode', 'search') ) | ||||
|   end = time.time() | ||||
|   #logger.log('Starting evaluating {:}'.format(epoch_info)) | ||||
|   with torch.no_grad(): | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|       # measure data loading time | ||||
|       data_time.update(time.time() - end) | ||||
|       # calculate prediction and loss | ||||
|       targets = targets.cuda(non_blocking=True) | ||||
|     network.eval() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     end = time.time() | ||||
|     # logger.log('Starting evaluating {:}'.format(epoch_info)) | ||||
|     with torch.no_grad(): | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # calculate prediction and loss | ||||
|             targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|       logits, expected_flop = network(inputs) | ||||
|       loss             = criterion(logits, targets) | ||||
|       # record | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       losses.update(loss.item(),  inputs.size(0)) | ||||
|       top1.update  (prec1.item(), inputs.size(0)) | ||||
|       top5.update  (prec5.item(), inputs.size(0)) | ||||
|             logits, expected_flop = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             # record | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|       # measure elapsed time | ||||
|       batch_time.update(time.time() - end) | ||||
|       end = time.time() | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|  | ||||
|       if i % print_freq == 0 or (i+1) == len(xloader): | ||||
|         Sstr = '**VALID** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader)) | ||||
|         Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|         Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) | ||||
|         Istr = 'Size={:}'.format(list(inputs.size())) | ||||
|         logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr) | ||||
|             if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|                 Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|                 Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                     batch_time=batch_time, data_time=data_time | ||||
|                 ) | ||||
|                 Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                     loss=losses, top1=top1, top5=top5 | ||||
|                 ) | ||||
|                 Istr = "Size={:}".format(list(inputs.size())) | ||||
|                 logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|   logger.log(' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg)) | ||||
|   | ||||
|   return losses.avg, top1.avg, top5.avg | ||||
|     logger.log( | ||||
|         " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
|   | ||||
| @@ -3,85 +3,118 @@ | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils     import obtain_accuracy | ||||
| from models    import change_key | ||||
| from utils import obtain_accuracy | ||||
| from models import change_key | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|   expected_flop = torch.mean( expected_flop ) | ||||
|     expected_flop = torch.mean(expected_flop) | ||||
|  | ||||
|   if flop_cur < flop_need - flop_tolerant:   # Too Small FLOP | ||||
|     loss = - torch.log( expected_flop ) | ||||
|   #elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|   elif flop_cur > flop_need: # Too Large FLOP | ||||
|     loss = torch.log( expected_flop ) | ||||
|   else: # Required FLOP | ||||
|     loss = None | ||||
|   if loss is None: return 0, 0 | ||||
|   else           : return loss, loss.item() | ||||
|     if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP | ||||
|         loss = -torch.log(expected_flop) | ||||
|     # elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|     elif flop_cur > flop_need:  # Too Large FLOP | ||||
|         loss = torch.log(expected_flop) | ||||
|     else:  # Required FLOP | ||||
|         loss = None | ||||
|     if loss is None: | ||||
|         return 0, 0 | ||||
|     else: | ||||
|         return loss, loss.item() | ||||
|  | ||||
|  | ||||
| def search_train_v2(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|   epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant'] | ||||
| def search_train_v2( | ||||
|     search_loader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     base_optimizer, | ||||
|     arch_optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|     epoch_str, flop_need, flop_weight, flop_tolerant = ( | ||||
|         extra_info["epoch-str"], | ||||
|         extra_info["FLOP-exp"], | ||||
|         extra_info["FLOP-weight"], | ||||
|         extra_info["FLOP-tolerant"], | ||||
|     ) | ||||
|  | ||||
|   network.train() | ||||
|   logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight)) | ||||
|   end = time.time() | ||||
|   network.apply( change_key('search_mode', 'search') ) | ||||
|   for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): | ||||
|     scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|     # calculate prediction and loss | ||||
|     base_targets = base_targets.cuda(non_blocking=True) | ||||
|     arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|      | ||||
|     # update the weights | ||||
|     base_optimizer.zero_grad() | ||||
|     logits, expected_flop = network(base_inputs) | ||||
|     base_loss = criterion(logits, base_targets) | ||||
|     base_loss.backward() | ||||
|     base_optimizer.step() | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|     base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|     top1.update       (prec1.item(), base_inputs.size(0)) | ||||
|     top5.update       (prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     logits, expected_flop = network(arch_inputs) | ||||
|     flop_cur  = network.module.get_flop('genotype', None, None) | ||||
|     flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) | ||||
|     acls_loss = criterion(logits, arch_targets) | ||||
|     arch_loss = acls_loss + flop_loss * flop_weight | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|    | ||||
|     # record | ||||
|     arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|     arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|     arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0)) | ||||
|      | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     network.train() | ||||
|     logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight)) | ||||
|     end = time.time() | ||||
|     if step % print_freq == 0 or (step+1) == len(search_loader): | ||||
|       Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5) | ||||
|       Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr) | ||||
|       #num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 | ||||
|       #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6)) | ||||
|       #Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|       #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|       #print(network.module.get_arch_info()) | ||||
|       #print(network.module.width_attentions[0]) | ||||
|       #print(network.module.width_attentions[1]) | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): | ||||
|         scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|         # calculate prediction and loss | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|   logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg)) | ||||
|   return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
|         # update the weights | ||||
|         base_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         base_optimizer.step() | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         top1.update(prec1.item(), base_inputs.size(0)) | ||||
|         top5.update(prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture | ||||
|         arch_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(arch_inputs) | ||||
|         flop_cur = network.module.get_flop("genotype", None, None) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) | ||||
|         acls_loss = criterion(logits, arch_targets) | ||||
|         arch_loss = acls_loss + flop_loss * flop_weight | ||||
|         arch_loss.backward() | ||||
|         arch_optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|         arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|         if step % print_freq == 0 or (step + 1) == len(search_loader): | ||||
|             Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=base_losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format( | ||||
|                 aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr) | ||||
|             # num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6)) | ||||
|             # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|             # print(network.module.get_arch_info()) | ||||
|             # print(network.module.width_attentions[0]) | ||||
|             # print(network.module.width_attentions[1]) | ||||
|  | ||||
|     logger.log( | ||||
|         " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format( | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             baseloss=base_losses.avg, | ||||
|             archloss=arch_losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
|   | ||||
| @@ -3,92 +3,143 @@ | ||||
| ##################################################### | ||||
| import os, sys, time, torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| # our modules | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils     import obtain_accuracy | ||||
| from utils import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def simple_KD_train(xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger): | ||||
|   loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger) | ||||
|   return loss, acc1, acc5 | ||||
| def simple_KD_train( | ||||
|     xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, | ||||
|         teacher, | ||||
|         network, | ||||
|         criterion, | ||||
|         scheduler, | ||||
|         optimizer, | ||||
|         "train", | ||||
|         optim_config, | ||||
|         extra_info, | ||||
|         print_freq, | ||||
|         logger, | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger): | ||||
|   with torch.no_grad(): | ||||
|     loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, None, None, 'valid', optim_config, extra_info, print_freq, logger) | ||||
|   return loss, acc1, acc5 | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def loss_KD_fn(criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature): | ||||
|   basic_loss = criterion(student_logits, targets) * (1. - alpha) | ||||
|   log_student= F.log_softmax(student_logits / temperature, dim=1) | ||||
|   sof_teacher= F.softmax    (teacher_logits / temperature, dim=1) | ||||
|   KD_loss    = F.kl_div(log_student, sof_teacher, reduction='batchmean') * (alpha * temperature * temperature) | ||||
|   return basic_loss + KD_loss | ||||
| def loss_KD_fn( | ||||
|     criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature | ||||
| ): | ||||
|     basic_loss = criterion(student_logits, targets) * (1.0 - alpha) | ||||
|     log_student = F.log_softmax(student_logits / temperature, dim=1) | ||||
|     sof_teacher = F.softmax(teacher_logits / temperature, dim=1) | ||||
|     KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature) | ||||
|     return basic_loss + KD_loss | ||||
|  | ||||
|  | ||||
| def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): | ||||
|   data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   Ttop1, Ttop5 = AverageMeter(), AverageMeter() | ||||
|   if mode == 'train': | ||||
|     network.train() | ||||
|   elif mode == 'valid': | ||||
|     network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|   teacher.eval() | ||||
|    | ||||
|   logger.log('[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, config.KD_alpha, config.KD_temperature)) | ||||
|   end = time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|     # calculate prediction and loss | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|  | ||||
|     student_f, logits = network(inputs) | ||||
|     if isinstance(logits, list): | ||||
|       assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits)) | ||||
|       logits, logits_aux = logits | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     Ttop1, Ttop5 = AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|       logits, logits_aux = logits, None | ||||
|     with torch.no_grad(): | ||||
|       teacher_f, teacher_logits = teacher(inputs) | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|     teacher.eval() | ||||
|  | ||||
|     loss             = loss_KD_fn(criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature) | ||||
|     if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0: | ||||
|       loss_aux = criterion(logits_aux, targets) | ||||
|       loss += config.auxiliary * loss_aux | ||||
|      | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       optimizer.step() | ||||
|  | ||||
|     # record | ||||
|     sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),   inputs.size(0)) | ||||
|     top1.update  (sprec1.item(), inputs.size(0)) | ||||
|     top5.update  (sprec5.item(), inputs.size(0)) | ||||
|     # teacher | ||||
|     tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5)) | ||||
|     Ttop1.update (tprec1.item(), inputs.size(0)) | ||||
|     Ttop5.update (tprec5.item(), inputs.size(0)) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     logger.log( | ||||
|         "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format( | ||||
|             mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|     if i % print_freq == 0 or (i+1) == len(xloader): | ||||
|       Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader)) | ||||
|       if scheduler is not None: | ||||
|         Sstr += ' {:}'.format(scheduler.get_min_info()) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) | ||||
|       Lstr+= ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg) | ||||
|       Istr = 'Size={:}'.format(list(inputs.size())) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr) | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|   logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg)) | ||||
|   logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg)) | ||||
|   return losses.avg, top1.avg, top5.avg | ||||
|         student_f, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits)) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
|         with torch.no_grad(): | ||||
|             teacher_f, teacher_logits = teacher(inputs) | ||||
|  | ||||
|         loss = loss_KD_fn( | ||||
|             criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature | ||||
|         ) | ||||
|         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | ||||
|             loss_aux = criterion(logits_aux, targets) | ||||
|             loss += config.auxiliary * loss_aux | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(sprec1.item(), inputs.size(0)) | ||||
|         top5.update(sprec5.item(), inputs.size(0)) | ||||
|         # teacher | ||||
|         tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5)) | ||||
|         Ttop1.update(tprec1.item(), inputs.size(0)) | ||||
|         Ttop5.update(tprec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|             Sstr = ( | ||||
|                 " {:5s} ".format(mode.upper()) | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|             ) | ||||
|             if scheduler is not None: | ||||
|                 Sstr += " {:}".format(scheduler.get_min_info()) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Lstr += " Teacher : acc@1={:.2f}, acc@5={:.2f}".format(Ttop1.avg, Ttop5.avg) | ||||
|             Istr = "Size={:}".format(list(inputs.size())) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}".format( | ||||
|             mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg | ||||
|         ) | ||||
|     ) | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
|   | ||||
| @@ -3,62 +3,71 @@ | ||||
| ################################################## | ||||
| import os, sys, torch, random, PIL, copy, numpy as np | ||||
| from os import path as osp | ||||
| from shutil  import copyfile | ||||
| from shutil import copyfile | ||||
|  | ||||
|  | ||||
| def prepare_seed(rand_seed): | ||||
|   random.seed(rand_seed) | ||||
|   np.random.seed(rand_seed) | ||||
|   torch.manual_seed(rand_seed) | ||||
|   torch.cuda.manual_seed(rand_seed) | ||||
|   torch.cuda.manual_seed_all(rand_seed) | ||||
|     random.seed(rand_seed) | ||||
|     np.random.seed(rand_seed) | ||||
|     torch.manual_seed(rand_seed) | ||||
|     torch.cuda.manual_seed(rand_seed) | ||||
|     torch.cuda.manual_seed_all(rand_seed) | ||||
|  | ||||
|  | ||||
| def prepare_logger(xargs): | ||||
|   args = copy.deepcopy( xargs ) | ||||
|   from log_utils import Logger | ||||
|   logger = Logger(args.save_dir, args.rand_seed) | ||||
|   logger.log('Main Function with logger : {:}'.format(logger)) | ||||
|   logger.log('Arguments : -------------------------------') | ||||
|   for name, value in args._get_kwargs(): | ||||
|     logger.log('{:16} : {:}'.format(name, value)) | ||||
|   logger.log("Python  Version  : {:}".format(sys.version.replace('\n', ' '))) | ||||
|   logger.log("Pillow  Version  : {:}".format(PIL.__version__)) | ||||
|   logger.log("PyTorch Version  : {:}".format(torch.__version__)) | ||||
|   logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version())) | ||||
|   logger.log("CUDA available   : {:}".format(torch.cuda.is_available())) | ||||
|   logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) | ||||
|   logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None')) | ||||
|   return logger | ||||
|     args = copy.deepcopy(xargs) | ||||
|     from log_utils import Logger | ||||
|  | ||||
|     logger = Logger(args.save_dir, args.rand_seed) | ||||
|     logger.log("Main Function with logger : {:}".format(logger)) | ||||
|     logger.log("Arguments : -------------------------------") | ||||
|     for name, value in args._get_kwargs(): | ||||
|         logger.log("{:16} : {:}".format(name, value)) | ||||
|     logger.log("Python  Version  : {:}".format(sys.version.replace("\n", " "))) | ||||
|     logger.log("Pillow  Version  : {:}".format(PIL.__version__)) | ||||
|     logger.log("PyTorch Version  : {:}".format(torch.__version__)) | ||||
|     logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version())) | ||||
|     logger.log("CUDA available   : {:}".format(torch.cuda.is_available())) | ||||
|     logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) | ||||
|     logger.log( | ||||
|         "CUDA_VISIBLE_DEVICES : {:}".format( | ||||
|             os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "None" | ||||
|         ) | ||||
|     ) | ||||
|     return logger | ||||
|  | ||||
|  | ||||
| def get_machine_info(): | ||||
|   info = "Python  Version  : {:}".format(sys.version.replace('\n', ' ')) | ||||
|   info+= "\nPillow  Version  : {:}".format(PIL.__version__) | ||||
|   info+= "\nPyTorch Version  : {:}".format(torch.__version__) | ||||
|   info+= "\ncuDNN   Version  : {:}".format(torch.backends.cudnn.version()) | ||||
|   info+= "\nCUDA available   : {:}".format(torch.cuda.is_available()) | ||||
|   info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count()) | ||||
|   if 'CUDA_VISIBLE_DEVICES' in os.environ: | ||||
|     info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES']) | ||||
|   else: | ||||
|     info+= "\nDoes not set CUDA_VISIBLE_DEVICES" | ||||
|   return info | ||||
|     info = "Python  Version  : {:}".format(sys.version.replace("\n", " ")) | ||||
|     info += "\nPillow  Version  : {:}".format(PIL.__version__) | ||||
|     info += "\nPyTorch Version  : {:}".format(torch.__version__) | ||||
|     info += "\ncuDNN   Version  : {:}".format(torch.backends.cudnn.version()) | ||||
|     info += "\nCUDA available   : {:}".format(torch.cuda.is_available()) | ||||
|     info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count()) | ||||
|     if "CUDA_VISIBLE_DEVICES" in os.environ: | ||||
|         info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"]) | ||||
|     else: | ||||
|         info += "\nDoes not set CUDA_VISIBLE_DEVICES" | ||||
|     return info | ||||
|  | ||||
|  | ||||
| def save_checkpoint(state, filename, logger): | ||||
|   if osp.isfile(filename): | ||||
|     if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename)) | ||||
|     os.remove(filename) | ||||
|   torch.save(state, filename) | ||||
|   assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename) | ||||
|   if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename)) | ||||
|   return filename | ||||
|     if osp.isfile(filename): | ||||
|         if hasattr(logger, "log"): | ||||
|             logger.log("Find {:} exist, delete is at first before saving".format(filename)) | ||||
|         os.remove(filename) | ||||
|     torch.save(state, filename) | ||||
|     assert osp.isfile(filename), "save filename : {:} failed, which is not found.".format(filename) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("save checkpoint into {:}".format(filename)) | ||||
|     return filename | ||||
|  | ||||
|  | ||||
| def copy_checkpoint(src, dst, logger): | ||||
|   if osp.isfile(dst): | ||||
|     if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst)) | ||||
|     os.remove(dst) | ||||
|   copyfile(src, dst) | ||||
|   if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst)) | ||||
|     if osp.isfile(dst): | ||||
|         if hasattr(logger, "log"): | ||||
|             logger.log("Find {:} exist, delete is at first before saving".format(dst)) | ||||
|         os.remove(dst) | ||||
|     copyfile(src, dst) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("copy the file from {:} into {:}".format(src, dst)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user