add autodl
This commit is contained in:
		
							
								
								
									
										38
									
								
								AutoDL-Projects/xautodl/procedures/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								AutoDL-Projects/xautodl/procedures/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ###################################################################### | ||||
| # This folder is deprecated, which is re-organized in "xalgorithms". # | ||||
| ###################################################################### | ||||
| from .starts import prepare_seed | ||||
| from .starts import prepare_logger | ||||
| from .starts import get_machine_info | ||||
| from .starts import save_checkpoint | ||||
| from .starts import copy_checkpoint | ||||
| from .optimizers import get_optim_scheduler | ||||
| from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed | ||||
| from .funcs_nasbench import pure_evaluate as bench_pure_evaluate | ||||
| from .funcs_nasbench import get_nas_bench_loaders | ||||
|  | ||||
|  | ||||
| def get_procedures(procedure): | ||||
|     from .basic_main import basic_train, basic_valid | ||||
|     from .search_main import search_train, search_valid | ||||
|     from .search_main_v2 import search_train_v2 | ||||
|     from .simple_KD_main import simple_KD_train, simple_KD_valid | ||||
|  | ||||
|     train_funcs = { | ||||
|         "basic": basic_train, | ||||
|         "search": search_train, | ||||
|         "Simple-KD": simple_KD_train, | ||||
|         "search-v2": search_train_v2, | ||||
|     } | ||||
|     valid_funcs = { | ||||
|         "basic": basic_valid, | ||||
|         "search": search_valid, | ||||
|         "Simple-KD": simple_KD_valid, | ||||
|         "search-v2": search_valid, | ||||
|     } | ||||
|  | ||||
|     train_func = train_funcs[procedure] | ||||
|     valid_func = valid_funcs[procedure] | ||||
|     return train_func, valid_func | ||||
							
								
								
									
										99
									
								
								AutoDL-Projects/xautodl/procedures/advanced_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								AutoDL-Projects/xautodl/procedures/advanced_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,99 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| # To be finished. | ||||
| # | ||||
| import os, sys, time, torch | ||||
| from typing import Optional, Text, Callable | ||||
|  | ||||
| # modules in AutoDL | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_device(tensors): | ||||
|     if isinstance(tensors, (list, tuple)): | ||||
|         return get_device(tensors[0]) | ||||
|     elif isinstance(tensors, dict): | ||||
|         for key, value in tensors.items(): | ||||
|             return get_device(value) | ||||
|     else: | ||||
|         return tensors.device | ||||
|  | ||||
|  | ||||
| def basic_train_fn( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     metric, | ||||
|     logger, | ||||
| ): | ||||
|     results = procedure( | ||||
|         xloader, | ||||
|         network, | ||||
|         criterion, | ||||
|         optimizer, | ||||
|         metric, | ||||
|         "train", | ||||
|         logger, | ||||
|     ) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def basic_eval_fn(xloader, network, metric, logger): | ||||
|     with torch.no_grad(): | ||||
|         results = procedure( | ||||
|             xloader, | ||||
|             network, | ||||
|             None, | ||||
|             None, | ||||
|             metric, | ||||
|             "valid", | ||||
|             logger, | ||||
|         ) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     metric, | ||||
|     mode: Text, | ||||
|     logger_fn: Callable = None, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     if mode.lower() == "train": | ||||
|         network.train() | ||||
|     elif mode.lower() == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         outputs = network(inputs) | ||||
|         targets = targets.to(get_device(outputs)) | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss = criterion(outputs, targets) | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         with torch.no_grad(): | ||||
|             results = metric(outputs, targets) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return metric.get_info() | ||||
							
								
								
									
										154
									
								
								AutoDL-Projects/xautodl/procedures/basic_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								AutoDL-Projects/xautodl/procedures/basic_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,154 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| # modules in AutoDL | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def basic_train( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, | ||||
|         network, | ||||
|         criterion, | ||||
|         scheduler, | ||||
|         optimizer, | ||||
|         "train", | ||||
|         optim_config, | ||||
|         extra_info, | ||||
|         print_freq, | ||||
|         logger, | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def basic_valid( | ||||
|     xloader, network, criterion, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, | ||||
|             network, | ||||
|             criterion, | ||||
|             None, | ||||
|             None, | ||||
|             "valid", | ||||
|             None, | ||||
|             extra_info, | ||||
|             print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     mode, | ||||
|     config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     # logger.log('[{:5s}] config ::  auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) | ||||
|     logger.log( | ||||
|         "[{:5s}] config ::  auxiliary={:}".format( | ||||
|             mode, config.auxiliary if hasattr(config, "auxiliary") else -1 | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         features, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format( | ||||
|                 2, len(logits) | ||||
|             ) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
|         loss = criterion(logits, targets) | ||||
|         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | ||||
|             loss_aux = criterion(logits_aux, targets) | ||||
|             loss += config.auxiliary * loss_aux | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|             Sstr = ( | ||||
|                 " {:5s} ".format(mode.upper()) | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|             ) | ||||
|             if scheduler is not None: | ||||
|                 Sstr += " {:}".format(scheduler.get_min_info()) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Istr = "Size={:}".format(list(inputs.size())) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
										20
									
								
								AutoDL-Projects/xautodl/procedures/eval_funcs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								AutoDL-Projects/xautodl/procedures/eval_funcs.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| import abc | ||||
|  | ||||
|  | ||||
| def obtain_accuracy(output, target, topk=(1,)): | ||||
|     """Computes the precision@k for the specified values of k""" | ||||
|     maxk = max(topk) | ||||
|     batch_size = target.size(0) | ||||
|  | ||||
|     _, pred = output.topk(maxk, 1, True, True) | ||||
|     pred = pred.t() | ||||
|     correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||||
|  | ||||
|     res = [] | ||||
|     for k in topk: | ||||
|         correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
							
								
								
									
										437
									
								
								AutoDL-Projects/xautodl/procedures/funcs_nasbench.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										437
									
								
								AutoDL-Projects/xautodl/procedures/funcs_nasbench.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,437 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| import os, time, copy, torch, pathlib | ||||
|  | ||||
| from xautodl import datasets | ||||
| from xautodl.config_utils import load_config | ||||
| from xautodl.procedures import prepare_seed, get_optim_scheduler | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net | ||||
| from xautodl.utils import get_model_infos | ||||
| from xautodl.procedures.eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| __all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"] | ||||
|  | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|     data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     latencies, device = [], torch.cuda.current_device() | ||||
|     network.eval() | ||||
|     with torch.no_grad(): | ||||
|         end = time.time() | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             targets = targets.cuda(device=device, non_blocking=True) | ||||
|             inputs = inputs.cuda(device=device, non_blocking=True) | ||||
|             data_time.update(time.time() - end) | ||||
|             # forward | ||||
|             features, logits = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             batch_time.update(time.time() - end) | ||||
|             if batch is None or batch == inputs.size(0): | ||||
|                 batch = inputs.size(0) | ||||
|                 latencies.append(batch_time.val - data_time.val) | ||||
|             # record loss and accuracy | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|             end = time.time() | ||||
|     if len(latencies) > 2: | ||||
|         latencies = latencies[1:] | ||||
|     return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|     device = torch.cuda.current_device() | ||||
|     data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|         targets = targets.cuda(device=device, non_blocking=True) | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|         # forward | ||||
|         features, logits = network(inputs) | ||||
|         loss = criterion(logits, targets) | ||||
|         # backward | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|         # record loss and accuracy | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|         # count time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def evaluate_for_seed( | ||||
|     arch_config, opt_config, train_loader, valid_loaders, seed: int, logger | ||||
| ): | ||||
|     """A modular function to train and evaluate a single network, using the given random seed and optimization config with the provided loaders.""" | ||||
|     prepare_seed(seed)  # random seed | ||||
|     net = get_cell_based_tiny_net(arch_config) | ||||
|     # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|     flop, param = get_model_infos(net, opt_config.xshape) | ||||
|     logger.log("Network : {:}".format(net.get_message()), False) | ||||
|     logger.log( | ||||
|         "{:} Seed-------------------------- {:} --------------------------".format( | ||||
|             time_string(), seed | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) | ||||
|     # train and valid | ||||
|     optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|     default_device = torch.cuda.current_device() | ||||
|     network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda( | ||||
|         device=default_device | ||||
|     ) | ||||
|     criterion = criterion.cuda(device=default_device) | ||||
|     # start training | ||||
|     start_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         opt_config.epochs + opt_config.warmup, | ||||
|     ) | ||||
|     ( | ||||
|         train_losses, | ||||
|         train_acc1es, | ||||
|         train_acc5es, | ||||
|         valid_losses, | ||||
|         valid_acc1es, | ||||
|         valid_acc5es, | ||||
|     ) = ({}, {}, {}, {}, {}, {}) | ||||
|     train_times, valid_times, lrs = {}, {}, {} | ||||
|     for epoch in range(total_epoch): | ||||
|         scheduler.update(epoch, 0.0) | ||||
|         lr = min(scheduler.get_lr()) | ||||
|         train_loss, train_acc1, train_acc5, train_tm = procedure( | ||||
|             train_loader, network, criterion, scheduler, optimizer, "train" | ||||
|         ) | ||||
|         train_losses[epoch] = train_loss | ||||
|         train_acc1es[epoch] = train_acc1 | ||||
|         train_acc5es[epoch] = train_acc5 | ||||
|         train_times[epoch] = train_tm | ||||
|         lrs[epoch] = lr | ||||
|         with torch.no_grad(): | ||||
|             for key, xloder in valid_loaders.items(): | ||||
|                 valid_loss, valid_acc1, valid_acc5, valid_tm = procedure( | ||||
|                     xloder, network, criterion, None, None, "valid" | ||||
|                 ) | ||||
|                 valid_losses["{:}@{:}".format(key, epoch)] = valid_loss | ||||
|                 valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1 | ||||
|                 valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5 | ||||
|                 valid_times["{:}@{:}".format(key, epoch)] = valid_tm | ||||
|  | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format( | ||||
|                 time_string(), | ||||
|                 need_time, | ||||
|                 epoch, | ||||
|                 total_epoch, | ||||
|                 train_loss, | ||||
|                 train_acc1, | ||||
|                 train_acc5, | ||||
|                 valid_loss, | ||||
|                 valid_acc1, | ||||
|                 valid_acc5, | ||||
|                 lr, | ||||
|             ) | ||||
|         ) | ||||
|     info_seed = { | ||||
|         "flop": flop, | ||||
|         "param": param, | ||||
|         "arch_config": arch_config._asdict(), | ||||
|         "opt_config": opt_config._asdict(), | ||||
|         "total_epoch": total_epoch, | ||||
|         "train_losses": train_losses, | ||||
|         "train_acc1es": train_acc1es, | ||||
|         "train_acc5es": train_acc5es, | ||||
|         "train_times": train_times, | ||||
|         "valid_losses": valid_losses, | ||||
|         "valid_acc1es": valid_acc1es, | ||||
|         "valid_acc5es": valid_acc5es, | ||||
|         "valid_times": valid_times, | ||||
|         "learning_rates": lrs, | ||||
|         "net_state_dict": net.state_dict(), | ||||
|         "net_string": "{:}".format(net), | ||||
|         "finish-train": True, | ||||
|     } | ||||
|     return info_seed | ||||
|  | ||||
|  | ||||
| def get_nas_bench_loaders(workers): | ||||
|  | ||||
|     torch.set_num_threads(workers) | ||||
|  | ||||
|     root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve() | ||||
|     torch_dir = pathlib.Path(os.environ["TORCH_HOME"]) | ||||
|     # cifar | ||||
|     cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config" | ||||
|     cifar_config = load_config(cifar_config_path, None, None) | ||||
|     get_datasets = datasets.get_datasets  # a function to return the dataset | ||||
|     break_line = "-" * 150 | ||||
|     print("{:} Create data-loader for all datasets".format(time_string())) | ||||
|     print(break_line) | ||||
|     TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets( | ||||
|         "cifar10", str(torch_dir / "cifar.python"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar10_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None | ||||
|     ) | ||||
|     assert cifar10_splits.train[:10] == [ | ||||
|         0, | ||||
|         5, | ||||
|         7, | ||||
|         11, | ||||
|         13, | ||||
|         15, | ||||
|         16, | ||||
|         17, | ||||
|         20, | ||||
|         24, | ||||
|     ] and cifar10_splits.valid[:10] == [ | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
|         4, | ||||
|         6, | ||||
|         8, | ||||
|         9, | ||||
|         10, | ||||
|         12, | ||||
|         14, | ||||
|     ] | ||||
|     temp_dataset = copy.deepcopy(TRAIN_CIFAR10) | ||||
|     temp_dataset.transform = VALID_CIFAR10.transform | ||||
|     # data loader | ||||
|     trainval_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     train_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         temp_dataset, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar10_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=False, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : trval-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(trainval_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : train-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(train_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : valid-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(valid_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : test--loader has {:3d} batch with {:} per batch".format( | ||||
|             len(test__cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print(break_line) | ||||
|     # CIFAR-100 | ||||
|     TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets( | ||||
|         "cifar100", str(torch_dir / "cifar.python"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar100_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None | ||||
|     ) | ||||
|     assert cifar100_splits.xvalid[:10] == [ | ||||
|         1, | ||||
|         3, | ||||
|         4, | ||||
|         5, | ||||
|         8, | ||||
|         10, | ||||
|         13, | ||||
|         14, | ||||
|         15, | ||||
|         16, | ||||
|     ] and cifar100_splits.xtest[:10] == [ | ||||
|         0, | ||||
|         2, | ||||
|         6, | ||||
|         7, | ||||
|         9, | ||||
|         11, | ||||
|         12, | ||||
|         17, | ||||
|         20, | ||||
|         24, | ||||
|     ] | ||||
|     train_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : train-loader has {:3d} batch".format(len(train_cifar100_loader)) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : test--loader has {:3d} batch".format(len(test__cifar100_loader)) | ||||
|     ) | ||||
|     print(break_line) | ||||
|  | ||||
|     imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config" | ||||
|     imagenet16_config = load_config(imagenet16_config_path, None, None) | ||||
|     TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets( | ||||
|         "ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     imagenet_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", | ||||
|         None, | ||||
|         None, | ||||
|     ) | ||||
|     assert imagenet_splits.xvalid[:10] == [ | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
|         6, | ||||
|         7, | ||||
|         8, | ||||
|         9, | ||||
|         12, | ||||
|         16, | ||||
|         18, | ||||
|     ] and imagenet_splits.xtest[:10] == [ | ||||
|         0, | ||||
|         4, | ||||
|         5, | ||||
|         10, | ||||
|         11, | ||||
|         13, | ||||
|         14, | ||||
|         15, | ||||
|         17, | ||||
|         20, | ||||
|     ] | ||||
|     train_imagenet_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_imagenet_loader = torch.utils.data.DataLoader( | ||||
|         VALID_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__imagenet_loader = torch.utils.data.DataLoader( | ||||
|         VALID_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : train-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(train_imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : valid-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(valid_imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : test--loader has {:3d} batch with {:} per batch".format( | ||||
|             len(test__imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     # 'cifar10', 'cifar100', 'ImageNet16-120' | ||||
|     loaders = { | ||||
|         "cifar10@trainval": trainval_cifar10_loader, | ||||
|         "cifar10@train": train_cifar10_loader, | ||||
|         "cifar10@valid": valid_cifar10_loader, | ||||
|         "cifar10@test": test__cifar10_loader, | ||||
|         "cifar100@train": train_cifar100_loader, | ||||
|         "cifar100@valid": valid_cifar100_loader, | ||||
|         "cifar100@test": test__cifar100_loader, | ||||
|         "ImageNet16-120@train": train_imagenet_loader, | ||||
|         "ImageNet16-120@valid": valid_imagenet_loader, | ||||
|         "ImageNet16-120@test": test__imagenet_loader, | ||||
|     } | ||||
|     return loaders | ||||
							
								
								
									
										166
									
								
								AutoDL-Projects/xautodl/procedures/metric_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								AutoDL-Projects/xautodl/procedures/metric_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,166 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| import abc | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class AverageMeter(object): | ||||
|     """Computes and stores the average and current value""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         self.val = 0.0 | ||||
|         self.avg = 0.0 | ||||
|         self.sum = 0.0 | ||||
|         self.count = 0.0 | ||||
|  | ||||
|     def update(self, val, n=1): | ||||
|         self.val = val | ||||
|         self.sum += val * n | ||||
|         self.count += n | ||||
|         self.avg = self.sum / self.count | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(val={val}, avg={avg}, count={count})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class Metric(abc.ABC): | ||||
|     """The default meta metric class.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({inner})".format( | ||||
|             name=self.__class__.__name__, inner=self.inner_repr() | ||||
|         ) | ||||
|  | ||||
|     def inner_repr(self): | ||||
|         return "" | ||||
|  | ||||
|  | ||||
| class ComposeMetric(Metric): | ||||
|     """The composed metric class.""" | ||||
|  | ||||
|     def __init__(self, *metric_list): | ||||
|         self.reset() | ||||
|         for metric in metric_list: | ||||
|             self.append(metric) | ||||
|  | ||||
|     def reset(self): | ||||
|         self._metric_list = [] | ||||
|  | ||||
|     def append(self, metric): | ||||
|         if not isinstance(metric, Metric): | ||||
|             raise ValueError( | ||||
|                 "The input metric is not correct: {:}".format(type(metric)) | ||||
|             ) | ||||
|         self._metric_list.append(metric) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._metric_list) | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         results = list() | ||||
|         for metric in self._metric_list: | ||||
|             results.append(metric(predictions, targets)) | ||||
|         return results | ||||
|  | ||||
|     def get_info(self): | ||||
|         results = dict() | ||||
|         for metric in self._metric_list: | ||||
|             for key, value in metric.get_info().items(): | ||||
|                 results[key] = value | ||||
|         return results | ||||
|  | ||||
|     def inner_repr(self): | ||||
|         xlist = [] | ||||
|         for metric in self._metric_list: | ||||
|             xlist.append(str(metric)) | ||||
|         return ",".join(xlist) | ||||
|  | ||||
|  | ||||
| class MSEMetric(Metric): | ||||
|     """The metric for mse.""" | ||||
|  | ||||
|     def __init__(self, ignore_batch): | ||||
|         super(MSEMetric, self).__init__() | ||||
|         self._ignore_batch = ignore_batch | ||||
|  | ||||
|     def reset(self): | ||||
|         self._mse = AverageMeter() | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||
|             loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item() | ||||
|             if self._ignore_batch: | ||||
|                 self._mse.update(loss, 1) | ||||
|             else: | ||||
|                 self._mse.update(loss, predictions.shape[0]) | ||||
|             return loss | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         return {"mse": self._mse.avg, "score": self._mse.avg} | ||||
|  | ||||
|  | ||||
| class Top1AccMetric(Metric): | ||||
|     """The metric for the top-1 accuracy.""" | ||||
|  | ||||
|     def __init__(self, ignore_batch): | ||||
|         super(Top1AccMetric, self).__init__() | ||||
|         self._ignore_batch = ignore_batch | ||||
|  | ||||
|     def reset(self): | ||||
|         self._accuracy = AverageMeter() | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||
|             max_prob_indexes = torch.argmax(predictions, dim=-1) | ||||
|             corrects = torch.eq(max_prob_indexes, targets) | ||||
|             accuracy = corrects.float().mean().float() | ||||
|             if self._ignore_batch: | ||||
|                 self._accuracy.update(accuracy, 1) | ||||
|             else:  # [TODO] for 3-d tensor | ||||
|                 self._accuracy.update(accuracy, predictions.shape[0]) | ||||
|             return accuracy | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100} | ||||
|  | ||||
|  | ||||
| class SaveMetric(Metric): | ||||
|     """The metric for mse.""" | ||||
|  | ||||
|     def reset(self): | ||||
|         self._predicts = [] | ||||
|  | ||||
|     def __call__(self, predictions, targets=None): | ||||
|         if isinstance(predictions, torch.Tensor): | ||||
|             predicts = predictions.cpu().numpy() | ||||
|             self._predicts.append(predicts) | ||||
|             return predicts | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         all_predicts = np.concatenate(self._predicts) | ||||
|         return {"predictions": all_predicts} | ||||
							
								
								
									
										263
									
								
								AutoDL-Projects/xautodl/procedures/optimizers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										263
									
								
								AutoDL-Projects/xautodl/procedures/optimizers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,263 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| from bisect import bisect_right | ||||
| from torch.optim import Optimizer | ||||
|  | ||||
|  | ||||
| class _LRScheduler(object): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs): | ||||
|         if not isinstance(optimizer, Optimizer): | ||||
|             raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__)) | ||||
|         self.optimizer = optimizer | ||||
|         for group in optimizer.param_groups: | ||||
|             group.setdefault("initial_lr", group["lr"]) | ||||
|         self.base_lrs = list( | ||||
|             map(lambda group: group["initial_lr"], optimizer.param_groups) | ||||
|         ) | ||||
|         self.max_epochs = epochs | ||||
|         self.warmup_epochs = warmup_epochs | ||||
|         self.current_epoch = 0 | ||||
|         self.current_iter = 0 | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) + ", {:})".format( | ||||
|             self.extra_repr() | ||||
|         ) | ||||
|  | ||||
|     def state_dict(self): | ||||
|         return { | ||||
|             key: value for key, value in self.__dict__.items() if key != "optimizer" | ||||
|         } | ||||
|  | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.__dict__.update(state_dict) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get_min_info(self): | ||||
|         lrs = self.get_lr() | ||||
|         return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format( | ||||
|             min(lrs), max(lrs), self.current_epoch, self.current_iter | ||||
|         ) | ||||
|  | ||||
|     def get_min_lr(self): | ||||
|         return min(self.get_lr()) | ||||
|  | ||||
|     def update(self, cur_epoch, cur_iter): | ||||
|         if cur_epoch is not None: | ||||
|             assert ( | ||||
|                 isinstance(cur_epoch, int) and cur_epoch >= 0 | ||||
|             ), "invalid cur-epoch : {:}".format(cur_epoch) | ||||
|             self.current_epoch = cur_epoch | ||||
|         if cur_iter is not None: | ||||
|             assert ( | ||||
|                 isinstance(cur_iter, float) and cur_iter >= 0 | ||||
|             ), "invalid cur-iter : {:}".format(cur_iter) | ||||
|             self.current_iter = cur_iter | ||||
|         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): | ||||
|             param_group["lr"] = lr | ||||
|  | ||||
|  | ||||
| class CosineAnnealingLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min): | ||||
|         self.T_max = T_max | ||||
|         self.eta_min = eta_min | ||||
|         super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, T-max={:}, eta-min={:}".format( | ||||
|             "cosine", self.T_max, self.eta_min | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if ( | ||||
|                 self.current_epoch >= self.warmup_epochs | ||||
|                 and self.current_epoch < self.max_epochs | ||||
|             ): | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 # if last_epoch < self.T_max: | ||||
|                 # if last_epoch < self.max_epochs: | ||||
|                 lr = ( | ||||
|                     self.eta_min | ||||
|                     + (base_lr - self.eta_min) | ||||
|                     * (1 + math.cos(math.pi * last_epoch / self.T_max)) | ||||
|                     / 2 | ||||
|                 ) | ||||
|                 # else: | ||||
|                 #  lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2 | ||||
|             elif self.current_epoch >= self.max_epochs: | ||||
|                 lr = self.eta_min | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class MultiStepLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas): | ||||
|         assert len(milestones) == len(gammas), "invalid {:} vs {:}".format( | ||||
|             len(milestones), len(gammas) | ||||
|         ) | ||||
|         self.milestones = milestones | ||||
|         self.gammas = gammas | ||||
|         super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format( | ||||
|             "multistep", self.milestones, self.gammas, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 idx = bisect_right(self.milestones, last_epoch) | ||||
|                 lr = base_lr | ||||
|                 for x in self.gammas[:idx]: | ||||
|                     lr *= x | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class ExponentialLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, gamma): | ||||
|         self.gamma = gamma | ||||
|         super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, gamma={:}, base-lrs={:}".format( | ||||
|             "exponential", self.gamma, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 lr = base_lr * (self.gamma**last_epoch) | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class LinearLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR): | ||||
|         self.max_LR = max_LR | ||||
|         self.min_LR = min_LR | ||||
|         super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format( | ||||
|             "LinearLR", self.max_LR, self.min_LR, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 ratio = ( | ||||
|                     (self.max_LR - self.min_LR) | ||||
|                     * last_epoch | ||||
|                     / self.max_epochs | ||||
|                     / self.max_LR | ||||
|                 ) | ||||
|                 lr = base_lr * (1 - ratio) | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class CrossEntropyLabelSmooth(nn.Module): | ||||
|     def __init__(self, num_classes, epsilon): | ||||
|         super(CrossEntropyLabelSmooth, self).__init__() | ||||
|         self.num_classes = num_classes | ||||
|         self.epsilon = epsilon | ||||
|         self.logsoftmax = nn.LogSoftmax(dim=1) | ||||
|  | ||||
|     def forward(self, inputs, targets): | ||||
|         log_probs = self.logsoftmax(inputs) | ||||
|         targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) | ||||
|         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes | ||||
|         loss = (-targets * log_probs).mean(0).sum() | ||||
|         return loss | ||||
|  | ||||
|  | ||||
| def get_optim_scheduler(parameters, config): | ||||
|     assert ( | ||||
|         hasattr(config, "optim") | ||||
|         and hasattr(config, "scheduler") | ||||
|         and hasattr(config, "criterion") | ||||
|     ), "config must have optim / scheduler / criterion keys instead of {:}".format( | ||||
|         config | ||||
|     ) | ||||
|     if config.optim == "SGD": | ||||
|         optim = torch.optim.SGD( | ||||
|             parameters, | ||||
|             config.LR, | ||||
|             momentum=config.momentum, | ||||
|             weight_decay=config.decay, | ||||
|             nesterov=config.nesterov, | ||||
|         ) | ||||
|     elif config.optim == "RMSprop": | ||||
|         optim = torch.optim.RMSprop( | ||||
|             parameters, config.LR, momentum=config.momentum, weight_decay=config.decay | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid optim : {:}".format(config.optim)) | ||||
|  | ||||
|     if config.scheduler == "cos": | ||||
|         T_max = getattr(config, "T_max", config.epochs) | ||||
|         scheduler = CosineAnnealingLR( | ||||
|             optim, config.warmup, config.epochs, T_max, config.eta_min | ||||
|         ) | ||||
|     elif config.scheduler == "multistep": | ||||
|         scheduler = MultiStepLR( | ||||
|             optim, config.warmup, config.epochs, config.milestones, config.gammas | ||||
|         ) | ||||
|     elif config.scheduler == "exponential": | ||||
|         scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma) | ||||
|     elif config.scheduler == "linear": | ||||
|         scheduler = LinearLR( | ||||
|             optim, config.warmup, config.epochs, config.LR, config.LR_min | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid scheduler : {:}".format(config.scheduler)) | ||||
|  | ||||
|     if config.criterion == "Softmax": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|     elif config.criterion == "SmoothSoftmax": | ||||
|         criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth) | ||||
|     else: | ||||
|         raise ValueError("invalid criterion : {:}".format(config.criterion)) | ||||
|     return optim, scheduler, criterion | ||||
							
								
								
									
										150
									
								
								AutoDL-Projects/xautodl/procedures/q_exps.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								AutoDL-Projects/xautodl/procedures/q_exps.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,150 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ##################################################### | ||||
|  | ||||
| import inspect | ||||
| import os | ||||
| import pprint | ||||
| import logging | ||||
| from copy import deepcopy | ||||
|  | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
|  | ||||
| def set_log_basic_config(filename=None, format=None, level=None): | ||||
|     """ | ||||
|     Set the basic configuration for the logging system. | ||||
|     See details at https://docs.python.org/3/library/logging.html#logging.basicConfig | ||||
|     :param filename: str or None | ||||
|         The path to save the logs. | ||||
|     :param format: the logging format | ||||
|     :param level: int | ||||
|     :return: Logger | ||||
|         Logger object. | ||||
|     """ | ||||
|     from qlib.config import C | ||||
|  | ||||
|     if level is None: | ||||
|         level = C.logging_level | ||||
|  | ||||
|     if format is None: | ||||
|         format = C.logging_config["formatters"]["logger_format"]["format"] | ||||
|  | ||||
|     # Remove all handlers associated with the root logger object. | ||||
|     for handler in logging.root.handlers[:]: | ||||
|         logging.root.removeHandler(handler) | ||||
|     logging.basicConfig(filename=filename, format=format, level=level) | ||||
|  | ||||
|  | ||||
| def update_gpu(config, gpu): | ||||
|     config = deepcopy(config) | ||||
|     if "task" in config and "model" in config["task"]: | ||||
|         if "GPU" in config["task"]["model"]: | ||||
|             config["task"]["model"]["GPU"] = gpu | ||||
|         elif ( | ||||
|             "kwargs" in config["task"]["model"] | ||||
|             and "GPU" in config["task"]["model"]["kwargs"] | ||||
|         ): | ||||
|             config["task"]["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "model" in config: | ||||
|         if "GPU" in config["model"]: | ||||
|             config["model"]["GPU"] = gpu | ||||
|         elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]: | ||||
|             config["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "kwargs" in config and "GPU" in config["kwargs"]: | ||||
|         config["kwargs"]["GPU"] = gpu | ||||
|     elif "GPU" in config: | ||||
|         config["GPU"] = gpu | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def update_market(config, market): | ||||
|     config = deepcopy(config.copy()) | ||||
|     config["market"] = market | ||||
|     config["data_handler_config"]["instruments"] = market | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def run_exp( | ||||
|     task_config, | ||||
|     dataset, | ||||
|     experiment_name, | ||||
|     recorder_name, | ||||
|     uri, | ||||
|     model_obj_name="model.pkl", | ||||
| ): | ||||
|  | ||||
|     model = init_instance_by_config(task_config["model"]) | ||||
|     model_fit_kwargs = dict(dataset=dataset) | ||||
|  | ||||
|     # Let's start the experiment. | ||||
|     with R.start( | ||||
|         experiment_name=experiment_name, | ||||
|         recorder_name=recorder_name, | ||||
|         uri=uri, | ||||
|         resume=True, | ||||
|     ): | ||||
|         # Setup log | ||||
|         recorder_root_dir = R.get_recorder().get_local_dir() | ||||
|         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name)) | ||||
|  | ||||
|         set_log_basic_config(log_file) | ||||
|         logger = get_module_logger("q.run_exp") | ||||
|         logger.info("task_config::\n{:}".format(pprint.pformat(task_config, indent=2))) | ||||
|         logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri)) | ||||
|         logger.info("dataset={:}".format(dataset)) | ||||
|  | ||||
|         # Train model | ||||
|         try: | ||||
|             if hasattr(model, "to"):  # Recoverable model | ||||
|                 ori_device = model.device | ||||
|                 model = R.load_object(model_obj_name) | ||||
|                 model.to(ori_device) | ||||
|             else: | ||||
|                 model = R.load_object(model_obj_name) | ||||
|             logger.info("[Find existing object from {:}]".format(model_obj_name)) | ||||
|         except OSError: | ||||
|             R.log_params(**flatten_dict(update_gpu(task_config, None))) | ||||
|             if "save_path" in inspect.getfullargspec(model.fit).args: | ||||
|                 model_fit_kwargs["save_path"] = os.path.join( | ||||
|                     recorder_root_dir, "model.ckp" | ||||
|                 ) | ||||
|             elif "save_dir" in inspect.getfullargspec(model.fit).args: | ||||
|                 model_fit_kwargs["save_dir"] = os.path.join( | ||||
|                     recorder_root_dir, "model-ckps" | ||||
|                 ) | ||||
|             model.fit(**model_fit_kwargs) | ||||
|             # remove model to CPU for saving | ||||
|             if hasattr(model, "to"): | ||||
|                 old_device = model.device | ||||
|                 model.to("cpu") | ||||
|                 R.save_objects(**{model_obj_name: model}) | ||||
|                 model.to(old_device) | ||||
|             else: | ||||
|                 R.save_objects(**{model_obj_name: model}) | ||||
|         except Exception as e: | ||||
|             raise ValueError("Something wrong: {:}".format(e)) | ||||
|         # Get the recorder | ||||
|         recorder = R.get_recorder() | ||||
|  | ||||
|         # Generate records: prediction, backtest, and analysis | ||||
|         for record in task_config["record"]: | ||||
|             record = deepcopy(record) | ||||
|             if record["class"] == "MultiSegRecord": | ||||
|                 record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder) | ||||
|                 sr = init_instance_by_config(record) | ||||
|                 sr.generate(**record["generate_kwargs"]) | ||||
|             elif record["class"] == "SignalRecord": | ||||
|                 srconf = {"model": model, "dataset": dataset, "recorder": recorder} | ||||
|                 record["kwargs"].update(srconf) | ||||
|                 sr = init_instance_by_config(record) | ||||
|                 sr.generate() | ||||
|             else: | ||||
|                 rconf = {"recorder": recorder} | ||||
|                 record["kwargs"].update(rconf) | ||||
|                 ar = init_instance_by_config(record) | ||||
|                 ar.generate() | ||||
							
								
								
									
										199
									
								
								AutoDL-Projects/xautodl/procedures/search_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										199
									
								
								AutoDL-Projects/xautodl/procedures/search_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,199 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from xautodl.models import change_key | ||||
|  | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|     expected_flop = torch.mean(expected_flop) | ||||
|  | ||||
|     if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP | ||||
|         loss = -torch.log(expected_flop) | ||||
|     # elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|     elif flop_cur > flop_need:  # Too Large FLOP | ||||
|         loss = torch.log(expected_flop) | ||||
|     else:  # Required FLOP | ||||
|         loss = None | ||||
|     if loss is None: | ||||
|         return 0, 0 | ||||
|     else: | ||||
|         return loss, loss.item() | ||||
|  | ||||
|  | ||||
| def search_train( | ||||
|     search_loader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     base_optimizer, | ||||
|     arch_optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|     epoch_str, flop_need, flop_weight, flop_tolerant = ( | ||||
|         extra_info["epoch-str"], | ||||
|         extra_info["FLOP-exp"], | ||||
|         extra_info["FLOP-weight"], | ||||
|         extra_info["FLOP-tolerant"], | ||||
|     ) | ||||
|  | ||||
|     network.train() | ||||
|     logger.log( | ||||
|         "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format( | ||||
|             epoch_str, flop_need, flop_weight | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         search_loader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|         # calculate prediction and loss | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the weights | ||||
|         base_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(base_inputs) | ||||
|         # network.apply( change_key('search_mode', 'basic') ) | ||||
|         # features, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         base_optimizer.step() | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         top1.update(prec1.item(), base_inputs.size(0)) | ||||
|         top5.update(prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture | ||||
|         arch_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(arch_inputs) | ||||
|         flop_cur = network.module.get_flop("genotype", None, None) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss( | ||||
|             expected_flop, flop_cur, flop_need, flop_tolerant | ||||
|         ) | ||||
|         acls_loss = criterion(logits, arch_targets) | ||||
|         arch_loss = acls_loss + flop_loss * flop_weight | ||||
|         arch_loss.backward() | ||||
|         arch_optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|         arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|         if step % print_freq == 0 or (step + 1) == len(search_loader): | ||||
|             Sstr = ( | ||||
|                 "**TRAIN** " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=base_losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format( | ||||
|                 aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr) | ||||
|             # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|             # print(network.module.get_arch_info()) | ||||
|             # print(network.module.width_attentions[0]) | ||||
|             # print(network.module.width_attentions[1]) | ||||
|  | ||||
|     logger.log( | ||||
|         " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format( | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             baseloss=base_losses.avg, | ||||
|             archloss=arch_losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
|  | ||||
|  | ||||
| def search_valid(xloader, network, criterion, extra_info, print_freq, logger): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|  | ||||
|     network.eval() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     end = time.time() | ||||
|     # logger.log('Starting evaluating {:}'.format(epoch_info)) | ||||
|     with torch.no_grad(): | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # calculate prediction and loss | ||||
|             targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|             logits, expected_flop = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             # record | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|  | ||||
|             if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|                 Sstr = ( | ||||
|                     "**VALID** " | ||||
|                     + time_string() | ||||
|                     + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|                 ) | ||||
|                 Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                     batch_time=batch_time, data_time=data_time | ||||
|                 ) | ||||
|                 Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                     loss=losses, top1=top1, top5=top5 | ||||
|                 ) | ||||
|                 Istr = "Size={:}".format(list(inputs.size())) | ||||
|                 logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
										139
									
								
								AutoDL-Projects/xautodl/procedures/search_main_v2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								AutoDL-Projects/xautodl/procedures/search_main_v2.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,139 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| # modules in AutoDL | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from xautodl.models import change_key | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|     expected_flop = torch.mean(expected_flop) | ||||
|  | ||||
|     if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP | ||||
|         loss = -torch.log(expected_flop) | ||||
|     # elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|     elif flop_cur > flop_need:  # Too Large FLOP | ||||
|         loss = torch.log(expected_flop) | ||||
|     else:  # Required FLOP | ||||
|         loss = None | ||||
|     if loss is None: | ||||
|         return 0, 0 | ||||
|     else: | ||||
|         return loss, loss.item() | ||||
|  | ||||
|  | ||||
| def search_train_v2( | ||||
|     search_loader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     base_optimizer, | ||||
|     arch_optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|     epoch_str, flop_need, flop_weight, flop_tolerant = ( | ||||
|         extra_info["epoch-str"], | ||||
|         extra_info["FLOP-exp"], | ||||
|         extra_info["FLOP-weight"], | ||||
|         extra_info["FLOP-tolerant"], | ||||
|     ) | ||||
|  | ||||
|     network.train() | ||||
|     logger.log( | ||||
|         "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format( | ||||
|             epoch_str, flop_need, flop_weight | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         search_loader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|         # calculate prediction and loss | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the weights | ||||
|         base_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         base_optimizer.step() | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         top1.update(prec1.item(), base_inputs.size(0)) | ||||
|         top5.update(prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture | ||||
|         arch_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(arch_inputs) | ||||
|         flop_cur = network.module.get_flop("genotype", None, None) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss( | ||||
|             expected_flop, flop_cur, flop_need, flop_tolerant | ||||
|         ) | ||||
|         acls_loss = criterion(logits, arch_targets) | ||||
|         arch_loss = acls_loss + flop_loss * flop_weight | ||||
|         arch_loss.backward() | ||||
|         arch_optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|         arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|         if step % print_freq == 0 or (step + 1) == len(search_loader): | ||||
|             Sstr = ( | ||||
|                 "**TRAIN** " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=base_losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format( | ||||
|                 aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr) | ||||
|             # num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6)) | ||||
|             # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|             # print(network.module.get_arch_info()) | ||||
|             # print(network.module.width_attentions[0]) | ||||
|             # print(network.module.width_attentions[1]) | ||||
|  | ||||
|     logger.log( | ||||
|         " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format( | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             baseloss=base_losses.avg, | ||||
|             archloss=arch_losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
										204
									
								
								AutoDL-Projects/xautodl/procedures/simple_KD_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										204
									
								
								AutoDL-Projects/xautodl/procedures/simple_KD_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,204 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import os, sys, time, torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| # modules in AutoDL | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def simple_KD_train( | ||||
|     xloader, | ||||
|     teacher, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, | ||||
|         teacher, | ||||
|         network, | ||||
|         criterion, | ||||
|         scheduler, | ||||
|         optimizer, | ||||
|         "train", | ||||
|         optim_config, | ||||
|         extra_info, | ||||
|         print_freq, | ||||
|         logger, | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def simple_KD_valid( | ||||
|     xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, | ||||
|             teacher, | ||||
|             network, | ||||
|             criterion, | ||||
|             None, | ||||
|             None, | ||||
|             "valid", | ||||
|             optim_config, | ||||
|             extra_info, | ||||
|             print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def loss_KD_fn( | ||||
|     criterion, | ||||
|     student_logits, | ||||
|     teacher_logits, | ||||
|     studentFeatures, | ||||
|     teacherFeatures, | ||||
|     targets, | ||||
|     alpha, | ||||
|     temperature, | ||||
| ): | ||||
|     basic_loss = criterion(student_logits, targets) * (1.0 - alpha) | ||||
|     log_student = F.log_softmax(student_logits / temperature, dim=1) | ||||
|     sof_teacher = F.softmax(teacher_logits / temperature, dim=1) | ||||
|     KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * ( | ||||
|         alpha * temperature * temperature | ||||
|     ) | ||||
|     return basic_loss + KD_loss | ||||
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     teacher, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     mode, | ||||
|     config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     Ttop1, Ttop5 = AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|     teacher.eval() | ||||
|  | ||||
|     logger.log( | ||||
|         "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format( | ||||
|             mode, | ||||
|             config.auxiliary if hasattr(config, "auxiliary") else -1, | ||||
|             config.KD_alpha, | ||||
|             config.KD_temperature, | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         student_f, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format( | ||||
|                 2, len(logits) | ||||
|             ) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
|         with torch.no_grad(): | ||||
|             teacher_f, teacher_logits = teacher(inputs) | ||||
|  | ||||
|         loss = loss_KD_fn( | ||||
|             criterion, | ||||
|             logits, | ||||
|             teacher_logits, | ||||
|             student_f, | ||||
|             teacher_f, | ||||
|             targets, | ||||
|             config.KD_alpha, | ||||
|             config.KD_temperature, | ||||
|         ) | ||||
|         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | ||||
|             loss_aux = criterion(logits_aux, targets) | ||||
|             loss += config.auxiliary * loss_aux | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(sprec1.item(), inputs.size(0)) | ||||
|         top5.update(sprec5.item(), inputs.size(0)) | ||||
|         # teacher | ||||
|         tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5)) | ||||
|         Ttop1.update(tprec1.item(), inputs.size(0)) | ||||
|         Ttop5.update(tprec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|             Sstr = ( | ||||
|                 " {:5s} ".format(mode.upper()) | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|             ) | ||||
|             if scheduler is not None: | ||||
|                 Sstr += " {:}".format(scheduler.get_min_info()) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Lstr += " Teacher : acc@1={:.2f}, acc@5={:.2f}".format(Ttop1.avg, Ttop5.avg) | ||||
|             Istr = "Size={:}".format(list(inputs.size())) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}".format( | ||||
|             mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg | ||||
|         ) | ||||
|     ) | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
										79
									
								
								AutoDL-Projects/xautodl/procedures/starts.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								AutoDL-Projects/xautodl/procedures/starts.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, torch, random, PIL, copy, numpy as np | ||||
| from os import path as osp | ||||
| from shutil import copyfile | ||||
|  | ||||
|  | ||||
| def prepare_seed(rand_seed): | ||||
|     random.seed(rand_seed) | ||||
|     np.random.seed(rand_seed) | ||||
|     torch.manual_seed(rand_seed) | ||||
|     torch.cuda.manual_seed(rand_seed) | ||||
|     torch.cuda.manual_seed_all(rand_seed) | ||||
|  | ||||
|  | ||||
| def prepare_logger(xargs): | ||||
|     args = copy.deepcopy(xargs) | ||||
|     from xautodl.log_utils import Logger | ||||
|  | ||||
|     logger = Logger(args.save_dir, args.rand_seed) | ||||
|     logger.log("Main Function with logger : {:}".format(logger)) | ||||
|     logger.log("Arguments : -------------------------------") | ||||
|     for name, value in args._get_kwargs(): | ||||
|         logger.log("{:16} : {:}".format(name, value)) | ||||
|     logger.log("Python  Version  : {:}".format(sys.version.replace("\n", " "))) | ||||
|     logger.log("Pillow  Version  : {:}".format(PIL.__version__)) | ||||
|     logger.log("PyTorch Version  : {:}".format(torch.__version__)) | ||||
|     logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version())) | ||||
|     logger.log("CUDA available   : {:}".format(torch.cuda.is_available())) | ||||
|     logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) | ||||
|     logger.log( | ||||
|         "CUDA_VISIBLE_DEVICES : {:}".format( | ||||
|             os.environ["CUDA_VISIBLE_DEVICES"] | ||||
|             if "CUDA_VISIBLE_DEVICES" in os.environ | ||||
|             else "None" | ||||
|         ) | ||||
|     ) | ||||
|     return logger | ||||
|  | ||||
|  | ||||
| def get_machine_info(): | ||||
|     info = "Python  Version  : {:}".format(sys.version.replace("\n", " ")) | ||||
|     info += "\nPillow  Version  : {:}".format(PIL.__version__) | ||||
|     info += "\nPyTorch Version  : {:}".format(torch.__version__) | ||||
|     info += "\ncuDNN   Version  : {:}".format(torch.backends.cudnn.version()) | ||||
|     info += "\nCUDA available   : {:}".format(torch.cuda.is_available()) | ||||
|     info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count()) | ||||
|     if "CUDA_VISIBLE_DEVICES" in os.environ: | ||||
|         info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"]) | ||||
|     else: | ||||
|         info += "\nDoes not set CUDA_VISIBLE_DEVICES" | ||||
|     return info | ||||
|  | ||||
|  | ||||
| def save_checkpoint(state, filename, logger): | ||||
|     if osp.isfile(filename): | ||||
|         if hasattr(logger, "log"): | ||||
|             logger.log( | ||||
|                 "Find {:} exist, delete is at first before saving".format(filename) | ||||
|             ) | ||||
|         os.remove(filename) | ||||
|     torch.save(state, filename) | ||||
|     assert osp.isfile( | ||||
|         filename | ||||
|     ), "save filename : {:} failed, which is not found.".format(filename) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("save checkpoint into {:}".format(filename)) | ||||
|     return filename | ||||
|  | ||||
|  | ||||
| def copy_checkpoint(src, dst, logger): | ||||
|     if osp.isfile(dst): | ||||
|         if hasattr(logger, "log"): | ||||
|             logger.log("Find {:} exist, delete is at first before saving".format(dst)) | ||||
|         os.remove(dst) | ||||
|     copyfile(src, dst) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("copy the file from {:} into {:}".format(src, dst)) | ||||
		Reference in New Issue
	
	Block a user