add autodl

AutoDL-Projects/xautodl/__init__.py (Normal file, 12 lines added)
@@ -0,0 +1,12 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
# An Automated Deep Learning Package to support     #
# research activities.                              #
#####################################################


def version():
    versions = ["0.9.9"]  # 2021.06.01
    versions = ["1.0.0"]  # 2021.08.14
    return versions[-1]
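
A quick sanity check of the reported version; a minimal sketch assuming the repository root is on PYTHONPATH so that xautodl is importable:

    from xautodl import version

    print(version())  # "1.0.0", the last entry in the versions list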

AutoDL-Projects/xautodl/config_utils/__init__.py (Normal file, 20 lines added)
@@ -0,0 +1,20 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# general config-related functions
from .config_utils import load_config, dict2config, configure2str

# the argument settings for different experiments
from .basic_args import obtain_basic_args
from .attention_args import obtain_attention_args
from .random_baseline import obtain_RandomSearch_args
from .cls_kd_args import obtain_cls_kd_args
from .cls_init_args import obtain_cls_init_args
from .search_single_args import obtain_search_single_args
from .search_args import obtain_search_args

# for network pruning
from .pruning_args import obtain_pruning_args

# utils for args
from .args_utils import arg_str2bool

AutoDL-Projects/xautodl/config_utils/args_utils.py (Normal file, 12 lines added)
@@ -0,0 +1,12 @@
import argparse


def arg_str2bool(v):
    if isinstance(v, bool):
        return v
    elif v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
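
A minimal sketch of plugging arg_str2bool into argparse as a type converter; the --use_bn flag is a hypothetical example, not an option defined in this commit:

    import argparse

    from xautodl.config_utils import arg_str2bool

    parser = argparse.ArgumentParser()
    # accepts yes/no, true/false, t/f, y/n, 1/0 (case-insensitive)
    parser.add_argument("--use_bn", type=arg_str2bool, default=True)
    args = parser.parse_args(["--use_bn", "0"])
    assert args.use_bn is False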

AutoDL-Projects/xautodl/config_utils/attention_args.py (Normal file, 32 lines added)
@@ -0,0 +1,32 @@
import random, argparse
from .share_args import add_shared_args


def obtain_attention_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--att_channel", type=int, help=".")
    parser.add_argument("--att_spatial", type=str, help=".")
    parser.add_argument("--att_active", type=str, help=".")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args

AutoDL-Projects/xautodl/config_utils/basic_args.py (Normal file, 44 lines added)
@@ -0,0 +1,44 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
import random, argparse
from .share_args import add_shared_args


def obtain_basic_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument(
        "--model_source",
        type=str,
        default="normal",
        help="The source of the model definition.",
    )
    parser.add_argument(
        "--extra_model_path",
        type=str,
        default=None,
        help="The extra model checkpoint file (helps to indicate the searched architecture).",
    )
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
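
A sketch of driving obtain_basic_args programmatically; since it calls parser.parse_args() with no explicit argument list, sys.argv is patched to stand in for a real command line, and every value below is a placeholder:

    import sys

    from xautodl.config_utils import obtain_basic_args

    sys.argv = [
        "train.py",  # program name, ignored by argparse
        "--dataset", "cifar10",
        "--data_path", "./data",
        "--save_dir", "./output/basic-demo",
        "--procedure", "basic",
    ]
    args = obtain_basic_args()
    print(args.batch_size, args.rand_seed)  # 2 and a freshly drawn positive seed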

AutoDL-Projects/xautodl/config_utils/cls_init_args.py (Normal file, 32 lines added)
@@ -0,0 +1,32 @@
import random, argparse
from .share_args import add_shared_args


def obtain_cls_init_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument(
        "--init_checkpoint", type=str, help="The checkpoint path to the initial model."
    )
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args

AutoDL-Projects/xautodl/config_utils/cls_kd_args.py (Normal file, 43 lines added)
@@ -0,0 +1,43 @@
import random, argparse
from .share_args import add_shared_args


def obtain_cls_kd_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument(
        "--KD_checkpoint",
        type=str,
        help="The teacher checkpoint in knowledge distillation.",
    )
    parser.add_argument(
        "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation."
    )
    parser.add_argument(
        "--KD_temperature",
        type=float,
        help="The temperature parameter in knowledge distillation.",
    )
    # parser.add_argument('--KD_feature',       type=float,                 help='Knowledge distillation at the feature level.')
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args

AutoDL-Projects/xautodl/config_utils/config_utils.py (Normal file, 135 lines added)
@@ -0,0 +1,135 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple

support_types = ("str", "int", "bool", "float", "none")


def convert_param(original_lists):
    assert isinstance(original_lists, list), "Expected a list but got : {:}".format(
        original_lists
    )
    ctype, value = original_lists[0], original_lists[1]
    assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types)
    is_list = isinstance(value, list)
    if not is_list:
        value = [value]
    outs = []
    for x in value:
        if ctype == "int":
            x = int(x)
        elif ctype == "str":
            x = str(x)
        elif ctype == "bool":
            x = bool(int(x))
        elif ctype == "float":
            x = float(x)
        elif ctype == "none":
            if x.lower() != "none":
                raise ValueError(
                    "For the none type, the value must be none instead of {:}".format(x)
                )
            x = None
        else:
            raise TypeError("Unknown type : {:}".format(ctype))
        outs.append(x)
    if not is_list:
        outs = outs[0]
    return outs


def load_config(path, extra, logger):
    path = str(path)
    if hasattr(logger, "log"):
        logger.log(path)
    assert os.path.exists(path), "Cannot find {:}".format(path)
    # Reading data back
    with open(path, "r") as f:
        data = json.load(f)
    content = {k: convert_param(v) for k, v in data.items()}
    assert extra is None or isinstance(
        extra, dict
    ), "invalid type of extra : {:}".format(extra)
    if isinstance(extra, dict):
        content = {**content, **extra}
    Arguments = namedtuple("Configure", " ".join(content.keys()))
    content = Arguments(**content)
    if hasattr(logger, "log"):
        logger.log("{:}".format(content))
    return content


def configure2str(config, xpath=None):
    if not isinstance(config, dict):
        config = config._asdict()

    def cstring(x):
        return '"{:}"'.format(x)

    def gtype(x):
        if isinstance(x, list):
            x = x[0]
        if isinstance(x, str):
            return "str"
        elif isinstance(x, bool):
            return "bool"
        elif isinstance(x, int):
            return "int"
        elif isinstance(x, float):
            return "float"
        elif x is None:
            return "none"
        else:
            raise ValueError("invalid : {:}".format(x))

    def cvalue(x, xtype):
        if isinstance(x, list):
            is_list = True
        else:
            is_list, x = False, [x]
        temps = []
        for temp in x:
            if xtype == "bool":
                temp = cstring(int(temp))
            elif xtype == "none":
                temp = cstring("None")
            else:
                temp = cstring(temp)
            temps.append(temp)
        if is_list:
            return "[{:}]".format(", ".join(temps))
        else:
            return temps[0]

    xstrings = []
    for key, value in config.items():
        xtype = gtype(value)
        string = "  {:20s} : [{:8s}, {:}]".format(
            cstring(key), cstring(xtype), cvalue(value, xtype)
        )
        xstrings.append(string)
    Fstring = "{\n" + ",\n".join(xstrings) + "\n}"
    if xpath is not None:
        parent = Path(xpath).resolve().parent
        parent.mkdir(parents=True, exist_ok=True)
        if osp.isfile(xpath):
            os.remove(xpath)
        with open(xpath, "w") as text_file:
            text_file.write("{:}".format(Fstring))
    return Fstring


def dict2config(xdict, logger):
    assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict))
    Arguments = namedtuple("Configure", " ".join(xdict.keys()))
    content = Arguments(**xdict)
    if hasattr(logger, "log"):
        logger.log("{:}".format(content))
    return content
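
A round-trip sketch of the JSON config format these helpers expect: each key maps to a [type, value] pair, where the value may itself be a list. The file and keys below are made-up examples, assuming xautodl is importable:

    import json, tempfile

    from xautodl.config_utils import load_config, configure2str, dict2config

    raw = {
        "class_num": ["int", "10"],
        "scheduler": ["str", "cos"],
        "gammas": ["float", ["0.1", "0.1"]],
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(raw, f)

    config = load_config(f.name, extra=None, logger=None)
    print(config.class_num, config.gammas)  # 10 [0.1, 0.1]
    print(configure2str(config))  # serialize back into the [type, value] text form
    cfg2 = dict2config({"class_num": 10}, logger=None)  # build a config from a plain dict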

AutoDL-Projects/xautodl/config_utils/pruning_args.py (Normal file, 48 lines added)
@@ -0,0 +1,48 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_pruning_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument(
        "--keep_ratio",
        type=float,
        help="The ratio of channels kept relative to the original network.",
    )
    parser.add_argument("--model_version", type=str, help="The network version.")
    parser.add_argument(
        "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation."
    )
    parser.add_argument(
        "--KD_temperature",
        type=float,
        help="The temperature parameter in knowledge distillation.",
    )
    parser.add_argument("--Regular_W_feat", type=float, help="The .")
    parser.add_argument("--Regular_W_conv", type=float, help="The .")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert (
        args.keep_ratio > 0 and args.keep_ratio <= 1
    ), "invalid keep ratio : {:}".format(args.keep_ratio)
    return args

AutoDL-Projects/xautodl/config_utils/random_baseline.py (Normal file, 44 lines added)
@@ -0,0 +1,44 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_RandomSearch_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--expect_flop", type=float, help="The expected FLOP keep ratio."
    )
    parser.add_argument(
        "--arch_nums",
        type=int,
        help="The maximum number of random architectures to generate.",
    )
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument(
        "--random_mode",
        type=str,
        choices=["random", "fix"],
        help="The mode for random architecture sampling (random or fix).",
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max)
    return args

AutoDL-Projects/xautodl/config_utils/search_args.py (Normal file, 53 lines added)
@@ -0,0 +1,53 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_search_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--split_path", type=str, help="The split file path.")
    # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.')
    parser.add_argument(
        "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel."
    )
    parser.add_argument(
        "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel."
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
    parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
    parser.add_argument(
        "--FLOP_tolerant", type=float, help="The tolerance range for FLOPs."
    )
    # ablation studies
    parser.add_argument(
        "--ablation_num_select",
        type=int,
        help="The number of randomly selected channels.",
    )
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
    assert (
        args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
    ), "invalid FLOP_tolerant : {:}".format(args.FLOP_tolerant)
    # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
    # args.arch_para_pure = bool(args.arch_para_pure)
    return args

AutoDL-Projects/xautodl/config_utils/search_single_args.py (Normal file, 48 lines added)
@@ -0,0 +1,48 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_search_single_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--split_path", type=str, help="The split file path.")
    parser.add_argument("--search_shape", type=str, help="The shape to be searched.")
    # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.')
    parser.add_argument(
        "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel."
    )
    parser.add_argument(
        "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel."
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
    parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
    parser.add_argument(
        "--FLOP_tolerant", type=float, help="The tolerance range for FLOPs."
    )
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
    assert (
        args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
    ), "invalid FLOP_tolerant : {:}".format(args.FLOP_tolerant)
    # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
    # args.arch_para_pure = bool(args.arch_para_pure)
    return args

AutoDL-Projects/xautodl/config_utils/share_args.py (Normal file, 39 lines added)
@@ -0,0 +1,39 @@
import os, sys, time, random, argparse


def add_shared_args(parser):
    # Data Generation
    parser.add_argument("--dataset", type=str, help="The dataset name.")
    parser.add_argument("--data_path", type=str, help="The path to the dataset.")
    parser.add_argument(
        "--cutout_length",
        type=int,
        help="The cutout length; a negative value disables cutout.",
    )
    # Printing
    parser.add_argument(
        "--print_freq", type=int, default=100, help="print frequency (default: 100)"
    )
    parser.add_argument(
        "--print_freq_eval",
        type=int,
        default=100,
        help="print frequency (default: 100)",
    )
    # Checkpoints
    parser.add_argument(
        "--eval_frequency",
        type=int,
        default=1,
        help="evaluation frequency (default: 1)",
    )
    parser.add_argument(
        "--save_dir", type=str, help="Folder to save checkpoints and log."
    )
    # Acceleration
    parser.add_argument(
        "--workers",
        type=int,
        default=8,
        help="number of data loading workers (default: 8)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
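
A minimal sketch of reusing add_shared_args in a new script-specific parser; the argument values are placeholders, assuming xautodl is importable:

    import argparse

    from xautodl.config_utils.share_args import add_shared_args

    parser = argparse.ArgumentParser("demo")
    add_shared_args(parser)
    args = parser.parse_args(
        ["--dataset", "cifar10", "--data_path", "./data", "--save_dir", "./output"]
    )
    print(args.workers, args.rand_seed)  # 8 -1 (the defaults)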

AutoDL-Projects/xautodl/log_utils/__init__.py (Normal file, 16 lines added)
@@ -0,0 +1,16 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# none of these modules rely on PyTorch or TensorFlow
# the dependencies are listed here: os, sys, time, numpy, (possibly) matplotlib
##################################################
from .logger import Logger, PrintLogger
from .meter import AverageMeter
from .time_utils import (
    time_for_file,
    time_string,
    time_string_short,
    time_print,
    convert_secs2time,
)
from .pickle_wrap import pickle_save, pickle_load

AutoDL-Projects/xautodl/log_utils/logger.py (Normal file, 173 lines added)
@@ -0,0 +1,173 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from pathlib import Path
import importlib.util, warnings
import os, sys, time, numpy as np

if sys.version_info.major == 2:  # Python 2.x
    from StringIO import StringIO as BIO
else:  # Python 3.x
    from io import BytesIO as BIO

if importlib.util.find_spec("tensorflow"):
    import tensorflow as tf


class PrintLogger(object):
    def __init__(self):
        """Create a logger that simply prints to stdout."""
        self.name = "PrintLogger"

    def log(self, string):
        print(string)

    def close(self):
        print("-" * 30 + " close printer " + "-" * 30)


class Logger(object):
    def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
        """Create a summary writer logging to log_dir."""
        self.seed = int(seed)
        self.log_dir = Path(log_dir)
        self.model_dir = Path(log_dir) / "checkpoint"
        self.log_dir.mkdir(parents=True, exist_ok=True)
        if create_model_dir:
            self.model_dir.mkdir(parents=True, exist_ok=True)
        # self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)

        self.use_tf = bool(use_tf)
        self.tensorboard_dir = self.log_dir / (
            "tensorboard-{:}".format(time.strftime("%d-%h", time.gmtime(time.time())))
        )
        # self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
        self.logger_path = self.log_dir / "seed-{:}-T-{:}.log".format(
            self.seed, time.strftime("%d-%h-at-%H-%M-%S", time.gmtime(time.time()))
        )
        self.logger_file = open(self.logger_path, "w")

        if self.use_tf:
            self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
            # TensorFlow 1.x summary API
            self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
        else:
            self.writer = None

    def __repr__(self):
        return "{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def path(self, mode):
        valids = ("model", "best", "info", "log", None)
        if mode is None:
            return self.log_dir
        elif mode == "model":
            return self.model_dir / "seed-{:}-basic.pth".format(self.seed)
        elif mode == "best":
            return self.model_dir / "seed-{:}-best.pth".format(self.seed)
        elif mode == "info":
            return self.log_dir / "seed-{:}-last-info.pth".format(self.seed)
        elif mode == "log":
            return self.log_dir
        else:
            raise TypeError("Unknown mode = {:}, valid modes = {:}".format(mode, valids))

    def extract_log(self):
        return self.logger_file

    def close(self):
        self.logger_file.close()
        if self.writer is not None:
            self.writer.close()

    def log(self, string, save=True, stdout=False):
        if stdout:
            sys.stdout.write(string)
            sys.stdout.flush()
        else:
            print(string)
        if save:
            self.logger_file.write("{:}\n".format(string))
            self.logger_file.flush()

    def scalar_summary(self, tags, values, step):
        """Log a scalar variable."""
        if not self.use_tf:
            warnings.warn("TensorFlow is not enabled (use_tf=False) but scalar_summary was called")
        else:
            assert isinstance(tags, list) == isinstance(
                values, list
            ), "Type : {:} vs {:}".format(type(tags), type(values))
            if not isinstance(tags, list):
                tags, values = [tags], [values]
            for tag, value in zip(tags, values):
                summary = tf.Summary(
                    value=[tf.Summary.Value(tag=tag, simple_value=value)]
                )
                self.writer.add_summary(summary, step)
                self.writer.flush()

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        import scipy

        if not self.use_tf:
            warnings.warn("TensorFlow is not enabled (use_tf=False) but image_summary was called")
            return

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to an in-memory buffer
            s = BIO()
            # note: scipy.misc.toimage was removed in SciPy 1.2, so this needs an older SciPy
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(
                encoded_image_string=s.getvalue(),
                height=img.shape[0],
                width=img.shape[1],
            )
            # Create a Summary value
            img_summaries.append(
                tf.Summary.Value(tag="{}/{}".format(tag, i), image=img_sum)
            )

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""
        if not self.use_tf:
            raise ValueError("Do not have tensorflow")
        import tensorflow as tf

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
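
A minimal usage sketch of Logger; "./output/demo" is a placeholder directory, and use_tf is left off so no TensorFlow is required:

    from xautodl.log_utils import Logger

    logger = Logger("./output/demo", seed=1, use_tf=False)
    logger.log("epoch 000 : loss = 2.30")  # printed and appended to the log file
    print(logger.path("model"))            # output/demo/checkpoint/seed-1-basic.pth
    logger.close()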

AutoDL-Projects/xautodl/log_utils/meter.py (Normal file, 120 lines added)
@@ -0,0 +1,120 @@
import numpy as np


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{name}(val={val}, avg={avg}, count={count})".format(
            name=self.__class__.__name__, **self.__dict__
        )


class RecorderMeter(object):
    """Computes and stores the minimum loss value and its epoch index"""

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        assert total_epoch > 0, "total_epoch should be greater than 0 vs {:}".format(
            total_epoch
        )
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )  # [epoch, train/val]
        self.epoch_losses = self.epoch_losses - 1
        self.epoch_accuracy = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )  # [epoch, train/val]
        # accuracies start at zero; losses above use -1 to mark epochs not yet recorded

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        assert (
            idx >= 0 and idx < self.total_epoch
        ), "total_epoch : {} , but update is called with index {}".format(
            self.total_epoch, idx
        )
        self.epoch_losses[idx, 0] = train_loss
        self.epoch_losses[idx, 1] = val_loss
        self.epoch_accuracy[idx, 0] = train_acc
        self.epoch_accuracy[idx, 1] = val_acc
        self.current_epoch = idx + 1
        return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]

    def max_accuracy(self, istrain):
        if self.current_epoch <= 0:
            return 0
        if istrain:
            return self.epoch_accuracy[: self.current_epoch, 0].max()
        else:
            return self.epoch_accuracy[: self.current_epoch, 1].max()

    def plot_curve(self, save_path):
        import matplotlib

        matplotlib.use("agg")
        import matplotlib.pyplot as plt

        title = "the accuracy/loss curve of train/val"
        dpi = 100
        width, height = 1600, 1000
        legend_fontsize = 10
        figsize = width / float(dpi), height / float(dpi)

        fig = plt.figure(figsize=figsize)
        x_axis = np.array([i for i in range(self.total_epoch)])  # epochs
        y_axis = np.zeros(self.total_epoch)

        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        interval_y = 5
        interval_x = 5
        plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
        plt.yticks(np.arange(0, 100 + interval_y, interval_y))
        plt.grid()
        plt.title(title, fontsize=20)
        plt.xlabel("the training epoch", fontsize=16)
        plt.ylabel("accuracy", fontsize=16)

        y_axis[:] = self.epoch_accuracy[:, 0]
        plt.plot(x_axis, y_axis, color="g", linestyle="-", label="train-accuracy", lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_accuracy[:, 1]
        plt.plot(x_axis, y_axis, color="y", linestyle="-", label="valid-accuracy", lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 0]
        plt.plot(
            x_axis, y_axis * 50, color="g", linestyle=":", label="train-loss-x50", lw=2
        )
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 1]
        plt.plot(
            x_axis, y_axis * 50, color="y", linestyle=":", label="valid-loss-x50", lw=2
        )
        plt.legend(loc=4, fontsize=legend_fontsize)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
            print("---- save figure {} into {}".format(title, save_path))
        plt.close(fig)
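
A small sketch of AverageMeter in a training loop; update(val, n) weights the running average by the batch size n:

    from xautodl.log_utils import AverageMeter

    loss_meter = AverageMeter()
    loss_meter.update(0.50, n=32)  # batch 1: mean loss 0.50 over 32 samples
    loss_meter.update(0.70, n=32)  # batch 2: mean loss 0.70 over 32 samples
    print(loss_meter.avg)  # 0.6 (up to floating-point rounding)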

AutoDL-Projects/xautodl/log_utils/pickle_wrap.py (Normal file, 21 lines added)
@@ -0,0 +1,21 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import pickle
from pathlib import Path


def pickle_save(obj, path):
    file_path = Path(path)
    file_dir = file_path.parent
    file_dir.mkdir(parents=True, exist_ok=True)
    with file_path.open("wb") as f:
        pickle.dump(obj, f)


def pickle_load(path):
    if not Path(path).exists():
        raise ValueError("{:} does not exist".format(path))
    with Path(path).open("rb") as f:
        data = pickle.load(f)
    return data
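
A round-trip sketch; pickle_save creates missing parent directories, and pickle_load raises ValueError when the path does not exist. The path below is a placeholder:

    from xautodl.log_utils import pickle_save, pickle_load

    pickle_save({"acc": 93.2}, "./output/demo/record.pkl")
    record = pickle_load("./output/demo/record.pkl")
    assert record["acc"] == 93.2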

AutoDL-Projects/xautodl/log_utils/time_utils.py (Normal file, 49 lines added)
@@ -0,0 +1,49 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import time, sys
import numpy as np


def time_for_file():
    ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S"
    return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))


def time_string():
    ISOTIMEFORMAT = "%Y-%m-%d %X"
    string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
    return string


def time_string_short():
    ISOTIMEFORMAT = "%Y%m%d"
    string = "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
    return string


def time_print(string, is_print=True):
    if is_print:
        print("{} : {}".format(time_string(), string))


def convert_secs2time(epoch_time, return_str=False):
    need_hour = int(epoch_time / 3600)
    need_mins = int((epoch_time - 3600 * need_hour) / 60)
    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
    if return_str:
        # avoid shadowing the built-in str
        time_str = "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs)
        return time_str
    else:
        return need_hour, need_mins, need_secs


def print_log(print_string, log):
    # if isinstance(log, Logger): log.log('{:}'.format(print_string))
    if hasattr(log, "log"):
        log.log("{:}".format(print_string))
    else:
        print("{:}".format(print_string))
        if log is not None:
            log.write("{:}\n".format(print_string))
            log.flush()
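
A quick sketch of the time helpers; note that all of them format UTC (time.gmtime), not local time:

    from xautodl.log_utils import convert_secs2time, time_string

    hours, mins, secs = convert_secs2time(3725)
    print(hours, mins, secs)              # 1 2 5
    print(convert_secs2time(3725, True))  # [01:02:05]
    print(time_string())                  # e.g. [2021-08-14 12:00:00]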

AutoDL-Projects/xautodl/models/CifarDenseNet.py (Normal file, 117 lines added)
@@ -0,0 +1,117 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet


class Bottleneck(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(Bottleneck, self).__init__()
        interChannels = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(interChannels)
        self.conv2 = nn.Conv2d(
            interChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat((x, out), 1)
        return out


class SingleLayer(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = torch.cat((x, out), 1)
        return out


class Transition(nn.Module):
    def __init__(self, nChannels, nOutChannels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()

        if bottleneck:
            nDenseBlocks = int((depth - 4) / 6)
        else:
            nDenseBlocks = int((depth - 4) / 3)

        self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format(
            "bottleneck" if bottleneck else "basic",
            depth,
            reduction,
            growthRate,
            nClasses,
        )

        nChannels = 2 * growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate

        self.act = nn.Sequential(
            nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8)
        )
        self.fc = nn.Linear(nChannels, nClasses)

        self.apply(initialize_resnet)

    def get_message(self):
        return self.message

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        layers = []
        for i in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(nChannels, growthRate))
            else:
                layers.append(SingleLayer(nChannels, growthRate))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, inputs):
        out = self.conv1(inputs)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        features = self.act(out)
        features = features.view(features.size(0), -1)
        out = self.fc(features)
        return features, out
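
A forward-pass sketch on CIFAR-sized inputs; the DenseNet-BC-100 setting below (growth rate 12, reduction 0.5) is a common configuration chosen for illustration, not one fixed by this commit, and it requires PyTorch:

    import torch

    from xautodl.models.CifarDenseNet import DenseNet

    net = DenseNet(growthRate=12, depth=100, reduction=0.5, nClasses=10, bottleneck=True)
    features, logits = net(torch.randn(2, 3, 32, 32))
    print(net.get_message())
    print(features.shape, logits.shape)  # torch.Size([2, 342]) torch.Size([2, 10])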

AutoDL-Projects/xautodl/models/CifarResNet.py (Normal file, 180 lines added)
@@ -0,0 +1,180 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
from .SharedUtils import additive_func


class Downsample(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(Downsample, self).__init__()
        assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format(
            stride, nIn, nOut
        )
        self.in_dim = nIn
        self.out_dim = nOut
        self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        x = self.avg(x)
        out = self.conv(x)
        return out


class ConvBNReLU(nn.Module):
    def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias
        )
        self.bn = nn.BatchNorm2d(nOut)
        if relu:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None
        self.out_dim = nOut
        self.num_conv = 1

    def forward(self, x):
        conv = self.conv(x)
        bn = self.bn(conv)
        if self.relu:
            return self.relu(bn)
        else:
            return bn


class ResNetBasicblock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True)
        self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False)
        if stride == 2:
            self.downsample = Downsample(inplanes, planes, stride)
        elif inplanes != planes:
            self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False)
        else:
            self.downsample = None
        self.out_dim = planes
        self.num_conv = 2

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, basicblock)
        return F.relu(out, inplace=True)


class ResNetBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride):
        super(ResNetBottleneck, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True)
        self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True)
        self.conv_1x4 = ConvBNReLU(
            planes, planes * self.expansion, 1, 1, 0, False, False
        )
        if stride == 2:
            self.downsample = Downsample(inplanes, planes * self.expansion, stride)
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, planes * self.expansion, 1, 1, 0, False, False | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.num_conv = 3 | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class CifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes, zero_init_residual): | ||||
|         super(CifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be 9n+2, e.g., 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format( | ||||
|             block_name, depth, layer_blocks | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)]) | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         assert ( | ||||
|             sum(x.num_conv for x in self.layers) + 1 == depth | ||||
|         ), "invalid depth check {:} vs {:}".format( | ||||
|             sum(x.num_conv for x in self.layers) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
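A hedged usage sketch for CifarResNet (values are illustrative). The depth check in __init__ requires 1 stem conv + 3 stages * layer_blocks blocks * num_conv convs + 1 to equal depth; depth=20 gives layer_blocks=3:

    import torch

    net = CifarResNet("ResNetBasicblock", depth=20, num_classes=10, zero_init_residual=True)
    print(net.get_message())  # per-layer channel/stride breakdown built in __init__
    features, logits = net(torch.randn(2, 3, 32, 32))
    assert logits.shape == (2, 10)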
115  AutoDL-Projects/xautodl/models/CifarWideResNet.py  Normal file
							| @@ -0,0 +1,115 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class WideBasicblock(nn.Module): | ||||
|     def __init__(self, inplanes, planes, stride, dropout=False): | ||||
|         super(WideBasicblock, self).__init__() | ||||
|  | ||||
|         self.bn_a = nn.BatchNorm2d(inplanes) | ||||
|         self.conv_a = nn.Conv2d( | ||||
|             inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|         self.bn_b = nn.BatchNorm2d(planes) | ||||
|         if dropout: | ||||
|             self.dropout = nn.Dropout2d(p=0.5, inplace=True) | ||||
|         else: | ||||
|             self.dropout = None | ||||
|         self.conv_b = nn.Conv2d( | ||||
|             planes, planes, kernel_size=3, stride=1, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|         if inplanes != planes: | ||||
|             self.downsample = nn.Conv2d( | ||||
|                 inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|  | ||||
|     def forward(self, x): | ||||
|  | ||||
|         basicblock = self.bn_a(x) | ||||
|         basicblock = F.relu(basicblock) | ||||
|         basicblock = self.conv_a(basicblock) | ||||
|  | ||||
|         basicblock = self.bn_b(basicblock) | ||||
|         basicblock = F.relu(basicblock) | ||||
|         if self.dropout is not None: | ||||
|             basicblock = self.dropout(basicblock) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             x = self.downsample(x) | ||||
|  | ||||
|         return x + basicblock | ||||
|  | ||||
|  | ||||
| class CifarWideResNet(nn.Module): | ||||
|     """ | ||||
|     Wide ResNet for the CIFAR datasets, as specified in | ||||
|     https://arxiv.org/abs/1605.07146 | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, depth, widen_factor, num_classes, dropout): | ||||
|         super(CifarWideResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         assert (depth - 4) % 6 == 0, "depth should be 6n+4, e.g., 16, 22, 28, 40" | ||||
|         layer_blocks = (depth - 4) // 6 | ||||
|         print( | ||||
|             "CifarWideResNet : Depth : {} , Layers for each block : {}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         self.num_classes = num_classes | ||||
|         self.dropout = dropout | ||||
|         self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) | ||||
|  | ||||
|         self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format( | ||||
|             depth, widen_factor, num_classes | ||||
|         ) | ||||
|         self.inplanes = 16 | ||||
|         self.stage_1 = self._make_layer( | ||||
|             WideBasicblock, 16 * widen_factor, layer_blocks, 1 | ||||
|         ) | ||||
|         self.stage_2 = self._make_layer( | ||||
|             WideBasicblock, 32 * widen_factor, layer_blocks, 2 | ||||
|         ) | ||||
|         self.stage_3 = self._make_layer( | ||||
|             WideBasicblock, 64 * widen_factor, layer_blocks, 2 | ||||
|         ) | ||||
|         self.lastact = nn.Sequential( | ||||
|             nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(64 * widen_factor, num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def _make_layer(self, block, planes, blocks, stride): | ||||
|  | ||||
|         layers = [] | ||||
|         layers.append(block(self.inplanes, planes, stride, self.dropout)) | ||||
|         self.inplanes = planes | ||||
|         for i in range(1, blocks): | ||||
|             layers.append(block(self.inplanes, planes, 1, self.dropout)) | ||||
|  | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.conv_3x3(x) | ||||
|         x = self.stage_1(x) | ||||
|         x = self.stage_2(x) | ||||
|         x = self.stage_3(x) | ||||
|         x = self.lastact(x) | ||||
|         x = self.avgpool(x) | ||||
|         features = x.view(x.size(0), -1) | ||||
|         outs = self.classifier(features) | ||||
|         return features, outs | ||||
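For reference, a hedged WRN-28-10 instantiation (illustrative values): (28 - 4) // 6 = 4 WideBasicblocks per stage, and the classifier width is 64 * widen_factor:

    import torch

    net = CifarWideResNet(depth=28, widen_factor=10, num_classes=100, dropout=True)
    features, outs = net(torch.randn(2, 3, 32, 32))
    assert features.shape == (2, 640) and outs.shape == (2, 100)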
117  AutoDL-Projects/xautodl/models/ImageNet_MobileNetV2.py  Normal file
							| @@ -0,0 +1,117 @@ | ||||
| # MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018 | ||||
| from torch import nn | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         padding = (kernel_size - 1) // 2 | ||||
|         self.conv = nn.Conv2d( | ||||
|             in_planes, | ||||
|             out_planes, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             groups=groups, | ||||
|             bias=False, | ||||
|         ) | ||||
|         self.bn = nn.BatchNorm2d(out_planes) | ||||
|         self.relu = nn.ReLU6(inplace=True) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         out = self.bn(out) | ||||
|         out = self.relu(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class InvertedResidual(nn.Module): | ||||
|     def __init__(self, inp, oup, stride, expand_ratio): | ||||
|         super(InvertedResidual, self).__init__() | ||||
|         self.stride = stride | ||||
|         assert stride in [1, 2] | ||||
|  | ||||
|         hidden_dim = int(round(inp * expand_ratio)) | ||||
|         self.use_res_connect = self.stride == 1 and inp == oup | ||||
|  | ||||
|         layers = [] | ||||
|         if expand_ratio != 1: | ||||
|             # pw | ||||
|             layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) | ||||
|         layers.extend( | ||||
|             [ | ||||
|                 # dw | ||||
|                 ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), | ||||
|                 # pw-linear | ||||
|                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), | ||||
|                 nn.BatchNorm2d(oup), | ||||
|             ] | ||||
|         ) | ||||
|         self.conv = nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.use_res_connect: | ||||
|             return x + self.conv(x) | ||||
|         else: | ||||
|             return self.conv(x) | ||||
|  | ||||
|  | ||||
| class MobileNetV2(nn.Module): | ||||
|     def __init__( | ||||
|         self, num_classes, width_mult, input_channel, last_channel, block_name, dropout | ||||
|     ): | ||||
|         super(MobileNetV2, self).__init__() | ||||
|         if block_name == "InvertedResidual": | ||||
|             block = InvertedResidual | ||||
|         else: | ||||
|             raise ValueError("invalid block name : {:}".format(block_name)) | ||||
|         inverted_residual_setting = [ | ||||
|             # t, c,  n, s | ||||
|             [1, 16, 1, 1], | ||||
|             [6, 24, 2, 2], | ||||
|             [6, 32, 3, 2], | ||||
|             [6, 64, 4, 2], | ||||
|             [6, 96, 3, 1], | ||||
|             [6, 160, 3, 2], | ||||
|             [6, 320, 1, 1], | ||||
|         ] | ||||
|  | ||||
|         # building first layer | ||||
|         input_channel = stem_channel = int(input_channel * width_mult)  # stem_channel kept for the message | ||||
|         self.last_channel = int(last_channel * max(1.0, width_mult)) | ||||
|         features = [ConvBNReLU(3, input_channel, stride=2)] | ||||
|         # building inverted residual blocks | ||||
|         for t, c, n, s in inverted_residual_setting: | ||||
|             output_channel = int(c * width_mult) | ||||
|             for i in range(n): | ||||
|                 stride = s if i == 0 else 1 | ||||
|                 features.append( | ||||
|                     block(input_channel, output_channel, stride, expand_ratio=t) | ||||
|                 ) | ||||
|                 input_channel = output_channel | ||||
|         # building last several layers | ||||
|         features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) | ||||
|         # make it nn.Sequential | ||||
|         self.features = nn.Sequential(*features) | ||||
|  | ||||
|         # building classifier | ||||
|         self.classifier = nn.Sequential( | ||||
|             nn.Dropout(dropout), | ||||
|             nn.Linear(self.last_channel, num_classes), | ||||
|         ) | ||||
|         self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format( | ||||
|             width_mult, stem_channel, last_channel, block_name, dropout | ||||
|         ) | ||||
|  | ||||
|         # weight initialization | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         features = self.features(inputs) | ||||
|         vectors = features.mean([2, 3]) | ||||
|         predicts = self.classifier(vectors) | ||||
|         return features, predicts | ||||
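A hedged instantiation matching the standard ImageNet setting (the values below mirror the usual width_mult=1.0 recipe; they are not taken from this commit's configs):

    import torch

    net = MobileNetV2(num_classes=1000, width_mult=1.0, input_channel=32,
                      last_channel=1280, block_name="InvertedResidual", dropout=0.2)
    features, predicts = net(torch.randn(2, 3, 224, 224))
    assert predicts.shape == (2, 1000)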
217  AutoDL-Projects/xautodl/models/ImageNet_ResNet.py  Normal file
							| @@ -0,0 +1,217 @@ | ||||
| # Deep Residual Learning for Image Recognition, CVPR 2016 | ||||
| import torch.nn as nn | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| def conv3x3(in_planes, out_planes, stride=1, groups=1): | ||||
|     return nn.Conv2d( | ||||
|         in_planes, | ||||
|         out_planes, | ||||
|         kernel_size=3, | ||||
|         stride=stride, | ||||
|         padding=1, | ||||
|         groups=groups, | ||||
|         bias=False, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def conv1x1(in_planes, out_planes, stride=1): | ||||
|     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||
|  | ||||
|  | ||||
| class BasicBlock(nn.Module): | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||
|     ): | ||||
|         super(BasicBlock, self).__init__() | ||||
|         if groups != 1 or base_width != 64: | ||||
|             raise ValueError("BasicBlock only supports groups=1 and base_width=64") | ||||
|         # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||||
|         self.conv1 = conv3x3(inplanes, planes, stride) | ||||
|         self.bn1 = nn.BatchNorm2d(planes) | ||||
|         self.relu = nn.ReLU(inplace=True) | ||||
|         self.conv2 = conv3x3(planes, planes) | ||||
|         self.bn2 = nn.BatchNorm2d(planes) | ||||
|         self.downsample = downsample | ||||
|         self.stride = stride | ||||
|  | ||||
|     def forward(self, x): | ||||
|         identity = x | ||||
|  | ||||
|         out = self.conv1(x) | ||||
|         out = self.bn1(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         out = self.conv2(out) | ||||
|         out = self.bn2(out) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             identity = self.downsample(x) | ||||
|  | ||||
|         out += identity | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class Bottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|  | ||||
|     def __init__( | ||||
|         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||
|     ): | ||||
|         super(Bottleneck, self).__init__() | ||||
|         width = int(planes * (base_width / 64.0)) * groups | ||||
|         # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||||
|         self.conv1 = conv1x1(inplanes, width) | ||||
|         self.bn1 = nn.BatchNorm2d(width) | ||||
|         self.conv2 = conv3x3(width, width, stride, groups) | ||||
|         self.bn2 = nn.BatchNorm2d(width) | ||||
|         self.conv3 = conv1x1(width, planes * self.expansion) | ||||
|         self.bn3 = nn.BatchNorm2d(planes * self.expansion) | ||||
|         self.relu = nn.ReLU(inplace=True) | ||||
|         self.downsample = downsample | ||||
|         self.stride = stride | ||||
|  | ||||
|     def forward(self, x): | ||||
|         identity = x | ||||
|  | ||||
|         out = self.conv1(x) | ||||
|         out = self.bn1(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         out = self.conv2(out) | ||||
|         out = self.bn2(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         out = self.conv3(out) | ||||
|         out = self.bn3(out) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             identity = self.downsample(x) | ||||
|  | ||||
|         out += identity | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         block_name, | ||||
|         layers, | ||||
|         deep_stem, | ||||
|         num_classes, | ||||
|         zero_init_residual, | ||||
|         groups, | ||||
|         width_per_group, | ||||
|     ): | ||||
|         super(ResNet, self).__init__() | ||||
|  | ||||
|         # planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] | ||||
|         if block_name == "BasicBlock": | ||||
|             block = BasicBlock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = Bottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block-name : {:}".format(block_name)) | ||||
|  | ||||
|         if not deep_stem: | ||||
|             self.conv = nn.Sequential( | ||||
|                 nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), | ||||
|                 nn.BatchNorm2d(64), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ) | ||||
|         else: | ||||
|             self.conv = nn.Sequential( | ||||
|                 nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(32), | ||||
|                 nn.ReLU(inplace=True), | ||||
|                 nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(32), | ||||
|                 nn.ReLU(inplace=True), | ||||
|                 nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(64), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ) | ||||
|         self.inplanes = 64 | ||||
|         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||
|         self.layer1 = self._make_layer( | ||||
|             block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer2 = self._make_layer( | ||||
|             block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer3 = self._make_layer( | ||||
|             block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer4 = self._make_layer( | ||||
|             block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.fc = nn.Linear(512 * block.expansion, num_classes) | ||||
|         self.message = ( | ||||
|             "block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format( | ||||
|                 block, layers, deep_stem, num_classes | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|         # Zero-initialize the last BN in each residual branch, | ||||
|         # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||||
|         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, Bottleneck): | ||||
|                     nn.init.constant_(m.bn3.weight, 0) | ||||
|                 elif isinstance(m, BasicBlock): | ||||
|                     nn.init.constant_(m.bn2.weight, 0) | ||||
|  | ||||
|     def _make_layer(self, block, planes, blocks, stride, groups, base_width): | ||||
|         downsample = None | ||||
|         if stride != 1 or self.inplanes != planes * block.expansion: | ||||
|             if stride == 2: | ||||
|                 downsample = nn.Sequential( | ||||
|                     nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||
|                     conv1x1(self.inplanes, planes * block.expansion, 1), | ||||
|                     nn.BatchNorm2d(planes * block.expansion), | ||||
|                 ) | ||||
|             elif stride == 1: | ||||
|                 downsample = nn.Sequential( | ||||
|                     conv1x1(self.inplanes, planes * block.expansion, stride), | ||||
|                     nn.BatchNorm2d(planes * block.expansion), | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid stride [{:}] for downsample".format(stride)) | ||||
|  | ||||
|         layers = [] | ||||
|         layers.append( | ||||
|             block(self.inplanes, planes, stride, downsample, groups, base_width) | ||||
|         ) | ||||
|         self.inplanes = planes * block.expansion | ||||
|         for _ in range(1, blocks): | ||||
|             layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) | ||||
|  | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.conv(x) | ||||
|         x = self.maxpool(x) | ||||
|  | ||||
|         x = self.layer1(x) | ||||
|         x = self.layer2(x) | ||||
|         x = self.layer3(x) | ||||
|         x = self.layer4(x) | ||||
|  | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.fc(features) | ||||
|  | ||||
|         return features, logits | ||||
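A hedged sketch of two common configurations this constructor supports (layer counts follow the standard ResNet/ResNeXt recipes; values are illustrative):

    import torch

    # ResNet-50: Bottleneck blocks with [3, 4, 6, 3] layers per stage.
    r50 = ResNet("Bottleneck", [3, 4, 6, 3], deep_stem=False, num_classes=1000,
                 zero_init_residual=True, groups=1, width_per_group=64)
    # ResNeXt-50 (32x4d) only changes the group-convolution settings.
    rx50 = ResNet("Bottleneck", [3, 4, 6, 3], deep_stem=False, num_classes=1000,
                  zero_init_residual=True, groups=32, width_per_group=4)
    features, logits = r50(torch.randn(2, 3, 224, 224))
    assert logits.shape == (2, 1000)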
37  AutoDL-Projects/xautodl/models/SharedUtils.py  Normal file
							| @@ -0,0 +1,37 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def additive_func(A, B): | ||||
|     assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format( | ||||
|         A.size(), B.size() | ||||
|     ) | ||||
|     C = min(A.size(1), B.size(1)) | ||||
|     if A.size(1) == B.size(1): | ||||
|         return A + B | ||||
|     elif A.size(1) < B.size(1): | ||||
|         out = B.clone() | ||||
|         out[:, :C] += A | ||||
|         return out | ||||
|     else: | ||||
|         out = A.clone() | ||||
|         out[:, :C] += B | ||||
|         return out | ||||
|  | ||||
|  | ||||
| def change_key(key, value): | ||||
|     def func(m): | ||||
|         if hasattr(m, key): | ||||
|             setattr(m, key, value) | ||||
|  | ||||
|     return func | ||||
|  | ||||
|  | ||||
| def parse_channel_info(xstring): | ||||
|     blocks = xstring.split(" ") | ||||
|     blocks = [x.split("-") for x in blocks] | ||||
|     blocks = [[int(_) for _ in x] for x in blocks] | ||||
|     return blocks | ||||
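The helpers above are small enough to check directly; a minimal sketch with illustrative tensors. additive_func adds the narrower input into the leading channels of the wider one:

    import torch

    a, b = torch.ones(2, 8, 4, 4), torch.ones(2, 16, 4, 4)
    out = additive_func(a, b)
    assert out[:, :8].eq(2).all() and out[:, 8:].eq(1).all()
    assert parse_channel_info("8-16 16-32") == [[8, 16], [16, 32]]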
326  AutoDL-Projects/xautodl/models/__init__.py  Normal file
							| @@ -0,0 +1,326 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from os import path as osp | ||||
| from typing import List, Text | ||||
| import torch | ||||
|  | ||||
| __all__ = [ | ||||
|     "change_key", | ||||
|     "get_cell_based_tiny_net", | ||||
|     "get_search_spaces", | ||||
|     "get_cifar_models", | ||||
|     "get_imagenet_models", | ||||
|     "obtain_model", | ||||
|     "obtain_search_model", | ||||
|     "load_net_from_checkpoint", | ||||
|     "CellStructure", | ||||
|     "CellArchitectures", | ||||
| ] | ||||
|  | ||||
| # useful modules | ||||
| from xautodl.config_utils import dict2config | ||||
| from .SharedUtils import change_key | ||||
| from .cell_searchs import CellStructure, CellArchitectures | ||||
|  | ||||
|  | ||||
| # Cell-based NAS Models | ||||
| def get_cell_based_tiny_net(config): | ||||
|     if isinstance(config, dict): | ||||
|         config = dict2config(config, None)  # to support the argument being a dict | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"] | ||||
|     if super_type == "basic" and config.name in group_names: | ||||
|         from .cell_searchs import nas201_super_nets as nas_super_nets | ||||
|  | ||||
|         try: | ||||
|             return nas_super_nets[config.name]( | ||||
|                 config.C, | ||||
|                 config.N, | ||||
|                 config.max_nodes, | ||||
|                 config.num_classes, | ||||
|                 config.space, | ||||
|                 config.affine, | ||||
|                 config.track_running_stats, | ||||
|             ) | ||||
|         except TypeError:  # some super-nets do not accept affine / track_running_stats | ||||
|             return nas_super_nets[config.name]( | ||||
|                 config.C, config.N, config.max_nodes, config.num_classes, config.space | ||||
|             ) | ||||
|     elif super_type == "search-shape": | ||||
|         from .shape_searchs import GenericNAS301Model | ||||
|  | ||||
|         genotype = CellStructure.str2structure(config.genotype) | ||||
|         return GenericNAS301Model( | ||||
|             config.candidate_Cs, | ||||
|             config.max_num_Cs, | ||||
|             genotype, | ||||
|             config.num_classes, | ||||
|             config.affine, | ||||
|             config.track_running_stats, | ||||
|         ) | ||||
|     elif super_type == "nasnet-super": | ||||
|         from .cell_searchs import nasnet_super_nets as nas_super_nets | ||||
|  | ||||
|         return nas_super_nets[config.name]( | ||||
|             config.C, | ||||
|             config.N, | ||||
|             config.steps, | ||||
|             config.multiplier, | ||||
|             config.stem_multiplier, | ||||
|             config.num_classes, | ||||
|             config.space, | ||||
|             config.affine, | ||||
|             config.track_running_stats, | ||||
|         ) | ||||
|     elif config.name == "infer.tiny": | ||||
|         from .cell_infers import TinyNetwork | ||||
|  | ||||
|         if hasattr(config, "genotype"): | ||||
|             genotype = config.genotype | ||||
|         elif hasattr(config, "arch_str"): | ||||
|             genotype = CellStructure.str2structure(config.arch_str) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "Cannot find genotype from this config : {:}".format(config) | ||||
|             ) | ||||
|         return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||
|     elif config.name == "infer.shape.tiny": | ||||
|         from .shape_infers import DynamicShapeTinyNet | ||||
|  | ||||
|         if isinstance(config.channels, str): | ||||
|             channels = tuple([int(x) for x in config.channels.split(":")]) | ||||
|         else: | ||||
|             channels = config.channels | ||||
|         genotype = CellStructure.str2structure(config.genotype) | ||||
|         return DynamicShapeTinyNet(channels, genotype, config.num_classes) | ||||
|     elif config.name == "infer.nasnet-cifar": | ||||
|         from .cell_infers import NASNetonCIFAR | ||||
|  | ||||
|         raise NotImplementedError | ||||
|     else: | ||||
|         raise ValueError("invalid network name : {:}".format(config.name)) | ||||
|  | ||||
|  | ||||
| # Obtain the search space: a list of op names (topology space) or a dict of candidate channels (size space). | ||||
| def get_search_spaces(xtype, name): | ||||
|     if xtype == "cell" or xtype == "tss":  # The topology search space. | ||||
|         from .cell_operations import SearchSpaceNames | ||||
|  | ||||
|         assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format( | ||||
|             name, SearchSpaceNames.keys() | ||||
|         ) | ||||
|         return SearchSpaceNames[name] | ||||
|     elif xtype == "sss":  # The size search space. | ||||
|         if name in ["nats-bench", "nats-bench-size"]: | ||||
|             return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5} | ||||
|         else: | ||||
|             raise ValueError("Invalid name : {:}".format(name)) | ||||
|     else: | ||||
|         raise ValueError("invalid search-space type is {:}".format(xtype)) | ||||
|  | ||||
|  | ||||
| def get_cifar_models(config, extra_path=None): | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     if super_type == "basic": | ||||
|         from .CifarResNet import CifarResNet | ||||
|         from .CifarDenseNet import DenseNet | ||||
|         from .CifarWideResNet import CifarWideResNet | ||||
|  | ||||
|         if config.arch == "resnet": | ||||
|             return CifarResNet( | ||||
|                 config.module, config.depth, config.class_num, config.zero_init_residual | ||||
|             ) | ||||
|         elif config.arch == "densenet": | ||||
|             return DenseNet( | ||||
|                 config.growthRate, | ||||
|                 config.depth, | ||||
|                 config.reduction, | ||||
|                 config.class_num, | ||||
|                 config.bottleneck, | ||||
|             ) | ||||
|         elif config.arch == "wideresnet": | ||||
|             return CifarWideResNet( | ||||
|                 config.depth, config.wide_factor, config.class_num, config.dropout | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid module type : {:}".format(config.arch)) | ||||
|     elif super_type.startswith("infer"): | ||||
|         from .shape_infers import InferWidthCifarResNet | ||||
|         from .shape_infers import InferDepthCifarResNet | ||||
|         from .shape_infers import InferCifarResNet | ||||
|         from .cell_infers import NASNetonCIFAR | ||||
|  | ||||
|         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||
|             super_type | ||||
|         ) | ||||
|         infer_mode = super_type.split("-")[1] | ||||
|         if infer_mode == "width": | ||||
|             return InferWidthCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xchannels, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "depth": | ||||
|             return InferDepthCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xblocks, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "shape": | ||||
|             return InferCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xblocks, | ||||
|                 config.xchannels, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "nasnet.cifar": | ||||
|             genotype = config.genotype | ||||
|             if extra_path is not None:  # reload genotype by extra_path | ||||
|                 if not osp.isfile(extra_path): | ||||
|                     raise ValueError("invalid extra_path : {:}".format(extra_path)) | ||||
|                 xdata = torch.load(extra_path) | ||||
|                 current_epoch = xdata["epoch"] | ||||
|                 genotype = xdata["genotypes"][current_epoch - 1] | ||||
|             C = config.C if hasattr(config, "C") else config.ichannel | ||||
|             N = config.N if hasattr(config, "N") else config.layers | ||||
|             return NASNetonCIFAR( | ||||
|                 C, N, config.stem_multi, config.class_num, genotype, config.auxiliary | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||
|     else: | ||||
|         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||
|  | ||||
|  | ||||
| def get_imagenet_models(config): | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     if super_type == "basic": | ||||
|         from .ImageNet_ResNet import ResNet | ||||
|         from .ImageNet_MobileNetV2 import MobileNetV2 | ||||
|  | ||||
|         if config.arch == "resnet": | ||||
|             return ResNet( | ||||
|                 config.block_name, | ||||
|                 config.layers, | ||||
|                 config.deep_stem, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|                 config.groups, | ||||
|                 config.width_per_group, | ||||
|             ) | ||||
|         elif config.arch == "mobilenet_v2": | ||||
|             return MobileNetV2( | ||||
|                 config.class_num, | ||||
|                 config.width_multi, | ||||
|                 config.input_channel, | ||||
|                 config.last_channel, | ||||
|                 "InvertedResidual", | ||||
|                 config.dropout, | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid arch : {:}".format(config.arch)) | ||||
|     elif super_type.startswith("infer"):  # NAS searched architecture | ||||
|         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||
|             super_type | ||||
|         ) | ||||
|         infer_mode = super_type.split("-")[1] | ||||
|         if infer_mode == "shape": | ||||
|             from .shape_infers import InferImagenetResNet | ||||
|             from .shape_infers import InferMobileNetV2 | ||||
|  | ||||
|             if config.arch == "resnet": | ||||
|                 return InferImagenetResNet( | ||||
|                     config.block_name, | ||||
|                     config.layers, | ||||
|                     config.xblocks, | ||||
|                     config.xchannels, | ||||
|                     config.deep_stem, | ||||
|                     config.class_num, | ||||
|                     config.zero_init_residual, | ||||
|                 ) | ||||
|             elif config.arch == "MobileNetV2": | ||||
|                 return InferMobileNetV2( | ||||
|                     config.class_num, config.xchannels, config.xblocks, config.dropout | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid arch-mode : {:}".format(config.arch)) | ||||
|         else: | ||||
|             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||
|     else: | ||||
|         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||
|  | ||||
|  | ||||
| # Try to obtain the network by config. | ||||
| def obtain_model(config, extra_path=None): | ||||
|     if config.dataset == "cifar": | ||||
|         return get_cifar_models(config, extra_path) | ||||
|     elif config.dataset == "imagenet": | ||||
|         return get_imagenet_models(config) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||
|  | ||||
|  | ||||
| def obtain_search_model(config): | ||||
|     if config.dataset == "cifar": | ||||
|         if config.arch == "resnet": | ||||
|             from .shape_searchs import SearchWidthCifarResNet | ||||
|             from .shape_searchs import SearchDepthCifarResNet | ||||
|             from .shape_searchs import SearchShapeCifarResNet | ||||
|  | ||||
|             if config.search_mode == "width": | ||||
|                 return SearchWidthCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             elif config.search_mode == "depth": | ||||
|                 return SearchDepthCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             elif config.search_mode == "shape": | ||||
|                 return SearchShapeCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||
|         elif config.arch == "simres": | ||||
|             from .shape_searchs import SearchWidthSimResNet | ||||
|  | ||||
|             if config.search_mode == "width": | ||||
|                 return SearchWidthSimResNet(config.depth, config.class_num) | ||||
|             else: | ||||
|                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "invalid arch : {:} for dataset [{:}]".format( | ||||
|                     config.arch, config.dataset | ||||
|                 ) | ||||
|             ) | ||||
|     elif config.dataset == "imagenet": | ||||
|         from .shape_searchs import SearchShapeImagenetResNet | ||||
|  | ||||
|         assert config.search_mode == "shape", "invalid search-mode : {:}".format( | ||||
|             config.search_mode | ||||
|         ) | ||||
|         if config.arch == "resnet": | ||||
|             return SearchShapeImagenetResNet( | ||||
|                 config.block_name, config.layers, config.deep_stem, config.class_num | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid model config : {:}".format(config)) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||
|  | ||||
|  | ||||
| def load_net_from_checkpoint(checkpoint): | ||||
|     assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint) | ||||
|     checkpoint = torch.load(checkpoint) | ||||
|     model_config = dict2config(checkpoint["model-config"], None) | ||||
|     model = obtain_model(model_config) | ||||
|     model.load_state_dict(checkpoint["base-model"]) | ||||
|     return model | ||||
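A hedged example of driving get_cell_based_tiny_net with a plain dict (the keys follow the "infer.tiny" branch above; the arch_str is an illustrative NAS-Bench-201 topology string, not one shipped in this commit):

    config = {
        "name": "infer.tiny",
        "C": 16,
        "N": 5,
        "num_classes": 10,
        "arch_str": "|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|"
                    "+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|",
    }
    net = get_cell_based_tiny_net(config)  # the dict is wrapped via dict2config internally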
5  AutoDL-Projects/xautodl/models/cell_infers/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| from .tiny_network import TinyNetwork | ||||
| from .nasnet_cifar import NASNetonCIFAR | ||||
155  AutoDL-Projects/xautodl/models/cell_infers/cells.py  Normal file
							| @@ -0,0 +1,155 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
|  | ||||
| from xautodl.models.cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # Cell for NAS-Bench-201 | ||||
| class InferCell(nn.Module): | ||||
|     def __init__( | ||||
|         self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True | ||||
|     ): | ||||
|         super(InferCell, self).__init__() | ||||
|  | ||||
|         self.layers = nn.ModuleList() | ||||
|         self.node_IN = [] | ||||
|         self.node_IX = [] | ||||
|         self.genotype = deepcopy(genotype) | ||||
|         for i in range(1, len(genotype)): | ||||
|             node_info = genotype[i - 1] | ||||
|             cur_index = [] | ||||
|             cur_innod = [] | ||||
|             for (op_name, op_in) in node_info: | ||||
|                 if op_in == 0: | ||||
|                     layer = OPS[op_name]( | ||||
|                         C_in, C_out, stride, affine, track_running_stats | ||||
|                     ) | ||||
|                 else: | ||||
|                     layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats) | ||||
|                 cur_index.append(len(self.layers)) | ||||
|                 cur_innod.append(op_in) | ||||
|                 self.layers.append(layer) | ||||
|             self.node_IX.append(cur_index) | ||||
|             self.node_IN.append(cur_innod) | ||||
|         self.nodes = len(genotype) | ||||
|         self.in_dim = C_in | ||||
|         self.out_dim = C_out | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         string = "info :: nodes={nodes}, inC={in_dim}, outC={out_dim}".format( | ||||
|             **self.__dict__ | ||||
|         ) | ||||
|         laystr = [] | ||||
|         for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)): | ||||
|             y = [ | ||||
|                 "I{:}-L{:}".format(_ii, _il) | ||||
|                 for _il, _ii in zip(node_layers, node_innods) | ||||
|             ] | ||||
|             x = "{:}<-({:})".format(i + 1, ",".join(y)) | ||||
|             laystr.append(x) | ||||
|         return ( | ||||
|             string | ||||
|             + ", [{:}]".format(" | ".join(laystr)) | ||||
|             + ", {:}".format(self.genotype.tostr()) | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         nodes = [inputs] | ||||
|         for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)): | ||||
|             node_feature = sum( | ||||
|                 self.layers[_il](nodes[_ii]) | ||||
|                 for _il, _ii in zip(node_layers, node_innods) | ||||
|             ) | ||||
|             nodes.append(node_feature) | ||||
|         return nodes[-1] | ||||
|  | ||||
|  | ||||
| # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 | ||||
| class NASNetInferCell(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         genotype, | ||||
|         C_prev_prev, | ||||
|         C_prev, | ||||
|         C, | ||||
|         reduction, | ||||
|         reduction_prev, | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ): | ||||
|         super(NASNetInferCell, self).__init__() | ||||
|         self.reduction = reduction | ||||
|         if reduction_prev: | ||||
|             self.preprocess0 = OPS["skip_connect"]( | ||||
|                 C_prev_prev, C, 2, affine, track_running_stats | ||||
|             ) | ||||
|         else: | ||||
|             self.preprocess0 = OPS["nor_conv_1x1"]( | ||||
|                 C_prev_prev, C, 1, affine, track_running_stats | ||||
|             ) | ||||
|         self.preprocess1 = OPS["nor_conv_1x1"]( | ||||
|             C_prev, C, 1, affine, track_running_stats | ||||
|         ) | ||||
|  | ||||
|         if not reduction: | ||||
|             nodes, concats = genotype["normal"], genotype["normal_concat"] | ||||
|         else: | ||||
|             nodes, concats = genotype["reduce"], genotype["reduce_concat"] | ||||
|         self._multiplier = len(concats) | ||||
|         self._concats = concats | ||||
|         self._steps = len(nodes) | ||||
|         self._nodes = nodes | ||||
|         self.edges = nn.ModuleDict() | ||||
|         for i, node in enumerate(nodes): | ||||
|             for in_node in node: | ||||
|                 name, j = in_node[0], in_node[1] | ||||
|                 stride = 2 if reduction and j < 2 else 1 | ||||
|                 node_str = "{:}<-{:}".format(i + 2, j) | ||||
|                 self.edges[node_str] = OPS[name]( | ||||
|                     C, C, stride, affine, track_running_stats | ||||
|                 ) | ||||
|  | ||||
|     # TODO: support drop-path probability in this function (currently unused). | ||||
|     def forward(self, s0, s1, unused_drop_prob): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         states = [s0, s1] | ||||
|         for i, node in enumerate(self._nodes): | ||||
|             clist = [] | ||||
|             for in_node in node: | ||||
|                 name, j = in_node[0], in_node[1] | ||||
|                 node_str = "{:}<-{:}".format(i + 2, j) | ||||
|                 op = self.edges[node_str] | ||||
|                 clist.append(op(states[j])) | ||||
|             states.append(sum(clist)) | ||||
|         return torch.cat([states[x] for x in self._concats], dim=1) | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadCIFAR(nn.Module): | ||||
|     def __init__(self, C, num_classes): | ||||
|         """assuming input size 8x8""" | ||||
|         super(AuxiliaryHeadCIFAR, self).__init__() | ||||
|         self.features = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.AvgPool2d( | ||||
|                 5, stride=3, padding=0, count_include_pad=False | ||||
|             ),  # image size = 2 x 2 | ||||
|             nn.Conv2d(C, 128, 1, bias=False), | ||||
|             nn.BatchNorm2d(128), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(128, 768, 2, bias=False), | ||||
|             nn.BatchNorm2d(768), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.features(x) | ||||
|         x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
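InferCell expects a CellStructure whose (i-1)-th entry lists (op_name, input_node) pairs for node i; a hedged sketch using the helper exported from xautodl.models (the topology string is illustrative):

    from xautodl.models import CellStructure

    genotype = CellStructure.str2structure(
        "|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|"
        "+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|"
    )
    cell = InferCell(genotype, C_in=16, C_out=16, stride=1)
    print(cell.extra_repr())  # e.g. "1<-(I0-L0) | 2<-(I0-L1,I1-L2) | ..."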
118  AutoDL-Projects/xautodl/models/cell_infers/nasnet_cifar.py  Normal file
							| @@ -0,0 +1,118 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
|  | ||||
| from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetonCIFAR(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C, | ||||
|         N, | ||||
|         stem_multiplier, | ||||
|         num_classes, | ||||
|         genotype, | ||||
|         auxiliary, | ||||
|         affine=True, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(NASNetonCIFAR, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|         self.auxiliary_index = None | ||||
|         self.auxiliary_head = None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             cell = InferCell( | ||||
|                 genotype, | ||||
|                 C_prev_prev, | ||||
|                 C_prev, | ||||
|                 C_curr, | ||||
|                 reduction, | ||||
|                 reduction_prev, | ||||
|                 affine, | ||||
|                 track_running_stats, | ||||
|             ) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = ( | ||||
|                 C_prev, | ||||
|                 cell._multiplier * C_curr, | ||||
|                 reduction, | ||||
|             ) | ||||
|             if reduction and C_curr == C * 4 and auxiliary: | ||||
|                 self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) | ||||
|                 self.auxiliary_index = index | ||||
|         self._Layer = len(self.cells) | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.drop_path_prob = -1 | ||||
|  | ||||
|     def update_drop_path(self, drop_path_prob): | ||||
|         self.drop_path_prob = drop_path_prob | ||||
|  | ||||
|     def auxiliary_param(self): | ||||
|         if self.auxiliary_head is None: | ||||
|             return [] | ||||
|         else: | ||||
|             return list(self.auxiliary_head.parameters()) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         stem_feature, logits_aux = self.stem(inputs), None | ||||
|         cell_results = [stem_feature, stem_feature] | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) | ||||
|             cell_results.append(cell_feature) | ||||
|             if ( | ||||
|                 self.auxiliary_index is not None | ||||
|                 and i == self.auxiliary_index | ||||
|                 and self.training | ||||
|             ): | ||||
|                 logits_aux = self.auxiliary_head(cell_results[-1]) | ||||
|         out = self.lastact(cell_results[-1]) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|         if logits_aux is None: | ||||
|             return out, logits | ||||
|         else: | ||||
|             return out, [logits, logits_aux] | ||||
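For reference, a minimal training-step sketch showing how the two return forms of NASNetonCIFAR.forward are typically consumed. This is a hedged sketch, not part of the commit: `model` is assumed to be a NASNetonCIFAR built with auxiliary=True, and the 0.4 auxiliary weight is the conventional NASNet/DARTS choice rather than anything fixed by this file.

    import torch.nn.functional as F

    def train_step(model, images, labels, aux_weight=0.4):
        model.train()
        _, outputs = model(images)
        if isinstance(outputs, list):  # auxiliary head fired (training mode, aux head present)
            logits, logits_aux = outputs
            loss = F.cross_entropy(logits, labels) \
                 + aux_weight * F.cross_entropy(logits_aux, labels)
        else:  # eval mode, or the model was built without an auxiliary head
            loss = F.cross_entropy(outputs, labels)
        return loss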
63  AutoDL-Projects/xautodl/models/cell_infers/tiny_network.py  Normal file
							| @@ -0,0 +1,63 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .cells import InferCell | ||||
|  | ||||
|  | ||||
| # The macro structure for architectures in NAS-Bench-201 | ||||
| class TinyNetwork(nn.Module): | ||||
|     def __init__(self, C, N, genotype, num_classes): | ||||
|         super(TinyNetwork, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|  | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev = C | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2, True) | ||||
|             else: | ||||
|                 cell = InferCell(genotype, C_prev, C_curr, 1) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self._Layer = len(self.cells) | ||||
|  | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
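A quick way to exercise TinyNetwork is to decode a NAS-Bench-201 architecture string via the Structure class from genotypes.py (also added in this commit). The `xautodl.*` import paths assume the package is installed with the layout shown here; the architecture string below is just one valid encoding.

    import torch
    from xautodl.models.cell_infers.tiny_network import TinyNetwork
    from xautodl.models.cell_searchs.genotypes import Structure

    arch = Structure.str2structure(
        "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|"
    )
    net = TinyNetwork(C=16, N=5, genotype=arch, num_classes=10)
    features, logits = net(torch.randn(2, 3, 32, 32))
    print(logits.shape)  # torch.Size([2, 10])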
553  AutoDL-Projects/xautodl/models/cell_operations.py  Normal file
							| @@ -0,0 +1,553 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| __all__ = ["OPS", "RAW_OP_CLASSES", "ResNetBasicblock", "SearchSpaceNames"] | ||||
|  | ||||
| OPS = { | ||||
|     "none": lambda C_in, C_out, stride, affine, track_running_stats: Zero( | ||||
|         C_in, C_out, stride | ||||
|     ), | ||||
|     "avg_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING( | ||||
|         C_in, C_out, stride, "avg", affine, track_running_stats | ||||
|     ), | ||||
|     "max_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING( | ||||
|         C_in, C_out, stride, "max", affine, track_running_stats | ||||
|     ), | ||||
|     "nor_conv_7x7": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (7, 7), | ||||
|         (stride, stride), | ||||
|         (3, 3), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "nor_conv_3x3": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (3, 3), | ||||
|         (stride, stride), | ||||
|         (1, 1), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "nor_conv_1x1": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (1, 1), | ||||
|         (stride, stride), | ||||
|         (0, 0), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "dua_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (3, 3), | ||||
|         (stride, stride), | ||||
|         (1, 1), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "dua_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (5, 5), | ||||
|         (stride, stride), | ||||
|         (2, 2), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "dil_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: SepConv( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (3, 3), | ||||
|         (stride, stride), | ||||
|         (2, 2), | ||||
|         (2, 2), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "dil_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: SepConv( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (5, 5), | ||||
|         (stride, stride), | ||||
|         (4, 4), | ||||
|         (2, 2), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "skip_connect": lambda C_in, C_out, stride, affine, track_running_stats: Identity() | ||||
|     if stride == 1 and C_in == C_out | ||||
|     else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), | ||||
| } | ||||
|  | ||||
| CONNECT_NAS_BENCHMARK = ["none", "skip_connect", "nor_conv_3x3"] | ||||
| NAS_BENCH_201 = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] | ||||
| DARTS_SPACE = [ | ||||
|     "none", | ||||
|     "skip_connect", | ||||
|     "dua_sepc_3x3", | ||||
|     "dua_sepc_5x5", | ||||
|     "dil_sepc_3x3", | ||||
|     "dil_sepc_5x5", | ||||
|     "avg_pool_3x3", | ||||
|     "max_pool_3x3", | ||||
| ] | ||||
|  | ||||
| SearchSpaceNames = { | ||||
|     "connect-nas": CONNECT_NAS_BENCHMARK, | ||||
|     "nats-bench": NAS_BENCH_201, | ||||
|     "nas-bench-201": NAS_BENCH_201, | ||||
|     "darts": DARTS_SPACE, | ||||
| } | ||||
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C_in, | ||||
|         C_out, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         padding, | ||||
|         dilation, | ||||
|         affine, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(ReLUConvBN, self).__init__() | ||||
|         self.op = nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, | ||||
|                 C_out, | ||||
|                 kernel_size, | ||||
|                 stride=stride, | ||||
|                 padding=padding, | ||||
|                 dilation=dilation, | ||||
|                 bias=not affine, | ||||
|             ), | ||||
|             nn.BatchNorm2d( | ||||
|                 C_out, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class SepConv(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C_in, | ||||
|         C_out, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         padding, | ||||
|         dilation, | ||||
|         affine, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(SepConv, self).__init__() | ||||
|         self.op = nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, | ||||
|                 C_in, | ||||
|                 kernel_size=kernel_size, | ||||
|                 stride=stride, | ||||
|                 padding=padding, | ||||
|                 dilation=dilation, | ||||
|                 groups=C_in, | ||||
|                 bias=False, | ||||
|             ), | ||||
|             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine), | ||||
|             nn.BatchNorm2d( | ||||
|                 C_out, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class DualSepConv(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C_in, | ||||
|         C_out, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         padding, | ||||
|         dilation, | ||||
|         affine, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(DualSepConv, self).__init__() | ||||
|         self.op_a = SepConv( | ||||
|             C_in, | ||||
|             C_in, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             dilation, | ||||
|             affine, | ||||
|             track_running_stats, | ||||
|         ) | ||||
|         self.op_b = SepConv( | ||||
|             C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.op_a(x) | ||||
|         x = self.op_b(x) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ReLUConvBN( | ||||
|             inplanes, planes, 3, stride, 1, 1, affine, track_running_stats | ||||
|         ) | ||||
|         self.conv_b = ReLUConvBN( | ||||
|             planes, planes, 3, 1, 1, 1, affine, track_running_stats | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = nn.Sequential( | ||||
|                 nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||
|                 nn.Conv2d( | ||||
|                     inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False | ||||
|                 ), | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ReLUConvBN( | ||||
|                 inplanes, planes, 1, 1, 0, 1, affine, track_running_stats | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.in_dim = inplanes | ||||
|         self.out_dim = planes | ||||
|         self.stride = stride | ||||
|         self.num_conv = 2 | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         string = "{name}(inC={in_dim}, outC={out_dim}, stride={stride})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|         return string | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         return residual + basicblock | ||||
|  | ||||
|  | ||||
| class POOLING(nn.Module): | ||||
|     def __init__( | ||||
|         self, C_in, C_out, stride, mode, affine=True, track_running_stats=True | ||||
|     ): | ||||
|         super(POOLING, self).__init__() | ||||
|         if C_in == C_out: | ||||
|             self.preprocess = None | ||||
|         else: | ||||
|             self.preprocess = ReLUConvBN( | ||||
|                 C_in, C_out, 1, 1, 0, 1, affine, track_running_stats | ||||
|             ) | ||||
|         if mode == "avg": | ||||
|             self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) | ||||
|         elif mode == "max": | ||||
|             self.op = nn.MaxPool2d(3, stride=stride, padding=1) | ||||
|         else: | ||||
|             raise ValueError("Invalid mode={:} in POOLING".format(mode)) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.preprocess: | ||||
|             x = self.preprocess(inputs) | ||||
|         else: | ||||
|             x = inputs | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class Identity(nn.Module): | ||||
|     def __init__(self): | ||||
|         super(Identity, self).__init__() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class Zero(nn.Module): | ||||
|     def __init__(self, C_in, C_out, stride): | ||||
|         super(Zero, self).__init__() | ||||
|         self.C_in = C_in | ||||
|         self.C_out = C_out | ||||
|         self.stride = stride | ||||
|         self.is_zero = True | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.C_in == self.C_out: | ||||
|             if self.stride == 1: | ||||
|                 return x.mul(0.0) | ||||
|             else: | ||||
|                 return x[:, :, :: self.stride, :: self.stride].mul(0.0) | ||||
|         else: | ||||
|             shape = list(x.shape) | ||||
|             shape[1] = self.C_out | ||||
|             if self.stride > 1:  # match the spatial size a strided slice would produce | ||||
|                 shape[2] = (shape[2] - 1) // self.stride + 1 | ||||
|                 shape[3] = (shape[3] - 1) // self.stride + 1 | ||||
|             return x.new_zeros(shape)  # new_zeros inherits dtype and device from x | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__) | ||||
|  | ||||
|  | ||||
| class FactorizedReduce(nn.Module): | ||||
|     def __init__(self, C_in, C_out, stride, affine, track_running_stats): | ||||
|         super(FactorizedReduce, self).__init__() | ||||
|         self.stride = stride | ||||
|         self.C_in = C_in | ||||
|         self.C_out = C_out | ||||
|         self.relu = nn.ReLU(inplace=False) | ||||
|         if stride == 2: | ||||
|             # assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) | ||||
|             C_outs = [C_out // 2, C_out - C_out // 2] | ||||
|             self.convs = nn.ModuleList() | ||||
|             for i in range(2): | ||||
|                 self.convs.append( | ||||
|                     nn.Conv2d( | ||||
|                         C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine | ||||
|                     ) | ||||
|                 ) | ||||
|             self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) | ||||
|         elif stride == 1: | ||||
|             self.conv = nn.Conv2d( | ||||
|                 C_in, C_out, 1, stride=stride, padding=0, bias=not affine | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("Invalid stride : {:}".format(stride)) | ||||
|         self.bn = nn.BatchNorm2d( | ||||
|             C_out, affine=affine, track_running_stats=track_running_stats | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.stride == 2: | ||||
|             x = self.relu(x) | ||||
|             y = self.pad(x) | ||||
|             out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1) | ||||
|         else: | ||||
|             out = self.conv(x) | ||||
|         out = self.bn(out) | ||||
|         return out | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__) | ||||
|  | ||||
|  | ||||
| # Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019 | ||||
| class PartAwareOp(nn.Module): | ||||
|     def __init__(self, C_in, C_out, stride, part=4): | ||||
|         super().__init__() | ||||
|         self.part = part | ||||
|         self.hidden = C_in // 3 | ||||
|         self.avg_pool = nn.AdaptiveAvgPool2d(1) | ||||
|         self.local_conv_list = nn.ModuleList() | ||||
|         for i in range(self.part): | ||||
|             self.local_conv_list.append( | ||||
|                 nn.Sequential( | ||||
|                     nn.ReLU(), | ||||
|                     nn.Conv2d(C_in, self.hidden, 1), | ||||
|                     nn.BatchNorm2d(self.hidden, affine=True), | ||||
|                 ) | ||||
|             ) | ||||
|         self.W_K = nn.Linear(self.hidden, self.hidden) | ||||
|         self.W_Q = nn.Linear(self.hidden, self.hidden) | ||||
|  | ||||
|         if stride == 2: | ||||
|             # FactorizedReduce requires the affine and track_running_stats arguments | ||||
|             self.last = FactorizedReduce(C_in + self.hidden, C_out, 2, True, True) | ||||
|         elif stride == 1: | ||||
|             self.last = FactorizedReduce(C_in + self.hidden, C_out, 1, True, True) | ||||
|         else: | ||||
|             raise ValueError("Invalid Stride : {:}".format(stride)) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         batch, C, H, W = x.size() | ||||
|         assert H >= self.part, "input size too small : {:} vs {:}".format( | ||||
|             x.shape, self.part | ||||
|         ) | ||||
|         IHs = [0] | ||||
|         for i in range(self.part): | ||||
|             IHs.append(min(H, int((i + 1) * (float(H) / self.part)))) | ||||
|         local_feat_list = [] | ||||
|         for i in range(self.part): | ||||
|             feature = x[:, :, IHs[i] : IHs[i + 1], :] | ||||
|             xfeax = self.avg_pool(feature) | ||||
|             xfea = self.local_conv_list[i](xfeax) | ||||
|             local_feat_list.append(xfea) | ||||
|         part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part) | ||||
|         part_feature = part_feature.transpose(1, 2).contiguous() | ||||
|         part_K = self.W_K(part_feature) | ||||
|         part_Q = self.W_Q(part_feature).transpose(1, 2).contiguous() | ||||
|         weight_att = torch.bmm(part_K, part_Q) | ||||
|         attention = torch.softmax(weight_att, dim=2) | ||||
|         aggregated = torch.bmm(attention, part_feature).transpose(1, 2).contiguous() | ||||
|         features = [] | ||||
|         for i in range(self.part): | ||||
|             feature = aggregated[:, :, i : i + 1].expand( | ||||
|                 batch, self.hidden, IHs[i + 1] - IHs[i] | ||||
|             ) | ||||
|             feature = feature.view(batch, self.hidden, IHs[i + 1] - IHs[i], 1) | ||||
|             features.append(feature) | ||||
|         features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W) | ||||
|         final_fea = torch.cat((x, features), dim=1) | ||||
|         outputs = self.last(final_fea) | ||||
|         return outputs | ||||
|  | ||||
|  | ||||
| def drop_path(x, drop_prob): | ||||
|     # per-sample stochastic depth: zero an entire branch with probability drop_prob | ||||
|     # and rescale the survivors by 1/keep_prob so the expected activation is unchanged | ||||
|     if drop_prob > 0.0: | ||||
|         keep_prob = 1.0 - drop_prob | ||||
|         mask = x.new_zeros(x.size(0), 1, 1, 1) | ||||
|         mask = mask.bernoulli_(keep_prob) | ||||
|         x = torch.div(x, keep_prob) | ||||
|         x.mul_(mask) | ||||
|     return x | ||||
|  | ||||
|  | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours | ||||
| class GDAS_Reduction_Cell(nn.Module): | ||||
|     def __init__( | ||||
|         self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats | ||||
|     ): | ||||
|         super(GDAS_Reduction_Cell, self).__init__() | ||||
|         if reduction_prev: | ||||
|             self.preprocess0 = FactorizedReduce( | ||||
|                 C_prev_prev, C, 2, affine, track_running_stats | ||||
|             ) | ||||
|         else: | ||||
|             self.preprocess0 = ReLUConvBN( | ||||
|                 C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats | ||||
|             ) | ||||
|         self.preprocess1 = ReLUConvBN( | ||||
|             C_prev, C, 1, 1, 0, 1, affine, track_running_stats | ||||
|         ) | ||||
|  | ||||
|         self.reduction = True | ||||
|         self.ops1 = nn.ModuleList( | ||||
|             [ | ||||
|                 nn.Sequential( | ||||
|                     nn.ReLU(inplace=False), | ||||
|                     nn.Conv2d( | ||||
|                         C, | ||||
|                         C, | ||||
|                         (1, 3), | ||||
|                         stride=(1, 2), | ||||
|                         padding=(0, 1), | ||||
|                         groups=8, | ||||
|                         bias=not affine, | ||||
|                     ), | ||||
|                     nn.Conv2d( | ||||
|                         C, | ||||
|                         C, | ||||
|                         (3, 1), | ||||
|                         stride=(2, 1), | ||||
|                         padding=(1, 0), | ||||
|                         groups=8, | ||||
|                         bias=not affine, | ||||
|                     ), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                     nn.ReLU(inplace=False), | ||||
|                     nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                 ), | ||||
|                 nn.Sequential( | ||||
|                     nn.ReLU(inplace=False), | ||||
|                     nn.Conv2d( | ||||
|                         C, | ||||
|                         C, | ||||
|                         (1, 3), | ||||
|                         stride=(1, 2), | ||||
|                         padding=(0, 1), | ||||
|                         groups=8, | ||||
|                         bias=not affine, | ||||
|                     ), | ||||
|                     nn.Conv2d( | ||||
|                         C, | ||||
|                         C, | ||||
|                         (3, 1), | ||||
|                         stride=(2, 1), | ||||
|                         padding=(1, 0), | ||||
|                         groups=8, | ||||
|                         bias=not affine, | ||||
|                     ), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                     nn.ReLU(inplace=False), | ||||
|                     nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                 ), | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         self.ops2 = nn.ModuleList( | ||||
|             [ | ||||
|                 nn.Sequential( | ||||
|                     nn.MaxPool2d(3, stride=2, padding=1), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                 ), | ||||
|                 nn.Sequential( | ||||
|                     nn.MaxPool2d(3, stride=2, padding=1), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                 ), | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def multiplier(self): | ||||
|         return 4 | ||||
|  | ||||
|     def forward(self, s0, s1, drop_prob=-1): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         X0 = self.ops1[0](s0) | ||||
|         X1 = self.ops1[1](s1) | ||||
|         if self.training and drop_prob > 0.0: | ||||
|             X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob) | ||||
|  | ||||
|         # X2 = self.ops2[0] (X0+X1) | ||||
|         X2 = self.ops2[0](s0) | ||||
|         X3 = self.ops2[1](s1) | ||||
|         if self.training and drop_prob > 0.0: | ||||
|             X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob) | ||||
|         return torch.cat([X0, X1, X2, X3], dim=1) | ||||
|  | ||||
|  | ||||
| # Registry of the raw (class-based) cell operations defined in this file. | ||||
| RAW_OP_CLASSES = {"gdas_reduction": GDAS_Reduction_Cell} | ||||
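Every entry in OPS shares the constructor signature (C_in, C_out, stride, affine, track_running_stats), which is what lets search code instantiate candidate operations uniformly. A small sanity check (the import path assumes the layout in this commit):

    import torch
    from xautodl.models.cell_operations import OPS

    op = OPS["nor_conv_3x3"](16, 32, 2, True, True)  # stride-2 ReLU-Conv-BN
    x = torch.randn(2, 16, 32, 32)
    print(op(x).shape)  # torch.Size([2, 32, 16, 16])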
33  AutoDL-Projects/xautodl/models/cell_searchs/__init__.py  Normal file
							| @@ -0,0 +1,33 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # The macro structure is defined in NAS-Bench-201 | ||||
| from .search_model_darts import TinyNetworkDarts | ||||
| from .search_model_gdas import TinyNetworkGDAS | ||||
| from .search_model_setn import TinyNetworkSETN | ||||
| from .search_model_enas import TinyNetworkENAS | ||||
| from .search_model_random import TinyNetworkRANDOM | ||||
| from .generic_model import GenericNAS201Model | ||||
| from .genotypes import Structure as CellStructure, architectures as CellArchitectures | ||||
|  | ||||
| # NASNet-based macro structure | ||||
| from .search_model_gdas_nasnet import NASNetworkGDAS | ||||
| from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC | ||||
| from .search_model_darts_nasnet import NASNetworkDARTS | ||||
|  | ||||
|  | ||||
| nas201_super_nets = { | ||||
|     "DARTS-V1": TinyNetworkDarts, | ||||
|     "DARTS-V2": TinyNetworkDarts, | ||||
|     "GDAS": TinyNetworkGDAS, | ||||
|     "SETN": TinyNetworkSETN, | ||||
|     "ENAS": TinyNetworkENAS, | ||||
|     "RANDOM": TinyNetworkRANDOM, | ||||
|     "generic": GenericNAS201Model, | ||||
| } | ||||
|  | ||||
| nasnet_super_nets = { | ||||
|     "GDAS": NASNetworkGDAS, | ||||
|     "GDAS_FRC": NASNetworkGDAS_FRC, | ||||
|     "DARTS": NASNetworkDARTS, | ||||
| } | ||||
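These dictionaries act as registries, letting experiment scripts select a super-network class by algorithm name. A sketch using the "generic" entry, whose constructor is defined in generic_model.py below (the argument values are illustrative, not prescribed by this commit):

    from xautodl.models.cell_searchs import nas201_super_nets

    search_space = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
    supernet = nas201_super_nets["generic"](
        C=16, N=5, max_nodes=4, num_classes=10,
        search_space=search_space, affine=False, track_running_stats=True,
    )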
14  AutoDL-Projects/xautodl/models/cell_searchs/_test_module.py  Normal file
							| @@ -0,0 +1,14 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| from search_model_enas_utils import Controller | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     controller = Controller(6, 4) | ||||
|     predictions = controller()  # smoke test: just sample once from the controller | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
366  AutoDL-Projects/xautodl/models/cell_searchs/generic_model.py  Normal file
							| @@ -0,0 +1,366 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 # | ||||
| ##################################################### | ||||
| import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import Text | ||||
| from torch.distributions.categorical import Categorical | ||||
|  | ||||
| from ..cell_operations import ResNetBasicblock, drop_path | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class Controller(nn.Module): | ||||
|     # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py | ||||
|     def __init__( | ||||
|         self, | ||||
|         edge2index, | ||||
|         op_names, | ||||
|         max_nodes, | ||||
|         lstm_size=32, | ||||
|         lstm_num_layers=2, | ||||
|         tanh_constant=2.5, | ||||
|         temperature=5.0, | ||||
|     ): | ||||
|         super(Controller, self).__init__() | ||||
|         # assign the attributes | ||||
|         self.max_nodes = max_nodes | ||||
|         self.num_edge = len(edge2index) | ||||
|         self.edge2index = edge2index | ||||
|         self.num_ops = len(op_names) | ||||
|         self.op_names = op_names | ||||
|         self.lstm_size = lstm_size | ||||
|         self.lstm_N = lstm_num_layers | ||||
|         self.tanh_constant = tanh_constant | ||||
|         self.temperature = temperature | ||||
|         # create parameters | ||||
|         self.register_parameter( | ||||
|             "input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size)) | ||||
|         ) | ||||
|         self.w_lstm = nn.LSTM( | ||||
|             input_size=self.lstm_size, | ||||
|             hidden_size=self.lstm_size, | ||||
|             num_layers=self.lstm_N, | ||||
|         ) | ||||
|         self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) | ||||
|         self.w_pred = nn.Linear(self.lstm_size, self.num_ops) | ||||
|  | ||||
|         nn.init.uniform_(self.input_vars, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_embd.weight, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_pred.weight, -0.1, 0.1) | ||||
|  | ||||
|     def convert_structure(self, _arch): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_index = _arch[self.edge2index[node_str]] | ||||
|                 op_name = self.op_names[op_index] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def forward(self): | ||||
|  | ||||
|         inputs, h0 = self.input_vars, None | ||||
|         log_probs, entropys, sampled_arch = [], [], [] | ||||
|         for iedge in range(self.num_edge): | ||||
|             outputs, h0 = self.w_lstm(inputs, h0) | ||||
|  | ||||
|             logits = self.w_pred(outputs) | ||||
|             logits = logits / self.temperature | ||||
|             logits = self.tanh_constant * torch.tanh(logits) | ||||
|             # distribution | ||||
|             op_distribution = Categorical(logits=logits) | ||||
|             op_index = op_distribution.sample() | ||||
|             sampled_arch.append(op_index.item()) | ||||
|  | ||||
|             op_log_prob = op_distribution.log_prob(op_index) | ||||
|             log_probs.append(op_log_prob.view(-1)) | ||||
|             op_entropy = op_distribution.entropy() | ||||
|             entropys.append(op_entropy.view(-1)) | ||||
|  | ||||
|             # obtain the input embedding for the next step | ||||
|             inputs = self.w_embd(op_index) | ||||
|         return ( | ||||
|             torch.sum(torch.cat(log_probs)), | ||||
|             torch.sum(torch.cat(entropys)), | ||||
|             self.convert_structure(sampled_arch), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class GenericNAS201Model(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(GenericNAS201Model, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._max_nodes = max_nodes | ||||
|         self._stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self._cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self._cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self._op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self._cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential( | ||||
|             nn.BatchNorm2d( | ||||
|                 C_prev, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self._num_edge = num_edge | ||||
|         # algorithm related | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self._mode = None | ||||
|         self.dynamic_cell = None | ||||
|         self._tau = None | ||||
|         self._algo = None | ||||
|         self._drop_path = None | ||||
|         self.verbose = False | ||||
|  | ||||
|     def set_algo(self, algo: Text): | ||||
|         # used for searching | ||||
|         assert self._algo is None, "This function can only be called once." | ||||
|         self._algo = algo | ||||
|         if algo == "enas": | ||||
|             self.controller = Controller( | ||||
|                 self.edge2index, self._op_names, self._max_nodes | ||||
|             ) | ||||
|         else: | ||||
|             self.arch_parameters = nn.Parameter( | ||||
|                 1e-3 * torch.randn(self._num_edge, len(self._op_names)) | ||||
|             ) | ||||
|             if algo == "gdas": | ||||
|                 self._tau = 10 | ||||
|  | ||||
|     def set_cal_mode(self, mode, dynamic_cell=None): | ||||
|         assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"] | ||||
|         self._mode = mode | ||||
|         if mode == "dynamic": | ||||
|             self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|         else: | ||||
|             self.dynamic_cell = None | ||||
|  | ||||
|     def set_drop_path(self, progress, drop_path_rate): | ||||
|         if drop_path_rate is None: | ||||
|             self._drop_path = None | ||||
|         elif progress is None: | ||||
|             self._drop_path = drop_path_rate | ||||
|         else: | ||||
|             self._drop_path = progress * drop_path_rate | ||||
|  | ||||
|     @property | ||||
|     def mode(self): | ||||
|         return self._mode | ||||
|  | ||||
|     @property | ||||
|     def drop_path(self): | ||||
|         return self._drop_path | ||||
|  | ||||
|     @property | ||||
|     def weights(self): | ||||
|         xlist = list(self._stem.parameters()) | ||||
|         xlist += list(self._cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) | ||||
|         xlist += list(self.global_pooling.parameters()) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self._tau = tau | ||||
|  | ||||
|     @property | ||||
|     def tau(self): | ||||
|         return self._tau | ||||
|  | ||||
|     @property | ||||
|     def alphas(self): | ||||
|         if self._algo == "enas": | ||||
|             return list(self.controller.parameters()) | ||||
|         else: | ||||
|             return [self.arch_parameters] | ||||
|  | ||||
|     @property | ||||
|     def message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self._cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self._cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             if self._algo == "enas": | ||||
|                 return "w_pred :\n{:}".format(self.controller.w_pred.weight) | ||||
|             else: | ||||
|                 return "arch-parameters :\n{:}".format( | ||||
|                     nn.functional.softmax(self.arch_parameters, dim=-1).cpu() | ||||
|                 ) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self._max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self._op_names[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def dync_genotype(self, use_random=False): | ||||
|         genotypes = [] | ||||
|         with torch.no_grad(): | ||||
|             alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         for i in range(1, self._max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 if use_random: | ||||
|                     op_name = random.choice(self._op_names) | ||||
|                 else: | ||||
|                     weights = alphas_cpu[self.edge2index[node_str]] | ||||
|                     op_index = torch.multinomial(weights, 1).item() | ||||
|                     op_name = self._op_names[op_index] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def get_log_prob(self, arch): | ||||
|         with torch.no_grad(): | ||||
|             logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) | ||||
|         select_logits = [] | ||||
|         for i, node_info in enumerate(arch.nodes): | ||||
|             for op, xin in node_info: | ||||
|                 node_str = "{:}<-{:}".format(i + 1, xin) | ||||
|                 op_index = self._op_names.index(op) | ||||
|                 select_logits.append(logits[self.edge2index[node_str], op_index]) | ||||
|         return sum(select_logits).item() | ||||
|  | ||||
|     def return_topK(self, K, use_random=False): | ||||
|         archs = Structure.gen_all(self._op_names, self._max_nodes, False) | ||||
|         if K < 0 or K >= len(archs): | ||||
|             K = len(archs) | ||||
|         if use_random: | ||||
|             return random.sample(archs, K) | ||||
|         else: | ||||
|             # score every architecture by its log-probability and keep the K best | ||||
|             pairs = [(self.get_log_prob(arch), arch) for arch in archs] | ||||
|             sorted_pairs = sorted(pairs, key=lambda x: -x[0]) | ||||
|             return [pair[1] for pair in sorted_pairs[:K]] | ||||
|  | ||||
|     def normalize_archp(self): | ||||
|         if self.mode == "gdas": | ||||
|             while True: | ||||
|                 gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() | ||||
|                 logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau | ||||
|                 probs = nn.functional.softmax(logits, dim=1) | ||||
|                 index = probs.max(-1, keepdim=True)[1] | ||||
|                 one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|                 hardwts = one_h - probs.detach() + probs | ||||
|                 if ( | ||||
|                     (torch.isinf(gumbels).any()) | ||||
|                     or (torch.isinf(probs).any()) | ||||
|                     or (torch.isnan(probs).any()) | ||||
|                 ): | ||||
|                     continue | ||||
|                 else: | ||||
|                     break | ||||
|             with torch.no_grad(): | ||||
|                 hardwts_cpu = hardwts.detach().cpu() | ||||
|             return hardwts, hardwts_cpu, index, "GUMBEL" | ||||
|         else: | ||||
|             alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|             index = alphas.max(-1, keepdim=True)[1] | ||||
|             with torch.no_grad(): | ||||
|                 alphas_cpu = alphas.detach().cpu() | ||||
|             return alphas, alphas_cpu, index, "SOFTMAX" | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         alphas, alphas_cpu, index, verbose_str = self.normalize_archp() | ||||
|         feature = self._stem(inputs) | ||||
|         for i, cell in enumerate(self._cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 if self.mode == "urs": | ||||
|                     feature = cell.forward_urs(feature) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_urs" | ||||
|                 elif self.mode == "select": | ||||
|                     feature = cell.forward_select(feature, alphas_cpu) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_select" | ||||
|                 elif self.mode == "joint": | ||||
|                     feature = cell.forward_joint(feature, alphas) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_joint" | ||||
|                 elif self.mode == "dynamic": | ||||
|                     feature = cell.forward_dynamic(feature, self.dynamic_cell) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_dynamic" | ||||
|                 elif self.mode == "gdas": | ||||
|                     feature = cell.forward_gdas(feature, alphas, index) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_gdas" | ||||
|                 elif self.mode == "gdas_v1": | ||||
|                     feature = cell.forward_gdas_v1(feature, alphas, index) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_gdas_v1" | ||||
|                 else: | ||||
|                     raise ValueError("invalid mode={:}".format(self.mode)) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|             if self.drop_path is not None: | ||||
|                 feature = drop_path(feature, self.drop_path) | ||||
|         if self.verbose and random.random() < 0.001: | ||||
|             print(verbose_str) | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|         return out, logits | ||||
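To tie the pieces together, a minimal sketch of one GDAS-style architecture-update step on GenericNAS201Model. This is hedged: the data, learning rate, and tau value are placeholders, the weight-update step is elided, and the forward pass relies on the companion search_cells module from this commit series being importable.

    import torch
    import torch.nn.functional as F
    from xautodl.models.cell_searchs import GenericNAS201Model

    search_space = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
    model = GenericNAS201Model(16, 5, 4, 10, search_space,
                               affine=False, track_running_stats=True)
    model.set_algo("gdas")       # allocate arch parameters and set tau=10
    model.set_cal_mode("gdas")   # forward passes use Gumbel-softmax sampling
    model.set_tau(10.0)

    arch_opt = torch.optim.Adam(model.alphas, lr=3e-4)
    x, y = torch.randn(2, 3, 32, 32), torch.randint(0, 10, (2,))
    _, logits = model(x)
    loss = F.cross_entropy(logits, y)
    arch_opt.zero_grad(); loss.backward(); arch_opt.step()
    print(model.genotype.tostr())  # current argmax architecture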
274  AutoDL-Projects/xautodl/models/cell_searchs/genotypes.py  Normal file
							| @@ -0,0 +1,274 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from copy import deepcopy | ||||
|  | ||||
|  | ||||
| def get_combination(space, num): | ||||
|     combs = [] | ||||
|     for i in range(num): | ||||
|         if i == 0: | ||||
|             for func in space: | ||||
|                 combs.append([(func, i)]) | ||||
|         else: | ||||
|             new_combs = [] | ||||
|             for string in combs: | ||||
|                 for func in space: | ||||
|                     xstring = string + [(func, i)] | ||||
|                     new_combs.append(xstring) | ||||
|             combs = new_combs | ||||
|     return combs | ||||
|  | ||||
|  | ||||
| class Structure: | ||||
|     def __init__(self, genotype): | ||||
|         assert isinstance(genotype, list) or isinstance( | ||||
|             genotype, tuple | ||||
|         ), "invalid class of genotype : {:}".format(type(genotype)) | ||||
|         self.node_num = len(genotype) + 1 | ||||
|         self.nodes = [] | ||||
|         self.node_N = [] | ||||
|         for idx, node_info in enumerate(genotype): | ||||
|             assert isinstance(node_info, list) or isinstance( | ||||
|                 node_info, tuple | ||||
|             ), "invalid class of node_info : {:}".format(type(node_info)) | ||||
|             assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info)) | ||||
|             for node_in in node_info: | ||||
|                 assert isinstance(node_in, list) or isinstance( | ||||
|                     node_in, tuple | ||||
|                 ), "invalid class of in-node : {:}".format(type(node_in)) | ||||
|                 assert ( | ||||
|                     len(node_in) == 2 and node_in[1] <= idx | ||||
|                 ), "invalid in-node : {:}".format(node_in) | ||||
|             self.node_N.append(len(node_info)) | ||||
|             self.nodes.append(tuple(deepcopy(node_info))) | ||||
|  | ||||
|     def tolist(self, remove_str): | ||||
|         # convert this class to a plain list; if remove_str is 'none', all 'none' operations are removed. | ||||
|         # note that the input nodes are re-ordered inside this function. | ||||
|         # returns (genotype-list, success); success is False when removal leaves a node without inputs. | ||||
|         genotypes = [] | ||||
|         for node_info in self.nodes: | ||||
|             node_info = list(node_info) | ||||
|             node_info = sorted(node_info, key=lambda x: (x[1], x[0])) | ||||
|             node_info = tuple(filter(lambda x: x[0] != remove_str, node_info)) | ||||
|             if len(node_info) == 0: | ||||
|                 return None, False | ||||
|             genotypes.append(node_info) | ||||
|         return genotypes, True | ||||
|  | ||||
|     def node(self, index): | ||||
|         # node indices are 1-based; node-0 is the input node and has no genotype entry | ||||
|         assert 0 < index < len(self), "invalid index={:}, expected in [1, {:})".format( | ||||
|             index, len(self) | ||||
|         ) | ||||
|         return self.nodes[index - 1] | ||||
|  | ||||
|     def tostr(self): | ||||
|         strings = [] | ||||
|         for node_info in self.nodes: | ||||
|             string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info]) | ||||
|             string = "|{:}|".format(string) | ||||
|             strings.append(string) | ||||
|         return "+".join(strings) | ||||
|  | ||||
|     def check_valid(self): | ||||
|         nodes = {0: True} | ||||
|         for i, node_info in enumerate(self.nodes): | ||||
|             sums = [] | ||||
|             for op, xin in node_info: | ||||
|                 if op == "none" or nodes[xin] is False: | ||||
|                     x = False | ||||
|                 else: | ||||
|                     x = True | ||||
|                 sums.append(x) | ||||
|             nodes[i + 1] = sum(sums) > 0 | ||||
|         return nodes[len(self.nodes)] | ||||
|  | ||||
|     def to_unique_str(self, consider_zero=False): | ||||
|         # this is used to identify isomorphic cells, which requires prior knowledge of the operations | ||||
|         # two operations are special, i.e., none and skip_connect | ||||
|         nodes = {0: "0"} | ||||
|         for i_node, node_info in enumerate(self.nodes): | ||||
|             cur_node = [] | ||||
|             for op, xin in node_info: | ||||
|                 if consider_zero is None: | ||||
|                     x = "(" + nodes[xin] + ")" + "@{:}".format(op) | ||||
|                 elif consider_zero: | ||||
|                     if op == "none" or nodes[xin] == "#": | ||||
|                         x = "#"  # zero | ||||
|                     elif op == "skip_connect": | ||||
|                         x = nodes[xin] | ||||
|                     else: | ||||
|                         x = "(" + nodes[xin] + ")" + "@{:}".format(op) | ||||
|                 else: | ||||
|                     if op == "skip_connect": | ||||
|                         x = nodes[xin] | ||||
|                     else: | ||||
|                         x = "(" + nodes[xin] + ")" + "@{:}".format(op) | ||||
|                 cur_node.append(x) | ||||
|             nodes[i_node + 1] = "+".join(sorted(cur_node)) | ||||
|         return nodes[len(self.nodes)] | ||||
|  | ||||
|     def check_valid_op(self, op_names): | ||||
|         for node_info in self.nodes: | ||||
|             for inode_edge in node_info: | ||||
|                 # assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) | ||||
|                 if inode_edge[0] not in op_names: | ||||
|                     return False | ||||
|         return True | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({node_num} nodes with {node_info})".format( | ||||
|             name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.nodes) + 1 | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         return self.nodes[index] | ||||
|  | ||||
|     @staticmethod | ||||
|     def str2structure(xstr): | ||||
|         if isinstance(xstr, Structure): | ||||
|             return xstr | ||||
|         assert isinstance(xstr, str), "must take string (not {:}) as input".format( | ||||
|             type(xstr) | ||||
|         ) | ||||
|         nodestrs = xstr.split("+") | ||||
|         genotypes = [] | ||||
|         for i, node_str in enumerate(nodestrs): | ||||
|             inputs = list(filter(lambda x: x != "", node_str.split("|"))) | ||||
|             for xinput in inputs: | ||||
|                 assert len(xinput.split("~")) == 2, "invalid input length : {:}".format( | ||||
|                     xinput | ||||
|                 ) | ||||
|             inputs = (xi.split("~") for xi in inputs) | ||||
|             input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs) | ||||
|             genotypes.append(input_infos) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     @staticmethod | ||||
|     def str2fullstructure(xstr, default_name="none"): | ||||
|         assert isinstance(xstr, str), "must take string (not {:}) as input".format( | ||||
|             type(xstr) | ||||
|         ) | ||||
|         nodestrs = xstr.split("+") | ||||
|         genotypes = [] | ||||
|         for i, node_str in enumerate(nodestrs): | ||||
|             inputs = list(filter(lambda x: x != "", node_str.split("|"))) | ||||
|             for xinput in inputs: | ||||
|                 assert len(xinput.split("~")) == 2, "invalid input length : {:}".format( | ||||
|                     xinput | ||||
|                 ) | ||||
|             inputs = (xi.split("~") for xi in inputs) | ||||
|             input_infos = list((op, int(IDX)) for (op, IDX) in inputs) | ||||
|             all_in_nodes = list(x[1] for x in input_infos) | ||||
|             for j in range(i): | ||||
|                 if j not in all_in_nodes: | ||||
|                     input_infos.append((default_name, j)) | ||||
|             node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) | ||||
|             genotypes.append(tuple(node_info)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     @staticmethod | ||||
|     def gen_all(search_space, num, return_ori): | ||||
|         assert isinstance(search_space, list) or isinstance( | ||||
|             search_space, tuple | ||||
|         ), "invalid class of search-space : {:}".format(type(search_space)) | ||||
|         assert ( | ||||
|             num >= 2 | ||||
|         ), "There should be at least two nodes in a neural cell instead of {:}".format( | ||||
|             num | ||||
|         ) | ||||
|         all_archs = get_combination(search_space, 1) | ||||
|         for i, arch in enumerate(all_archs): | ||||
|             all_archs[i] = [tuple(arch)] | ||||
|  | ||||
|         for inode in range(2, num): | ||||
|             cur_nodes = get_combination(search_space, inode) | ||||
|             new_all_archs = [] | ||||
|             for previous_arch in all_archs: | ||||
|                 for cur_node in cur_nodes: | ||||
|                     new_all_archs.append(previous_arch + [tuple(cur_node)]) | ||||
|             all_archs = new_all_archs | ||||
|         if return_ori: | ||||
|             return all_archs | ||||
|         else: | ||||
|             return [Structure(x) for x in all_archs] | ||||
|  | ||||
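| # Editor's note (not in the original commit): gen_all enumerates the whole | ||||
| # cell space.  Node i has |search_space|**i incoming-op choices, so a K-op | ||||
| # space with `num` nodes yields K**(1 + 2 + ... + (num - 1)) cells; the | ||||
| # NAS-Bench-201 setting (K=5, num=4) gives 5**6 = 15625. | ||||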
|  | ||||
| ResNet_CODE = Structure( | ||||
|     [ | ||||
|         (("nor_conv_3x3", 0),),  # node-1 | ||||
|         (("nor_conv_3x3", 1),),  # node-2 | ||||
|         (("skip_connect", 0), ("skip_connect", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| AllConv3x3_CODE = Structure( | ||||
|     [ | ||||
|         (("nor_conv_3x3", 0),),  # node-1 | ||||
|         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1)),  # node-2 | ||||
|         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| AllFull_CODE = Structure( | ||||
|     [ | ||||
|         ( | ||||
|             ("skip_connect", 0), | ||||
|             ("nor_conv_1x1", 0), | ||||
|             ("nor_conv_3x3", 0), | ||||
|             ("avg_pool_3x3", 0), | ||||
|         ),  # node-1 | ||||
|         ( | ||||
|             ("skip_connect", 0), | ||||
|             ("nor_conv_1x1", 0), | ||||
|             ("nor_conv_3x3", 0), | ||||
|             ("avg_pool_3x3", 0), | ||||
|             ("skip_connect", 1), | ||||
|             ("nor_conv_1x1", 1), | ||||
|             ("nor_conv_3x3", 1), | ||||
|             ("avg_pool_3x3", 1), | ||||
|         ),  # node-2 | ||||
|         ( | ||||
|             ("skip_connect", 0), | ||||
|             ("nor_conv_1x1", 0), | ||||
|             ("nor_conv_3x3", 0), | ||||
|             ("avg_pool_3x3", 0), | ||||
|             ("skip_connect", 1), | ||||
|             ("nor_conv_1x1", 1), | ||||
|             ("nor_conv_3x3", 1), | ||||
|             ("avg_pool_3x3", 1), | ||||
|             ("skip_connect", 2), | ||||
|             ("nor_conv_1x1", 2), | ||||
|             ("nor_conv_3x3", 2), | ||||
|             ("avg_pool_3x3", 2), | ||||
|         ), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| AllConv1x1_CODE = Structure( | ||||
|     [ | ||||
|         (("nor_conv_1x1", 0),),  # node-1 | ||||
|         (("nor_conv_1x1", 0), ("nor_conv_1x1", 1)),  # node-2 | ||||
|         (("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| AllIdentity_CODE = Structure( | ||||
|     [ | ||||
|         (("skip_connect", 0),),  # node-1 | ||||
|         (("skip_connect", 0), ("skip_connect", 1)),  # node-2 | ||||
|         (("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| architectures = { | ||||
|     "resnet": ResNet_CODE, | ||||
|     "all_c3x3": AllConv3x3_CODE, | ||||
|     "all_c1x1": AllConv1x1_CODE, | ||||
|     "all_idnt": AllIdentity_CODE, | ||||
|     "all_full": AllFull_CODE, | ||||
| } | ||||
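|  | ||||
|  | ||||
| # Editor's sketch (not in the original commit): a minimal round-trip check, | ||||
| # assuming the standard NAS-Bench-201 operation names. | ||||
| if __name__ == "__main__": | ||||
|     arch = Structure.str2structure("|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|") | ||||
|     assert arch.check_valid_op( | ||||
|         ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] | ||||
|     ) | ||||
|     print(len(arch))  # 3 nodes in total, counting the input node | ||||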
							
								
								
									
AutoDL-Projects/xautodl/models/cell_searchs/search_cells.py (new file, 267 lines)
							| @@ -0,0 +1,267 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, random, torch | ||||
| import warnings | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # This module is used for NAS-Bench-201; it represents a small search space | ||||
| # as a complete DAG over the nodes of a cell. | ||||
| class NAS201SearchCell(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C_in, | ||||
|         C_out, | ||||
|         stride, | ||||
|         max_nodes, | ||||
|         op_names, | ||||
|         affine=False, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(NAS201SearchCell, self).__init__() | ||||
|  | ||||
|         self.op_names = deepcopy(op_names) | ||||
|         self.edges = nn.ModuleDict() | ||||
|         self.max_nodes = max_nodes | ||||
|         self.in_dim = C_in | ||||
|         self.out_dim = C_out | ||||
|         for i in range(1, max_nodes): | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 if j == 0: | ||||
|                     xlists = [ | ||||
|                         OPS[op_name](C_in, C_out, stride, affine, track_running_stats) | ||||
|                         for op_name in op_names | ||||
|                     ] | ||||
|                 else: | ||||
|                     xlists = [ | ||||
|                         OPS[op_name](C_in, C_out, 1, affine, track_running_stats) | ||||
|                         for op_name in op_names | ||||
|                     ] | ||||
|                 self.edges[node_str] = nn.ModuleList(xlists) | ||||
|         self.edge_keys = sorted(list(self.edges.keys())) | ||||
|         self.edge2index = {key: i for i, key in enumerate(self.edge_keys)} | ||||
|         self.num_edges = len(self.edges) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format( | ||||
|             **self.__dict__ | ||||
|         ) | ||||
|         return string | ||||
|  | ||||
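|     # `weightss` has shape (num_edges, num_ops): row edge2index["i<-j"] holds | ||||
|     # the mixing weights for the candidate ops on edge j -> i (editor's note). | ||||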
|     def forward(self, inputs, weightss): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 inter_nodes.append( | ||||
|                     sum( | ||||
|                         layer(nodes[j]) * w | ||||
|                         for layer, w in zip(self.edges[node_str], weights) | ||||
|                     ) | ||||
|                 ) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # GDAS: hardwts rows are straight-through one-hot samples, so only the | ||||
|     # argmax op is evaluated per edge, while the remaining (numerically zero) | ||||
|     # weights keep the gradient path to the architecture parameters alive. | ||||
|     def forward_gdas(self, inputs, hardwts, index): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = hardwts[self.edge2index[node_str]] | ||||
|                 argmaxs = index[self.edge2index[node_str]].item() | ||||
|                 weigsum = sum( | ||||
|                     weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] | ||||
|                     for _ie, edge in enumerate(self.edges[node_str]) | ||||
|                 ) | ||||
|                 inter_nodes.append(weigsum) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119 | ||||
|     def forward_gdas_v1(self, inputs, hardwts, index): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = hardwts[self.edge2index[node_str]] | ||||
|                 argmaxs = index[self.edge2index[node_str]].item() | ||||
|                 weigsum = weights[argmaxs] * self.edges[node_str][argmaxs](nodes[j]) | ||||
|                 inter_nodes.append(weigsum) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # joint: mix all candidate ops on each edge with the given weights | ||||
|     def forward_joint(self, inputs, weightss): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 # aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() | ||||
|                 aggregation = sum( | ||||
|                     layer(nodes[j]) * w | ||||
|                     for layer, w in zip(self.edges[node_str], weights) | ||||
|                 ) | ||||
|                 inter_nodes.append(aggregation) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # uniform random sampling per iteration, SETN | ||||
|     def forward_urs(self, inputs): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             while True:  # resample to avoid selecting the zero op on every edge | ||||
|                 sops, has_non_zero = [], False | ||||
|                 for j in range(i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     candidates = self.edges[node_str] | ||||
|                     select_op = random.choice(candidates) | ||||
|                     sops.append(select_op) | ||||
|                     if not hasattr(select_op, "is_zero") or select_op.is_zero is False: | ||||
|                         has_non_zero = True | ||||
|                 if has_non_zero: | ||||
|                     break | ||||
|             inter_nodes = [] | ||||
|             for j, select_op in enumerate(sops): | ||||
|                 inter_nodes.append(select_op(nodes[j])) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # select the argmax | ||||
|     def forward_select(self, inputs, weightss): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 inter_nodes.append( | ||||
|                     self.edges[node_str][weights.argmax().item()](nodes[j]) | ||||
|                 ) | ||||
|                 # inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # forward with a specific structure | ||||
|     def forward_dynamic(self, inputs, structure): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             cur_op_node = structure.nodes[i - 1] | ||||
|             inter_nodes = [] | ||||
|             for op_name, j in cur_op_node: | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_index = self.op_names.index(op_name) | ||||
|                 inter_nodes.append(self.edges[node_str][op_index](nodes[j])) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
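| # Editor's note (not in the original commit): forward_dynamic evaluates one | ||||
| # concrete cell; with `arch = Structure.str2structure(...)` (from .genotypes), | ||||
| # `cell.forward_dynamic(x, arch)` runs only the chosen op on each edge. | ||||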
|  | ||||
| # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 | ||||
|  | ||||
|  | ||||
| class MixedOp(nn.Module): | ||||
|     def __init__(self, space, C, stride, affine, track_running_stats): | ||||
|         super(MixedOp, self).__init__() | ||||
|         self._ops = nn.ModuleList() | ||||
|         for primitive in space: | ||||
|             op = OPS[primitive](C, C, stride, affine, track_running_stats) | ||||
|             self._ops.append(op) | ||||
|  | ||||
|     def forward_gdas(self, x, weights, index): | ||||
|         return self._ops[index](x) * weights[index] | ||||
|  | ||||
|     def forward_darts(self, x, weights): | ||||
|         return sum(w * op(x) for w, op in zip(weights, self._ops)) | ||||
|  | ||||
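| # DARTS relaxes the discrete choice per edge as o(x) = sum_k w_k * o_k(x), | ||||
| # where w = softmax(alpha); forward_darts computes exactly this weighted sum | ||||
| # (editor's note, not in the original commit). | ||||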
|  | ||||
| class NASNetSearchCell(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         space, | ||||
|         steps, | ||||
|         multiplier, | ||||
|         C_prev_prev, | ||||
|         C_prev, | ||||
|         C, | ||||
|         reduction, | ||||
|         reduction_prev, | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ): | ||||
|         super(NASNetSearchCell, self).__init__() | ||||
|         self.reduction = reduction | ||||
|         self.op_names = deepcopy(space) | ||||
|         if reduction_prev: | ||||
|             self.preprocess0 = OPS["skip_connect"]( | ||||
|                 C_prev_prev, C, 2, affine, track_running_stats | ||||
|             ) | ||||
|         else: | ||||
|             self.preprocess0 = OPS["nor_conv_1x1"]( | ||||
|                 C_prev_prev, C, 1, affine, track_running_stats | ||||
|             ) | ||||
|         self.preprocess1 = OPS["nor_conv_1x1"]( | ||||
|             C_prev, C, 1, affine, track_running_stats | ||||
|         ) | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|  | ||||
|         self._ops = nn.ModuleList() | ||||
|         self.edges = nn.ModuleDict() | ||||
|         for i in range(self._steps): | ||||
|             for j in range(2 + i): | ||||
|                 node_str = "{:}<-{:}".format( | ||||
|                     i, j | ||||
|                 )  # indicate the edge from node-(j) to node-(i+2) | ||||
|                 stride = 2 if reduction and j < 2 else 1 | ||||
|                 op = MixedOp(space, C, stride, affine, track_running_stats) | ||||
|                 self.edges[node_str] = op | ||||
|         self.edge_keys = sorted(list(self.edges.keys())) | ||||
|         self.edge2index = {key: i for i, key in enumerate(self.edge_keys)} | ||||
|         self.num_edges = len(self.edges) | ||||
|  | ||||
|     @property | ||||
|     def multiplier(self): | ||||
|         return self._multiplier | ||||
|  | ||||
|     def forward_gdas(self, s0, s1, weightss, indexs): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         states = [s0, s1] | ||||
|         for i in range(self._steps): | ||||
|             clist = [] | ||||
|             for j, h in enumerate(states): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op = self.edges[node_str] | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 index = indexs[self.edge2index[node_str]].item() | ||||
|                 clist.append(op.forward_gdas(h, weights, index)) | ||||
|             states.append(sum(clist)) | ||||
|  | ||||
|         return torch.cat(states[-self._multiplier :], dim=1) | ||||
|  | ||||
|     def forward_darts(self, s0, s1, weightss): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         states = [s0, s1] | ||||
|         for i in range(self._steps): | ||||
|             clist = [] | ||||
|             for j, h in enumerate(states): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op = self.edges[node_str] | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 clist.append(op.forward_darts(h, weights)) | ||||
|             states.append(sum(clist)) | ||||
|  | ||||
|         return torch.cat(states[-self._multiplier :], dim=1) | ||||
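|  | ||||
|  | ||||
| # Editor's sketch (not in the original commit): drive a NAS201SearchCell with | ||||
| # uniform architecture weights, assuming the NAS-Bench-201 ops exist in OPS. | ||||
| if __name__ == "__main__": | ||||
|     ops = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] | ||||
|     cell = NAS201SearchCell(16, 16, 1, 4, ops) | ||||
|     x = torch.randn(2, 16, 32, 32) | ||||
|     weightss = torch.full((cell.num_edges, len(ops)), 1.0 / len(ops)) | ||||
|     print(cell(x, weightss).shape)  # torch.Size([2, 16, 32, 32]) | ||||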
| @@ -0,0 +1,122 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ######################################################## | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 # | ||||
| ######################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class TinyNetworkDarts(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkDarts, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             return "arch-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self.op_names[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 feature = cell(feature, alphas) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
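|  | ||||
|  | ||||
| # Editor's sketch (not in the original commit): a tiny DARTS supernet over the | ||||
| # NAS-Bench-201 space; genotype() reads off the current argmax architecture. | ||||
| if __name__ == "__main__": | ||||
|     space = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] | ||||
|     net = TinyNetworkDarts(16, 5, 4, 10, space, False, True) | ||||
|     out, logits = net(torch.randn(2, 3, 32, 32)) | ||||
|     print(logits.shape, net.genotype()) | ||||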
| @@ -0,0 +1,178 @@ | ||||
| #################### | ||||
| # DARTS, ICLR 2019 # | ||||
| #################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import List, Text, Dict | ||||
| from .search_cells import NASNetSearchCell as SearchCell | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkDARTS(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C: int, | ||||
|         N: int, | ||||
|         steps: int, | ||||
|         multiplier: int, | ||||
|         stem_multiplier: int, | ||||
|         num_classes: int, | ||||
|         search_space: List[Text], | ||||
|         affine: bool, | ||||
|         track_running_stats: bool, | ||||
|     ): | ||||
|         super(NASNetworkDARTS, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         num_edge, edge2index = None, None | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             cell = SearchCell( | ||||
|                 search_space, | ||||
|                 steps, | ||||
|                 multiplier, | ||||
|                 C_prev_prev, | ||||
|                 C_prev, | ||||
|                 C_curr, | ||||
|                 reduction, | ||||
|                 reduction_prev, | ||||
|                 affine, | ||||
|                 track_running_stats, | ||||
|             ) | ||||
|             if num_edge is None: | ||||
|                 num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|             else: | ||||
|                 assert ( | ||||
|                     num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                 ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_normal_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.arch_reduce_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|  | ||||
|     def get_weights(self) -> List[torch.nn.Parameter]: | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def get_alphas(self) -> List[torch.nn.Parameter]: | ||||
|         return [self.arch_normal_parameters, self.arch_reduce_parameters] | ||||
|  | ||||
|     def show_alphas(self) -> Text: | ||||
|         with torch.no_grad(): | ||||
|             A = "arch-normal-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|             B = "arch-reduce-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|         return "{:}\n{:}".format(A, B) | ||||
|  | ||||
|     def get_message(self) -> Text: | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self) -> Text: | ||||
|         return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self) -> Dict[Text, List]: | ||||
|         def _parse(weights): | ||||
|             gene = [] | ||||
|             for i in range(self._steps): | ||||
|                 edges = [] | ||||
|                 for j in range(2 + i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     ws = weights[self.edge2index[node_str]] | ||||
|                     for k, op_name in enumerate(self.op_names): | ||||
|                         if op_name == "none": | ||||
|                             continue | ||||
|                         edges.append((op_name, j, ws[k])) | ||||
|                 # (TODO) xuanyidong: | ||||
|                 # The two selected edges may come from the same input node; in | ||||
|                 # that case they collapse into a single edge at evaluation time, | ||||
|                 # since we assume at most one edge per input node. | ||||
|                 edges = sorted(edges, key=lambda x: -x[-1]) | ||||
|                 selected_edges = edges[:2] | ||||
|                 gene.append(tuple(selected_edges)) | ||||
|             return gene | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             gene_normal = _parse( | ||||
|                 torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|             gene_reduce = _parse( | ||||
|                 torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
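|         # e.g. with steps=4 and multiplier=4, the concat ranges below are | ||||
|         # [2, 3, 4, 5]: all four intermediate nodes feed the cell output | ||||
|         # (editor's note, not in the original commit). | ||||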
|         return { | ||||
|             "normal": gene_normal, | ||||
|             "normal_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|             "reduce": gene_reduce, | ||||
|             "reduce_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|         } | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1) | ||||
|         reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1) | ||||
|  | ||||
|         s0 = s1 = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if cell.reduction: | ||||
|                 ww = reduce_w | ||||
|             else: | ||||
|                 ww = normal_w | ||||
|             s0, s1 = s1, cell.forward_darts(s0, s1, ww) | ||||
|         out = self.lastact(s1) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
							
								
								
									
AutoDL-Projects/xautodl/models/cell_searchs/search_model_enas.py (new file, 114 lines)
							| @@ -0,0 +1,114 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ########################################################################## | ||||
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # | ||||
| ########################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
| from .search_model_enas_utils import Controller | ||||
|  | ||||
|  | ||||
| class TinyNetworkENAS(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkENAS, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         # to maintain the sampled architecture | ||||
|         self.sampled_arch = None | ||||
|  | ||||
|     def update_arch(self, _arch): | ||||
|         if _arch is None: | ||||
|             self.sampled_arch = None | ||||
|         elif isinstance(_arch, Structure): | ||||
|             self.sampled_arch = _arch | ||||
|         elif isinstance(_arch, (list, tuple)): | ||||
|             genotypes = [] | ||||
|             for i in range(1, self.max_nodes): | ||||
|                 xlist = [] | ||||
|                 for j in range(i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     op_index = _arch[self.edge2index[node_str]] | ||||
|                     op_name = self.op_names[op_index] | ||||
|                     xlist.append((op_name, j)) | ||||
|                 genotypes.append(tuple(xlist)) | ||||
|             self.sampled_arch = Structure(genotypes) | ||||
|         else: | ||||
|             raise ValueError("invalid type of input architecture : {:}".format(_arch)) | ||||
|         return self.sampled_arch | ||||
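|  | ||||
|     # Editor's note (not in the original commit): `_arch` may also be a list | ||||
|     # or tuple of op indices, one per edge in edge2index order; a 4-node cell | ||||
|     # has 6 edges, e.g. update_arch([3, 0, 3, 1, 2, 3]). | ||||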
|  | ||||
|     def create_controller(self): | ||||
|         return Controller(len(self.edge2index), len(self.op_names)) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 feature = cell.forward_dynamic(feature, self.sampled_arch) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
| @@ -0,0 +1,74 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ########################################################################## | ||||
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # | ||||
| ########################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.distributions.categorical import Categorical | ||||
|  | ||||
|  | ||||
| class Controller(nn.Module): | ||||
|     # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_edge, | ||||
|         num_ops, | ||||
|         lstm_size=32, | ||||
|         lstm_num_layers=2, | ||||
|         tanh_constant=2.5, | ||||
|         temperature=5.0, | ||||
|     ): | ||||
|         super(Controller, self).__init__() | ||||
|         # assign the attributes | ||||
|         self.num_edge = num_edge | ||||
|         self.num_ops = num_ops | ||||
|         self.lstm_size = lstm_size | ||||
|         self.lstm_N = lstm_num_layers | ||||
|         self.tanh_constant = tanh_constant | ||||
|         self.temperature = temperature | ||||
|         # create parameters | ||||
|         self.register_parameter( | ||||
|             "input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size)) | ||||
|         ) | ||||
|         self.w_lstm = nn.LSTM( | ||||
|             input_size=self.lstm_size, | ||||
|             hidden_size=self.lstm_size, | ||||
|             num_layers=self.lstm_N, | ||||
|         ) | ||||
|         self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) | ||||
|         self.w_pred = nn.Linear(self.lstm_size, self.num_ops) | ||||
|  | ||||
|         nn.init.uniform_(self.input_vars, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_embd.weight, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_pred.weight, -0.1, 0.1) | ||||
|  | ||||
|     def forward(self): | ||||
|  | ||||
|         inputs, h0 = self.input_vars, None | ||||
|         log_probs, entropys, sampled_arch = [], [], [] | ||||
|         for iedge in range(self.num_edge): | ||||
|             outputs, h0 = self.w_lstm(inputs, h0) | ||||
|  | ||||
|             logits = self.w_pred(outputs) | ||||
|             logits = logits / self.temperature | ||||
|             logits = self.tanh_constant * torch.tanh(logits) | ||||
|             # distribution | ||||
|             op_distribution = Categorical(logits=logits) | ||||
|             op_index = op_distribution.sample() | ||||
|             sampled_arch.append(op_index.item()) | ||||
|  | ||||
|             op_log_prob = op_distribution.log_prob(op_index) | ||||
|             log_probs.append(op_log_prob.view(-1)) | ||||
|             op_entropy = op_distribution.entropy() | ||||
|             entropys.append(op_entropy.view(-1)) | ||||
|  | ||||
|             # obtain the input embedding for the next step | ||||
|             inputs = self.w_embd(op_index) | ||||
|         return ( | ||||
|             torch.sum(torch.cat(log_probs)), | ||||
|             torch.sum(torch.cat(entropys)), | ||||
|             sampled_arch, | ||||
|         ) | ||||
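|  | ||||
|  | ||||
| # Editor's sketch (not in the original commit): one REINFORCE step for the | ||||
| # controller; `reward` would come from evaluating the sampled architecture | ||||
| # with the shared weights. | ||||
| if __name__ == "__main__": | ||||
|     controller = Controller(num_edge=6, num_ops=5) | ||||
|     log_prob, entropy, sampled_arch = controller() | ||||
|     reward = torch.tensor(0.9)  # hypothetical validation accuracy | ||||
|     loss = -log_prob * reward - 1e-4 * entropy | ||||
|     loss.backward() | ||||
|     print(sampled_arch) | ||||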
							
								
								
									
AutoDL-Projects/xautodl/models/cell_searchs/search_model_gdas.py (new file, 142 lines)
							| @@ -0,0 +1,142 @@ | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class TinyNetworkGDAS(nn.Module): | ||||
|  | ||||
|     # def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkGDAS, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.tau = 10 | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_tau(self): | ||||
|         return self.tau | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             return "arch-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self.op_names[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
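|         # Gumbel-softmax with a straight-through estimator: perturb the | ||||
|         # log-softmax with Gumbel(0, 1) noise, soften by tau, then snap to a | ||||
|         # one-hot while keeping the soft probs in the backward pass; resample | ||||
|         # on inf/nan draws (editor's comment, not in the original commit). | ||||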
|         while True: | ||||
|             gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() | ||||
|             logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau | ||||
|             probs = nn.functional.softmax(logits, dim=1) | ||||
|             index = probs.max(-1, keepdim=True)[1] | ||||
|             one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|             hardwts = one_h - probs.detach() + probs | ||||
|             if ( | ||||
|                 (torch.isinf(gumbels).any()) | ||||
|                 or (torch.isinf(probs).any()) | ||||
|                 or (torch.isnan(probs).any()) | ||||
|             ): | ||||
|                 continue | ||||
|             else: | ||||
|                 break | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 feature = cell.forward_gdas(feature, hardwts, index) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
| @@ -0,0 +1,200 @@ | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
|  | ||||
| from .search_cells import NASNetSearchCell as SearchCell | ||||
| from ..cell_operations import RAW_OP_CLASSES | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkGDAS_FRC(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C, | ||||
|         N, | ||||
|         steps, | ||||
|         multiplier, | ||||
|         stem_multiplier, | ||||
|         num_classes, | ||||
|         search_space, | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ): | ||||
|         super(NASNetworkGDAS_FRC, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         num_edge, edge2index = None, None | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = RAW_OP_CLASSES["gdas_reduction"]( | ||||
|                     C_prev_prev, | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     reduction_prev, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     search_space, | ||||
|                     steps, | ||||
|                     multiplier, | ||||
|                     C_prev_prev, | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     reduction, | ||||
|                     reduction_prev, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|             if num_edge is None: | ||||
|                 num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|             else: | ||||
|                 assert reduction or ( | ||||
|                     num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                 ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = ( | ||||
|                 C_prev, | ||||
|                 cell.multiplier * C_curr, | ||||
|                 reduction, | ||||
|             ) | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.tau = 10 | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_tau(self): | ||||
|         return self.tau | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             A = "arch-normal-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|         return "{:}".format(A) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         def _parse(weights): | ||||
|             gene = [] | ||||
|             for i in range(self._steps): | ||||
|                 edges = [] | ||||
|                 for j in range(2 + i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     ws = weights[self.edge2index[node_str]] | ||||
|                     for k, op_name in enumerate(self.op_names): | ||||
|                         if op_name == "none": | ||||
|                             continue | ||||
|                         edges.append((op_name, j, ws[k])) | ||||
|                 edges = sorted(edges, key=lambda x: -x[-1]) | ||||
|                 selected_edges = edges[:2] | ||||
|                 gene.append(tuple(selected_edges)) | ||||
|             return gene | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             gene_normal = _parse( | ||||
|                 torch.softmax(self.arch_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|         return { | ||||
|             "normal": gene_normal, | ||||
|             "normal_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|         } | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         def get_gumbel_prob(xins): | ||||
|             while True: | ||||
|                 gumbels = -torch.empty_like(xins).exponential_().log() | ||||
|                 logits = (xins.log_softmax(dim=1) + gumbels) / self.tau | ||||
|                 probs = nn.functional.softmax(logits, dim=1) | ||||
|                 index = probs.max(-1, keepdim=True)[1] | ||||
|                 one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|                 hardwts = one_h - probs.detach() + probs | ||||
|                 if ( | ||||
|                     (torch.isinf(gumbels).any()) | ||||
|                     or (torch.isinf(probs).any()) | ||||
|                     or (torch.isnan(probs).any()) | ||||
|                 ): | ||||
|                     continue | ||||
|                 else: | ||||
|                     break | ||||
|             return hardwts, index | ||||
|  | ||||
|         hardwts, index = get_gumbel_prob(self.arch_parameters) | ||||
|  | ||||
|         s0 = s1 = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if cell.reduction: | ||||
|                 s0, s1 = s1, cell(s0, s1) | ||||
|             else: | ||||
|                 s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) | ||||
|         out = self.lastact(s1) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
| @@ -0,0 +1,197 @@ | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from .search_cells import NASNetSearchCell as SearchCell | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkGDAS(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C, | ||||
|         N, | ||||
|         steps, | ||||
|         multiplier, | ||||
|         stem_multiplier, | ||||
|         num_classes, | ||||
|         search_space, | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ): | ||||
|         super(NASNetworkGDAS, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         num_edge, edge2index = None, None | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             cell = SearchCell( | ||||
|                 search_space, | ||||
|                 steps, | ||||
|                 multiplier, | ||||
|                 C_prev_prev, | ||||
|                 C_prev, | ||||
|                 C_curr, | ||||
|                 reduction, | ||||
|                 reduction_prev, | ||||
|                 affine, | ||||
|                 track_running_stats, | ||||
|             ) | ||||
|             if num_edge is None: | ||||
|                 num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|             else: | ||||
|                 assert ( | ||||
|                     num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                 ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_normal_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.arch_reduce_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.tau = 10 | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_tau(self): | ||||
|         return self.tau | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_normal_parameters, self.arch_reduce_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             A = "arch-normal-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|             B = "arch-reduce-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|         return "{:}\n{:}".format(A, B) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         def _parse(weights): | ||||
|             gene = [] | ||||
|             for i in range(self._steps): | ||||
|                 edges = [] | ||||
|                 for j in range(2 + i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     ws = weights[self.edge2index[node_str]] | ||||
|                     for k, op_name in enumerate(self.op_names): | ||||
|                         if op_name == "none": | ||||
|                             continue | ||||
|                         edges.append((op_name, j, ws[k])) | ||||
|                 edges = sorted(edges, key=lambda x: -x[-1]) | ||||
|                 selected_edges = edges[:2] | ||||
|                 gene.append(tuple(selected_edges)) | ||||
|             return gene | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             gene_normal = _parse( | ||||
|                 torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|             gene_reduce = _parse( | ||||
|                 torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|         return { | ||||
|             "normal": gene_normal, | ||||
|             "normal_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|             "reduce": gene_reduce, | ||||
|             "reduce_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|         } | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         def get_gumbel_prob(xins): | ||||
|             # Same straight-through Gumbel-softmax sampler as above. | ||||
|             while True: | ||||
|                 gumbels = -torch.empty_like(xins).exponential_().log() | ||||
|                 logits = (xins.log_softmax(dim=1) + gumbels) / self.tau | ||||
|                 probs = nn.functional.softmax(logits, dim=1) | ||||
|                 index = probs.max(-1, keepdim=True)[1] | ||||
|                 one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|                 hardwts = one_h - probs.detach() + probs | ||||
|                 if ( | ||||
|                     (torch.isinf(gumbels).any()) | ||||
|                     or (torch.isinf(probs).any()) | ||||
|                     or (torch.isnan(probs).any()) | ||||
|                 ): | ||||
|                     continue | ||||
|                 else: | ||||
|                     break | ||||
|             return hardwts, index | ||||
|  | ||||
|         normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters) | ||||
|         reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters) | ||||
|  | ||||
|         s0 = s1 = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if cell.reduction: | ||||
|                 hardwts, index = reduce_hardwts, reduce_index | ||||
|             else: | ||||
|                 hardwts, index = normal_hardwts, normal_index | ||||
|             s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) | ||||
|         out = self.lastact(s1) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
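|  | ||||
| # Hedged usage sketch: the sizes, tau schedule, and `ops` list below are | ||||
| # illustrative assumptions; `ops` must match the NASNet SearchCell op set. | ||||
| # | ||||
| #   import torch | ||||
| #   net = NASNetworkGDAS(C=16, N=2, steps=4, multiplier=4, stem_multiplier=3, | ||||
| #                        num_classes=10, search_space=ops, affine=False, | ||||
| #                        track_running_stats=True) | ||||
| #   w_opt = torch.optim.SGD(net.get_weights(), lr=0.025, momentum=0.9) | ||||
| #   a_opt = torch.optim.Adam(net.get_alphas(), lr=3e-4) | ||||
| #   net.set_tau(10.0)  # typically annealed towards ~0.1 during search | ||||
| #   out, logits = net(torch.randn(2, 3, 32, 32)) | ||||
| #   print(net.genotype()) | ||||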
| @@ -0,0 +1,102 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ############################################################################## | ||||
| # Random Search and Reproducibility for Neural Architecture Search, UAI 2019 # | ||||
| ############################################################################## | ||||
| import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class TinyNetworkRANDOM(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkRANDOM, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_cache = None | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def random_genotype(self, set_cache): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = random.choice(self.op_names) | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         arch = Structure(genotypes) | ||||
|         if set_cache: | ||||
|             self.arch_cache = arch | ||||
|         return arch | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 feature = cell.forward_dynamic(feature, self.arch_cache) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|         return out, logits | ||||
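|  | ||||
| # Hedged usage sketch (the NAS-Bench-201-style op set below is an assumption): | ||||
| # | ||||
| #   import torch | ||||
| #   space = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] | ||||
| #   net = TinyNetworkRANDOM(C=16, N=5, max_nodes=4, num_classes=10, | ||||
| #                           search_space=space, affine=False, | ||||
| #                           track_running_stats=True) | ||||
| #   arch = net.random_genotype(set_cache=True)    # sample and cache an arch | ||||
| #   out, logits = net(torch.randn(2, 3, 32, 32))  # runs the cached arch | ||||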
							
								
								
									
178  AutoDL-Projects/xautodl/models/cell_searchs/search_model_setn.py  Normal file
							| @@ -0,0 +1,178 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ###################################################################################### | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # | ||||
| ###################################################################################### | ||||
| import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class TinyNetworkSETN(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkSETN, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.mode = "urs" | ||||
|         self.dynamic_cell = None | ||||
|  | ||||
|     def set_cal_mode(self, mode, dynamic_cell=None): | ||||
|         # Calculation modes: "urs" samples one op per edge uniformly at random; | ||||
|         # "joint" mixes all ops weighted by the softmax of the arch parameters; | ||||
|         # "select" takes the per-edge argmax op; "dynamic" runs the fixed | ||||
|         # genotype passed in as `dynamic_cell`. | ||||
|         assert mode in ["urs", "joint", "select", "dynamic"] | ||||
|         self.mode = mode | ||||
|         if mode == "dynamic": | ||||
|             self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|         else: | ||||
|             self.dynamic_cell = None | ||||
|  | ||||
|     def get_cal_mode(self): | ||||
|         return self.mode | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_parameters] | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self.op_names[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def dync_genotype(self, use_random=False): | ||||
|         genotypes = [] | ||||
|         with torch.no_grad(): | ||||
|             alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 if use_random: | ||||
|                     op_name = random.choice(self.op_names) | ||||
|                 else: | ||||
|                     weights = alphas_cpu[self.edge2index[node_str]] | ||||
|                     op_index = torch.multinomial(weights, 1).item() | ||||
|                     op_name = self.op_names[op_index] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def get_log_prob(self, arch): | ||||
|         with torch.no_grad(): | ||||
|             logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) | ||||
|         select_logits = [] | ||||
|         for i, node_info in enumerate(arch.nodes): | ||||
|             for op, xin in node_info: | ||||
|                 node_str = "{:}<-{:}".format(i + 1, xin) | ||||
|                 op_index = self.op_names.index(op) | ||||
|                 select_logits.append(logits[self.edge2index[node_str], op_index]) | ||||
|         return sum(select_logits).item() | ||||
|  | ||||
|     def return_topK(self, K): | ||||
|         archs = Structure.gen_all(self.op_names, self.max_nodes, False) | ||||
|         pairs = [(self.get_log_prob(arch), arch) for arch in archs] | ||||
|         if K < 0 or K >= len(archs): | ||||
|             K = len(archs) | ||||
|         sorted_pairs = sorted(pairs, key=lambda x: -x[0]) | ||||
|         return_pairs = [sorted_pairs[_][1] for _ in range(K)] | ||||
|         return return_pairs | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         with torch.no_grad(): | ||||
|             alphas_cpu = alphas.detach().cpu() | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 if self.mode == "urs": | ||||
|                     feature = cell.forward_urs(feature) | ||||
|                 elif self.mode == "select": | ||||
|                     feature = cell.forward_select(feature, alphas_cpu) | ||||
|                 elif self.mode == "joint": | ||||
|                     feature = cell.forward_joint(feature, alphas) | ||||
|                 elif self.mode == "dynamic": | ||||
|                     feature = cell.forward_dynamic(feature, self.dynamic_cell) | ||||
|                 else: | ||||
|                     raise ValueError("invalid mode={:}".format(self.mode)) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
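|  | ||||
| # Hedged usage sketch of the SETN evaluation modes (`space` assumed as in the | ||||
| # NAS-Bench-201 models above): | ||||
| # | ||||
| #   import torch | ||||
| #   net = TinyNetworkSETN(C=16, N=5, max_nodes=4, num_classes=10, | ||||
| #                         search_space=space, affine=False, | ||||
| #                         track_running_stats=True) | ||||
| #   net.set_cal_mode("urs")                      # uniform random sampling | ||||
| #   out, logits = net(torch.randn(2, 3, 32, 32)) | ||||
| #   arch = net.dync_genotype(use_random=False)   # sample from learned alphas | ||||
| #   net.set_cal_mode("dynamic", arch)            # evaluate that specific cell | ||||
| #   top5 = net.return_topK(5)                    # most likely architectures | ||||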
| @@ -0,0 +1,205 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ###################################################################################### | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # | ||||
| ###################################################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import List, Text, Dict | ||||
| from .search_cells import NASNetSearchCell as SearchCell | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkSETN(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C: int, | ||||
|         N: int, | ||||
|         steps: int, | ||||
|         multiplier: int, | ||||
|         stem_multiplier: int, | ||||
|         num_classes: int, | ||||
|         search_space: List[Text], | ||||
|         affine: bool, | ||||
|         track_running_stats: bool, | ||||
|     ): | ||||
|         super(NASNetworkSETN, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         num_edge, edge2index = None, None | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             cell = SearchCell( | ||||
|                 search_space, | ||||
|                 steps, | ||||
|                 multiplier, | ||||
|                 C_prev_prev, | ||||
|                 C_prev, | ||||
|                 C_curr, | ||||
|                 reduction, | ||||
|                 reduction_prev, | ||||
|                 affine, | ||||
|                 track_running_stats, | ||||
|             ) | ||||
|             if num_edge is None: | ||||
|                 num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|             else: | ||||
|                 assert ( | ||||
|                     num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                 ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_normal_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.arch_reduce_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.mode = "urs" | ||||
|         self.dynamic_cell = None | ||||
|  | ||||
|     def set_cal_mode(self, mode, dynamic_cell=None): | ||||
|         assert mode in ["urs", "joint", "select", "dynamic"] | ||||
|         self.mode = mode | ||||
|         if mode == "dynamic": | ||||
|             self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|         else: | ||||
|             self.dynamic_cell = None | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_normal_parameters, self.arch_reduce_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             A = "arch-normal-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|             B = "arch-reduce-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|         return "{:}\n{:}".format(A, B) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def dync_genotype(self, use_random=False): | ||||
|         # NOTE: carried over from the NAS-Bench-201 SETN model above; this class | ||||
|         # defines neither `arch_parameters` nor `max_nodes`, and `random` / | ||||
|         # `Structure` are not imported here, so calling this raises until it is | ||||
|         # adapted to the NASNet search space. | ||||
|         genotypes = [] | ||||
|         with torch.no_grad(): | ||||
|             alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 if use_random: | ||||
|                     op_name = random.choice(self.op_names) | ||||
|                 else: | ||||
|                     weights = alphas_cpu[self.edge2index[node_str]] | ||||
|                     op_index = torch.multinomial(weights, 1).item() | ||||
|                     op_name = self.op_names[op_index] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def genotype(self): | ||||
|         def _parse(weights): | ||||
|             gene = [] | ||||
|             for i in range(self._steps): | ||||
|                 edges = [] | ||||
|                 for j in range(2 + i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     ws = weights[self.edge2index[node_str]] | ||||
|                     for k, op_name in enumerate(self.op_names): | ||||
|                         if op_name == "none": | ||||
|                             continue | ||||
|                         edges.append((op_name, j, ws[k])) | ||||
|                 edges = sorted(edges, key=lambda x: -x[-1]) | ||||
|                 selected_edges = edges[:2] | ||||
|                 gene.append(tuple(selected_edges)) | ||||
|             return gene | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             gene_normal = _parse( | ||||
|                 torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|             gene_reduce = _parse( | ||||
|                 torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|         return { | ||||
|             "normal": gene_normal, | ||||
|             "normal_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|             "reduce": gene_reduce, | ||||
|             "reduce_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|         } | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1) | ||||
|         reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1) | ||||
|  | ||||
|         s0 = s1 = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             # [TODO] SETN routing for the NASNet macro structure is not | ||||
|             # implemented yet; the unreachable lines below sketch GDAS-style | ||||
|             # routing and reference `normal_index` / `reduce_index`, which are | ||||
|             # not defined in this method. | ||||
|             raise NotImplementedError | ||||
|             if cell.reduction: | ||||
|                 hardwts, index = reduce_hardwts, reduce_index | ||||
|             else: | ||||
|                 hardwts, index = normal_hardwts, normal_index | ||||
|             s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) | ||||
|         out = self.lastact(s1) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
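|  | ||||
| # Hedged note: unlike the NAS-Bench-201 variant above, only the architecture- | ||||
| # parameter utilities are usable here, since `forward` raises | ||||
| # NotImplementedError. A sketch of what does work (sizes and `ops` assumed): | ||||
| # | ||||
| #   net = NASNetworkSETN(C=16, N=2, steps=4, multiplier=4, stem_multiplier=3, | ||||
| #                        num_classes=10, search_space=ops, affine=False, | ||||
| #                        track_running_stats=True) | ||||
| #   print(net.show_alphas())  # softmax of normal/reduce arch parameters | ||||
| #   print(net.genotype())     # top-2 non-"none" ops per node, by weight | ||||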
							
								
								
									
74  AutoDL-Projects/xautodl/models/clone_weights.py  Normal file
							| @@ -0,0 +1,74 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
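| # Warm-start helpers: each copies the leading input/output slices of a larger | ||||
| # pretrained module into a (possibly narrower) module, so width-pruned | ||||
| # candidates inherit weights instead of starting from random initialization. | ||||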
| def copy_conv(module, init): | ||||
|     assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module) | ||||
|     assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init) | ||||
|     new_i, new_o = module.in_channels, module.out_channels | ||||
|     module.weight.copy_(init.weight.detach()[:new_o, :new_i]) | ||||
|     if module.bias is not None: | ||||
|         module.bias.copy_(init.bias.detach()[:new_o]) | ||||
|  | ||||
|  | ||||
| def copy_bn(module, init): | ||||
|     assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module) | ||||
|     assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init) | ||||
|     num_features = module.num_features | ||||
|     if module.weight is not None: | ||||
|         module.weight.copy_(init.weight.detach()[:num_features]) | ||||
|     if module.bias is not None: | ||||
|         module.bias.copy_(init.bias.detach()[:num_features]) | ||||
|     if module.running_mean is not None: | ||||
|         module.running_mean.copy_(init.running_mean.detach()[:num_features]) | ||||
|     if module.running_var is not None: | ||||
|         module.running_var.copy_(init.running_var.detach()[:num_features]) | ||||
|  | ||||
|  | ||||
| def copy_fc(module, init): | ||||
|     assert isinstance(module, nn.Linear), "invalid module : {:}".format(module) | ||||
|     assert isinstance(init, nn.Linear), "invalid module : {:}".format(init) | ||||
|     new_i, new_o = module.in_features, module.out_features | ||||
|     module.weight.copy_(init.weight.detach()[:new_o, :new_i]) | ||||
|     if module.bias is not None: | ||||
|         module.bias.copy_(init.bias.detach()[:new_o]) | ||||
|  | ||||
|  | ||||
| def copy_base(module, init): | ||||
|     assert type(module).__name__ in [ | ||||
|         "ConvBNReLU", | ||||
|         "Downsample", | ||||
|     ], "invalid module : {:}".format(module) | ||||
|     assert type(init).__name__ in [ | ||||
|         "ConvBNReLU", | ||||
|         "Downsample", | ||||
|     ], "invalid module : {:}".format(init) | ||||
|     if module.conv is not None: | ||||
|         copy_conv(module.conv, init.conv) | ||||
|     if module.bn is not None: | ||||
|         copy_bn(module.bn, init.bn) | ||||
|  | ||||
|  | ||||
| def copy_basic(module, init): | ||||
|     copy_base(module.conv_a, init.conv_a) | ||||
|     copy_base(module.conv_b, init.conv_b) | ||||
|     if module.downsample is not None: | ||||
|         if init.downsample is not None: | ||||
|             copy_base(module.downsample, init.downsample) | ||||
|         # else: | ||||
|         # import pdb; pdb.set_trace() | ||||
|  | ||||
|  | ||||
| def init_from_model(network, init_model): | ||||
|     with torch.no_grad(): | ||||
|         copy_fc(network.classifier, init_model.classifier) | ||||
|         for base, target in zip(init_model.layers, network.layers): | ||||
|             assert ( | ||||
|                 type(base).__name__ == type(target).__name__ | ||||
|             ), "invalid type : {:} vs {:}".format(base, target) | ||||
|             if type(base).__name__ == "ConvBNReLU": | ||||
|                 copy_base(target, base) | ||||
|             elif type(base).__name__ == "ResNetBasicblock": | ||||
|                 copy_basic(target, base) | ||||
|             else: | ||||
|                 raise ValueError("unknown type name : {:}".format(type(base).__name__)) | ||||
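|  | ||||
| # Hedged usage sketch: `full_net` and `slim_net` are hypothetical models that | ||||
| # expose `.layers` of ConvBNReLU / ResNetBasicblock modules and a Linear | ||||
| # `.classifier` (e.g. the shape-infer ResNets defined below): | ||||
| # | ||||
| #   init_from_model(slim_net, full_net)  # copy overlapping weight slices | ||||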
							
								
								
									
16  AutoDL-Projects/xautodl/models/initialization.py  Normal file
							| @@ -0,0 +1,16 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def initialize_resnet(m): | ||||
|     if isinstance(m, nn.Conv2d): | ||||
|         nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | ||||
|         if m.bias is not None: | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|     elif isinstance(m, nn.BatchNorm2d): | ||||
|         nn.init.constant_(m.weight, 1) | ||||
|         if m.bias is not None: | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|     elif isinstance(m, nn.Linear): | ||||
|         nn.init.normal_(m.weight, 0, 0.01) | ||||
|         nn.init.constant_(m.bias, 0) | ||||
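|  | ||||
| # Hedged usage sketch: `net` is any nn.Module; `apply` visits every submodule. | ||||
| # | ||||
| #   net.apply(initialize_resnet)  # Kaiming conv, unit BN, N(0, 0.01) Linear | ||||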
							
								
								
									
287  AutoDL-Projects/xautodl/models/shape_infers/InferCifarResNet.py  Normal file
							| @@ -0,0 +1,287 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg is not None: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         out = self.conv(out) | ||||
|         if self.bn is not None: | ||||
|             out = self.bn(out) | ||||
|         if self.relu is not None: | ||||
|             out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferCifarResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual | ||||
|     ): | ||||
|         super(InferCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks) | ||||
|  | ||||
|         self.message = ( | ||||
|             "InferCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     xchannels[0], | ||||
|                     xchannels[1], | ||||
|                     3, | ||||
|                     1, | ||||
|                     1, | ||||
|                     False, | ||||
|                     has_avg=False, | ||||
|                     has_bn=True, | ||||
|                     has_relu=True, | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         last_channel_idx = 1 | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reached this stage's searched depth | ||||
|                     # advance past the channel indices of the pruned blocks | ||||
|                     for iiL in range(iL + 1, layer_blocks): | ||||
|                         last_channel_idx += num_conv | ||||
|                     self.xchannels[last_channel_idx] = module.out_dim | ||||
|                     break | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
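|  | ||||
| # Hedged usage sketch: a depth-20 basic-block net at full depth and width; the | ||||
| # xblocks / xchannels values below are illustrative assumptions. | ||||
| # | ||||
| #   import torch | ||||
| #   net = InferCifarResNet("ResNetBasicblock", depth=20, xblocks=[3, 3, 3], | ||||
| #                          xchannels=[3] + [16] * 7 + [32] * 6 + [64] * 6, | ||||
| #                          num_classes=10, zero_init_residual=True) | ||||
| #   features, logits = net(torch.randn(2, 3, 32, 32))  # CIFAR-sized input | ||||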
| @@ -0,0 +1,263 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg is not None: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         out = self.conv(out) | ||||
|         if self.bn is not None: | ||||
|             out = self.bn(out) | ||||
|         if self.relu is not None: | ||||
|             out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|  | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferDepthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual): | ||||
|         super(InferDepthCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks) | ||||
|  | ||||
|         self.message = ( | ||||
|             "InferDepthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.channels = [16] | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     planes, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     break | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.channels[-1], num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
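| # --- Illustrative usage sketch (not part of the original file) --- | ||||
| # A minimal check of InferDepthCifarResNet, assuming a CIFAR-10 setup: | ||||
| # depth=20 gives 3 blocks per stage, and each xblocks entry prunes the | ||||
| # corresponding stage down to that many blocks. | ||||
| # | ||||
| # import torch | ||||
| # net = InferDepthCifarResNet( | ||||
| #     "ResNetBasicblock", depth=20, xblocks=[2, 3, 1], | ||||
| #     num_classes=10, zero_init_residual=True, | ||||
| # ) | ||||
| # features, logits = net(torch.randn(2, 3, 32, 32)) | ||||
| # assert logits.shape == (2, 10) | ||||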
| @@ -0,0 +1,277 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance( | ||||
|             iCs, (tuple, list) | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance( | ||||
|             iCs, (tuple, list) | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferWidthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual): | ||||
|         super(InferWidthCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     xchannels[0], | ||||
|                     xchannels[1], | ||||
|                     3, | ||||
|                     1, | ||||
|                     1, | ||||
|                     False, | ||||
|                     has_avg=False, | ||||
|                     has_bn=True, | ||||
|                     has_relu=True, | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         last_channel_idx = 1 | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
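| # --- Illustrative usage sketch (not part of the original file) --- | ||||
| # A rough sketch of how InferWidthCifarResNet consumes `xchannels`: a flat | ||||
| # list of per-convolution widths, starting with the 3 input channels and the | ||||
| # stem width, followed by num_conv entries per residual block (2 + 2*9 = 20 | ||||
| # entries for depth=20).  The widths below are one plausible choice. | ||||
| # | ||||
| # import torch | ||||
| # xchannels = [3, 16] + [16, 16] * 3 + [32, 32] * 3 + [64, 64] * 3 | ||||
| # net = InferWidthCifarResNet( | ||||
| #     "ResNetBasicblock", 20, xchannels, num_classes=10, | ||||
| #     zero_init_residual=False, | ||||
| # ) | ||||
| # features, logits = net(torch.randn(2, 3, 32, 32)) | ||||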
| @@ -0,0 +1,324 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|  | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance( | ||||
|             iCs, (tuple, list) | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance( | ||||
|             iCs, (tuple, list) | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferImagenetResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         block_name, | ||||
|         layers, | ||||
|         xblocks, | ||||
|         xchannels, | ||||
|         deep_stem, | ||||
|         num_classes, | ||||
|         zero_init_residual, | ||||
|     ): | ||||
|         super(InferImagenetResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for the ImageNet model | ||||
|         if block_name == "BasicBlock": | ||||
|             block = ResNetBasicblock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = ResNetBottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == len( | ||||
|             layers | ||||
|         ), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks) | ||||
|  | ||||
|         self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format( | ||||
|             sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         if not deep_stem: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[0], | ||||
|                         xchannels[1], | ||||
|                         7, | ||||
|                         2, | ||||
|                         3, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ) | ||||
|                 ] | ||||
|             ) | ||||
|             last_channel_idx = 1 | ||||
|         else: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[0], | ||||
|                         xchannels[1], | ||||
|                         3, | ||||
|                         2, | ||||
|                         1, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ), | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[1], | ||||
|                         xchannels[2], | ||||
|                         3, | ||||
|                         1, | ||||
|                         1, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ), | ||||
|                 ] | ||||
|             ) | ||||
|             last_channel_idx = 2 | ||||
|         self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) | ||||
|         for stage, layer_blocks in enumerate(layers): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     # skip the channel entries of the pruned (unused) blocks | ||||
|                     for iiL in range(iL + 1, layer_blocks): | ||||
|                         last_channel_idx += num_conv | ||||
|                     self.xchannels[last_channel_idx] = module.out_dim | ||||
|                     break | ||||
|         assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format( | ||||
|             last_channel_idx, len(self.xchannels) | ||||
|         ) | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
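| # --- Illustrative usage sketch (not part of the original file) --- | ||||
| # A ResNet-18-like instantiation, assuming no depth pruning (xblocks equals | ||||
| # layers) and the flat overlapping width list described above: 2 stem | ||||
| # entries plus num_conv entries per block. | ||||
| # | ||||
| # import torch | ||||
| # xchannels = ( | ||||
| #     [3, 64] | ||||
| #     + [64, 64] * 2 + [128, 128] * 2 + [256, 256] * 2 + [512, 512] * 2 | ||||
| # ) | ||||
| # net = InferImagenetResNet( | ||||
| #     "BasicBlock", layers=[2, 2, 2, 2], xblocks=[2, 2, 2, 2], | ||||
| #     xchannels=xchannels, deep_stem=False, num_classes=1000, | ||||
| #     zero_init_residual=True, | ||||
| # ) | ||||
| # features, logits = net(torch.randn(2, 3, 224, 224)) | ||||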
							
								
								
									
176  AutoDL-Projects/xautodl/models/shape_infers/InferMobileNetV2.py  Normal file
							| @@ -0,0 +1,176 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018 | ||||
| ##################################################### | ||||
| from torch import nn | ||||
|  | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import parse_channel_info | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_planes, | ||||
|         out_planes, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         groups, | ||||
|         has_bn=True, | ||||
|         has_relu=True, | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         padding = (kernel_size - 1) // 2 | ||||
|         self.conv = nn.Conv2d( | ||||
|             in_planes, | ||||
|             out_planes, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             groups=groups, | ||||
|             bias=False, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(out_planes) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU6(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         if self.bn: | ||||
|             out = self.bn(out) | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class InvertedResidual(nn.Module): | ||||
|     def __init__(self, channels, stride, expand_ratio, additive): | ||||
|         super(InvertedResidual, self).__init__() | ||||
|         self.stride = stride | ||||
|         assert stride in [1, 2], "invalid stride : {:}".format(stride) | ||||
|         assert len(channels) in [2, 3], "invalid channels : {:}".format(channels) | ||||
|  | ||||
|         if len(channels) == 2: | ||||
|             layers = [] | ||||
|         else: | ||||
|             layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)] | ||||
|         layers.extend( | ||||
|             [ | ||||
|                 # dw | ||||
|                 ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]), | ||||
|                 # pw-linear | ||||
|                 ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False), | ||||
|             ] | ||||
|         ) | ||||
|         self.conv = nn.Sequential(*layers) | ||||
|         self.additive = additive | ||||
|         if self.additive and channels[0] != channels[-1]: | ||||
|             self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False) | ||||
|         else: | ||||
|             self.shortcut = None | ||||
|         self.out_dim = channels[-1] | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         # if self.additive: return additive_func(out, x) | ||||
|         if self.shortcut: | ||||
|             return out + self.shortcut(x) | ||||
|         else: | ||||
|             return out | ||||
|  | ||||
|  | ||||
| class InferMobileNetV2(nn.Module): | ||||
|     def __init__(self, num_classes, xchannels, xblocks, dropout): | ||||
|         super(InferMobileNetV2, self).__init__() | ||||
|         block = InvertedResidual | ||||
|         inverted_residual_setting = [ | ||||
|             # t, c,  n, s | ||||
|             [1, 16, 1, 1], | ||||
|             [6, 24, 2, 2], | ||||
|             [6, 32, 3, 2], | ||||
|             [6, 64, 4, 2], | ||||
|             [6, 96, 3, 1], | ||||
|             [6, 160, 3, 2], | ||||
|             [6, 320, 1, 1], | ||||
|         ] | ||||
|         assert len(inverted_residual_setting) == len( | ||||
|             xblocks | ||||
|         ), "invalid number of layers : {:} vs {:}".format( | ||||
|             len(inverted_residual_setting), len(xblocks) | ||||
|         ) | ||||
|         for block_num, ir_setting in zip(xblocks, inverted_residual_setting): | ||||
|             assert block_num <= ir_setting[2], "{:} vs {:}".format( | ||||
|                 block_num, ir_setting | ||||
|             ) | ||||
|         xchannels = parse_channel_info(xchannels) | ||||
|         # for i, chs in enumerate(xchannels): | ||||
|         #  if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs) | ||||
|         self.xchannels = xchannels | ||||
|         self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks) | ||||
|         # building first layer | ||||
|         features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)] | ||||
|         last_channel_idx = 1 | ||||
|  | ||||
|         # building inverted residual blocks | ||||
|         for stage, (t, c, n, s) in enumerate(inverted_residual_setting): | ||||
|             for i in range(n): | ||||
|                 stride = s if i == 0 else 1 | ||||
|                 additv = i > 0 | ||||
|                 module = block(self.xchannels[last_channel_idx], stride, t, additv) | ||||
|                 features.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format( | ||||
|                     stage, | ||||
|                     i, | ||||
|                     n, | ||||
|                     len(features), | ||||
|                     self.xchannels[last_channel_idx], | ||||
|                     stride, | ||||
|                     t, | ||||
|                     c, | ||||
|                 ) | ||||
|                 last_channel_idx += 1 | ||||
|                 if i + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     # skip the channel groups of the pruned (unused) blocks | ||||
|                     for iiL in range(i + 1, n): | ||||
|                         last_channel_idx += 1 | ||||
|                     self.xchannels[last_channel_idx][0] = module.out_dim | ||||
|                     break | ||||
|         # building last several layers | ||||
|         features.append( | ||||
|             ConvBNReLU( | ||||
|                 self.xchannels[last_channel_idx][0], | ||||
|                 self.xchannels[last_channel_idx][1], | ||||
|                 1, | ||||
|                 1, | ||||
|                 1, | ||||
|             ) | ||||
|         ) | ||||
|         assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format( | ||||
|             last_channel_idx, len(self.xchannels) | ||||
|         ) | ||||
|         # make it nn.Sequential | ||||
|         self.features = nn.Sequential(*features) | ||||
|  | ||||
|         # building classifier | ||||
|         self.classifier = nn.Sequential( | ||||
|             nn.Dropout(dropout), | ||||
|             nn.Linear(self.xchannels[last_channel_idx][1], num_classes), | ||||
|         ) | ||||
|  | ||||
|         # weight initialization | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         features = self.features(inputs) | ||||
|         vectors = features.mean([2, 3]) | ||||
|         predicts = self.classifier(vectors) | ||||
|         return features, predicts | ||||
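| # --- Illustrative note (not part of the original file) --- | ||||
| # `xchannels` is a string parsed by parse_channel_info (imported above): | ||||
| # space-separated channel groups, one group per layer, with dash-separated | ||||
| # widths inside a group.  A 2-element group maps to an InvertedResidual | ||||
| # without the expansion 1x1 conv (the t == 1 case), while a 3-element group | ||||
| # lists in-channels, expanded hidden width and out-channels.  See the | ||||
| # parse_channel_info example in SharedUtils below. | ||||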
| @@ -0,0 +1,65 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| from typing import List, Text, Any | ||||
| import torch.nn as nn | ||||
|  | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from ..cell_infers.cells import InferCell | ||||
|  | ||||
|  | ||||
| class DynamicShapeTinyNet(nn.Module): | ||||
|     def __init__(self, channels: List[int], genotype: Any, num_classes: int): | ||||
|         super(DynamicShapeTinyNet, self).__init__() | ||||
|         self._channels = channels | ||||
|         if len(channels) % 3 != 2: | ||||
|             raise ValueError("invalid number of layers : {:}".format(len(channels))) | ||||
|         self._num_stage = N = len(channels) // 3 | ||||
|  | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(channels[0]), | ||||
|         ) | ||||
|  | ||||
|         # layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         c_prev = channels[0] | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(c_prev, c_curr, 2, True) | ||||
|             else: | ||||
|                 cell = InferCell(genotype, c_prev, c_curr, 1) | ||||
|             self.cells.append(cell) | ||||
|             c_prev = cell.out_dim | ||||
|         self._num_layer = len(self.cells) | ||||
|  | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(c_prev, num_classes) | ||||
|  | ||||
|     def get_message(self) -> Text: | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
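| # --- Illustrative usage sketch (not part of the original file) --- | ||||
| # The channel list must satisfy len(channels) % 3 == 2: N normal cells per | ||||
| # stage plus the two reduction blocks in between.  The stage layout can be | ||||
| # checked without a genotype (pure Python): | ||||
| # | ||||
| # channels = [16, 32, 32, 64, 64]  # N = 1 | ||||
| # N = len(channels) // 3 | ||||
| # reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
| # print(list(zip(channels, reductions))) | ||||
| # # [(16, False), (32, True), (32, False), (64, True), (64, False)] | ||||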
							
								
								
									
9  AutoDL-Projects/xautodl/models/shape_infers/__init__.py  Normal file
							| @@ -0,0 +1,9 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| from .InferCifarResNet_width import InferWidthCifarResNet | ||||
| from .InferImagenetResNet import InferImagenetResNet | ||||
| from .InferCifarResNet_depth import InferDepthCifarResNet | ||||
| from .InferCifarResNet import InferCifarResNet | ||||
| from .InferMobileNetV2 import InferMobileNetV2 | ||||
| from .InferTinyCellNet import DynamicShapeTinyNet | ||||
| @@ -0,0 +1,5 @@ | ||||
| def parse_channel_info(xstring): | ||||
|     blocks = xstring.split(" ") | ||||
|     blocks = [x.split("-") for x in blocks] | ||||
|     blocks = [[int(_) for _ in x] for x in blocks] | ||||
|     return blocks | ||||
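| # For example (illustrative): | ||||
| #   parse_channel_info("3-32 32-16 16-96-24") | ||||
| #   -> [[3, 32], [32, 16], [16, 96, 24]] | ||||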
| @@ -0,0 +1,760 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| from collections import OrderedDict | ||||
| from bisect import bisect_right | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices | ||||
|  | ||||
|  | ||||
| def get_depth_choices(nDepth, return_num): | ||||
|     if nDepth == 2: | ||||
|         choices = (1, 2) | ||||
|     elif nDepth == 3: | ||||
|         choices = (1, 2, 3) | ||||
|     elif nDepth > 3: | ||||
|         choices = list(range(1, nDepth + 1, 2)) | ||||
|         if choices[-1] < nDepth: | ||||
|             choices.append(nDepth) | ||||
|     else: | ||||
|         raise ValueError("invalid nDepth : {:}".format(nDepth)) | ||||
|     if return_num: | ||||
|         return len(choices) | ||||
|     else: | ||||
|         return choices | ||||
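| # For example (illustrative values for the helper above): | ||||
| #   get_depth_choices(3, False) -> (1, 2, 3) | ||||
| #   get_depth_choices(4, False) -> [1, 3, 4]  (the full depth is appended) | ||||
| #   get_depth_choices(5, True)  -> 3          (from the choices [1, 3, 5]) | ||||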
|  | ||||
|  | ||||
| def conv_forward(inputs, conv, choices): | ||||
|     iC = conv.in_channels | ||||
|     fill_size = list(inputs.size()) | ||||
|     fill_size[1] = iC - fill_size[1] | ||||
|     filled = torch.zeros(fill_size, device=inputs.device) | ||||
|     xinputs = torch.cat((inputs, filled), dim=1) | ||||
|     outputs = conv(xinputs) | ||||
|     selecteds = [outputs[:, :oC] for oC in choices] | ||||
|     return selecteds | ||||
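| # Note (added for clarity): conv_forward evaluates several candidate output | ||||
| # widths with a single convolution.  The input is zero-padded along the | ||||
| # channel dimension up to the layer's full in_channels (all-zero channels | ||||
| # contribute nothing), the full-width conv runs once, and each candidate | ||||
| # width oC is read off as the first oC output channels. | ||||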
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_width_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         # if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|         # else       : self.bn  = None | ||||
|         self.has_bn = has_bn | ||||
|         self.BNs = nn.ModuleList() | ||||
|         for i, _out in enumerate(self.choices): | ||||
|             self.BNs.append(nn.BatchNorm2d(_out)) | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, channels, check_range=True, divide=1): | ||||
|         iC, oC = channels | ||||
|         if check_range: | ||||
|             assert ( | ||||
|                 iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|             ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|                 iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|             ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
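|     # Worked example (added for clarity): a 3x3 conv with groups=1, iC=16, | ||||
|     # oC=32 and a 32x32 output map costs (3 * 3 / 1) * (32 * 32) * 16 * 32 | ||||
|     # = 4,718,592 multiply-accumulates with divide=1; search_forward passes | ||||
|     # divide=1e6 to obtain the value in MFLOPs. | ||||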
|  | ||||
|     def get_range(self): | ||||
|         return [self.choices] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, index, prob = tuple_inputs | ||||
|         index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) | ||||
|         probability = torch.squeeze(probability) | ||||
|         assert len(index) == 2, "invalid length : {:}".format(index) | ||||
|         # compute expected flop | ||||
|         # coordinates   = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) | ||||
|         expected_outC = (self.choices_tensor * probability).sum() | ||||
|         expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         # convolutional layer | ||||
|         out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) | ||||
|         out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] | ||||
|         # merge | ||||
|         out_channel = max([x.size(1) for x in out_bns]) | ||||
|         outA = ChannelWiseInter(out_bns[0], out_channel) | ||||
|         outB = ChannelWiseInter(out_bns[1], out_channel) | ||||
|         out = outA * prob[0] + outB * prob[1] | ||||
|         # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) | ||||
|  | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         return out, expected_outC, expected_flop | ||||
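|     # Note (added for clarity): in search mode each ConvBNReLU consumes a | ||||
|     # 5-tuple (inputs, expected_inC, probability, index, prob): | ||||
|     # `probability` is the distribution over all width choices (used for | ||||
|     # the expected FLOP), while `index` and `prob` hold the two sampled | ||||
|     # width choices and their mixing weights; both candidate outputs are | ||||
|     # channel-wise interpolated to a common width and blended by `prob`. | ||||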
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.has_bn: | ||||
|             out = self.BNs[-1](conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 2 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return self.conv_a.get_range() + self.conv_b.get_range() | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 3, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_a.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_b.get_flops([channels[1], channels[2]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this short-cut will be added during the infer-train | ||||
|             flop_C = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_b.OutShape[0] | ||||
|                 * self.conv_b.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 | ||||
|         out_a, expected_inC_a, expected_flop_a = self.conv_a( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_b, expected_inC_b, expected_flop_b = self.conv_b( | ||||
|             (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[1], indexes[1], probs[1]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_b) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_b, | ||||
|             sum([expected_flop_a, expected_flop_b, expected_flop_c]), | ||||
|         ) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return ( | ||||
|             self.conv_1x1.get_range() | ||||
|             + self.conv_3x3.get_range() | ||||
|             + self.conv_1x4.get_range() | ||||
|         ) | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 4, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) | ||||
|         flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_D = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_D = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this short-cut will be added during the infer-train | ||||
|             flop_D = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_1x4.OutShape[0] | ||||
|                 * self.conv_1x4.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C + flop_D | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 | ||||
|         out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( | ||||
|             (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( | ||||
|             (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[2], indexes[2], probs[2]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_1x4) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_1x4, | ||||
|             sum( | ||||
|                 [ | ||||
|                     expected_flop_1x1, | ||||
|                     expected_flop_3x3, | ||||
|                     expected_flop_1x4, | ||||
|                     expected_flop_c, | ||||
|                 ] | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SearchShapeCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes): | ||||
|         super(SearchShapeCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.InShape = None | ||||
|         self.depth_info = OrderedDict() | ||||
|         self.depth_at_i = OrderedDict() | ||||
|         for stage in range(3): | ||||
|             cur_block_choices = get_depth_choices(layer_blocks, False) | ||||
|             assert ( | ||||
|                 cur_block_choices[-1] == layer_blocks | ||||
|             ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) | ||||
|             self.message += ( | ||||
|                 "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format( | ||||
|                     stage, cur_block_choices, layer_blocks | ||||
|                 ) | ||||
|             ) | ||||
|             block_choices, xstart = [], len(self.layers) | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 # added for depth | ||||
|                 layer_index = len(self.layers) - 1 | ||||
|                 if iL + 1 in cur_block_choices: | ||||
|                     block_choices.append(layer_index) | ||||
|                 if iL + 1 == layer_blocks: | ||||
|                     self.depth_info[layer_index] = { | ||||
|                         "choices": block_choices, | ||||
|                         "stage": stage, | ||||
|                         "xstart": xstart, | ||||
|                     } | ||||
|         self.depth_info_list = [] | ||||
|         for xend, info in self.depth_info.items(): | ||||
|             self.depth_info_list.append((xend, info)) | ||||
|             xstart, xstage = info["xstart"], info["stage"] | ||||
|             for ilayer in range(xstart, xend + 1): | ||||
|                 idx = bisect_right(info["choices"], ilayer - 1) | ||||
|                 self.depth_at_i[ilayer] = (xstage, idx) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|         # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|  | ||||
|         # parameters for width | ||||
|         self.Ranges = [] | ||||
|         self.layer2indexRange = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             start_index = len(self.Ranges) | ||||
|             self.Ranges += layer.get_range() | ||||
|             self.layer2indexRange.append((start_index, len(self.Ranges))) | ||||
|         assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( | ||||
|             len(self.Ranges) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "width_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))), | ||||
|         ) | ||||
|         self.register_parameter( | ||||
|             "depth_attentions", | ||||
|             nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))), | ||||
|         ) | ||||
|         nn.init.normal_(self.width_attentions, 0, 0.01) | ||||
|         nn.init.normal_(self.depth_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self, LR=None): | ||||
|         if LR is None: | ||||
|             return [self.width_attentions, self.depth_attentions] | ||||
|         else: | ||||
|             return [ | ||||
|                 {"params": self.width_attentions, "lr": LR}, | ||||
|                 {"params": self.depth_attentions, "lr": LR}, | ||||
|             ] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # select channels | ||||
|         channels = [3] | ||||
|         for i, weight in enumerate(self.width_attentions): | ||||
|             if mode == "genotype": | ||||
|                 with torch.no_grad(): | ||||
|                     probe = nn.functional.softmax(weight, dim=0) | ||||
|                     C = self.Ranges[i][torch.argmax(probe).item()] | ||||
|             elif mode == "max": | ||||
|                 C = self.Ranges[i][-1] | ||||
|             elif mode == "fix": | ||||
|                 C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|             elif mode == "random": | ||||
|                 assert isinstance(extra_info, float), "invalid extra_info : {:}".format( | ||||
|                     extra_info | ||||
|                 ) | ||||
|                 with torch.no_grad(): | ||||
|                     prob = nn.functional.softmax(weight, dim=0) | ||||
|                     approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|                     for j in range(prob.size(0)): | ||||
|                         prob[j] = 1 / ( | ||||
|                             abs(j - (approximate_C - self.Ranges[i][j])) + 0.2 | ||||
|                         ) | ||||
|                     C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] | ||||
|             else: | ||||
|                 raise ValueError("invalid mode : {:}".format(mode)) | ||||
|             channels.append(C) | ||||
|         # select depth | ||||
|         if mode == "genotype": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 choices = torch.argmax(depth_probs, dim=1).cpu().tolist() | ||||
|         elif mode == "max" or mode == "fix": | ||||
|             choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))] | ||||
|         elif mode == "random": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 # flatten the (stages, 1) sample so each choice is a plain int | ||||
|                 choices = ( | ||||
|                     torch.multinomial(depth_probs, 1, False).view(-1).cpu().tolist() | ||||
|                 ) | ||||
|         else: | ||||
|             raise ValueError("invalid mode : {:}".format(mode)) | ||||
|         selected_layers = [] | ||||
|         for choice, xvalue in zip(choices, self.depth_info_list): | ||||
|             xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 | ||||
|             selected_layers.append(xtemp) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             s, e = self.layer2indexRange[i] | ||||
|             xchl = tuple(channels[s : e + 1]) | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 if xatti <= choices[xstagei]:  # leave this depth | ||||
|                     flop += layer.get_flops(xchl) | ||||
|                 else: | ||||
|                     flop += 0  # do not use this layer | ||||
|             else: | ||||
|                 flop += layer.get_flops(xchl) | ||||
|         # the last fc layer | ||||
|         flop += channels[-1] * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xchannels"] = channels | ||||
|             config_dict["xblocks"] = selected_layers | ||||
|             config_dict["super_type"] = "infer-shape" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
|  | ||||
|     def get_arch_info(self): | ||||
|         string = ( | ||||
|             "for depth and width, there are {:} + {:} attention probabilities.".format( | ||||
|                 len(self.depth_attentions), len(self.width_attentions) | ||||
|             ) | ||||
|         ) | ||||
|         string += "\n{:}".format(self.depth_info) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.depth_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.depth_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:17s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || discrepancy={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|             string += "\n-----------------------------------------------" | ||||
|             for i, att in enumerate(self.width_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.width_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:52s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || dis={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
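|  | ||||
|     # Worked example of the cosine schedule above (illustrative values): with | ||||
|     # tau_max=10 and tau_min=0.1, epoch_ratio=0.0 gives tau=10.0, | ||||
|     # epoch_ratio=0.5 gives tau~5.05, and epoch_ratio=1.0 gives tau=0.1. | ||||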
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1) | ||||
|         flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|         flop_depth_probs = torch.flip( | ||||
|             torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] | ||||
|         ) | ||||
|         selected_widths, selected_width_probs = select2withP( | ||||
|             self.width_attentions, self.tau | ||||
|         ) | ||||
|         selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) | ||||
|         with torch.no_grad(): | ||||
|             selected_widths = selected_widths.cpu() | ||||
|  | ||||
|         x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] | ||||
|         feature_maps = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             selected_w_index = selected_widths[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             selected_w_probs = selected_width_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             layer_prob = flop_width_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             x, expected_inC, expected_flop = layer( | ||||
|                 (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) | ||||
|             ) | ||||
|             feature_maps.append(x) | ||||
|             last_channel_idx += layer.num_conv | ||||
|             if i in self.depth_info:  # aggregate the information | ||||
|                 choices = self.depth_info[i]["choices"] | ||||
|                 xstagei = self.depth_info[i]["stage"] | ||||
|                 # print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist())) | ||||
|                 # for A, W in zip(choices, selected_depth_probs[xstagei]): | ||||
|                 #  print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W)) | ||||
|                 possible_tensors = [] | ||||
|                 max_C = max(feature_maps[A].size(1) for A in choices) | ||||
|                 for tempi, A in enumerate(choices): | ||||
|                     xtensor = ChannelWiseInter(feature_maps[A], max_C) | ||||
|                     # drop_ratio = 1-(tempi+1.0)/len(choices) | ||||
|                     # xtensor = drop_path(xtensor, drop_ratio) | ||||
|                     possible_tensors.append(xtensor) | ||||
|                 weighted_sum = sum( | ||||
|                     xtensor * W | ||||
|                     for xtensor, W in zip( | ||||
|                         possible_tensors, selected_depth_probs[xstagei] | ||||
|                     ) | ||||
|                 ) | ||||
|                 x = weighted_sum | ||||
|  | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop | ||||
|             else: | ||||
|                 x_expected_flop = expected_flop | ||||
|             flops.append(x_expected_flop) | ||||
|         flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
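|  | ||||
|     # Aside on the flipped-cumsum in search_forward above: for one stage with | ||||
|     # depth probs p = [p1, p2, p3] over its choices (shallow -> deep), flip + | ||||
|     # cumsum + flip yields [p1+p2+p3, p2+p3, p3]; entry i is the probability | ||||
|     # that the sampled depth reaches at least the i-th choice, which is what | ||||
|     # scales each block's expected FLOPs. | ||||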
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
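|  | ||||
| # A minimal smoke-test sketch (an illustration, not part of the original file), | ||||
| # assuming the relative imports above resolve inside the xautodl package: | ||||
| # | ||||
| #   net = SearchShapeCifarResNet("ResNetBasicblock", 20, 10) | ||||
| #   features, logits = net(torch.randn(2, 3, 32, 32))  # search_mode == "basic" | ||||
| #   print(logits.shape)  # torch.Size([2, 10]) | ||||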
| @@ -0,0 +1,515 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| from collections import OrderedDict | ||||
| from bisect import bisect_right | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices | ||||
|  | ||||
|  | ||||
| def get_depth_choices(nDepth, return_num): | ||||
|     if nDepth == 2: | ||||
|         choices = (1, 2) | ||||
|     elif nDepth == 3: | ||||
|         choices = (1, 2, 3) | ||||
|     elif nDepth > 3: | ||||
|         choices = list(range(1, nDepth + 1, 2)) | ||||
|         if choices[-1] < nDepth: | ||||
|             choices.append(nDepth) | ||||
|     else: | ||||
|         raise ValueError("invalid nDepth : {:}".format(nDepth)) | ||||
|     if return_num: | ||||
|         return len(choices) | ||||
|     else: | ||||
|         return choices | ||||
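|  | ||||
| # For example, following the branches above: | ||||
| #   get_depth_choices(3, False) -> (1, 2, 3) | ||||
| #   get_depth_choices(5, False) -> [1, 3, 5] | ||||
| #   get_depth_choices(5, True)  -> 3 | ||||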
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_width_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=False) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|  | ||||
|     def get_flops(self, divide=1): | ||||
|         iC, oC = self.in_dim, self.out_dim | ||||
|         assert ( | ||||
|             iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|         ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|             iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|         ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
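|  | ||||
|     # Worked example of the formula above (illustrative numbers): a 3x3 conv | ||||
|     # with groups=1, 16 -> 32 channels, and a 32x32 output map gives | ||||
|     #   (3 * 3 / 1) * (32 * 32) * 16 * 32 = 4,718,592 MACs before `divide`. | ||||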
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 2 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, divide=1): | ||||
|         flop_A = self.conv_a.get_flops(divide) | ||||
|         flop_B = self.conv_b.get_flops(divide) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops(divide) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         return flop_A + flop_B + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         # NOTE: unused in this depth-only search; the ConvBNReLU defined in | ||||
|         # this file does not implement get_range, so calling this would fail. | ||||
|         return ( | ||||
|             self.conv_1x1.get_range() | ||||
|             + self.conv_3x3.get_range() | ||||
|             + self.conv_1x4.get_range() | ||||
|         ) | ||||
|  | ||||
|     def get_flops(self, divide=1):  # default added to match the other blocks | ||||
|         flop_A = self.conv_1x1.get_flops(divide) | ||||
|         flop_B = self.conv_3x3.get_flops(divide) | ||||
|         flop_C = self.conv_1x4.get_flops(divide) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_D = self.downsample.get_flops(divide) | ||||
|         else: | ||||
|             flop_D = 0 | ||||
|         return flop_A + flop_B + flop_C + flop_D | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class SearchDepthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes): | ||||
|         super(SearchDepthCifarResNet, self).__init__() | ||||
|  | ||||
|         # The block type determines the number of layers for the CIFAR-10 and CIFAR-100 models | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.InShape = None | ||||
|         self.depth_info = OrderedDict() | ||||
|         self.depth_at_i = OrderedDict() | ||||
|         for stage in range(3): | ||||
|             cur_block_choices = get_depth_choices(layer_blocks, False) | ||||
|             assert ( | ||||
|                 cur_block_choices[-1] == layer_blocks | ||||
|             ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) | ||||
|             self.message += ( | ||||
|                 "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format( | ||||
|                     stage, cur_block_choices, layer_blocks | ||||
|                 ) | ||||
|             ) | ||||
|             block_choices, xstart = [], len(self.layers) | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 # added for depth | ||||
|                 layer_index = len(self.layers) - 1 | ||||
|                 if iL + 1 in cur_block_choices: | ||||
|                     block_choices.append(layer_index) | ||||
|                 if iL + 1 == layer_blocks: | ||||
|                     self.depth_info[layer_index] = { | ||||
|                         "choices": block_choices, | ||||
|                         "stage": stage, | ||||
|                         "xstart": xstart, | ||||
|                     } | ||||
|         self.depth_info_list = [] | ||||
|         for xend, info in self.depth_info.items(): | ||||
|             self.depth_info_list.append((xend, info)) | ||||
|             xstart, xstage = info["xstart"], info["stage"] | ||||
|             for ilayer in range(xstart, xend + 1): | ||||
|                 idx = bisect_right(info["choices"], ilayer - 1) | ||||
|                 self.depth_at_i[ilayer] = (xstage, idx) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|         # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "depth_attentions", | ||||
|             nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))), | ||||
|         ) | ||||
|         nn.init.normal_(self.depth_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self): | ||||
|         return [self.depth_attentions] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # select depth | ||||
|         if mode == "genotype": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 choices = torch.argmax(depth_probs, dim=1).cpu().tolist() | ||||
|         elif mode == "max": | ||||
|             choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))] | ||||
|         elif mode == "random": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 # flatten the (stages, 1) sample so each choice is a plain int | ||||
|                 choices = ( | ||||
|                     torch.multinomial(depth_probs, 1, False).view(-1).cpu().tolist() | ||||
|                 ) | ||||
|         else: | ||||
|             raise ValueError("invalid mode : {:}".format(mode)) | ||||
|         selected_layers = [] | ||||
|         for choice, xvalue in zip(choices, self.depth_info_list): | ||||
|             xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 | ||||
|             selected_layers.append(xtemp) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 if xatti <= choices[xstagei]:  # leave this depth | ||||
|                     flop += layer.get_flops() | ||||
|                 else: | ||||
|                     flop += 0  # do not use this layer | ||||
|             else: | ||||
|                 flop += layer.get_flops() | ||||
|         # the last fc layer | ||||
|         flop += self.classifier.in_features * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xblocks"] = selected_layers | ||||
|             config_dict["super_type"] = "infer-depth" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
|  | ||||
|     def get_arch_info(self): | ||||
|         string = "for depth, there are {:} attention probabilities.".format( | ||||
|             len(self.depth_attentions) | ||||
|         ) | ||||
|         string += "\n{:}".format(self.depth_info) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.depth_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.depth_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:17s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || discrepancy={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|         flop_depth_probs = torch.flip( | ||||
|             torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] | ||||
|         ) | ||||
|         selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) | ||||
|  | ||||
|         x, flops = inputs, [] | ||||
|         feature_maps = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             layer_i = layer(x) | ||||
|             feature_maps.append(layer_i) | ||||
|             if i in self.depth_info:  # aggregate the information | ||||
|                 choices = self.depth_info[i]["choices"] | ||||
|                 xstagei = self.depth_info[i]["stage"] | ||||
|                 possible_tensors = [] | ||||
|                 for tempi, A in enumerate(choices): | ||||
|                     xtensor = feature_maps[A] | ||||
|                     possible_tensors.append(xtensor) | ||||
|                 weighted_sum = sum( | ||||
|                     xtensor * W | ||||
|                     for xtensor, W in zip( | ||||
|                         possible_tensors, selected_depth_probs[xstagei] | ||||
|                     ) | ||||
|                 ) | ||||
|                 x = weighted_sum | ||||
|             else: | ||||
|                 x = layer_i | ||||
|  | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 # print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6))) | ||||
|                 x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops( | ||||
|                     1e6 | ||||
|                 ) | ||||
|             else: | ||||
|                 x_expected_flop = layer.get_flops(1e6) | ||||
|             flops.append(x_expected_flop) | ||||
|         flops.append( | ||||
|             (self.classifier.in_features * self.classifier.out_features * 1.0 / 1e6) | ||||
|         ) | ||||
|  | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
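|  | ||||
| # A minimal sketch of extracting a depth genotype (an illustration, assuming | ||||
| # the package imports above resolve): | ||||
| # | ||||
| #   net = SearchDepthCifarResNet("ResNetBasicblock", 20, 10) | ||||
| #   net(torch.randn(2, 3, 32, 32))   # one forward populates In/Out shapes | ||||
| #   mflop = net.get_flop("genotype", None, None)  # estimated FLOPs in millions | ||||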
| @@ -0,0 +1,619 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices as get_choices | ||||
|  | ||||
|  | ||||
| def conv_forward(inputs, conv, choices): | ||||
|     iC = conv.in_channels | ||||
|     fill_size = list(inputs.size()) | ||||
|     fill_size[1] = iC - fill_size[1] | ||||
|     filled = torch.zeros(fill_size, device=inputs.device) | ||||
|     xinputs = torch.cat((inputs, filled), dim=1) | ||||
|     outputs = conv(xinputs) | ||||
|     selecteds = [outputs[:, :oC] for oC in choices] | ||||
|     return selecteds | ||||
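|  | ||||
| # For example (shapes only, illustrative): with conv.in_channels == 16 and an | ||||
| # input of shape (N, 12, H, W), conv_forward zero-pads 4 channels, runs the | ||||
| # full convolution once, and returns one sliced view per candidate width, | ||||
| # e.g. choices=[8, 16] -> tensors of shape (N, 8, H', W') and (N, 16, H', W'). | ||||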
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         # if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|         # else       : self.bn  = None | ||||
|         self.has_bn = has_bn | ||||
|         self.BNs = nn.ModuleList() | ||||
|         for i, _out in enumerate(self.choices): | ||||
|             self.BNs.append(nn.BatchNorm2d(_out)) | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, channels, check_range=True, divide=1): | ||||
|         iC, oC = channels | ||||
|         if check_range: | ||||
|             assert ( | ||||
|                 iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|             ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|                 iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|             ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
|  | ||||
|     def get_range(self): | ||||
|         return [self.choices] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, index, prob = tuple_inputs | ||||
|         index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) | ||||
|         probability = torch.squeeze(probability) | ||||
|         assert len(index) == 2, "invalid length : {:}".format(index) | ||||
|         # compute expected flop | ||||
|         # coordinates   = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) | ||||
|         expected_outC = (self.choices_tensor * probability).sum() | ||||
|         expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         # convolutional layer | ||||
|         out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) | ||||
|         out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] | ||||
|         # merge | ||||
|         out_channel = max([x.size(1) for x in out_bns]) | ||||
|         outA = ChannelWiseInter(out_bns[0], out_channel) | ||||
|         outB = ChannelWiseInter(out_bns[1], out_channel) | ||||
|         out = outA * prob[0] + outB * prob[1] | ||||
|         # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) | ||||
|  | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         return out, expected_outC, expected_flop | ||||
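|  | ||||
|     # In the merge above, the narrower candidate is interpolated channel-wise | ||||
|     # to the wider one (ChannelWiseInter), so the two BN outputs can be mixed | ||||
|     # as a prob-weighted sum; e.g. prob = [0.7, 0.3] blends 70% of the first | ||||
|     # sampled width with 30% of the second. | ||||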
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.has_bn: | ||||
|             out = self.BNs[-1](conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 2 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return self.conv_a.get_range() + self.conv_b.get_range() | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 3, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_a.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_b.get_flops([channels[1], channels[2]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this short-cut will be added during the infer-train | ||||
|             flop_C = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_b.OutShape[0] | ||||
|                 * self.conv_b.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 | ||||
|         out_a, expected_inC_a, expected_flop_a = self.conv_a( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_b, expected_inC_b, expected_flop_b = self.conv_b( | ||||
|             (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[1], indexes[1], probs[1]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_b) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_b, | ||||
|             sum([expected_flop_a, expected_flop_b, expected_flop_c]), | ||||
|         ) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return ( | ||||
|             self.conv_1x1.get_range() | ||||
|             + self.conv_3x3.get_range() | ||||
|             + self.conv_1x4.get_range() | ||||
|         ) | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 4, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) | ||||
|         flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_D = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_D = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this short-cut will be added during the infer-train | ||||
|             flop_D = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_1x4.OutShape[0] | ||||
|                 * self.conv_1x4.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C + flop_D | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 | ||||
|         out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( | ||||
|             (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( | ||||
|             (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[2], indexes[2], probs[2]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_1x4) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_1x4, | ||||
|             sum( | ||||
|                 [ | ||||
|                     expected_flop_1x1, | ||||
|                     expected_flop_3x3, | ||||
|                     expected_flop_1x4, | ||||
|                     expected_flop_c, | ||||
|                 ] | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SearchWidthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes): | ||||
|         super(SearchWidthCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be 164 or another value with (depth - 2) % 9 == 0" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.InShape = None | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|         # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|  | ||||
|         # parameters for width | ||||
|         self.Ranges = [] | ||||
|         self.layer2indexRange = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             start_index = len(self.Ranges) | ||||
|             self.Ranges += layer.get_range() | ||||
|             self.layer2indexRange.append((start_index, len(self.Ranges))) | ||||
|         assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( | ||||
|             len(self.Ranges) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "width_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))), | ||||
|         ) | ||||
|         nn.init.normal_(self.width_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self): | ||||
|         return [self.width_attentions] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # weights = [F.softmax(x, dim=0) for x in self.width_attentions] | ||||
|         channels = [3] | ||||
|         for i, weight in enumerate(self.width_attentions): | ||||
|             if mode == "genotype": | ||||
|                 with torch.no_grad(): | ||||
|                     probe = nn.functional.softmax(weight, dim=0) | ||||
|                     C = self.Ranges[i][torch.argmax(probe).item()] | ||||
|             elif mode == "max": | ||||
|                 C = self.Ranges[i][-1] | ||||
|             elif mode == "fix": | ||||
|                 C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|             elif mode == "random": | ||||
|                 assert isinstance(extra_info, float), "invalid extra_info : {:}".format( | ||||
|                     extra_info | ||||
|                 ) | ||||
|                 with torch.no_grad(): | ||||
|                     prob = nn.functional.softmax(weight, dim=0) | ||||
|                     approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|                     for j in range(prob.size(0)): | ||||
|                         prob[j] = 1 / ( | ||||
|                             abs(j - (approximate_C - self.Ranges[i][j])) + 0.2 | ||||
|                         ) | ||||
|                     C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] | ||||
|             else: | ||||
|                 raise ValueError("invalid mode : {:}".format(mode)) | ||||
|             channels.append(C) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             s, e = self.layer2indexRange[i] | ||||
|             xchl = tuple(channels[s : e + 1]) | ||||
|             flop += layer.get_flops(xchl) | ||||
|         # the last fc layer | ||||
|         flop += channels[-1] * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xchannels"] = channels | ||||
|             config_dict["super_type"] = "infer-width" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
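|     # Usage note (a sketch under assumptions): get_flop needs one regular | ||||
|     # forward pass beforehand, because ConvBNReLU.get_flops asserts that the | ||||
|     # per-layer In/Out shapes were recorded by basic_forward, e.g. | ||||
|     #   net(torch.randn(1, 3, 32, 32))  # records InShape / OutShape | ||||
|     #   mega_flops = net.get_flop("genotype", None, None) | ||||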
|  | ||||
|     def get_arch_info(self): | ||||
|         string = "for width, there are {:} attention probabilities.".format( | ||||
|             len(self.width_attentions) | ||||
|         ) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.width_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.width_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:52s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || dis={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
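|         # A worked example of the cosine schedule above, with tau_max=10 and | ||||
|         # tau_min=0.1: | ||||
|         #   epoch_ratio=0.0 -> tau = 0.1 + 9.9 * (1 + cos(0)) / 2 = 10.0 | ||||
|         #   epoch_ratio=0.5 -> tau = 0.1 + 9.9 * (1 + cos(pi/2)) / 2 = 5.05 | ||||
|         #   epoch_ratio=1.0 -> tau = 0.1 + 9.9 * (1 + cos(pi)) / 2 = 0.1 | ||||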
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_probs = nn.functional.softmax(self.width_attentions, dim=1) | ||||
|         selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) | ||||
|         with torch.no_grad(): | ||||
|             selected_widths = selected_widths.cpu() | ||||
|  | ||||
|         x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             selected_w_index = selected_widths[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             selected_w_probs = selected_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             layer_prob = flop_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             x, expected_inC, expected_flop = layer( | ||||
|                 (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) | ||||
|             ) | ||||
|             last_channel_idx += layer.num_conv | ||||
|             flops.append(expected_flop) | ||||
|         flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
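|  | ||||
|  | ||||
| # A hypothetical training-loop sketch (the criterion, data tensors, and the | ||||
| # FLOP-penalty weight are assumptions, not part of this file): the width | ||||
| # search runs with every sub-module in "search" mode, tau annealed per epoch, | ||||
| # and the expected-FLOP output penalized alongside the task loss. | ||||
| # | ||||
| #   net = SearchWidthCifarResNet("ResNetBasicblock", 20, 10) | ||||
| #   for m in net.modules(): | ||||
| #       if hasattr(m, "search_mode"): | ||||
| #           m.search_mode = "search" | ||||
| #   for epoch in range(epochs): | ||||
| #       net.set_tau(10, 0.1, epoch / (epochs - 1)) | ||||
| #       logits, expected_flop = net(images) | ||||
| #       loss = criterion(logits, labels) + 0.1 * expected_flop.mean() | ||||
| #       loss.backward() | ||||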
| @@ -0,0 +1,766 @@ | ||||
| import math, torch | ||||
| from collections import OrderedDict | ||||
| from bisect import bisect_right | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices | ||||
|  | ||||
|  | ||||
| def get_depth_choices(layers): | ||||
|     min_depth = min(layers) | ||||
|     info = {"num": min_depth} | ||||
|     for i, depth in enumerate(layers): | ||||
|         choices = [] | ||||
|         for j in range(1, min_depth + 1): | ||||
|             choices.append(int(float(depth) * j / min_depth)) | ||||
|         info[i] = choices | ||||
|     return info | ||||
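|  | ||||
|  | ||||
| # A worked example of the helper above (values follow directly from the code): | ||||
| #   get_depth_choices([3, 4, 6, 3]) | ||||
| #   -> {"num": 3, 0: [1, 2, 3], 1: [1, 2, 4], 2: [2, 4, 6], 3: [1, 2, 3]} | ||||
| # i.e., every stage exposes min(layers) cumulative depth choices, and the last | ||||
| # choice is always the full stage depth. | ||||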
|  | ||||
|  | ||||
| def conv_forward(inputs, conv, choices): | ||||
|     iC = conv.in_channels | ||||
|     fill_size = list(inputs.size()) | ||||
|     fill_size[1] = iC - fill_size[1] | ||||
|     filled = torch.zeros(fill_size, device=inputs.device) | ||||
|     xinputs = torch.cat((inputs, filled), dim=1) | ||||
|     outputs = conv(xinputs) | ||||
|     selecteds = [outputs[:, :oC] for oC in choices] | ||||
|     return selecteds | ||||
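|  | ||||
|  | ||||
| # A minimal sketch of the padding trick above (shapes are illustrative | ||||
| # assumptions): a narrower input is zero-filled up to conv.in_channels, the | ||||
| # convolution runs once, and each candidate width is read off as a slice. | ||||
| #   conv = nn.Conv2d(8, 16, 3, padding=1) | ||||
| #   outs = conv_forward(torch.randn(2, 3, 32, 32), conv, [8, 16]) | ||||
| #   assert outs[0].shape == (2, 8, 32, 32) and outs[1].shape == (2, 16, 32, 32) | ||||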
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         nIn, | ||||
|         nOut, | ||||
|         kernel, | ||||
|         stride, | ||||
|         padding, | ||||
|         bias, | ||||
|         has_avg, | ||||
|         has_bn, | ||||
|         has_relu, | ||||
|         last_max_pool=False, | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_width_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         # if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|         # else       : self.bn  = None | ||||
|         self.has_bn = has_bn | ||||
|         self.BNs = nn.ModuleList() | ||||
|         for i, _out in enumerate(self.choices): | ||||
|             self.BNs.append(nn.BatchNorm2d(_out)) | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|         if last_max_pool: | ||||
|             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||
|         else: | ||||
|             self.maxpool = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, channels, check_range=True, divide=1): | ||||
|         iC, oC = channels | ||||
|         if check_range: | ||||
|             assert ( | ||||
|                 iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|             ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|                 iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|             ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
|  | ||||
|     def get_range(self): | ||||
|         return [self.choices] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, index, prob = tuple_inputs | ||||
|         index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) | ||||
|         probability = torch.squeeze(probability) | ||||
|         assert len(index) == 2, "invalid length : {:}".format(index) | ||||
|         # compute expected flop | ||||
|         # coordinates   = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) | ||||
|         expected_outC = (self.choices_tensor * probability).sum() | ||||
|         expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         # convolutional layer | ||||
|         out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) | ||||
|         out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] | ||||
|         # merge | ||||
|         out_channel = max([x.size(1) for x in out_bns]) | ||||
|         outA = ChannelWiseInter(out_bns[0], out_channel) | ||||
|         outB = ChannelWiseInter(out_bns[1], out_channel) | ||||
|         out = outA * prob[0] + outB * prob[1] | ||||
|         # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) | ||||
|  | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.maxpool: | ||||
|             out = self.maxpool(out) | ||||
|         return out, expected_outC, expected_flop | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.has_bn: | ||||
|             out = self.BNs[-1](conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         if self.maxpool: | ||||
|             out = self.maxpool(out) | ||||
|         return out | ||||
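|  | ||||
|  | ||||
| # Design note: ConvBNReLU keeps one BatchNorm2d per candidate width (self.BNs) | ||||
| # rather than a single BN, so every sliced width tracks its own running | ||||
| # statistics; basic_forward normalizes with self.BNs[-1], the BN built for the | ||||
| # last (full) width choice. | ||||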
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 2 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return self.conv_a.get_range() + self.conv_b.get_range() | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 3, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_a.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_b.get_flops([channels[1], channels[2]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this shortcut will be added during the infer-train stage | ||||
|             flop_C = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_b.OutShape[0] | ||||
|                 * self.conv_b.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 | ||||
|         # import pdb; pdb.set_trace() | ||||
|         out_a, expected_inC_a, expected_flop_a = self.conv_a( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_b, expected_inC_b, expected_flop_b = self.conv_b( | ||||
|             (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[1], indexes[1], probs[1]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_b) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_b, | ||||
|             sum([expected_flop_a, expected_flop_b, expected_flop_c]), | ||||
|         ) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return ( | ||||
|             self.conv_1x1.get_range() | ||||
|             + self.conv_3x3.get_range() | ||||
|             + self.conv_1x4.get_range() | ||||
|         ) | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 4, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) | ||||
|         flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_D = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_D = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this shortcut will be added during the infer-train stage | ||||
|             flop_D = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_1x4.OutShape[0] | ||||
|                 * self.conv_1x4.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C + flop_D | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 | ||||
|         out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( | ||||
|             (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( | ||||
|             (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[2], indexes[2], probs[2]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_1x4) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_1x4, | ||||
|             sum( | ||||
|                 [ | ||||
|                     expected_flop_1x1, | ||||
|                     expected_flop_3x3, | ||||
|                     expected_flop_1x4, | ||||
|                     expected_flop_c, | ||||
|                 ] | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SearchShapeImagenetResNet(nn.Module): | ||||
|     def __init__(self, block_name, layers, deep_stem, num_classes): | ||||
|         super(SearchShapeImagenetResNet, self).__init__() | ||||
|  | ||||
|         # block_name selects the residual block type for this ImageNet model | ||||
|         if block_name == "BasicBlock": | ||||
|             block = ResNetBasicblock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = ResNetBottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 sum(layers) * block.num_conv, layers | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         if not deep_stem: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         3, | ||||
|                         64, | ||||
|                         7, | ||||
|                         2, | ||||
|                         3, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                         last_max_pool=True, | ||||
|                     ) | ||||
|                 ] | ||||
|             ) | ||||
|             self.channels = [64] | ||||
|         else: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                     ), | ||||
|                     ConvBNReLU( | ||||
|                         32, | ||||
|                         64, | ||||
|                         3, | ||||
|                         1, | ||||
|                         1, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                         last_max_pool=True, | ||||
|                     ), | ||||
|                 ] | ||||
|             ) | ||||
|             self.channels = [32, 64] | ||||
|  | ||||
|         meta_depth_info = get_depth_choices(layers) | ||||
|         self.InShape = None | ||||
|         self.depth_info = OrderedDict() | ||||
|         self.depth_at_i = OrderedDict() | ||||
|         for stage, layer_blocks in enumerate(layers): | ||||
|             cur_block_choices = meta_depth_info[stage] | ||||
|             assert ( | ||||
|                 cur_block_choices[-1] == layer_blocks | ||||
|             ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) | ||||
|             block_choices, xstart = [], len(self.layers) | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 64 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 # added for depth | ||||
|                 layer_index = len(self.layers) - 1 | ||||
|                 if iL + 1 in cur_block_choices: | ||||
|                     block_choices.append(layer_index) | ||||
|                 if iL + 1 == layer_blocks: | ||||
|                     self.depth_info[layer_index] = { | ||||
|                         "choices": block_choices, | ||||
|                         "stage": stage, | ||||
|                         "xstart": xstart, | ||||
|                     } | ||||
|         self.depth_info_list = [] | ||||
|         for xend, info in self.depth_info.items(): | ||||
|             self.depth_info_list.append((xend, info)) | ||||
|             xstart, xstage = info["xstart"], info["stage"] | ||||
|             for ilayer in range(xstart, xend + 1): | ||||
|                 idx = bisect_right(info["choices"], ilayer - 1) | ||||
|                 self.depth_at_i[ilayer] = (xstage, idx) | ||||
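|         # depth_at_i maps a layer index to (stage, depth-slot): a layer is | ||||
|         # kept iff the depth choice sampled for its stage is >= its slot, | ||||
|         # which is exactly what get_flop and search_forward test below. | ||||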
|  | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|         # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|  | ||||
|         # parameters for width | ||||
|         self.Ranges = [] | ||||
|         self.layer2indexRange = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             start_index = len(self.Ranges) | ||||
|             self.Ranges += layer.get_range() | ||||
|             self.layer2indexRange.append((start_index, len(self.Ranges))) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "width_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))), | ||||
|         ) | ||||
|         self.register_parameter( | ||||
|             "depth_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(layers), meta_depth_info["num"])), | ||||
|         ) | ||||
|         nn.init.normal_(self.width_attentions, 0, 0.01) | ||||
|         nn.init.normal_(self.depth_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self, LR=None): | ||||
|         if LR is None: | ||||
|             return [self.width_attentions, self.depth_attentions] | ||||
|         else: | ||||
|             return [ | ||||
|                 {"params": self.width_attentions, "lr": LR}, | ||||
|                 {"params": self.depth_attentions, "lr": LR}, | ||||
|             ] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # select channels | ||||
|         channels = [3] | ||||
|         for i, weight in enumerate(self.width_attentions): | ||||
|             if mode == "genotype": | ||||
|                 with torch.no_grad(): | ||||
|                     probe = nn.functional.softmax(weight, dim=0) | ||||
|                     C = self.Ranges[i][torch.argmax(probe).item()] | ||||
|             else: | ||||
|                 raise ValueError("invalid mode : {:}".format(mode)) | ||||
|             channels.append(C) | ||||
|         # select depth | ||||
|         if mode == "genotype": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 choices = torch.argmax(depth_probs, dim=1).cpu().tolist() | ||||
|         else: | ||||
|             raise ValueError("invalid mode : {:}".format(mode)) | ||||
|         selected_layers = [] | ||||
|         for choice, xvalue in zip(choices, self.depth_info_list): | ||||
|             xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 | ||||
|             selected_layers.append(xtemp) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             s, e = self.layer2indexRange[i] | ||||
|             xchl = tuple(channels[s : e + 1]) | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 if xatti <= choices[xstagei]:  # this layer is kept at the selected depth | ||||
|                     flop += layer.get_flops(xchl) | ||||
|                 else: | ||||
|                     flop += 0  # do not use this layer | ||||
|             else: | ||||
|                 flop += layer.get_flops(xchl) | ||||
|         # the last fc layer | ||||
|         flop += channels[-1] * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xchannels"] = channels | ||||
|             config_dict["xblocks"] = selected_layers | ||||
|             config_dict["super_type"] = "infer-shape" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
|  | ||||
|     def get_arch_info(self): | ||||
|         string = ( | ||||
|             "for depth and width, there are {:} + {:} attention probabilities.".format( | ||||
|                 len(self.depth_attentions), len(self.width_attentions) | ||||
|             ) | ||||
|         ) | ||||
|         string += "\n{:}".format(self.depth_info) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.depth_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.depth_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:17s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || discrepancy={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|             string += "\n-----------------------------------------------" | ||||
|             for i, att in enumerate(self.width_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.width_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:52s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || dis={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1) | ||||
|         flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|         flop_depth_probs = torch.flip( | ||||
|             torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] | ||||
|         ) | ||||
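|         # flip-cumsum-flip turns the per-choice probabilities into tail sums: | ||||
|         # flop_depth_probs[stage, k] = P(selected depth choice >= k), i.e. the | ||||
|         # probability that the k-th depth segment of the stage is kept, used | ||||
|         # below to weight each layer's expected FLOPs. | ||||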
|         selected_widths, selected_width_probs = select2withP( | ||||
|             self.width_attentions, self.tau | ||||
|         ) | ||||
|         selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) | ||||
|         with torch.no_grad(): | ||||
|             selected_widths = selected_widths.cpu() | ||||
|  | ||||
|         x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] | ||||
|         feature_maps = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             selected_w_index = selected_widths[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             selected_w_probs = selected_width_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             layer_prob = flop_width_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             x, expected_inC, expected_flop = layer( | ||||
|                 (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) | ||||
|             ) | ||||
|             feature_maps.append(x) | ||||
|             last_channel_idx += layer.num_conv | ||||
|             if i in self.depth_info:  # aggregate the information | ||||
|                 choices = self.depth_info[i]["choices"] | ||||
|                 xstagei = self.depth_info[i]["stage"] | ||||
|                 # print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist())) | ||||
|                 # for A, W in zip(choices, selected_depth_probs[xstagei]): | ||||
|                 #  print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W)) | ||||
|                 possible_tensors = [] | ||||
|                 max_C = max(feature_maps[A].size(1) for A in choices) | ||||
|                 for tempi, A in enumerate(choices): | ||||
|                     xtensor = ChannelWiseInter(feature_maps[A], max_C) | ||||
|                     possible_tensors.append(xtensor) | ||||
|                 weighted_sum = sum( | ||||
|                     xtensor * W | ||||
|                     for xtensor, W in zip( | ||||
|                         possible_tensors, selected_depth_probs[xstagei] | ||||
|                     ) | ||||
|                 ) | ||||
|                 x = weighted_sum | ||||
|  | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop | ||||
|             else: | ||||
|                 x_expected_flop = expected_flop | ||||
|             flops.append(x_expected_flop) | ||||
|         flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
| @@ -0,0 +1,466 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices as get_choices | ||||
|  | ||||
|  | ||||
| def conv_forward(inputs, conv, choices): | ||||
|     iC = conv.in_channels | ||||
|     fill_size = list(inputs.size()) | ||||
|     fill_size[1] = iC - fill_size[1] | ||||
|     filled = torch.zeros(fill_size, device=inputs.device) | ||||
|     xinputs = torch.cat((inputs, filled), dim=1) | ||||
|     outputs = conv(xinputs) | ||||
|     selecteds = [outputs[:, :oC] for oC in choices] | ||||
|     return selecteds | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         # if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|         # else       : self.bn  = None | ||||
|         self.has_bn = has_bn | ||||
|         self.BNs = nn.ModuleList() | ||||
|         for i, _out in enumerate(self.choices): | ||||
|             self.BNs.append(nn.BatchNorm2d(_out)) | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, channels, check_range=True, divide=1): | ||||
|         iC, oC = channels | ||||
|         if check_range: | ||||
|             assert ( | ||||
|                 iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|             ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|                 iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|             ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
|  | ||||
|     def get_range(self): | ||||
|         return [self.choices] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, index, prob = tuple_inputs | ||||
|         index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) | ||||
|         probability = torch.squeeze(probability) | ||||
|         assert len(index) == 2, "invalid length : {:}".format(index) | ||||
|         # compute expected flop | ||||
|         # coordinates   = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) | ||||
|         expected_outC = (self.choices_tensor * probability).sum() | ||||
|         expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         # convolutional layer | ||||
|         out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) | ||||
|         out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] | ||||
|         # merge | ||||
|         out_channel = max([x.size(1) for x in out_bns]) | ||||
|         outA = ChannelWiseInter(out_bns[0], out_channel) | ||||
|         outB = ChannelWiseInter(out_bns[1], out_channel) | ||||
|         out = outA * prob[0] + outB * prob[1] | ||||
|         # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) | ||||
|  | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         return out, expected_outC, expected_flop | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.has_bn: | ||||
|             out = self.BNs[-1](conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class SimBlock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(SimBlock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return self.conv.get_range() | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 2, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv.get_flops([channels[0], channels[1]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this shortcut will be added during the infer-train stage | ||||
|             flop_C = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv.OutShape[0] | ||||
|                 * self.conv.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert ( | ||||
|             indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1 | ||||
|         ), "invalid size : {:}, {:}, {:}".format( | ||||
|             indexes.size(), probs.size(), probability.size() | ||||
|         ) | ||||
|         out, expected_next_inC, expected_flop = self.conv( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[-1], indexes[-1], probs[-1]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_next_inC, | ||||
|             sum([expected_flop, expected_flop_c]), | ||||
|         ) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         basicblock = self.conv(inputs) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class SearchWidthSimResNet(nn.Module): | ||||
|     def __init__(self, depth, num_classes): | ||||
|         super(SearchWidthSimResNet, self).__init__() | ||||
|  | ||||
|         assert ( | ||||
|             depth - 2 | ||||
|         ) % 3 == 0, "depth should be one of 5, 8, 11, 14, ... instead of {:}".format( | ||||
|             depth | ||||
|         ) | ||||
|         layer_blocks = (depth - 2) // 3 | ||||
|         self.message = ( | ||||
|             "SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.InShape = None | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = SimBlock(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|         # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|  | ||||
|         # parameters for width | ||||
|         self.Ranges = [] | ||||
|         self.layer2indexRange = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             start_index = len(self.Ranges) | ||||
|             self.Ranges += layer.get_range() | ||||
|             self.layer2indexRange.append((start_index, len(self.Ranges))) | ||||
|         assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( | ||||
|             len(self.Ranges) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "width_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))), | ||||
|         ) | ||||
|         nn.init.normal_(self.width_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self): | ||||
|         return [self.width_attentions] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # weights = [F.softmax(x, dim=0) for x in self.width_attentions] | ||||
|         channels = [3] | ||||
|         for i, weight in enumerate(self.width_attentions): | ||||
|             if mode == "genotype": | ||||
|                 with torch.no_grad(): | ||||
|                     probe = nn.functional.softmax(weight, dim=0) | ||||
|                     C = self.Ranges[i][torch.argmax(probe).item()] | ||||
|             elif mode == "max": | ||||
|                 C = self.Ranges[i][-1] | ||||
|             elif mode == "fix": | ||||
|                 C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|             elif mode == "random": | ||||
|                 assert isinstance(extra_info, float), "invalid extra_info : {:}".format( | ||||
|                     extra_info | ||||
|                 ) | ||||
|                 with torch.no_grad(): | ||||
|                     prob = nn.functional.softmax(weight, dim=0) | ||||
|                     approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|                     for j in range(prob.size(0)): | ||||
|                         prob[j] = 1 / ( | ||||
|                             abs(j - (approximate_C - self.Ranges[i][j])) + 0.2 | ||||
|                         ) | ||||
|                     C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] | ||||
|             else: | ||||
|                 raise ValueError("invalid mode : {:}".format(mode)) | ||||
|             channels.append(C) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             s, e = self.layer2indexRange[i] | ||||
|             xchl = tuple(channels[s : e + 1]) | ||||
|             flop += layer.get_flops(xchl) | ||||
|         # the last fc layer | ||||
|         flop += channels[-1] * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xchannels"] = channels | ||||
|             config_dict["super_type"] = "infer-width" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
|  | ||||
|     def get_arch_info(self): | ||||
|         string = "for width, there are {:} attention probabilities.".format( | ||||
|             len(self.width_attentions) | ||||
|         ) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.width_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.width_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:52s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || dis={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_probs = nn.functional.softmax(self.width_attentions, dim=1) | ||||
|         selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) | ||||
|         with torch.no_grad(): | ||||
|             selected_widths = selected_widths.cpu() | ||||
|  | ||||
|         x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             selected_w_index = selected_widths[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             selected_w_probs = selected_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             layer_prob = flop_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             x, expected_inC, expected_flop = layer( | ||||
|                 (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) | ||||
|             ) | ||||
|             last_channel_idx += layer.num_conv | ||||
|             flops.append(expected_flop) | ||||
|         flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
|  | ||||
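|     # NOTE: unlike search_forward (which returns logits and the expected FLOPs), | ||||
|     # basic_forward returns (features, logits); callers unpack per search_mode. | ||||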
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
AutoDL-Projects/xautodl/models/shape_searchs/SoftSelect.py (new file, 128 lines)
							| @@ -0,0 +1,128 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7): | ||||
|     if tau <= 0: | ||||
|         new_logits = logits | ||||
|         probs = nn.functional.softmax(new_logits, dim=1) | ||||
|     else: | ||||
|         while True:  # a trick to avoid the gumbels bug | ||||
|             gumbels = -torch.empty_like(logits).exponential_().log() | ||||
|             new_logits = (logits.log_softmax(dim=1) + gumbels) / tau | ||||
|             probs = nn.functional.softmax(new_logits, dim=1) | ||||
|             if ( | ||||
|                 (not torch.isinf(gumbels).any()) | ||||
|                 and (not torch.isinf(probs).any()) | ||||
|                 and (not torch.isnan(probs).any()) | ||||
|             ): | ||||
|                 break | ||||
|  | ||||
|     if just_prob: | ||||
|         return probs | ||||
|  | ||||
|     # with torch.no_grad(): # add eps for unexpected torch error | ||||
|     #  probs = nn.functional.softmax(new_logits, dim=1) | ||||
|     #  selected_index = torch.multinomial(probs + eps, 2, False) | ||||
|     with torch.no_grad():  # add eps for unexpected torch error | ||||
|         probs = probs.cpu() | ||||
|         selected_index = torch.multinomial(probs + eps, num, False).to(logits.device) | ||||
|     selected_logit = torch.gather(new_logits, 1, selected_index) | ||||
|     selected_probs = nn.functional.softmax(selected_logit, dim=1) | ||||
|     return selected_index, selected_probs | ||||
|  | ||||
|  | ||||
| def ChannelWiseInter(inputs, oC, mode="v2"): | ||||
|     if mode == "v1": | ||||
|         return ChannelWiseInterV1(inputs, oC) | ||||
|     elif mode == "v2": | ||||
|         return ChannelWiseInterV2(inputs, oC) | ||||
|     else: | ||||
|         raise ValueError("invalid mode : {:}".format(mode)) | ||||
|  | ||||
|  | ||||
| def ChannelWiseInterV1(inputs, oC): | ||||
|     assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size()) | ||||
|  | ||||
|     def start_index(a, b, c): | ||||
|         return int(math.floor(float(a * c) / b)) | ||||
|  | ||||
|     def end_index(a, b, c): | ||||
|         return int(math.ceil(float((a + 1) * c) / b)) | ||||
|  | ||||
|     batch, iC, H, W = inputs.size() | ||||
|     if iC == oC: | ||||
|         return inputs | ||||
|     outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device) | ||||
|     for ot in range(oC): | ||||
|         istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC) | ||||
|         values = inputs[:, istartT:iendT].mean(dim=1) | ||||
|         outputs[:, ot, :, :] = values | ||||
|     return outputs | ||||
|  | ||||
|  | ||||
| def ChannelWiseInterV2(inputs, oC): | ||||
|     assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size()) | ||||
|     batch, C, H, W = inputs.size() | ||||
|     if C == oC: | ||||
|         return inputs | ||||
|     else: | ||||
|         return nn.functional.adaptive_avg_pool3d(inputs, (oC, H, W)) | ||||
|     # inputs_5D = inputs.view(batch, 1, C, H, W) | ||||
|     # otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None) | ||||
|     # otputs    = otputs_5D.view(batch, oC, H, W) | ||||
|     # otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False) | ||||
|     # return otputs | ||||
|  | ||||
|  | ||||
| def linear_forward(inputs, linear): | ||||
|     if linear is None: | ||||
|         return inputs | ||||
|     iC = inputs.size(1) | ||||
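|     # slice to the first iC input features so a single shared Linear serves any searched width | ||||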
|     weight = linear.weight[:, :iC] | ||||
|     if linear.bias is None: | ||||
|         bias = None | ||||
|     else: | ||||
|         bias = linear.bias | ||||
|     return nn.functional.linear(inputs, weight, bias) | ||||
|  | ||||
|  | ||||
| def get_width_choices(nOut): | ||||
|     xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] | ||||
|     if nOut is None: | ||||
|         return len(xsrange) | ||||
|     else: | ||||
|         Xs = [int(nOut * i) for i in xsrange] | ||||
|         # xs = [ int(nOut * i // 10) for i in range(2, 11)] | ||||
|         # Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1] | ||||
|         Xs = sorted(list(set(Xs))) | ||||
|         return tuple(Xs) | ||||
|  | ||||
|  | ||||
| def get_depth_choices(nDepth): | ||||
|     if nDepth is None: | ||||
|         return 3 | ||||
|     else: | ||||
|         assert nDepth >= 1, "nDepth should be a positive int vs {:}".format(nDepth) | ||||
|         if nDepth == 1: | ||||
|             return (1, 1, 1) | ||||
|         elif nDepth == 2: | ||||
|             return (1, 1, 2) | ||||
|         else: | ||||
|             return (nDepth // 3, nDepth * 2 // 3, nDepth) | ||||
|  | ||||
|  | ||||
| def drop_path(x, drop_prob): | ||||
|     if drop_prob > 0.0: | ||||
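|         # zero whole samples with probability drop_prob, rescale survivors to keep E[x] unchanged | ||||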
|         keep_prob = 1.0 - drop_prob | ||||
|         mask = x.new_zeros(x.size(0), 1, 1, 1) | ||||
|         mask = mask.bernoulli_(keep_prob) | ||||
|         x = x * (mask / keep_prob) | ||||
|         # x.div_(keep_prob) | ||||
|         # x.mul_(mask) | ||||
|     return x | ||||
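|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # Editor's smoke test (a sketch, not part of the original commit): exercises | ||||
|     # the helpers above with illustrative shapes and prints the expected results. | ||||
|     xin = torch.randn(2, 8, 4, 4) | ||||
|     logits = torch.randn(4, get_width_choices(None))  # 4 layers, 8 width choices | ||||
|     index, probs = select2withP(logits, tau=10.0, num=2) | ||||
|     print(index.shape, probs.shape)  # torch.Size([4, 2]) torch.Size([4, 2]) | ||||
|     print(ChannelWiseInter(xin, 5).shape)  # torch.Size([2, 5, 4, 4]) | ||||
|     print(get_width_choices(16))  # (4, 6, 8, 9, 11, 12, 14, 16) | ||||
|     print(get_depth_choices(20))  # (6, 13, 20) | ||||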
AutoDL-Projects/xautodl/models/shape_searchs/__init__.py (new file, 9 lines)
							| @@ -0,0 +1,9 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .SearchCifarResNet_width import SearchWidthCifarResNet | ||||
| from .SearchCifarResNet_depth import SearchDepthCifarResNet | ||||
| from .SearchCifarResNet import SearchShapeCifarResNet | ||||
| from .SearchSimResNet_width import SearchWidthSimResNet | ||||
| from .SearchImagenetResNet import SearchShapeImagenetResNet | ||||
| from .generic_size_tiny_cell_model import GenericNAS301Model | ||||
| @@ -0,0 +1,209 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # Here, we utilize three techniques to search for the number of channels: | ||||
| # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" | ||||
| # - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" | ||||
| # - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" | ||||
| from typing import List, Text, Any | ||||
| import random, torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from ..cell_infers.cells import InferCell | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
|  | ||||
|  | ||||
| class GenericNAS301Model(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         candidate_Cs: List[int], | ||||
|         max_num_Cs: int, | ||||
|         genotype: Any, | ||||
|         num_classes: int, | ||||
|         affine: bool, | ||||
|         track_running_stats: bool, | ||||
|     ): | ||||
|         super(GenericNAS301Model, self).__init__() | ||||
|         self._max_num_Cs = max_num_Cs | ||||
|         self._candidate_Cs = candidate_Cs | ||||
|         if max_num_Cs % 3 != 2: | ||||
|             raise ValueError("invalid number of layers : {:}".format(max_num_Cs)) | ||||
|         self._num_stage = N = max_num_Cs // 3 | ||||
|         self._max_C = max(candidate_Cs) | ||||
|  | ||||
|         stem = nn.Sequential( | ||||
|             nn.Conv2d(3, self._max_C, kernel_size=3, padding=1, bias=not affine), | ||||
|             nn.BatchNorm2d( | ||||
|                 self._max_C, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         c_prev = self._max_C | ||||
|         self._cells = nn.ModuleList() | ||||
|         self._cells.append(stem) | ||||
|         for index, reduction in enumerate(layer_reductions): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(c_prev, self._max_C, 2, True) | ||||
|             else: | ||||
|                 cell = InferCell( | ||||
|                     genotype, c_prev, self._max_C, 1, affine, track_running_stats | ||||
|                 ) | ||||
|             self._cells.append(cell) | ||||
|             c_prev = cell.out_dim | ||||
|         self._num_layer = len(self._cells) | ||||
|  | ||||
|         self.lastact = nn.Sequential( | ||||
|             nn.BatchNorm2d( | ||||
|                 c_prev, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(c_prev, num_classes) | ||||
|         # algorithm related | ||||
|         self.register_buffer("_tau", torch.zeros(1)) | ||||
|         self._algo = None | ||||
|         self._warmup_ratio = None | ||||
|  | ||||
|     def set_algo(self, algo: Text): | ||||
|         # used for searching | ||||
|         assert self._algo is None, "This functioin can only be called once." | ||||
|         assert algo in ["mask_gumbel", "mask_rl", "tas"], "invalid algo : {:}".format( | ||||
|             algo | ||||
|         ) | ||||
|         self._algo = algo | ||||
|         self._arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(self._max_num_Cs, len(self._candidate_Cs)) | ||||
|         ) | ||||
|         # if algo == 'mask_gumbel' or algo == 'mask_rl': | ||||
|         self.register_buffer( | ||||
|             "_masks", torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)) | ||||
|         ) | ||||
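|         # each row i masks out all but the first candidate_Cs[i] channels, | ||||
|         # e.g. candidate_Cs=[8, 16, 24] -> rows keep 8/16/24 of the 24 max channels | ||||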
|         for i in range(len(self._candidate_Cs)): | ||||
|             self._masks.data[i, : self._candidate_Cs[i]] = 1 | ||||
|  | ||||
|     @property | ||||
|     def tau(self): | ||||
|         return self._tau | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self._tau.data[:] = tau | ||||
|  | ||||
|     @property | ||||
|     def warmup_ratio(self): | ||||
|         return self._warmup_ratio | ||||
|  | ||||
|     def set_warmup_ratio(self, ratio: float): | ||||
|         self._warmup_ratio = ratio | ||||
|  | ||||
|     @property | ||||
|     def weights(self): | ||||
|         xlist = list(self._cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) | ||||
|         xlist += list(self.global_pooling.parameters()) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     @property | ||||
|     def alphas(self): | ||||
|         return [self._arch_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             return "arch-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self._arch_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|  | ||||
|     @property | ||||
|     def random(self): | ||||
|         cs = [] | ||||
|         for i in range(self._max_num_Cs): | ||||
|             index = random.randint(0, len(self._candidate_Cs) - 1) | ||||
|             cs.append(str(self._candidate_Cs[index])) | ||||
|         return ":".join(cs) | ||||
|  | ||||
|     @property | ||||
|     def genotype(self): | ||||
|         cs = [] | ||||
|         for i in range(self._max_num_Cs): | ||||
|             with torch.no_grad(): | ||||
|                 index = self._arch_parameters[i].argmax().item() | ||||
|                 cs.append(str(self._candidate_Cs[index])) | ||||
|         return ":".join(cs) | ||||
|  | ||||
|     def get_message(self) -> Text: | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self._cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self._cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(candidates={_candidate_Cs}, num={_max_num_Cs}, N={_num_stage}, L={_num_layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         feature = inputs | ||||
|  | ||||
|         log_probs = [] | ||||
|         for i, cell in enumerate(self._cells): | ||||
|             feature = cell(feature) | ||||
|             # apply different searching algorithms | ||||
|             idx = max(0, i - 1) | ||||
|             if self._warmup_ratio is not None: | ||||
|                 if random.random() < self._warmup_ratio: | ||||
|                     mask = self._masks[-1] | ||||
|                 else: | ||||
|                     mask = self._masks[random.randint(0, len(self._masks) - 1)] | ||||
|                 feature = feature * mask.view(1, -1, 1, 1) | ||||
|             elif self._algo == "mask_gumbel": | ||||
|                 weights = nn.functional.gumbel_softmax( | ||||
|                     self._arch_parameters[idx : idx + 1], tau=self.tau, dim=-1 | ||||
|                 ) | ||||
|                 mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1) | ||||
|                 feature = feature * mask | ||||
|             elif self._algo == "tas": | ||||
|                 selected_cs, selected_probs = select2withP( | ||||
|                     self._arch_parameters[idx : idx + 1], self.tau, num=2 | ||||
|                 ) | ||||
|                 with torch.no_grad(): | ||||
|                     i1, i2 = selected_cs.cpu().view(-1).tolist() | ||||
|                 c1, c2 = self._candidate_Cs[i1], self._candidate_Cs[i2] | ||||
|                 out_channel = max(c1, c2) | ||||
|                 out1 = ChannelWiseInter(feature[:, :c1], out_channel) | ||||
|                 out2 = ChannelWiseInter(feature[:, :c2], out_channel) | ||||
|                 out = out1 * selected_probs[0, 0] + out2 * selected_probs[0, 1] | ||||
|                 if feature.shape[1] == out.shape[1]: | ||||
|                     feature = out | ||||
|                 else: | ||||
|                     miss = torch.zeros( | ||||
|                         feature.shape[0], | ||||
|                         feature.shape[1] - out.shape[1], | ||||
|                         feature.shape[2], | ||||
|                         feature.shape[3], | ||||
|                         device=feature.device, | ||||
|                     ) | ||||
|                     feature = torch.cat((out, miss), dim=1) | ||||
|             elif self._algo == "mask_rl": | ||||
|                 prob = nn.functional.softmax( | ||||
|                     self._arch_parameters[idx : idx + 1], dim=-1 | ||||
|                 ) | ||||
|                 dist = torch.distributions.Categorical(prob) | ||||
|                 action = dist.sample() | ||||
|                 log_probs.append(dist.log_prob(action)) | ||||
|                 mask = self._masks[action.item()].view(1, -1, 1, 1) | ||||
|                 feature = feature * mask | ||||
|             else: | ||||
|                 raise ValueError("invalid algorithm : {:}".format(self._algo)) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits, log_probs | ||||
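|  | ||||
|  | ||||
| # Hypothetical search step for the 'mask_rl' branch (an editor's sketch, not part | ||||
| # of this commit); `targets` and the reward definition come from the search loop: | ||||
| #   out, logits, log_probs = model(inputs) | ||||
| #   reward = -nn.functional.cross_entropy(logits, targets).detach() | ||||
| #   policy_loss = -sum(log_probs) * reward  # REINFORCE on the width logits | ||||
| #   policy_loss.backward() | ||||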
AutoDL-Projects/xautodl/nas_infer_model/DXYs/CifarNet.py (new file, 76 lines)
							| @@ -0,0 +1,76 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from .construct_utils import drop_path | ||||
| from .head_utils      import CifarHEAD, AuxiliaryHeadCIFAR | ||||
| from .base_cells      import InferCell | ||||
|  | ||||
|  | ||||
| class NetworkCIFAR(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, stem_multiplier, auxiliary, genotype, num_classes): | ||||
|     super(NetworkCIFAR, self).__init__() | ||||
|     self._C               = C | ||||
|     self._layerN          = N | ||||
|     self._stem_multiplier = stem_multiplier | ||||
|  | ||||
|     C_curr = self._stem_multiplier * C | ||||
|     self.stem = CifarHEAD(C_curr) | ||||
|    | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|     block_indexs     = [0    ] * N + [-1  ] + [1    ] * N + [-1  ] + [2    ] * N | ||||
|     block2index      = {0:[], 1:[], 2:[]} | ||||
|  | ||||
|     C_prev_prev, C_prev, C_curr = C_curr, C_curr, C | ||||
|     reduction_prev, spatial, dims = False, 1, [] | ||||
|     self.auxiliary_index = None | ||||
|     self.auxiliary_head  = None | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) | ||||
|       reduction_prev = reduction | ||||
|       self.cells.append( cell ) | ||||
|       C_prev_prev, C_prev = C_prev, cell._multiplier*C_curr | ||||
|       if reduction and C_curr == C*4: | ||||
|         if auxiliary: | ||||
|           self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) | ||||
|           self.auxiliary_index = index | ||||
|  | ||||
|       if reduction: spatial *= 2 | ||||
|       dims.append( (C_prev, spatial) ) | ||||
|        | ||||
|     self._Layer= len(self.cells) | ||||
|  | ||||
|  | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|     self.drop_path_prob = -1 | ||||
|  | ||||
|   def update_drop_path(self, drop_path_prob): | ||||
|     self.drop_path_prob = drop_path_prob | ||||
|  | ||||
|   def auxiliary_param(self): | ||||
|     if self.auxiliary_head is None: return [] | ||||
|     else: return list( self.auxiliary_head.parameters() ) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.extra_repr() | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N={_layerN}, L={_Layer}, stem={_stem_multiplier}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     stem_feature, logits_aux = self.stem(inputs), None | ||||
|     cell_results = [stem_feature, stem_feature] | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) | ||||
|       cell_results.append( cell_feature ) | ||||
|  | ||||
|       if self.auxiliary_index is not None and i == self.auxiliary_index and self.training: | ||||
|         logits_aux = self.auxiliary_head( cell_results[-1] ) | ||||
|     out = self.global_pooling( cell_results[-1] ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     if logits_aux is None: return out, logits | ||||
|     else                 : return out, [logits, logits_aux] | ||||
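|  | ||||
|  | ||||
| # Hypothetical construction (an editor's sketch; Networks comes from the genotypes | ||||
| # module added later in this commit): | ||||
| #   from .genotypes import Networks | ||||
| #   net = NetworkCIFAR(C=36, N=6, stem_multiplier=3, auxiliary=True, | ||||
| #                      genotype=Networks['GDAS_V1'], num_classes=10) | ||||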
AutoDL-Projects/xautodl/nas_infer_model/DXYs/ImageNet.py (new file, 77 lines)
							| @@ -0,0 +1,77 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from .construct_utils import drop_path | ||||
| from .base_cells import InferCell | ||||
| from .head_utils import ImageNetHEAD, AuxiliaryHeadImageNet | ||||
|  | ||||
|  | ||||
| class NetworkImageNet(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, auxiliary, genotype, num_classes): | ||||
|     super(NetworkImageNet, self).__init__() | ||||
|     self._C          = C | ||||
|     self._layerN     = N | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4] * N | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|     self.stem0 = nn.Sequential( | ||||
|       nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C // 2), | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C), | ||||
|     ) | ||||
|  | ||||
|     self.stem1 = nn.Sequential( | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C), | ||||
|     ) | ||||
|  | ||||
|     C_prev_prev, C_prev, C_curr, reduction_prev = C, C, C, True | ||||
|  | ||||
|     self.cells = nn.ModuleList() | ||||
|     self.auxiliary_index = None | ||||
|     for i, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) | ||||
|       reduction_prev = reduction | ||||
|       self.cells += [cell] | ||||
|       C_prev_prev, C_prev = C_prev, cell._multiplier * C_curr | ||||
|       if reduction and C_curr == C*4: | ||||
|         C_to_auxiliary = C_prev | ||||
|         self.auxiliary_index = i | ||||
|    | ||||
|     self._NNN = len(self.cells) | ||||
|     if auxiliary: | ||||
|       self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes) | ||||
|     else: | ||||
|       self.auxiliary_head = None | ||||
|     self.global_pooling = nn.AvgPool2d(7) | ||||
|     self.classifier     = nn.Linear(C_prev, num_classes) | ||||
|     self.drop_path_prob = -1 | ||||
|  | ||||
|   def update_drop_path(self, drop_path_prob): | ||||
|     self.drop_path_prob = drop_path_prob | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N=[{_layerN}, {_NNN}], aux-index={auxiliary_index}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.extra_repr() | ||||
|  | ||||
|   def auxiliary_param(self): | ||||
|     if self.auxiliary_head is None: return [] | ||||
|     else: return list( self.auxiliary_head.parameters() ) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     s0 = self.stem0(inputs) | ||||
|     s1 = self.stem1(s0) | ||||
|     logits_aux = None | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       s0, s1 = s1, cell(s0, s1, self.drop_path_prob) | ||||
|       if i == self.auxiliary_index and self.auxiliary_head and self.training: | ||||
|         logits_aux = self.auxiliary_head(s1) | ||||
|     out = self.global_pooling(s1) | ||||
|     logits = self.classifier(out.view(out.size(0), -1)) | ||||
|  | ||||
|     if logits_aux is None: return out, logits | ||||
|     else                 : return out, [logits, logits_aux] | ||||
AutoDL-Projects/xautodl/nas_infer_model/DXYs/__init__.py (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Performance-Aware Template Network for One-Shot Neural Architecture Search | ||||
| from .CifarNet import NetworkCIFAR as CifarNet | ||||
| from .ImageNet import NetworkImageNet as ImageNet | ||||
| from .genotypes import Networks | ||||
| from .genotypes import build_genotype_from_dict | ||||
AutoDL-Projects/xautodl/nas_infer_model/DXYs/base_cells.py (new file, 173 lines)
							| @@ -0,0 +1,173 @@ | ||||
| import math | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from .construct_utils import drop_path | ||||
| from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN | ||||
|  | ||||
|  | ||||
| class MixedOp(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, stride, PRIMITIVES): | ||||
|     super(MixedOp, self).__init__() | ||||
|     self._ops = nn.ModuleList() | ||||
|     self.name2idx = {} | ||||
|     for idx, primitive in enumerate(PRIMITIVES): | ||||
|       op = OPS[primitive](C, C, stride, False) | ||||
|       self._ops.append(op) | ||||
|       assert primitive not in self.name2idx, '{:} is already registered'.format(primitive) | ||||
|       self.name2idx[primitive] = idx | ||||
|  | ||||
|   def forward(self, x, weights, op_name): | ||||
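|     # Three modes: op_name given -> run that op only; weights given -> weighted sum | ||||
|     # over all candidate ops; both None -> return a list of every op's output. | ||||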
|     if op_name is None: | ||||
|       if weights is None: | ||||
|         return [op(x) for op in self._ops] | ||||
|       else: | ||||
|         return sum(w * op(x) for w, op in zip(weights, self._ops)) | ||||
|     else: | ||||
|       op_index = self.name2idx[op_name] | ||||
|       return self._ops[op_index](x) | ||||
|  | ||||
|  | ||||
|  | ||||
| class SearchCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual): | ||||
|     super(SearchCell, self).__init__() | ||||
|     self.reduction  = reduction | ||||
|     self.PRIMITIVES = deepcopy(PRIMITIVES) | ||||
|    | ||||
|     if reduction_prev: | ||||
|       self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False) | ||||
|     else: | ||||
|       self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) | ||||
|     self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) | ||||
|     self._steps        = steps | ||||
|     self._multiplier   = multiplier | ||||
|     self._use_residual = use_residual | ||||
|  | ||||
|     self._ops = nn.ModuleList() | ||||
|     for i in range(self._steps): | ||||
|       for j in range(2+i): | ||||
|         stride = 2 if reduction and j < 2 else 1 | ||||
|         op = MixedOp(C, stride, self.PRIMITIVES) | ||||
|         self._ops.append(op) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes): | ||||
|     if modes[0] is None: | ||||
|       if modes[1] == 'normal': | ||||
|         output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob) | ||||
|       elif modes[1] == 'only_W': | ||||
|         output = self.__forwardOnlyW(S0, S1, drop_prob) | ||||
|     else: | ||||
|       test_genotype = modes[0] | ||||
|       if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat | ||||
|       else             : operations, concats = test_genotype.normal, test_genotype.normal_concat | ||||
|       s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|       states, offset = [s0, s1], 0 | ||||
|       assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations)) | ||||
|       for i, (opA, opB) in enumerate(operations): | ||||
|         A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0]) | ||||
|         B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0]) | ||||
|         state = A + B | ||||
|         offset += len(states) | ||||
|         states.append(state) | ||||
|       output = torch.cat([states[i] for i in concats], dim=1) | ||||
|     if self._use_residual and S1.size() == output.size(): | ||||
|       return S1 + output | ||||
|     else: return output | ||||
|    | ||||
|   def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob): | ||||
|     s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|     states, offset = [s0, s1], 0 | ||||
|     for i in range(self._steps): | ||||
|       clist = [] | ||||
|       for j, h in enumerate(states): | ||||
|         x = self._ops[offset+j](h, weights[offset+j], None) | ||||
|         if self.training and drop_prob > 0.: | ||||
|           x = drop_path(x, math.pow(drop_prob, 1./len(states))) | ||||
|         clist.append( x ) | ||||
|       connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0) | ||||
|       state = sum(w * node for w, node in zip(connection, clist)) | ||||
|       offset += len(states) | ||||
|       states.append(state) | ||||
|     return torch.cat(states[-self._multiplier:], dim=1) | ||||
|  | ||||
|   def __forwardOnlyW(self, S0, S1, drop_prob): | ||||
|     s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|     states, offset = [s0, s1], 0 | ||||
|     for i in range(self._steps): | ||||
|       clist = [] | ||||
|       for j, h in enumerate(states): | ||||
|         xs = self._ops[offset+j](h, None, None) | ||||
|         clist += xs | ||||
|       if self.training and drop_prob > 0.: | ||||
|         xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist] | ||||
|       else: xlist = clist | ||||
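|       # average all candidates, scaled by 2 to match the two-input sum used at inference | ||||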
|       state = sum(xlist) * 2 / len(xlist) | ||||
|       offset += len(states) | ||||
|       states.append(state) | ||||
|     return torch.cat(states[-self._multiplier:], dim=1) | ||||
|  | ||||
|  | ||||
|  | ||||
| class InferCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): | ||||
|     super(InferCell, self).__init__() | ||||
|     print(C_prev_prev, C_prev, C) | ||||
|  | ||||
|     if reduction_prev is None: | ||||
|       self.preprocess0 = Identity() | ||||
|     elif reduction_prev: | ||||
|       self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2) | ||||
|     else: | ||||
|       self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) | ||||
|     self.preprocess1   = ReLUConvBN(C_prev, C, 1, 1, 0) | ||||
|      | ||||
|     if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat | ||||
|     else        : step_ops, concat = genotype.normal, genotype.normal_concat | ||||
|     self._steps        = len(step_ops) | ||||
|     self._concat       = concat | ||||
|     self._multiplier   = len(concat) | ||||
|     self._ops          = nn.ModuleList() | ||||
|     self._indices      = [] | ||||
|     for operations in step_ops: | ||||
|       for name, index in operations: | ||||
|         stride = 2 if reduction and index < 2 else 1 | ||||
|         if reduction_prev is None and index == 0: | ||||
|           op = OPS[name](C_prev_prev, C, stride, True) | ||||
|         else: | ||||
|           op = OPS[name](C          , C, stride, True) | ||||
|         self._ops.append( op ) | ||||
|         self._indices.append( index ) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, S0, S1, drop_prob): | ||||
|     s0 = self.preprocess0(S0) | ||||
|     s1 = self.preprocess1(S1) | ||||
|  | ||||
|     states = [s0, s1] | ||||
|     for i in range(self._steps): | ||||
|       h1 = states[self._indices[2*i]] | ||||
|       h2 = states[self._indices[2*i+1]] | ||||
|       op1 = self._ops[2*i] | ||||
|       op2 = self._ops[2*i+1] | ||||
|       h1 = op1(h1) | ||||
|       h2 = op2(h2) | ||||
|       if self.training and drop_prob > 0.: | ||||
|         if not isinstance(op1, Identity): | ||||
|           h1 = drop_path(h1, drop_prob) | ||||
|         if not isinstance(op2, Identity): | ||||
|           h2 = drop_path(h2, drop_prob) | ||||
|  | ||||
|       state = h1 + h2 | ||||
|       states += [state] | ||||
|     output = torch.cat([states[i] for i in self._concat], dim=1) | ||||
|     return output | ||||
| @@ -0,0 +1,60 @@ | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
|  | ||||
| def drop_path(x, drop_prob): | ||||
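|   # zero entire samples with probability drop_prob and rescale survivors by 1/keep_prob | ||||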
|   if drop_prob > 0.: | ||||
|     keep_prob = 1. - drop_prob | ||||
|     mask = x.new_zeros(x.size(0), 1, 1, 1) | ||||
|     mask = mask.bernoulli_(keep_prob) | ||||
|     x = torch.div(x, keep_prob) | ||||
|     x.mul_(mask) | ||||
|   return x | ||||
|  | ||||
|  | ||||
| def return_alphas_str(basemodel): | ||||
|   if hasattr(basemodel, 'alphas_normal'): | ||||
|     string = 'normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1) ) | ||||
|   else: string = '' | ||||
|   if hasattr(basemodel, 'alphas_reduce'): | ||||
|     string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) ) | ||||
|  | ||||
|   if hasattr(basemodel, 'get_adjacency'): | ||||
|     adjacency = basemodel.get_adjacency() | ||||
|     for i in range( len(adjacency) ): | ||||
|       weight = F.softmax( basemodel.connect_normal[str(i)], dim=-1 ) | ||||
|       adj = torch.mm(weight, adjacency[i]).view(-1) | ||||
|       adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()] | ||||
|       string = string + '\nnormal--{:}-->{:}'.format(i, ', '.join(adj)) | ||||
|     for i in range( len(adjacency) ): | ||||
|       weight = F.softmax( basemodel.connect_reduce[str(i)], dim=-1 ) | ||||
|       adj = torch.mm(weight, adjacency[i]).view(-1) | ||||
|       adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()] | ||||
|       string = string + '\nreduce--{:}-->{:}'.format(i, ', '.join(adj)) | ||||
|  | ||||
|   if hasattr(basemodel, 'alphas_connect'): | ||||
|     weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu() | ||||
|     ZERO = ['{:.3f}'.format(x) for x in weight[:,0].tolist()] | ||||
|     IDEN = ['{:.3f}'.format(x) for x in weight[:,1].tolist()] | ||||
|     string = string + '\nconnect [{:}] : \n ->{:}\n ->{:}'.format( list(basemodel.alphas_connect.size()), ZERO, IDEN ) | ||||
|   else: | ||||
|     string = string + '\nconnect = None' | ||||
|    | ||||
|   if hasattr(basemodel, 'get_gcn_out'): | ||||
|     outputs = basemodel.get_gcn_out(True) | ||||
|     for i, output in enumerate(outputs): | ||||
|       string = string + '\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1) ) | ||||
|  | ||||
|   return string | ||||
|  | ||||
|  | ||||
| def remove_duplicate_archs(all_archs): | ||||
|   archs = [] | ||||
|   str_archs = ['{:}'.format(x) for x in all_archs] | ||||
|   for i, arch_x in enumerate(str_archs): | ||||
|     choose = True | ||||
|     for j in range(i): | ||||
|       if arch_x == str_archs[j]: | ||||
|         choose = False; break | ||||
|     if choose: archs.append(all_archs[i]) | ||||
|   return archs | ||||
AutoDL-Projects/xautodl/nas_infer_model/DXYs/genotypes.py (new file, 182 lines)
							| @@ -0,0 +1,182 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat connectN connects') | ||||
| #Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') | ||||
|  | ||||
| PRIMITIVES_small = [ | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'conv_3x1_1x3', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES_large = [ | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'dil_conv_3x3', | ||||
|     'dil_conv_5x5', | ||||
|     'conv_3x1_1x3', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES_huge = [ | ||||
|     'skip_connect', | ||||
|     'nor_conv_1x1', | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'nor_conv_3x3', | ||||
|     'sep_conv_3x3', | ||||
|     'dil_conv_3x3', | ||||
|     'conv_3x1_1x3', | ||||
|     'sep_conv_5x5', | ||||
|     'dil_conv_5x5', | ||||
|     'sep_conv_7x7', | ||||
|     'conv_7x1_1x7', | ||||
|     'att_squeeze', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES = {'small': PRIMITIVES_small, | ||||
|               'large': PRIMITIVES_large, | ||||
|               'huge' : PRIMITIVES_huge} | ||||
|  | ||||
| NASNet = Genotype( | ||||
|   normal = [ | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 0)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 0)), | ||||
|     (('avg_pool_3x3', 1), ('skip_connect', 0)), | ||||
|     (('avg_pool_3x3', 0), ('avg_pool_3x3', 0)), | ||||
|     (('sep_conv_3x3', 1), ('skip_connect', 1)), | ||||
|   ], | ||||
|   normal_concat = [2, 3, 4, 5, 6], | ||||
|   reduce = [ | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_7x7', 0)), | ||||
|     (('max_pool_3x3', 1), ('sep_conv_7x7', 0)), | ||||
|     (('avg_pool_3x3', 1), ('sep_conv_5x5', 0)), | ||||
|     (('skip_connect', 3), ('avg_pool_3x3', 2)), | ||||
|     (('sep_conv_3x3', 2), ('max_pool_3x3', 1)), | ||||
|   ], | ||||
|   reduce_concat = [4, 5, 6], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
| PNASNet = Genotype( | ||||
|   normal = [ | ||||
|     (('sep_conv_5x5', 0), ('max_pool_3x3', 0)), | ||||
|     (('sep_conv_7x7', 1), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_3x3', 4), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 1)), | ||||
|   ], | ||||
|   normal_concat = [2, 3, 4, 5, 6], | ||||
|   reduce = [ | ||||
|     (('sep_conv_5x5', 0), ('max_pool_3x3', 0)), | ||||
|     (('sep_conv_7x7', 1), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_3x3', 4), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 1)), | ||||
|   ], | ||||
|   reduce_concat = [2, 3, 4, 5, 6], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
|  | ||||
| DARTS_V1 = Genotype( | ||||
|   normal=[ | ||||
|     (('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1 | ||||
|     (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2 | ||||
|     (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3 | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 2))  # step 4 | ||||
|   ], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 0)), # step 2 | ||||
|     (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3 | ||||
|     (('skip_connect', 2), ('avg_pool_3x3', 0))  # step 4 | ||||
|   ], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 | ||||
| DARTS_V2 = Genotype( | ||||
|   normal=[ | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1 | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2 | ||||
|     (('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3 | ||||
|     (('skip_connect', 0), ('dil_conv_3x3', 2))  # step 4 | ||||
|   ], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 1)), # step 2 | ||||
|     (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 1))  # step 4 | ||||
|   ], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 | ||||
| SETN = Genotype( | ||||
|   normal=[ | ||||
|     (('skip_connect', 0), ('sep_conv_5x5', 1)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_5x5', 3)), | ||||
|     (('max_pool_3x3', 1), ('conv_3x1_1x3', 4))], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('skip_connect', 1))], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None | ||||
| ) | ||||
|  | ||||
|  | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 | ||||
| GDAS_V1 = Genotype( | ||||
|   normal=[ | ||||
|     (('skip_connect', 0), ('skip_connect', 1)), | ||||
|     (('skip_connect', 0), ('sep_conv_5x5', 2)), | ||||
|     (('sep_conv_3x3', 3), ('skip_connect', 0)), | ||||
|     (('sep_conv_5x5', 4), ('sep_conv_3x3', 3))], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),  | ||||
|     (('sep_conv_5x5', 2), ('sep_conv_5x5', 1)), | ||||
|     (('dil_conv_5x5', 2), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_5x5', 1))], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None | ||||
| ) | ||||
|  | ||||
|  | ||||
|  | ||||
| Networks = {'DARTS_V1': DARTS_V1, | ||||
|             'DARTS_V2': DARTS_V2, | ||||
|             'DARTS'   : DARTS_V2, | ||||
|             'NASNet'  : NASNet, | ||||
|             'GDAS_V1' : GDAS_V1, | ||||
|             'PNASNet' : PNASNet, | ||||
|             'SETN'    : SETN, | ||||
|            } | ||||
|  | ||||
| # This function will return a Genotype from a dict. | ||||
| def build_genotype_from_dict(xdict): | ||||
|   def remove_value(nodes): | ||||
|     return [tuple([(x[0], x[1]) for x in node]) for node in nodes] | ||||
|   genotype = Genotype( | ||||
|       normal=remove_value(xdict['normal']), | ||||
|       normal_concat=xdict['normal_concat'], | ||||
|       reduce=remove_value(xdict['reduce']), | ||||
|       reduce_concat=xdict['reduce_concat'], | ||||
|       connectN=None, connects=None | ||||
|       ) | ||||
|   return genotype | ||||
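|  | ||||
| # Illustrative call (hypothetical dict, mirroring the Genotype fields above): | ||||
| #   g = build_genotype_from_dict({'normal': [...], 'normal_concat': [2, 3, 4, 5], | ||||
| #                                 'reduce': [...], 'reduce_concat': [2, 3, 4, 5]}) | ||||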
AutoDL-Projects/xautodl/nas_infer_model/DXYs/head_utils.py (new file, 71 lines)
							| @@ -0,0 +1,71 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| class ImageNetHEAD(nn.Sequential): | ||||
|     def __init__(self, C, stride=2): | ||||
|         super(ImageNetHEAD, self).__init__() | ||||
|         self.add_module( | ||||
|             "conv1", | ||||
|             nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|         ) | ||||
|         self.add_module("bn1", nn.BatchNorm2d(C // 2)) | ||||
|         self.add_module("relu1", nn.ReLU(inplace=True)) | ||||
|         self.add_module( | ||||
|             "conv2", | ||||
|             nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False), | ||||
|         ) | ||||
|         self.add_module("bn2", nn.BatchNorm2d(C)) | ||||
|  | ||||
|  | ||||
| class CifarHEAD(nn.Sequential): | ||||
|     def __init__(self, C): | ||||
|         super(CifarHEAD, self).__init__() | ||||
|         self.add_module("conv", nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False)) | ||||
|         self.add_module("bn", nn.BatchNorm2d(C)) | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadCIFAR(nn.Module): | ||||
|     def __init__(self, C, num_classes): | ||||
|         """assuming input size 8x8""" | ||||
|         super(AuxiliaryHeadCIFAR, self).__init__() | ||||
|         self.features = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.AvgPool2d( | ||||
|                 5, stride=3, padding=0, count_include_pad=False | ||||
|             ),  # image size = 2 x 2 | ||||
|             nn.Conv2d(C, 128, 1, bias=False), | ||||
|             nn.BatchNorm2d(128), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(128, 768, 2, bias=False), | ||||
|             nn.BatchNorm2d(768), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.features(x) | ||||
|         x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadImageNet(nn.Module): | ||||
|     def __init__(self, C, num_classes): | ||||
|         """assuming input size 14x14""" | ||||
|         super(AuxiliaryHeadImageNet, self).__init__() | ||||
|         self.features = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), | ||||
|             nn.Conv2d(C, 128, 1, bias=False), | ||||
|             nn.BatchNorm2d(128), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(128, 768, 2, bias=False), | ||||
|             nn.BatchNorm2d(768), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.features(x) | ||||
|         x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
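
The two auxiliary heads are hard-wired to fixed spatial sizes, as the docstrings state: for the CIFAR head, the avg-pool (kernel 5, stride 3) maps 8x8 to 2x2 and the 2x2 conv collapses that to 1x1, so exactly 768 features reach the classifier. A quick shape check (a sketch; the channel count and batch size are arbitrary):

    import torch
    head = AuxiliaryHeadCIFAR(C=64, num_classes=10)   # defined above
    out = head(torch.randn(4, 64, 8, 8))
    print(out.shape)   # torch.Size([4, 10])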
51  AutoDL-Projects/xautodl/nas_infer_model/__init__.py  Normal file
							| @@ -0,0 +1,51 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # I wrote this package to make AutoDL-Projects compatible with the old GDAS projects. | ||||
| # Ideally, this package will be merged into lib/models/cell_infers in the future. | ||||
| # Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019). | ||||
| ################################################## | ||||
|  | ||||
| import os, torch | ||||
|  | ||||
|  | ||||
| def obtain_nas_infer_model(config, extra_model_path=None): | ||||
|  | ||||
|     if config.arch == "dxys": | ||||
|         from .DXYs import CifarNet, ImageNet, Networks | ||||
|         from .DXYs import build_genotype_from_dict | ||||
|  | ||||
|         if config.genotype is None: | ||||
|             if extra_model_path is None or not os.path.isfile(extra_model_path): | ||||
|                 raise ValueError( | ||||
|                     "When genotype in config is None, extra_model_path must be set as a path instead of {:}".format( | ||||
|                         extra_model_path | ||||
|                     ) | ||||
|                 ) | ||||
|             xdata = torch.load(extra_model_path) | ||||
|             current_epoch = xdata["epoch"] | ||||
|             genotype_dict = xdata["genotypes"][current_epoch - 1] | ||||
|             genotype = build_genotype_from_dict(genotype_dict) | ||||
|         else: | ||||
|             genotype = Networks[config.genotype] | ||||
|         if config.dataset == "cifar": | ||||
|             return CifarNet( | ||||
|                 config.ichannel, | ||||
|                 config.layers, | ||||
|                 config.stem_multi, | ||||
|                 config.auxiliary, | ||||
|                 genotype, | ||||
|                 config.class_num, | ||||
|             ) | ||||
|         elif config.dataset == "imagenet": | ||||
|             return ImageNet( | ||||
|                 config.ichannel, | ||||
|                 config.layers, | ||||
|                 config.auxiliary, | ||||
|                 genotype, | ||||
|                 config.class_num, | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid dataset : {:}".format(config.dataset)) | ||||
|     else: | ||||
|         raise ValueError("invalid nas arch type : {:}".format(config.arch)) | ||||
183  AutoDL-Projects/xautodl/nas_infer_model/operations.py  Normal file
							| @@ -0,0 +1,183 @@ | ||||
| ############################################################################################## | ||||
| # This code is copied and modified from Hanxiao Liu's work (https://github.com/quark0/darts) # | ||||
| ############################################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| OPS = { | ||||
|   'none'         : lambda C_in, C_out, stride, affine: Zero(stride), | ||||
|   'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'), | ||||
|   'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'), | ||||
|   'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), affine), | ||||
|   'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), affine), | ||||
|   'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), affine), | ||||
|   'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine), | ||||
|   'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine), | ||||
|   'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine), | ||||
|   'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine), | ||||
|   'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine), | ||||
|   'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine), | ||||
|   'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: Conv717(C_in, C_out, stride, affine), | ||||
|   'conv_3x1_1x3' : lambda C_in, C_out, stride, affine: Conv313(C_in, C_out, stride, affine) | ||||
| } | ||||
|  | ||||
|  | ||||
| class POOLING(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, mode): | ||||
|     super(POOLING, self).__init__() | ||||
|     if C_in == C_out: | ||||
|       self.preprocess = None | ||||
|     else: | ||||
|       self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0) | ||||
|     if mode == 'avg'  : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) | ||||
|     elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1) | ||||
|     else: raise ValueError('Invalid pooling mode : {:}'.format(mode)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     if self.preprocess is not None: | ||||
|       x = self.preprocess(inputs) | ||||
|     else: x = inputs | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Conv313(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine): | ||||
|     super(Conv313, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in , C_out, (1,3), stride=(1, stride), padding=(0, 1), bias=False), | ||||
|       nn.Conv2d(C_out, C_out, (3,1), stride=(stride, 1), padding=(1, 0), bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Conv717(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine): | ||||
|     super(Conv717, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in , C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False), | ||||
|       nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | ||||
|     super(ReLUConvBN, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class DilConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): | ||||
|     super(DilConv, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in,  kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine), | ||||
|       ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class SepConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | ||||
|     super(SepConv, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_in, affine=affine), | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=     1, padding=padding, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine), | ||||
|       ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Identity(nn.Module): | ||||
|  | ||||
|   def __init__(self): | ||||
|     super(Identity, self).__init__() | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return x | ||||
|  | ||||
|  | ||||
| class Zero(nn.Module): | ||||
|  | ||||
|   def __init__(self, stride): | ||||
|     super(Zero, self).__init__() | ||||
|     self.stride = stride | ||||
|  | ||||
|   def forward(self, x): | ||||
|     if self.stride == 1: | ||||
|       return x.mul(0.) | ||||
|     return x[:,:,::self.stride,::self.stride].mul(0.) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'stride={stride}'.format(**self.__dict__) | ||||
|  | ||||
|  | ||||
| class FactorizedReduce(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine=True): | ||||
|     super(FactorizedReduce, self).__init__() | ||||
|     self.stride = stride | ||||
|     self.C_in   = C_in   | ||||
|     self.C_out  = C_out   | ||||
|     self.relu   = nn.ReLU(inplace=False) | ||||
|     if stride == 2: | ||||
|       #assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) | ||||
|       C_outs = [C_out // 2, C_out - C_out // 2] | ||||
|       self.convs = nn.ModuleList() | ||||
|       for i in range(2): | ||||
|         self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) ) | ||||
|       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) | ||||
|     elif stride == 4: | ||||
|       assert C_out % 4 == 0, 'C_out : {:}'.format(C_out) | ||||
|       self.convs = nn.ModuleList() | ||||
|       for i in range(4): | ||||
|         self.convs.append( nn.Conv2d(C_in, C_out // 4, 1, stride=stride, padding=0, bias=False) ) | ||||
|       self.pad = nn.ConstantPad2d((0, 3, 0, 3), 0) | ||||
|     else: | ||||
|       raise ValueError('Invalid stride : {:}'.format(stride)) | ||||
|      | ||||
|     self.bn = nn.BatchNorm2d(C_out, affine=affine) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.relu(x) | ||||
|     y = self.pad(x) | ||||
|     if self.stride == 2: | ||||
|       out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1) | ||||
|     else: | ||||
|       out = torch.cat([self.convs[0](x),            self.convs[1](y[:,:,1:-2,1:-2]), | ||||
|                        self.convs[2](y[:,:,2:-1,2:-1]), self.convs[3](y[:,:,3:,3:])], dim=1) | ||||
|     out = self.bn(out) | ||||
|     return out | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) | ||||
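
Every entry in OPS shares the factory signature (C_in, C_out, stride, affine) -> nn.Module, which is what lets cell code instantiate candidate operations uniformly. A small sketch: with stride 1 and matching channels, each op preserves the input shape.

    import torch
    x = torch.randn(2, 16, 32, 32)
    for name in ('sep_conv_3x3', 'dil_conv_5x5', 'skip_connect', 'avg_pool_3x3'):
        op = OPS[name](16, 16, 1, True)   # C_in, C_out, stride, affine
        assert op(x).shape == x.shape, name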
38  AutoDL-Projects/xautodl/procedures/__init__.py  Normal file
							| @@ -0,0 +1,38 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ###################################################################### | ||||
| # This folder is deprecated, which is re-organized in "xalgorithms". # | ||||
| ###################################################################### | ||||
| from .starts import prepare_seed | ||||
| from .starts import prepare_logger | ||||
| from .starts import get_machine_info | ||||
| from .starts import save_checkpoint | ||||
| from .starts import copy_checkpoint | ||||
| from .optimizers import get_optim_scheduler | ||||
| from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed | ||||
| from .funcs_nasbench import pure_evaluate as bench_pure_evaluate | ||||
| from .funcs_nasbench import get_nas_bench_loaders | ||||
|  | ||||
|  | ||||
| def get_procedures(procedure): | ||||
|     from .basic_main import basic_train, basic_valid | ||||
|     from .search_main import search_train, search_valid | ||||
|     from .search_main_v2 import search_train_v2 | ||||
|     from .simple_KD_main import simple_KD_train, simple_KD_valid | ||||
|  | ||||
|     train_funcs = { | ||||
|         "basic": basic_train, | ||||
|         "search": search_train, | ||||
|         "Simple-KD": simple_KD_train, | ||||
|         "search-v2": search_train_v2, | ||||
|     } | ||||
|     valid_funcs = { | ||||
|         "basic": basic_valid, | ||||
|         "search": search_valid, | ||||
|         "Simple-KD": simple_KD_valid, | ||||
|         "search-v2": search_valid, | ||||
|     } | ||||
|  | ||||
|     train_func = train_funcs[procedure] | ||||
|     valid_func = valid_funcs[procedure] | ||||
|     return train_func, valid_func | ||||
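
Consumption sketch: the registry maps a procedure name to its (train, valid) pair, and unknown names raise a KeyError.

    train_func, valid_func = get_procedures("basic")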
99  AutoDL-Projects/xautodl/procedures/advanced_main.py  Normal file
							| @@ -0,0 +1,99 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| # To be finished. | ||||
| # | ||||
| import os, sys, time, torch | ||||
| from typing import Optional, Text, Callable | ||||
|  | ||||
| # modules in AutoDL | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_device(tensors): | ||||
|     if isinstance(tensors, (list, tuple)): | ||||
|         return get_device(tensors[0]) | ||||
|     elif isinstance(tensors, dict): | ||||
|         for key, value in tensors.items(): | ||||
|             return get_device(value) | ||||
|     else: | ||||
|         return tensors.device | ||||
|  | ||||
|  | ||||
| def basic_train_fn( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     metric, | ||||
|     logger, | ||||
| ): | ||||
|     results = procedure( | ||||
|         xloader, | ||||
|         network, | ||||
|         criterion, | ||||
|         optimizer, | ||||
|         metric, | ||||
|         "train", | ||||
|         logger, | ||||
|     ) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def basic_eval_fn(xloader, network, metric, logger): | ||||
|     with torch.no_grad(): | ||||
|         results = procedure( | ||||
|             xloader, | ||||
|             network, | ||||
|             None, | ||||
|             None, | ||||
|             metric, | ||||
|             "valid", | ||||
|             logger, | ||||
|         ) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     metric, | ||||
|     mode: Text, | ||||
|     logger_fn: Optional[Callable] = None, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     if mode.lower() == "train": | ||||
|         network.train() | ||||
|     elif mode.lower() == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         outputs = network(inputs) | ||||
|         targets = targets.to(get_device(outputs)) | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss = criterion(outputs, targets) | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         with torch.no_grad(): | ||||
|             results = metric(outputs, targets) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return metric.get_info() | ||||
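
A wiring sketch for the generic procedure (the toy model and loader are hypothetical; Top1AccMetric is defined in metric_utils.py below):

    import torch
    from xautodl.procedures.metric_utils import Top1AccMetric

    model = torch.nn.Linear(32, 10)
    loader = [(torch.randn(8, 32), torch.randint(0, 10, (8,))) for _ in range(4)]
    metric = Top1AccMetric(ignore_batch=False)
    info = basic_eval_fn(loader, model, metric, logger=None)
    print(info)   # {'accuracy': ..., 'score': ...}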
154  AutoDL-Projects/xautodl/procedures/basic_main.py  Normal file
							| @@ -0,0 +1,154 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| # modules in AutoDL | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def basic_train( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, | ||||
|         network, | ||||
|         criterion, | ||||
|         scheduler, | ||||
|         optimizer, | ||||
|         "train", | ||||
|         optim_config, | ||||
|         extra_info, | ||||
|         print_freq, | ||||
|         logger, | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def basic_valid( | ||||
|     xloader, network, criterion, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, | ||||
|             network, | ||||
|             criterion, | ||||
|             None, | ||||
|             None, | ||||
|             "valid", | ||||
|             None, | ||||
|             extra_info, | ||||
|             print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     mode, | ||||
|     config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     # logger.log('[{:5s}] config ::  auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) | ||||
|     logger.log( | ||||
|         "[{:5s}] config ::  auxiliary={:}".format( | ||||
|             mode, config.auxiliary if hasattr(config, "auxiliary") else -1 | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         features, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must have {:} items instead of {:}".format( | ||||
|                 2, len(logits) | ||||
|             ) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
|         loss = criterion(logits, targets) | ||||
|         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | ||||
|             loss_aux = criterion(logits_aux, targets) | ||||
|             loss += config.auxiliary * loss_aux | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|             Sstr = ( | ||||
|                 " {:5s} ".format(mode.upper()) | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|             ) | ||||
|             if scheduler is not None: | ||||
|                 Sstr += " {:}".format(scheduler.get_min_info()) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Istr = "Size={:}".format(list(inputs.size())) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
20  AutoDL-Projects/xautodl/procedures/eval_funcs.py  Normal file
							| @@ -0,0 +1,20 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| import abc | ||||
|  | ||||
|  | ||||
| def obtain_accuracy(output, target, topk=(1,)): | ||||
|     """Computes the precision@k for the specified values of k""" | ||||
|     maxk = max(topk) | ||||
|     batch_size = target.size(0) | ||||
|  | ||||
|     _, pred = output.topk(maxk, 1, True, True) | ||||
|     pred = pred.t() | ||||
|     correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||||
|  | ||||
|     res = [] | ||||
|     for k in topk: | ||||
|         correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
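
A worked toy case (a sketch): with the top-1 prediction correct for 3 of the 4 samples, obtain_accuracy reports 75.0 (the function returns percentages).

    import torch
    logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.5], [0.1, 0.9]])
    target = torch.tensor([0, 1, 1, 1])   # the third sample is wrong at top-1
    top1, = obtain_accuracy(logits, target, topk=(1,))
    print(top1.item())   # 75.0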
437  AutoDL-Projects/xautodl/procedures/funcs_nasbench.py  Normal file
							| @@ -0,0 +1,437 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| import os, time, copy, torch, pathlib | ||||
|  | ||||
| from xautodl import datasets | ||||
| from xautodl.config_utils import load_config | ||||
| from xautodl.procedures import prepare_seed, get_optim_scheduler | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net | ||||
| from xautodl.utils import get_model_infos | ||||
| from xautodl.procedures.eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| __all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"] | ||||
|  | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|     data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     latencies, device = [], torch.cuda.current_device() | ||||
|     network.eval() | ||||
|     with torch.no_grad(): | ||||
|         end = time.time() | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             targets = targets.cuda(device=device, non_blocking=True) | ||||
|             inputs = inputs.cuda(device=device, non_blocking=True) | ||||
|             data_time.update(time.time() - end) | ||||
|             # forward | ||||
|             features, logits = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             batch_time.update(time.time() - end) | ||||
|             if batch is None or batch == inputs.size(0): | ||||
|                 batch = inputs.size(0) | ||||
|                 latencies.append(batch_time.val - data_time.val) | ||||
|             # record loss and accuracy | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|             end = time.time() | ||||
|     if len(latencies) > 2: | ||||
|         latencies = latencies[1:] | ||||
|     return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|     device = torch.cuda.current_device() | ||||
|     data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|         targets = targets.cuda(device=device, non_blocking=True) | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|         # forward | ||||
|         features, logits = network(inputs) | ||||
|         loss = criterion(logits, targets) | ||||
|         # backward | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|         # record loss and accuracy | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|         # count time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def evaluate_for_seed( | ||||
|     arch_config, opt_config, train_loader, valid_loaders, seed: int, logger | ||||
| ): | ||||
|     """A modular function to train and evaluate a single network, using the given random seed and optimization config with the provided loaders.""" | ||||
|     prepare_seed(seed)  # random seed | ||||
|     net = get_cell_based_tiny_net(arch_config) | ||||
|     # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|     flop, param = get_model_infos(net, opt_config.xshape) | ||||
|     logger.log("Network : {:}".format(net.get_message()), False) | ||||
|     logger.log( | ||||
|         "{:} Seed-------------------------- {:} --------------------------".format( | ||||
|             time_string(), seed | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) | ||||
|     # train and valid | ||||
|     optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|     default_device = torch.cuda.current_device() | ||||
|     network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda( | ||||
|         device=default_device | ||||
|     ) | ||||
|     criterion = criterion.cuda(device=default_device) | ||||
|     # start training | ||||
|     start_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         opt_config.epochs + opt_config.warmup, | ||||
|     ) | ||||
|     ( | ||||
|         train_losses, | ||||
|         train_acc1es, | ||||
|         train_acc5es, | ||||
|         valid_losses, | ||||
|         valid_acc1es, | ||||
|         valid_acc5es, | ||||
|     ) = ({}, {}, {}, {}, {}, {}) | ||||
|     train_times, valid_times, lrs = {}, {}, {} | ||||
|     for epoch in range(total_epoch): | ||||
|         scheduler.update(epoch, 0.0) | ||||
|         lr = min(scheduler.get_lr()) | ||||
|         train_loss, train_acc1, train_acc5, train_tm = procedure( | ||||
|             train_loader, network, criterion, scheduler, optimizer, "train" | ||||
|         ) | ||||
|         train_losses[epoch] = train_loss | ||||
|         train_acc1es[epoch] = train_acc1 | ||||
|         train_acc5es[epoch] = train_acc5 | ||||
|         train_times[epoch] = train_tm | ||||
|         lrs[epoch] = lr | ||||
|         with torch.no_grad(): | ||||
|             for key, xloader in valid_loaders.items(): | ||||
|                 valid_loss, valid_acc1, valid_acc5, valid_tm = procedure( | ||||
|                     xloader, network, criterion, None, None, "valid" | ||||
|                 ) | ||||
|                 valid_losses["{:}@{:}".format(key, epoch)] = valid_loss | ||||
|                 valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1 | ||||
|                 valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5 | ||||
|                 valid_times["{:}@{:}".format(key, epoch)] = valid_tm | ||||
|  | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format( | ||||
|                 time_string(), | ||||
|                 need_time, | ||||
|                 epoch, | ||||
|                 total_epoch, | ||||
|                 train_loss, | ||||
|                 train_acc1, | ||||
|                 train_acc5, | ||||
|                 valid_loss, | ||||
|                 valid_acc1, | ||||
|                 valid_acc5, | ||||
|                 lr, | ||||
|             ) | ||||
|         ) | ||||
|     info_seed = { | ||||
|         "flop": flop, | ||||
|         "param": param, | ||||
|         "arch_config": arch_config._asdict(), | ||||
|         "opt_config": opt_config._asdict(), | ||||
|         "total_epoch": total_epoch, | ||||
|         "train_losses": train_losses, | ||||
|         "train_acc1es": train_acc1es, | ||||
|         "train_acc5es": train_acc5es, | ||||
|         "train_times": train_times, | ||||
|         "valid_losses": valid_losses, | ||||
|         "valid_acc1es": valid_acc1es, | ||||
|         "valid_acc5es": valid_acc5es, | ||||
|         "valid_times": valid_times, | ||||
|         "learning_rates": lrs, | ||||
|         "net_state_dict": net.state_dict(), | ||||
|         "net_string": "{:}".format(net), | ||||
|         "finish-train": True, | ||||
|     } | ||||
|     return info_seed | ||||
|  | ||||
|  | ||||
| def get_nas_bench_loaders(workers): | ||||
|  | ||||
|     torch.set_num_threads(workers) | ||||
|  | ||||
|     root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve() | ||||
|     torch_dir = pathlib.Path(os.environ["TORCH_HOME"]) | ||||
|     # cifar | ||||
|     cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config" | ||||
|     cifar_config = load_config(cifar_config_path, None, None) | ||||
|     get_datasets = datasets.get_datasets  # a function to return the dataset | ||||
|     break_line = "-" * 150 | ||||
|     print("{:} Create data-loader for all datasets".format(time_string())) | ||||
|     print(break_line) | ||||
|     TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets( | ||||
|         "cifar10", str(torch_dir / "cifar.python"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar10_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None | ||||
|     ) | ||||
|     assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] | ||||
|     assert cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14] | ||||
|     temp_dataset = copy.deepcopy(TRAIN_CIFAR10) | ||||
|     temp_dataset.transform = VALID_CIFAR10.transform | ||||
|     # data loader | ||||
|     trainval_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     train_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         temp_dataset, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar10_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=False, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : trval-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(trainval_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : train-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(train_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : valid-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(valid_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : test--loader has {:3d} batch with {:} per batch".format( | ||||
|             len(test__cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print(break_line) | ||||
|     # CIFAR-100 | ||||
|     TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets( | ||||
|         "cifar100", str(torch_dir / "cifar.python"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar100_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None | ||||
|     ) | ||||
|     assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] | ||||
|     assert cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24] | ||||
|     train_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : train-loader has {:3d} batch".format(len(train_cifar100_loader)) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : test--loader has {:3d} batch".format(len(test__cifar100_loader)) | ||||
|     ) | ||||
|     print(break_line) | ||||
|  | ||||
|     imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config" | ||||
|     imagenet16_config = load_config(imagenet16_config_path, None, None) | ||||
|     TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets( | ||||
|         "ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     imagenet_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", | ||||
|         None, | ||||
|         None, | ||||
|     ) | ||||
|     assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] | ||||
|     assert imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20] | ||||
|     train_imagenet_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_imagenet_loader = torch.utils.data.DataLoader( | ||||
|         VALID_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__imagenet_loader = torch.utils.data.DataLoader( | ||||
|         VALID_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : train-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(train_imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : valid-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(valid_imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : test--loader has {:3d} batch with {:} per batch".format( | ||||
|             len(test__imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     # 'cifar10', 'cifar100', 'ImageNet16-120' | ||||
|     loaders = { | ||||
|         "cifar10@trainval": trainval_cifar10_loader, | ||||
|         "cifar10@train": train_cifar10_loader, | ||||
|         "cifar10@valid": valid_cifar10_loader, | ||||
|         "cifar10@test": test__cifar10_loader, | ||||
|         "cifar100@train": train_cifar100_loader, | ||||
|         "cifar100@valid": valid_cifar100_loader, | ||||
|         "cifar100@test": test__cifar100_loader, | ||||
|         "ImageNet16-120@train": train_imagenet_loader, | ||||
|         "ImageNet16-120@valid": valid_imagenet_loader, | ||||
|         "ImageNet16-120@test": test__imagenet_loader, | ||||
|     } | ||||
|     return loaders | ||||
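
Usage sketch, assuming $TORCH_HOME points at the prepared datasets and the configs/nas-benchmark files sit relative to the repository root as resolved above:

    import os
    os.environ.setdefault("TORCH_HOME", os.path.expanduser("~/.torch"))
    loaders = get_nas_bench_loaders(workers=4)
    train_loader = loaders["cifar10@train"]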
166  AutoDL-Projects/xautodl/procedures/metric_utils.py  Normal file
							| @@ -0,0 +1,166 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| import abc | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class AverageMeter(object): | ||||
|     """Computes and stores the average and current value""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         self.val = 0.0 | ||||
|         self.avg = 0.0 | ||||
|         self.sum = 0.0 | ||||
|         self.count = 0.0 | ||||
|  | ||||
|     def update(self, val, n=1): | ||||
|         self.val = val | ||||
|         self.sum += val * n | ||||
|         self.count += n | ||||
|         self.avg = self.sum / self.count | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(val={val}, avg={avg}, count={count})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class Metric(abc.ABC): | ||||
|     """The default meta metric class.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({inner})".format( | ||||
|             name=self.__class__.__name__, inner=self.inner_repr() | ||||
|         ) | ||||
|  | ||||
|     def inner_repr(self): | ||||
|         return "" | ||||
|  | ||||
|  | ||||
| class ComposeMetric(Metric): | ||||
|     """The composed metric class.""" | ||||
|  | ||||
|     def __init__(self, *metric_list): | ||||
|         self.reset() | ||||
|         for metric in metric_list: | ||||
|             self.append(metric) | ||||
|  | ||||
|     def reset(self): | ||||
|         self._metric_list = [] | ||||
|  | ||||
|     def append(self, metric): | ||||
|         if not isinstance(metric, Metric): | ||||
|             raise ValueError( | ||||
|                 "The input metric is not correct: {:}".format(type(metric)) | ||||
|             ) | ||||
|         self._metric_list.append(metric) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._metric_list) | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         results = list() | ||||
|         for metric in self._metric_list: | ||||
|             results.append(metric(predictions, targets)) | ||||
|         return results | ||||
|  | ||||
|     def get_info(self): | ||||
|         results = dict() | ||||
|         for metric in self._metric_list: | ||||
|             for key, value in metric.get_info().items(): | ||||
|                 results[key] = value | ||||
|         return results | ||||
|  | ||||
|     def inner_repr(self): | ||||
|         xlist = [] | ||||
|         for metric in self._metric_list: | ||||
|             xlist.append(str(metric)) | ||||
|         return ",".join(xlist) | ||||
|  | ||||
|  | ||||
| class MSEMetric(Metric): | ||||
|     """The metric for mse.""" | ||||
|  | ||||
|     def __init__(self, ignore_batch): | ||||
|         super(MSEMetric, self).__init__() | ||||
|         self._ignore_batch = ignore_batch | ||||
|  | ||||
|     def reset(self): | ||||
|         self._mse = AverageMeter() | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||
|             loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item() | ||||
|             if self._ignore_batch: | ||||
|                 self._mse.update(loss, 1) | ||||
|             else: | ||||
|                 self._mse.update(loss, predictions.shape[0]) | ||||
|             return loss | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         return {"mse": self._mse.avg, "score": self._mse.avg} | ||||
|  | ||||
|  | ||||
| class Top1AccMetric(Metric): | ||||
|     """The metric for the top-1 accuracy.""" | ||||
|  | ||||
|     def __init__(self, ignore_batch): | ||||
|         super(Top1AccMetric, self).__init__() | ||||
|         self._ignore_batch = ignore_batch | ||||
|  | ||||
|     def reset(self): | ||||
|         self._accuracy = AverageMeter() | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||
|             max_prob_indexes = torch.argmax(predictions, dim=-1) | ||||
|             corrects = torch.eq(max_prob_indexes, targets) | ||||
|             accuracy = corrects.float().mean().float() | ||||
|             if self._ignore_batch: | ||||
|                 self._accuracy.update(accuracy, 1) | ||||
|             else:  # [TODO] for 3-d tensor | ||||
|                 self._accuracy.update(accuracy, predictions.shape[0]) | ||||
|             return accuracy | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100} | ||||
|  | ||||
|  | ||||
| class SaveMetric(Metric): | ||||
|     """The metric for mse.""" | ||||
|  | ||||
|     def reset(self): | ||||
|         self._predicts = [] | ||||
|  | ||||
|     def __call__(self, predictions, targets=None): | ||||
|         if isinstance(predictions, torch.Tensor): | ||||
|             predicts = predictions.cpu().numpy() | ||||
|             self._predicts.append(predicts) | ||||
|             return predicts | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         all_predicts = np.concatenate(self._predicts) | ||||
|         return {"predictions": all_predicts} | ||||
263  AutoDL-Projects/xautodl/procedures/optimizers.py  Normal file
							| @@ -0,0 +1,263 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| from bisect import bisect_right | ||||
| from torch.optim import Optimizer | ||||
|  | ||||
|  | ||||
| class _LRScheduler(object): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs): | ||||
|         if not isinstance(optimizer, Optimizer): | ||||
|             raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__)) | ||||
|         self.optimizer = optimizer | ||||
|         for group in optimizer.param_groups: | ||||
|             group.setdefault("initial_lr", group["lr"]) | ||||
|         self.base_lrs = list( | ||||
|             map(lambda group: group["initial_lr"], optimizer.param_groups) | ||||
|         ) | ||||
|         self.max_epochs = epochs | ||||
|         self.warmup_epochs = warmup_epochs | ||||
|         self.current_epoch = 0 | ||||
|         self.current_iter = 0 | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) + ", {:})".format( | ||||
|             self.extra_repr() | ||||
|         ) | ||||
|  | ||||
|     def state_dict(self): | ||||
|         return { | ||||
|             key: value for key, value in self.__dict__.items() if key != "optimizer" | ||||
|         } | ||||
|  | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.__dict__.update(state_dict) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get_min_info(self): | ||||
|         lrs = self.get_lr() | ||||
|         return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format( | ||||
|             min(lrs), max(lrs), self.current_epoch, self.current_iter | ||||
|         ) | ||||
|  | ||||
|     def get_min_lr(self): | ||||
|         return min(self.get_lr()) | ||||
|  | ||||
|     def update(self, cur_epoch, cur_iter): | ||||
|         if cur_epoch is not None: | ||||
|             assert ( | ||||
|                 isinstance(cur_epoch, int) and cur_epoch >= 0 | ||||
|             ), "invalid cur-epoch : {:}".format(cur_epoch) | ||||
|             self.current_epoch = cur_epoch | ||||
|         if cur_iter is not None: | ||||
|             assert ( | ||||
|                 isinstance(cur_iter, float) and cur_iter >= 0 | ||||
|             ), "invalid cur-iter : {:}".format(cur_iter) | ||||
|             self.current_iter = cur_iter | ||||
|         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): | ||||
|             param_group["lr"] = lr | ||||
|  | ||||
|  | ||||
| class CosineAnnealingLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min): | ||||
|         self.T_max = T_max | ||||
|         self.eta_min = eta_min | ||||
|         super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, T-max={:}, eta-min={:}".format( | ||||
|             "cosine", self.T_max, self.eta_min | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if ( | ||||
|                 self.current_epoch >= self.warmup_epochs | ||||
|                 and self.current_epoch < self.max_epochs | ||||
|             ): | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 # if last_epoch < self.T_max: | ||||
|                 # if last_epoch < self.max_epochs: | ||||
|                 lr = ( | ||||
|                     self.eta_min | ||||
|                     + (base_lr - self.eta_min) | ||||
|                     * (1 + math.cos(math.pi * last_epoch / self.T_max)) | ||||
|                     / 2 | ||||
|                 ) | ||||
|                 # else: | ||||
|                 #  lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2 | ||||
|             elif self.current_epoch >= self.max_epochs: | ||||
|                 lr = self.eta_min | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
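A minimal standalone check of the cosine formula above (the numbers are assumed for illustration and are not part of this commit): warmup ramps linearly toward base_lr, and the cosine phase then starts at exactly base_lr.

import math

base_lr, eta_min, warmup, T_max = 0.1, 0.0, 5, 95

def cosine_lr(epoch, iter_frac=0.0):
    if epoch < warmup:  # linear warmup, advanced per-batch via iter_frac
        return (epoch / warmup + iter_frac / warmup) * base_lr
    last_epoch = epoch - warmup
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * last_epoch / T_max)) / 2

assert abs(cosine_lr(5) - base_lr) < 1e-9    # cosine phase begins at base_lr
assert abs(cosine_lr(0, 0.5) - 0.01) < 1e-9  # halfway through the first warmup epoch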
|  | ||||
|  | ||||
| class MultiStepLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas): | ||||
|         assert len(milestones) == len(gammas), "invalid {:} vs {:}".format( | ||||
|             len(milestones), len(gammas) | ||||
|         ) | ||||
|         self.milestones = milestones | ||||
|         self.gammas = gammas | ||||
|         super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format( | ||||
|             "multistep", self.milestones, self.gammas, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 idx = bisect_right(self.milestones, last_epoch) | ||||
|                 lr = base_lr | ||||
|                 for x in self.gammas[:idx]: | ||||
|                     lr *= x | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
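To see how the milestone logic behaves, here is a small sketch with assumed milestones: bisect_right counts how many milestones have already passed, and exactly that many gammas are multiplied in, so the decay is cumulative.

from bisect import bisect_right

milestones, gammas, base_lr = [30, 60], [0.1, 0.1], 0.1
for epoch, expected in [(10, 0.1), (30, 0.01), (59, 0.01), (60, 0.001)]:
    idx = bisect_right(milestones, epoch)  # number of milestones passed
    lr = base_lr
    for g in gammas[:idx]:
        lr *= g
    assert abs(lr - expected) < 1e-12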
|  | ||||
|  | ||||
| class ExponentialLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, gamma): | ||||
|         self.gamma = gamma | ||||
|         super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, gamma={:}, base-lrs={:}".format( | ||||
|             "exponential", self.gamma, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 lr = base_lr * (self.gamma**last_epoch) | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class LinearLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR): | ||||
|         self.max_LR = max_LR | ||||
|         self.min_LR = min_LR | ||||
|         super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format( | ||||
|             "LinearLR", self.max_LR, self.min_LR, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 ratio = ( | ||||
|                     (self.max_LR - self.min_LR) | ||||
|                     * last_epoch | ||||
|                     / self.max_epochs | ||||
|                     / self.max_LR | ||||
|                 ) | ||||
|                 lr = base_lr * (1 - ratio) | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class CrossEntropyLabelSmooth(nn.Module): | ||||
|     def __init__(self, num_classes, epsilon): | ||||
|         super(CrossEntropyLabelSmooth, self).__init__() | ||||
|         self.num_classes = num_classes | ||||
|         self.epsilon = epsilon | ||||
|         self.logsoftmax = nn.LogSoftmax(dim=1) | ||||
|  | ||||
|     def forward(self, inputs, targets): | ||||
|         log_probs = self.logsoftmax(inputs) | ||||
|         targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) | ||||
|         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes | ||||
|         loss = (-targets * log_probs).mean(0).sum() | ||||
|         return loss | ||||
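The smoothed target above is (1 - epsilon) * one_hot + epsilon / K, which should agree with PyTorch's built-in label_smoothing argument (available in torch >= 1.10). A quick equivalence sketch, assuming the class as defined above:

import torch

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
ours = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)(logits, targets)
ref = torch.nn.functional.cross_entropy(logits, targets, label_smoothing=0.1)
assert torch.allclose(ours, ref, atol=1e-6)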
|  | ||||
|  | ||||
| def get_optim_scheduler(parameters, config): | ||||
|     assert ( | ||||
|         hasattr(config, "optim") | ||||
|         and hasattr(config, "scheduler") | ||||
|         and hasattr(config, "criterion") | ||||
|     ), "config must have optim / scheduler / criterion keys instead of {:}".format( | ||||
|         config | ||||
|     ) | ||||
|     if config.optim == "SGD": | ||||
|         optim = torch.optim.SGD( | ||||
|             parameters, | ||||
|             config.LR, | ||||
|             momentum=config.momentum, | ||||
|             weight_decay=config.decay, | ||||
|             nesterov=config.nesterov, | ||||
|         ) | ||||
|     elif config.optim == "RMSprop": | ||||
|         optim = torch.optim.RMSprop( | ||||
|             parameters, config.LR, momentum=config.momentum, weight_decay=config.decay | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid optim : {:}".format(config.optim)) | ||||
|  | ||||
|     if config.scheduler == "cos": | ||||
|         T_max = getattr(config, "T_max", config.epochs) | ||||
|         scheduler = CosineAnnealingLR( | ||||
|             optim, config.warmup, config.epochs, T_max, config.eta_min | ||||
|         ) | ||||
|     elif config.scheduler == "multistep": | ||||
|         scheduler = MultiStepLR( | ||||
|             optim, config.warmup, config.epochs, config.milestones, config.gammas | ||||
|         ) | ||||
|     elif config.scheduler == "exponential": | ||||
|         scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma) | ||||
|     elif config.scheduler == "linear": | ||||
|         scheduler = LinearLR( | ||||
|             optim, config.warmup, config.epochs, config.LR, config.LR_min | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid scheduler : {:}".format(config.scheduler)) | ||||
|  | ||||
|     if config.criterion == "Softmax": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|     elif config.criterion == "SmoothSoftmax": | ||||
|         criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth) | ||||
|     else: | ||||
|         raise ValueError("invalid criterion : {:}".format(config.criterion)) | ||||
|     return optim, scheduler, criterion | ||||
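A hypothetical usage sketch of get_optim_scheduler; the SimpleNamespace below merely stands in for the dict2config-style config object used elsewhere in this repo, and all hyperparameter values are made up.

import torch
from types import SimpleNamespace

net = torch.nn.Linear(16, 10)
config = SimpleNamespace(
    optim="SGD", LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
    scheduler="cos", warmup=5, epochs=100, eta_min=0.0,
    criterion="Softmax",
)
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
for epoch in range(config.epochs):
    scheduler.update(epoch, 0.0)  # set the epoch; per-batch fractions follow

Within an epoch, the training loops in this package call scheduler.update(None, step / len(loader)) to advance the fractional iteration counter.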
							
								
								
									
150  AutoDL-Projects/xautodl/procedures/q_exps.py  Normal file
							| @@ -0,0 +1,150 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ##################################################### | ||||
|  | ||||
| import inspect | ||||
| import os | ||||
| import pprint | ||||
| import logging | ||||
| from copy import deepcopy | ||||
|  | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
|  | ||||
| def set_log_basic_config(filename=None, format=None, level=None): | ||||
|     """ | ||||
|     Set the basic configuration for the logging system. | ||||
|     See details at https://docs.python.org/3/library/logging.html#logging.basicConfig | ||||
|     :param filename: str or None | ||||
|         The path to save the logs. | ||||
|     :param format: the logging format | ||||
|     :param level: int | ||||
|     :return: None | ||||
|     """ | ||||
|     from qlib.config import C | ||||
|  | ||||
|     if level is None: | ||||
|         level = C.logging_level | ||||
|  | ||||
|     if format is None: | ||||
|         format = C.logging_config["formatters"]["logger_format"]["format"] | ||||
|  | ||||
|     # Remove all handlers associated with the root logger object. | ||||
|     for handler in logging.root.handlers[:]: | ||||
|         logging.root.removeHandler(handler) | ||||
|     logging.basicConfig(filename=filename, format=format, level=level) | ||||
|  | ||||
|  | ||||
| def update_gpu(config, gpu): | ||||
|     config = deepcopy(config) | ||||
|     if "task" in config and "model" in config["task"]: | ||||
|         if "GPU" in config["task"]["model"]: | ||||
|             config["task"]["model"]["GPU"] = gpu | ||||
|         elif ( | ||||
|             "kwargs" in config["task"]["model"] | ||||
|             and "GPU" in config["task"]["model"]["kwargs"] | ||||
|         ): | ||||
|             config["task"]["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "model" in config: | ||||
|         if "GPU" in config["model"]: | ||||
|             config["model"]["GPU"] = gpu | ||||
|         elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]: | ||||
|             config["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "kwargs" in config and "GPU" in config["kwargs"]: | ||||
|         config["kwargs"]["GPU"] = gpu | ||||
|     elif "GPU" in config: | ||||
|         config["GPU"] = gpu | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def update_market(config, market): | ||||
|     config = deepcopy(config) | ||||
|     config["market"] = market | ||||
|     config["data_handler_config"]["instruments"] = market | ||||
|     return config | ||||
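A quick illustration (with a made-up config) of how update_gpu walks the nested dict: it only rewrites a "GPU" entry that already exists, and the deepcopy leaves the input untouched.

cfg = {"task": {"model": {"class": "SomeModel", "kwargs": {"GPU": 0}}}}
new_cfg = update_gpu(cfg, 3)
assert new_cfg["task"]["model"]["kwargs"]["GPU"] == 3
assert cfg["task"]["model"]["kwargs"]["GPU"] == 0  # original left untouched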
|  | ||||
|  | ||||
| def run_exp( | ||||
|     task_config, | ||||
|     dataset, | ||||
|     experiment_name, | ||||
|     recorder_name, | ||||
|     uri, | ||||
|     model_obj_name="model.pkl", | ||||
| ): | ||||
|  | ||||
|     model = init_instance_by_config(task_config["model"]) | ||||
|     model_fit_kwargs = dict(dataset=dataset) | ||||
|  | ||||
|     # Let's start the experiment. | ||||
|     with R.start( | ||||
|         experiment_name=experiment_name, | ||||
|         recorder_name=recorder_name, | ||||
|         uri=uri, | ||||
|         resume=True, | ||||
|     ): | ||||
|         # Setup log | ||||
|         recorder_root_dir = R.get_recorder().get_local_dir() | ||||
|         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name)) | ||||
|  | ||||
|         set_log_basic_config(log_file) | ||||
|         logger = get_module_logger("q.run_exp") | ||||
|         logger.info("task_config::\n{:}".format(pprint.pformat(task_config, indent=2))) | ||||
|         logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri)) | ||||
|         logger.info("dataset={:}".format(dataset)) | ||||
|  | ||||
|         # Train model | ||||
|         try: | ||||
|             if hasattr(model, "to"):  # Recoverable model | ||||
|                 ori_device = model.device | ||||
|                 model = R.load_object(model_obj_name) | ||||
|                 model.to(ori_device) | ||||
|             else: | ||||
|                 model = R.load_object(model_obj_name) | ||||
|             logger.info("[Find existing object from {:}]".format(model_obj_name)) | ||||
|         except OSError: | ||||
|             R.log_params(**flatten_dict(update_gpu(task_config, None))) | ||||
|             if "save_path" in inspect.getfullargspec(model.fit).args: | ||||
|                 model_fit_kwargs["save_path"] = os.path.join( | ||||
|                     recorder_root_dir, "model.ckp" | ||||
|                 ) | ||||
|             elif "save_dir" in inspect.getfullargspec(model.fit).args: | ||||
|                 model_fit_kwargs["save_dir"] = os.path.join( | ||||
|                     recorder_root_dir, "model-ckps" | ||||
|                 ) | ||||
|             model.fit(**model_fit_kwargs) | ||||
|             # move the model to CPU for saving | ||||
|             if hasattr(model, "to"): | ||||
|                 old_device = model.device | ||||
|                 model.to("cpu") | ||||
|                 R.save_objects(**{model_obj_name: model}) | ||||
|                 model.to(old_device) | ||||
|             else: | ||||
|                 R.save_objects(**{model_obj_name: model}) | ||||
|         except Exception as e: | ||||
|             raise ValueError("Something wrong: {:}".format(e)) | ||||
|         # Get the recorder | ||||
|         recorder = R.get_recorder() | ||||
|  | ||||
|         # Generate records: prediction, backtest, and analysis | ||||
|         for record in task_config["record"]: | ||||
|             record = deepcopy(record) | ||||
|             if record["class"] == "MultiSegRecord": | ||||
|                 record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder) | ||||
|                 sr = init_instance_by_config(record) | ||||
|                 sr.generate(**record["generate_kwargs"]) | ||||
|             elif record["class"] == "SignalRecord": | ||||
|                 srconf = {"model": model, "dataset": dataset, "recorder": recorder} | ||||
|                 record["kwargs"].update(srconf) | ||||
|                 sr = init_instance_by_config(record) | ||||
|                 sr.generate() | ||||
|             else: | ||||
|                 rconf = {"recorder": recorder} | ||||
|                 record["kwargs"].update(rconf) | ||||
|                 ar = init_instance_by_config(record) | ||||
|                 ar.generate() | ||||
							
								
								
									
199  AutoDL-Projects/xautodl/procedures/search_main.py  Normal file
							| @@ -0,0 +1,199 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from xautodl.models import change_key | ||||
|  | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|     expected_flop = torch.mean(expected_flop) | ||||
|  | ||||
|     if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP | ||||
|         loss = -torch.log(expected_flop) | ||||
|     # elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|     elif flop_cur > flop_need:  # Too Large FLOP | ||||
|         loss = torch.log(expected_flop) | ||||
|     else:  # Required FLOP | ||||
|         loss = None | ||||
|     if loss is None: | ||||
|         return 0, 0 | ||||
|     else: | ||||
|         return loss, loss.item() | ||||
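The penalty above has three regimes; a small check with assumed numbers (expected_flop is the differentiable estimate, flop_cur the measured cost):

import torch

expected = torch.tensor([100.0])
loss, scale = get_flop_loss(expected, flop_cur=80, flop_need=100, flop_tolerant=10)
assert scale < 0   # too small: -log(expected) rewards growing the network
loss, scale = get_flop_loss(expected, flop_cur=120, flop_need=100, flop_tolerant=10)
assert scale > 0   # too large: +log(expected) penalizes cost
loss, scale = get_flop_loss(expected, flop_cur=95, flop_need=100, flop_tolerant=10)
assert (loss, scale) == (0, 0)  # within tolerance: no penalty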
|  | ||||
|  | ||||
| def search_train( | ||||
|     search_loader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     base_optimizer, | ||||
|     arch_optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|     epoch_str, flop_need, flop_weight, flop_tolerant = ( | ||||
|         extra_info["epoch-str"], | ||||
|         extra_info["FLOP-exp"], | ||||
|         extra_info["FLOP-weight"], | ||||
|         extra_info["FLOP-tolerant"], | ||||
|     ) | ||||
|  | ||||
|     network.train() | ||||
|     logger.log( | ||||
|         "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format( | ||||
|             epoch_str, flop_need, flop_weight | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         search_loader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|         # calculate prediction and loss | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the weights | ||||
|         base_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(base_inputs) | ||||
|         # network.apply( change_key('search_mode', 'basic') ) | ||||
|         # features, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         base_optimizer.step() | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         top1.update(prec1.item(), base_inputs.size(0)) | ||||
|         top5.update(prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture | ||||
|         arch_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(arch_inputs) | ||||
|         flop_cur = network.module.get_flop("genotype", None, None) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss( | ||||
|             expected_flop, flop_cur, flop_need, flop_tolerant | ||||
|         ) | ||||
|         acls_loss = criterion(logits, arch_targets) | ||||
|         arch_loss = acls_loss + flop_loss * flop_weight | ||||
|         arch_loss.backward() | ||||
|         arch_optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|         arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|         if step % print_freq == 0 or (step + 1) == len(search_loader): | ||||
|             Sstr = ( | ||||
|                 "**TRAIN** " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=base_losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format( | ||||
|                 aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr) | ||||
|             # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|             # print(network.module.get_arch_info()) | ||||
|             # print(network.module.width_attentions[0]) | ||||
|             # print(network.module.width_attentions[1]) | ||||
|  | ||||
|     logger.log( | ||||
|         " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format( | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             baseloss=base_losses.avg, | ||||
|             archloss=arch_losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
|  | ||||
|  | ||||
| def search_valid(xloader, network, criterion, extra_info, print_freq, logger): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|  | ||||
|     network.eval() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     end = time.time() | ||||
|     # logger.log('Starting evaluating {:}'.format(epoch_info)) | ||||
|     with torch.no_grad(): | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # calculate prediction and loss | ||||
|             targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|             logits, expected_flop = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             # record | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|  | ||||
|             if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|                 Sstr = ( | ||||
|                     "**VALID** " | ||||
|                     + time_string() | ||||
|                     + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|                 ) | ||||
|                 Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                     batch_time=batch_time, data_time=data_time | ||||
|                 ) | ||||
|                 Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                     loss=losses, top1=top1, top5=top5 | ||||
|                 ) | ||||
|                 Istr = "Size={:}".format(list(inputs.size())) | ||||
|                 logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
139  AutoDL-Projects/xautodl/procedures/search_main_v2.py  Normal file
							| @@ -0,0 +1,139 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| # modules in AutoDL | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from xautodl.models import change_key | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|     expected_flop = torch.mean(expected_flop) | ||||
|  | ||||
|     if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP | ||||
|         loss = -torch.log(expected_flop) | ||||
|     # elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|     elif flop_cur > flop_need:  # Too Large FLOP | ||||
|         loss = torch.log(expected_flop) | ||||
|     else:  # Required FLOP | ||||
|         loss = None | ||||
|     if loss is None: | ||||
|         return 0, 0 | ||||
|     else: | ||||
|         return loss, loss.item() | ||||
|  | ||||
|  | ||||
| def search_train_v2( | ||||
|     search_loader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     base_optimizer, | ||||
|     arch_optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, arch_losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|     epoch_str, flop_need, flop_weight, flop_tolerant = ( | ||||
|         extra_info["epoch-str"], | ||||
|         extra_info["FLOP-exp"], | ||||
|         extra_info["FLOP-weight"], | ||||
|         extra_info["FLOP-tolerant"], | ||||
|     ) | ||||
|  | ||||
|     network.train() | ||||
|     logger.log( | ||||
|         "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format( | ||||
|             epoch_str, flop_need, flop_weight | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     network.apply(change_key("search_mode", "search")) | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         search_loader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|         # calculate prediction and loss | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the weights | ||||
|         base_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         base_optimizer.step() | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         top1.update(prec1.item(), base_inputs.size(0)) | ||||
|         top5.update(prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture | ||||
|         arch_optimizer.zero_grad() | ||||
|         logits, expected_flop = network(arch_inputs) | ||||
|         flop_cur = network.module.get_flop("genotype", None, None) | ||||
|         flop_loss, flop_loss_scale = get_flop_loss( | ||||
|             expected_flop, flop_cur, flop_need, flop_tolerant | ||||
|         ) | ||||
|         acls_loss = criterion(logits, arch_targets) | ||||
|         arch_loss = acls_loss + flop_loss * flop_weight | ||||
|         arch_loss.backward() | ||||
|         arch_optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|         arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|         if step % print_freq == 0 or (step + 1) == len(search_loader): | ||||
|             Sstr = ( | ||||
|                 "**TRAIN** " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=base_losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format( | ||||
|                 aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr) | ||||
|             # num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6)) | ||||
|             # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|             # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|             # print(network.module.get_arch_info()) | ||||
|             # print(network.module.width_attentions[0]) | ||||
|             # print(network.module.width_attentions[1]) | ||||
|  | ||||
|     logger.log( | ||||
|         " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format( | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             baseloss=base_losses.avg, | ||||
|             archloss=arch_losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
204  AutoDL-Projects/xautodl/procedures/simple_KD_main.py  Normal file
							| @@ -0,0 +1,204 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import os, sys, time, torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| # modules in AutoDL | ||||
| from xautodl.log_utils import AverageMeter, time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def simple_KD_train( | ||||
|     xloader, | ||||
|     teacher, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, | ||||
|         teacher, | ||||
|         network, | ||||
|         criterion, | ||||
|         scheduler, | ||||
|         optimizer, | ||||
|         "train", | ||||
|         optim_config, | ||||
|         extra_info, | ||||
|         print_freq, | ||||
|         logger, | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def simple_KD_valid( | ||||
|     xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, | ||||
|             teacher, | ||||
|             network, | ||||
|             criterion, | ||||
|             None, | ||||
|             None, | ||||
|             "valid", | ||||
|             optim_config, | ||||
|             extra_info, | ||||
|             print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def loss_KD_fn( | ||||
|     criterion, | ||||
|     student_logits, | ||||
|     teacher_logits, | ||||
|     studentFeatures, | ||||
|     teacherFeatures, | ||||
|     targets, | ||||
|     alpha, | ||||
|     temperature, | ||||
| ): | ||||
|     basic_loss = criterion(student_logits, targets) * (1.0 - alpha) | ||||
|     log_student = F.log_softmax(student_logits / temperature, dim=1) | ||||
|     sof_teacher = F.softmax(teacher_logits / temperature, dim=1) | ||||
|     KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * ( | ||||
|         alpha * temperature * temperature | ||||
|     ) | ||||
|     return basic_loss + KD_loss | ||||
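This is the standard Hinton-style KD objective: (1 - alpha) * CE plus alpha * T^2 * KL between temperature-softened distributions. A shape-level sketch with random tensors (note that the two feature arguments are accepted but unused by this particular loss):

import torch

student_logits = torch.randn(8, 100)
teacher_logits = torch.randn(8, 100)
targets = torch.randint(0, 100, (8,))
ce = torch.nn.CrossEntropyLoss()
loss = loss_KD_fn(
    ce, student_logits, teacher_logits,
    None, None, targets, alpha=0.9, temperature=4.0,
)
assert loss.dim() == 0  # a scalar combining CE and the softened KL term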
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     teacher, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     mode, | ||||
|     config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     Ttop1, Ttop5 = AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|     teacher.eval() | ||||
|  | ||||
|     logger.log( | ||||
|         "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format( | ||||
|             mode, | ||||
|             config.auxiliary if hasattr(config, "auxiliary") else -1, | ||||
|             config.KD_alpha, | ||||
|             config.KD_temperature, | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         student_f, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must has {:} items instead of {:}".format( | ||||
|                 2, len(logits) | ||||
|             ) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
|         with torch.no_grad(): | ||||
|             teacher_f, teacher_logits = teacher(inputs) | ||||
|  | ||||
|         loss = loss_KD_fn( | ||||
|             criterion, | ||||
|             logits, | ||||
|             teacher_logits, | ||||
|             student_f, | ||||
|             teacher_f, | ||||
|             targets, | ||||
|             config.KD_alpha, | ||||
|             config.KD_temperature, | ||||
|         ) | ||||
|         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | ||||
|             loss_aux = criterion(logits_aux, targets) | ||||
|             loss += config.auxiliary * loss_aux | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(sprec1.item(), inputs.size(0)) | ||||
|         top5.update(sprec5.item(), inputs.size(0)) | ||||
|         # teacher | ||||
|         tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5)) | ||||
|         Ttop1.update(tprec1.item(), inputs.size(0)) | ||||
|         Ttop5.update(tprec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|             Sstr = ( | ||||
|                 " {:5s} ".format(mode.upper()) | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|             ) | ||||
|             if scheduler is not None: | ||||
|                 Sstr += " {:}".format(scheduler.get_min_info()) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Lstr += " Teacher : acc@1={:.2f}, acc@5={:.2f}".format(Ttop1.avg, Ttop5.avg) | ||||
|             Istr = "Size={:}".format(list(inputs.size())) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}".format( | ||||
|             mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg | ||||
|         ) | ||||
|     ) | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
79  AutoDL-Projects/xautodl/procedures/starts.py  Normal file
							| @@ -0,0 +1,79 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, torch, random, PIL, copy, numpy as np | ||||
| from os import path as osp | ||||
| from shutil import copyfile | ||||
|  | ||||
|  | ||||
| def prepare_seed(rand_seed): | ||||
|     random.seed(rand_seed) | ||||
|     np.random.seed(rand_seed) | ||||
|     torch.manual_seed(rand_seed) | ||||
|     torch.cuda.manual_seed(rand_seed) | ||||
|     torch.cuda.manual_seed_all(rand_seed) | ||||
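prepare_seed fixes the RNG state, but cuDNN may still select non-deterministic kernels. A stricter variant, shown here as an optional extension that is not part of this commit (it trades speed for reproducibility), would add:

import torch  # already imported at the top of starts.py

def prepare_seed_strict(rand_seed):
    prepare_seed(rand_seed)  # the function defined above
    torch.backends.cudnn.deterministic = True  # force deterministic kernels
    torch.backends.cudnn.benchmark = False     # disable autotuner nondeterminism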
|  | ||||
|  | ||||
| def prepare_logger(xargs): | ||||
|     args = copy.deepcopy(xargs) | ||||
|     from xautodl.log_utils import Logger | ||||
|  | ||||
|     logger = Logger(args.save_dir, args.rand_seed) | ||||
|     logger.log("Main Function with logger : {:}".format(logger)) | ||||
|     logger.log("Arguments : -------------------------------") | ||||
|     for name, value in args._get_kwargs(): | ||||
|         logger.log("{:16} : {:}".format(name, value)) | ||||
|     logger.log("Python  Version  : {:}".format(sys.version.replace("\n", " "))) | ||||
|     logger.log("Pillow  Version  : {:}".format(PIL.__version__)) | ||||
|     logger.log("PyTorch Version  : {:}".format(torch.__version__)) | ||||
|     logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version())) | ||||
|     logger.log("CUDA available   : {:}".format(torch.cuda.is_available())) | ||||
|     logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) | ||||
|     logger.log( | ||||
|         "CUDA_VISIBLE_DEVICES : {:}".format( | ||||
|             os.environ["CUDA_VISIBLE_DEVICES"] | ||||
|             if "CUDA_VISIBLE_DEVICES" in os.environ | ||||
|             else "None" | ||||
|         ) | ||||
|     ) | ||||
|     return logger | ||||
|  | ||||
|  | ||||
| def get_machine_info(): | ||||
|     info = "Python  Version  : {:}".format(sys.version.replace("\n", " ")) | ||||
|     info += "\nPillow  Version  : {:}".format(PIL.__version__) | ||||
|     info += "\nPyTorch Version  : {:}".format(torch.__version__) | ||||
|     info += "\ncuDNN   Version  : {:}".format(torch.backends.cudnn.version()) | ||||
|     info += "\nCUDA available   : {:}".format(torch.cuda.is_available()) | ||||
|     info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count()) | ||||
|     if "CUDA_VISIBLE_DEVICES" in os.environ: | ||||
|         info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"]) | ||||
|     else: | ||||
|         info += "\nDoes not set CUDA_VISIBLE_DEVICES" | ||||
|     return info | ||||
|  | ||||
|  | ||||
| def save_checkpoint(state, filename, logger): | ||||
|     if osp.isfile(filename): | ||||
|         if hasattr(logger, "log"): | ||||
|             logger.log( | ||||
|                 "Find {:} exist, delete is at first before saving".format(filename) | ||||
|             ) | ||||
|         os.remove(filename) | ||||
|     torch.save(state, filename) | ||||
|     assert osp.isfile( | ||||
|         filename | ||||
|     ), "save filename : {:} failed, which is not found.".format(filename) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("save checkpoint into {:}".format(filename)) | ||||
|     return filename | ||||
|  | ||||
|  | ||||
| def copy_checkpoint(src, dst, logger): | ||||
|     if osp.isfile(dst): | ||||
|         if hasattr(logger, "log"): | ||||
|             logger.log("Find {:} exist, delete is at first before saving".format(dst)) | ||||
|         os.remove(dst) | ||||
|     copyfile(src, dst) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("copy the file from {:} into {:}".format(src, dst)) | ||||
							
								
								
									
17  AutoDL-Projects/xautodl/spaces/__init__.py  Normal file
							| @@ -0,0 +1,17 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # | ||||
| ##################################################### | ||||
| # Define complex search space for AutoDL            # | ||||
| ##################################################### | ||||
|  | ||||
| from .basic_space import Categorical | ||||
| from .basic_space import Continuous | ||||
| from .basic_space import Integer | ||||
| from .basic_space import Space | ||||
| from .basic_space import VirtualNode | ||||
| from .basic_op import has_categorical | ||||
| from .basic_op import has_continuous | ||||
| from .basic_op import is_determined | ||||
| from .basic_op import get_determined_value | ||||
| from .basic_op import get_min | ||||
| from .basic_op import get_max | ||||
							
								
								
									
74  AutoDL-Projects/xautodl/spaces/basic_op.py  Normal file
							| @@ -0,0 +1,74 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| from .basic_space import Space | ||||
| from .basic_space import VirtualNode | ||||
| from .basic_space import Integer | ||||
| from .basic_space import Continuous | ||||
| from .basic_space import Categorical | ||||
| from .basic_space import _EPS | ||||
|  | ||||
|  | ||||
| def has_categorical(space_or_value, x): | ||||
|     if isinstance(space_or_value, Space): | ||||
|         return space_or_value.has(x) | ||||
|     else: | ||||
|         return space_or_value == x | ||||
|  | ||||
|  | ||||
| def has_continuous(space_or_value, x): | ||||
|     if isinstance(space_or_value, Space): | ||||
|         return space_or_value.has(x) | ||||
|     else: | ||||
|         return abs(space_or_value - x) <= _EPS | ||||
|  | ||||
|  | ||||
| def is_determined(space_or_value): | ||||
|     if isinstance(space_or_value, Space): | ||||
|         return space_or_value.determined | ||||
|     else: | ||||
|         return True | ||||
|  | ||||
|  | ||||
| def get_determined_value(space_or_value): | ||||
|     if not is_determined(space_or_value): | ||||
|         raise ValueError("This input is not determined: {:}".format(space_or_value)) | ||||
|     if isinstance(space_or_value, Space): | ||||
|         if isinstance(space_or_value, Continuous): | ||||
|             return space_or_value.lower | ||||
|         elif isinstance(space_or_value, Categorical): | ||||
|             return get_determined_value(space_or_value[0]) | ||||
|         else:  # VirtualNode | ||||
|             return space_or_value.value | ||||
|     else: | ||||
|         return space_or_value | ||||
|  | ||||
|  | ||||
| def get_max(space_or_value): | ||||
|     if isinstance(space_or_value, Integer): | ||||
|         return max(space_or_value.candidates) | ||||
|     elif isinstance(space_or_value, Continuous): | ||||
|         return space_or_value.upper | ||||
|     elif isinstance(space_or_value, Categorical): | ||||
|         values = [] | ||||
|         for index in range(len(space_or_value)): | ||||
|             max_value = get_max(space_or_value[index]) | ||||
|             values.append(max_value) | ||||
|         return max(values) | ||||
|     else: | ||||
|         return space_or_value | ||||
|  | ||||
|  | ||||
| def get_min(space_or_value): | ||||
|     if isinstance(space_or_value, Integer): | ||||
|         return min(space_or_value.candidates) | ||||
|     elif isinstance(space_or_value, Continuous): | ||||
|         return space_or_value.lower | ||||
|     elif isinstance(space_or_value, Categorical): | ||||
|         values = [] | ||||
|         for index in range(len(space_or_value)): | ||||
|             min_value = get_min(space_or_value[index]) | ||||
|             values.append(min_value) | ||||
|         return min(values) | ||||
|     else: | ||||
|         return space_or_value | ||||
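An example of these helpers on a small nested space (assuming, as in basic_space.py, that Continuous takes lower and upper bounds): Categorical recurses into its candidates, while plain values pass through unchanged.

space = Categorical(Continuous(0.01, 0.1), 0.5, 1)
assert get_min(space) == 0.01   # lower bound of the Continuous child
assert get_max(space) == 1      # the largest plain candidate
assert get_min(0.3) == 0.3      # non-space values are returned as-is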
							
								
								
									
434  AutoDL-Projects/xautodl/spaces/basic_space.py  Normal file
							| @@ -0,0 +1,434 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
|  | ||||
| import abc | ||||
| import math | ||||
| import copy | ||||
| import random | ||||
| import numpy as np | ||||
| from collections import OrderedDict | ||||
|  | ||||
| from typing import Optional, Text | ||||
|  | ||||
|  | ||||
| __all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"] | ||||
|  | ||||
| _EPS = 1e-9 | ||||
|  | ||||
|  | ||||
| class Space(metaclass=abc.ABCMeta): | ||||
|     """Basic search space describing the set of possible candidate values for hyperparameter. | ||||
|     All search space must inherit from this basic class. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         # used to avoid duplicate sample | ||||
|         self._last_sample = None | ||||
|         self._last_abstract = None | ||||
|  | ||||
|     @abc.abstractmethod  # xrepr takes an argument, so it is a method, not a property | ||||
|     def xrepr(self, depth=0) -> Text: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self) -> Text: | ||||
|         return self.xrepr() | ||||
|  | ||||
|     @abc.abstractmethod  # called as a method (with reuse_last), not a property | ||||
|     def abstract(self, reuse_last=False) -> "Space": | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def clean_last_sample(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def clean_last_abstract(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def clean_last(self): | ||||
|         self.clean_last_sample() | ||||
|         self.clean_last_abstract() | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def determined(self) -> bool: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def has(self, x) -> bool: | ||||
|         """Check whether x is in this search space.""" | ||||
|         assert not isinstance( | ||||
|             x, Space | ||||
|         ), "The input value itself can not be a search space." | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def __eq__(self, other): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def copy(self) -> "Space": | ||||
|         return copy.deepcopy(self) | ||||
|  | ||||
|  | ||||
| class VirtualNode(Space): | ||||
|     """For a nested search space, we represent it as a tree structure. | ||||
|  | ||||
|     For example, a node may hold a Categorical child for an operation type | ||||
|     and a Continuous child for its width. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, id=None, value=None): | ||||
|         super(VirtualNode, self).__init__() | ||||
|         self._id = id | ||||
|         self._value = value | ||||
|         self._attributes = OrderedDict() | ||||
|  | ||||
|     @property | ||||
|     def value(self): | ||||
|         return self._value | ||||
|  | ||||
|     def append(self, key, value): | ||||
|         if not isinstance(key, str): | ||||
|             raise TypeError( | ||||
|                 "Only accept string as a key instead of {:}".format(type(key)) | ||||
|             ) | ||||
|         if not isinstance(value, Space): | ||||
|             raise ValueError("Invalid type of value: {:}".format(type(value))) | ||||
|         # if value.determined: | ||||
|         #    raise ValueError("Can not attach a determined value: {:}".format(value)) | ||||
|         self._attributes[key] = value | ||||
|  | ||||
|     def xrepr(self, depth=0) -> Text: | ||||
|         strs = [self.__class__.__name__ + "(value={:}".format(self._value)] | ||||
|         for key, value in self._attributes.items(): | ||||
|             strs.append(key + " = " + value.xrepr(depth + 1)) | ||||
|         strs.append(")") | ||||
|         if len(strs) == 2: | ||||
|             return "".join(strs) | ||||
|         else: | ||||
|             space = "  " | ||||
|             xstrs = ( | ||||
|                 [strs[0]] | ||||
|                 + [space * (depth + 1) + x for x in strs[1:-1]] | ||||
|                 + [space * depth + strs[-1]] | ||||
|             ) | ||||
|             return ",\n".join(xstrs) | ||||
|  | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         node = VirtualNode(id(self)) | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined: | ||||
|                 node.append(key, value.abstract(reuse_last))  # append requires both key and value | ||||
|         self._last_abstract = node | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         node = VirtualNode(None, self._value) | ||||
|         for key, value in self._attributes.items(): | ||||
|             node.append(key, value.random(recursion, reuse_last)) | ||||
|         self._last_sample = node  # record the last sample | ||||
|         return node | ||||
|  | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|         for key, value in self._attributes.items(): | ||||
|             value.clean_last_sample() | ||||
|  | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|         for key, value in self._attributes.items(): | ||||
|             value.clean_last_abstract() | ||||
|  | ||||
|     def has(self, x) -> bool: | ||||
|         for key, value in self._attributes.items(): | ||||
|             if value.has(x): | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     def __contains__(self, key): | ||||
|         return key in self._attributes | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         return self._attributes[key] | ||||
|  | ||||
|     @property | ||||
|     def determined(self) -> bool: | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined: | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, VirtualNode): | ||||
|             return False | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not key in other: | ||||
|                 return False | ||||
|             if value != other[key]: | ||||
|                 return False | ||||
|         return True | ||||
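|  | ||||
|  | ||||
| # Illustrative sketch, not part of the original module: how a nested | ||||
| # space is assembled and sampled via VirtualNode. The child spaces and | ||||
| # their values below are assumptions chosen for the demo. | ||||
| def _demo_virtual_node(): | ||||
|     root = VirtualNode(value="root") | ||||
|     root.append("depth", Categorical(4, 6, 8)) | ||||
|     root.append("lr", Continuous(0.001, 0.1, log=True)) | ||||
|     sample = root.random()  # a VirtualNode holding concrete child values | ||||
|     print(root.xrepr())  # indented, tree-like representation | ||||
|     assert not root.determined  # both children are still undecided | ||||
|     assert sample.determined  # the sampled values are concrete | ||||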
|  | ||||
|  | ||||
| class Categorical(Space): | ||||
|     """A space contains the categorical values. | ||||
|     It can be a nested space, which means that the candidate in this space can also be a search space. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, *data, default: Optional[int] = None): | ||||
|         super(Categorical, self).__init__() | ||||
|         self._candidates = [*data] | ||||
|         self._default = default | ||||
|         assert self._default is None or 0 <= self._default < len( | ||||
|             self._candidates | ||||
|         ), "invalid default index {:} for {:} candidates".format( | ||||
|             self._default, len(self._candidates) | ||||
|         ) | ||||
|         assert len(self) > 0, "Please provide at least one candidate" | ||||
|  | ||||
|     @property | ||||
|     def candidates(self): | ||||
|         return self._candidates | ||||
|  | ||||
|     @property | ||||
|     def default(self): | ||||
|         return self._default | ||||
|  | ||||
|     @property | ||||
|     def determined(self): | ||||
|         if len(self) == 1: | ||||
|             return ( | ||||
|                 not isinstance(self._candidates[0], Space) | ||||
|                 or self._candidates[0].determined | ||||
|             ) | ||||
|         else: | ||||
|             return False | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         return self._candidates[index] | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._candidates) | ||||
|  | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|         for candidate in self._candidates: | ||||
|             if isinstance(candidate, Space): | ||||
|                 candidate.clean_last_sample() | ||||
|  | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|         for candidate in self._candidates: | ||||
|             if isinstance(candidate, Space): | ||||
|                 candidate.clean_last_abstract() | ||||
|  | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         if self.determined: | ||||
|             result = VirtualNode(id(self), self) | ||||
|         else: | ||||
|             # [TO-IMPROVE] | ||||
|             data = [] | ||||
|             for candidate in self.candidates: | ||||
|                 if isinstance(candidate, Space): | ||||
|                     data.append(candidate.abstract()) | ||||
|                 else: | ||||
|                     data.append(VirtualNode(id(candidate), candidate)) | ||||
|             result = Categorical(*data, default=self._default) | ||||
|         self._last_abstract = result | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         sample = random.choice(self._candidates) | ||||
|         if recursion and isinstance(sample, Space): | ||||
|             sample = sample.random(recursion, reuse_last) | ||||
|         if isinstance(sample, VirtualNode): | ||||
|             sample = sample.copy() | ||||
|         else: | ||||
|             sample = VirtualNode(None, sample) | ||||
|         self._last_sample = sample | ||||
|         return self._last_sample | ||||
|  | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( | ||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||
|         ) | ||||
|         return xrepr | ||||
|  | ||||
|     def has(self, x): | ||||
|         super().has(x) | ||||
|         for candidate in self._candidates: | ||||
|             if isinstance(candidate, Space) and candidate.has(x): | ||||
|                 return True | ||||
|             elif candidate == x: | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, Categorical): | ||||
|             return False | ||||
|         if len(self) != len(other): | ||||
|             return False | ||||
|         if self.default != other.default: | ||||
|             return False | ||||
|         for index in range(len(self)): | ||||
|             if self.__getitem__(index) != other[index]: | ||||
|                 return False | ||||
|         return True | ||||
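|  | ||||
|  | ||||
| # Illustrative sketch, not part of the original module: sampling and | ||||
| # membership checks on a Categorical space; the candidate values are | ||||
| # arbitrary assumptions. | ||||
| def _demo_categorical(): | ||||
|     space = Categorical("relu", "tanh", "gelu", default=0) | ||||
|     assert space.has("relu") and not space.has("swish") | ||||
|     sample = space.random()  # a VirtualNode wrapping one candidate | ||||
|     assert sample.value in space.candidates | ||||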
|  | ||||
|  | ||||
| class Integer(Categorical): | ||||
|     """A space contains the integer values.""" | ||||
|  | ||||
|     def __init__(self, lower: int, upper: int, default: Optional[int] = None): | ||||
|         if not isinstance(lower, int) or not isinstance(upper, int): | ||||
|             raise ValueError( | ||||
|                 "The lower [{:}] and uppwer [{:}] must be int.".format(lower, upper) | ||||
|             ) | ||||
|         data = list(range(lower, upper + 1)) | ||||
|         self._raw_lower = lower | ||||
|         self._raw_upper = upper | ||||
|         self._raw_default = default | ||||
|         if default is not None: | ||||
|             if default < lower or default > upper: | ||||
|                 raise ValueError( | ||||
|                     "The default value [{:}] is out of range.".format(default) | ||||
|                 ) | ||||
|             # convert the raw default value into its index among the candidates | ||||
|             default = data.index(default) | ||||
|         super(Integer, self).__init__(*data, default=default) | ||||
|  | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             lower=self._raw_lower, | ||||
|             upper=self._raw_upper, | ||||
|             default=self._raw_default, | ||||
|         ) | ||||
|         return xrepr | ||||
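|  | ||||
|  | ||||
| # Illustrative sketch, not part of the original module: Integer is sugar | ||||
| # over Categorical, so it inherits sampling and membership checks; the | ||||
| # bounds below are assumptions. | ||||
| def _demo_integer(): | ||||
|     space = Integer(lower=1, upper=4, default=2) | ||||
|     assert len(space) == 4 and space.has(3) | ||||
|     assert space.random().value in (1, 2, 3, 4) | ||||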
|  | ||||
|  | ||||
| np_float_types = (np.float16, np.float32, np.float64) | ||||
| np_int_types = ( | ||||
|     np.uint8, | ||||
|     np.int8, | ||||
|     np.uint16, | ||||
|     np.int16, | ||||
|     np.uint32, | ||||
|     np.int32, | ||||
|     np.uint64, | ||||
|     np.int64, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class Continuous(Space): | ||||
|     """A space contains the continuous values.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         lower: float, | ||||
|         upper: float, | ||||
|         default: Optional[float] = None, | ||||
|         log: bool = False, | ||||
|         eps: float = _EPS, | ||||
|     ): | ||||
|         super(Continuous, self).__init__() | ||||
|         self._lower = lower | ||||
|         self._upper = upper | ||||
|         self._default = default | ||||
|         self._log_scale = log | ||||
|         self._eps = eps | ||||
|  | ||||
|     @property | ||||
|     def lower(self): | ||||
|         return self._lower | ||||
|  | ||||
|     @property | ||||
|     def upper(self): | ||||
|         return self._upper | ||||
|  | ||||
|     @property | ||||
|     def default(self): | ||||
|         return self._default | ||||
|  | ||||
|     @property | ||||
|     def use_log(self): | ||||
|         return self._log_scale | ||||
|  | ||||
|     @property | ||||
|     def eps(self): | ||||
|         return self._eps | ||||
|  | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         self._last_abstract = self.copy() | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         del recursion | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         if self._log_scale: | ||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||
|             sample = math.exp(sample) | ||||
|         else: | ||||
|             sample = random.uniform(self._lower, self._upper) | ||||
|         self._last_sample = VirtualNode(None, sample) | ||||
|         return self._last_sample | ||||
|  | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             lower=self._lower, | ||||
|             upper=self._upper, | ||||
|             default=self._default, | ||||
|             log=self._log_scale, | ||||
|         ) | ||||
|         return xrepr | ||||
|  | ||||
|     def convert(self, x): | ||||
|         if isinstance(x, np_float_types) and x.size == 1: | ||||
|             return float(x), True | ||||
|         elif isinstance(x, np_int_types) and x.size == 1: | ||||
|             return float(x), True | ||||
|         elif isinstance(x, int): | ||||
|             return float(x), True | ||||
|         elif isinstance(x, float): | ||||
|             return float(x), True | ||||
|         else: | ||||
|             return None, False | ||||
|  | ||||
|     def has(self, x): | ||||
|         super().has(x) | ||||
|         converted_x, success = self.convert(x) | ||||
|         return success and self.lower <= converted_x <= self.upper | ||||
|  | ||||
|     @property | ||||
|     def determined(self): | ||||
|         return abs(self.lower - self.upper) <= self._eps | ||||
|  | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|  | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, Continuous): | ||||
|             return False | ||||
|         if self is other: | ||||
|             return True | ||||
|         else: | ||||
|             return ( | ||||
|                 self.lower == other.lower | ||||
|                 and self.upper == other.upper | ||||
|                 and self.default == other.default | ||||
|                 and self.use_log == other.use_log | ||||
|                 and self.eps == other.eps | ||||
|             ) | ||||
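|  | ||||
|  | ||||
| # Illustrative sketch, not part of the original module: with log=True, | ||||
| # samples are uniform in log-space (the usual choice for learning | ||||
| # rates); the bounds here are assumptions. | ||||
| def _demo_continuous(): | ||||
|     space = Continuous(1e-4, 1e-1, log=True) | ||||
|     sample = space.random().value | ||||
|     assert space.has(sample) and 1e-4 <= sample <= 1e-1 | ||||
|     assert not space.determined  # lower and upper differ beyond eps | ||||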
							
								
								
									
								4  AutoDL-Projects/xautodl/trade_models/__init__.py  Normal file
									
								
							| @@ -0,0 +1,4 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| from .transformers import get_transformer | ||||
							
								
								
									
								102  AutoDL-Projects/xautodl/trade_models/naive_v1_model.py  Normal file
									
								
							| @@ -0,0 +1,102 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 # | ||||
| ################################################## | ||||
| # Use noise as prediction                        # | ||||
| ################################################## | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
|  | ||||
| import random | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
|  | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
| from qlib.model.base import Model | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
|  | ||||
| class NAIVE_V1(Model): | ||||
|     """NAIVE Version 1 Quant Model""" | ||||
|  | ||||
|     def __init__(self, d_feat=6, seed=None, **kwargs): | ||||
|         # Set logger. | ||||
|         self.logger = get_module_logger("NAIVE") | ||||
|         self.logger.info("NAIVE 1st version: random noise ...") | ||||
|  | ||||
|         # set hyper-parameters. | ||||
|         self.d_feat = d_feat | ||||
|         self.seed = seed | ||||
|  | ||||
|         self.logger.info( | ||||
|             "NAIVE-V1 parameters setting: d_feat={:}, seed={:}".format( | ||||
|                 self.d_feat, self.seed | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         if self.seed is not None: | ||||
|             random.seed(self.seed) | ||||
|             np.random.seed(self.seed) | ||||
|         self._mean = None | ||||
|         self._std = None | ||||
|         self.fitted = False | ||||
|  | ||||
|     def process_data(self, features): | ||||
|         features = features.reshape(len(features), self.d_feat, -1) | ||||
|         features = features.transpose((0, 2, 1)) | ||||
|         return features[:, :59, 0] | ||||
|  | ||||
|     def mse(self, preds, labels): | ||||
|         masks = ~np.isnan(labels) | ||||
|         masked_preds = preds[masks] | ||||
|         masked_labels = labels[masks] | ||||
|         return np.square(masked_preds - masked_labels).mean() | ||||
|  | ||||
|     def model(self, x): | ||||
|         num = len(x) | ||||
|         return np.random.normal(loc=self._mean, scale=self._std, size=num).astype( | ||||
|             x.dtype | ||||
|         ) | ||||
|  | ||||
|     def fit(self, dataset: DatasetH): | ||||
|         def _prepare_dataset(df_data): | ||||
|             features = df_data["feature"].values | ||||
|             features = self.process_data(features) | ||||
|             labels = df_data["label"].values.squeeze() | ||||
|             return dict(features=features, labels=labels) | ||||
|  | ||||
|         df_train, df_valid, df_test = dataset.prepare( | ||||
|             ["train", "valid", "test"], | ||||
|             col_set=["feature", "label"], | ||||
|             data_key=DataHandlerLP.DK_L, | ||||
|         ) | ||||
|         train_dataset, valid_dataset, test_dataset = ( | ||||
|             _prepare_dataset(df_train), | ||||
|             _prepare_dataset(df_valid), | ||||
|             _prepare_dataset(df_test), | ||||
|         ) | ||||
|         # df_train['feature']['CLOSE1'].values | ||||
|         # train_dataset['features'][:, -1] | ||||
|         masks = ~np.isnan(train_dataset["labels"]) | ||||
|         self._mean, self._std = np.mean(train_dataset["labels"][masks]), np.std( | ||||
|             train_dataset["labels"][masks] | ||||
|         ) | ||||
|         train_mse_loss = self.mse( | ||||
|             self.model(train_dataset["features"]), train_dataset["labels"] | ||||
|         ) | ||||
|         valid_mse_loss = self.mse( | ||||
|             self.model(valid_dataset["features"]), valid_dataset["labels"] | ||||
|         ) | ||||
|         self.logger.info("Training MSE loss: {:}".format(train_mse_loss)) | ||||
|         self.logger.info("Validation MSE loss: {:}".format(valid_mse_loss)) | ||||
|         self.fitted = True | ||||
|  | ||||
|     def predict(self, dataset): | ||||
|         if not self.fitted: | ||||
|             raise ValueError("The model is not fitted yet!") | ||||
|         x_test = dataset.prepare("test", col_set="feature") | ||||
|         index = x_test.index | ||||
|  | ||||
|         preds = self.model(self.process_data(x_test.values)) | ||||
|         return pd.Series(preds, index=index) | ||||
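|  | ||||
|  | ||||
| # Illustrative sketch, not part of the original module: a standalone | ||||
| # check of NAIVE_V1.model() without a qlib dataset. The mean/std are | ||||
| # set by hand here to stand in for what fit() estimates from labels. | ||||
| def _demo_naive_v1_model(): | ||||
|     m = NAIVE_V1(d_feat=6, seed=1) | ||||
|     m._mean, m._std = 0.0, 1.0  # placeholders for fit() statistics | ||||
|     fake_feats = np.zeros((5, 59), dtype=np.float32) | ||||
|     preds = m.model(fake_feats) | ||||
|     assert preds.shape == (5,) and preds.dtype == np.float32 | ||||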
							
								
								
									
								103  AutoDL-Projects/xautodl/trade_models/naive_v2_model.py  Normal file
									
								
							| @@ -0,0 +1,103 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 # | ||||
| ################################################## | ||||
| # A simple model that reuses the prices of the last day | ||||
| ################################################## | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
|  | ||||
| import random | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
|  | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
| from qlib.model.base import Model | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
|  | ||||
| class NAIVE_V2(Model): | ||||
|     """NAIVE Version 2 Quant Model""" | ||||
|  | ||||
|     def __init__(self, d_feat=6, seed=None, **kwargs): | ||||
|         # Set logger. | ||||
|         self.logger = get_module_logger("NAIVE") | ||||
|         self.logger.info("NAIVE version...") | ||||
|  | ||||
|         # set hyper-parameters. | ||||
|         self.d_feat = d_feat | ||||
|         self.seed = seed | ||||
|  | ||||
|         self.logger.info( | ||||
|             "NAIVE parameters setting: d_feat={:}, seed={:}".format( | ||||
|                 self.d_feat, self.seed | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         if self.seed is not None: | ||||
|             random.seed(self.seed) | ||||
|             np.random.seed(self.seed) | ||||
|  | ||||
|         self.fitted = False | ||||
|  | ||||
|     def process_data(self, features): | ||||
|         features = features.reshape(len(features), self.d_feat, -1) | ||||
|         features = features.transpose((0, 2, 1)) | ||||
|         return features[:, :59, 0] | ||||
|  | ||||
|     def mse(self, preds, labels): | ||||
|         masks = ~np.isnan(labels) | ||||
|         masked_preds = preds[masks] | ||||
|         masked_labels = labels[masks] | ||||
|         return np.square(masked_preds - masked_labels).mean() | ||||
|  | ||||
|     def model(self, x): | ||||
|         x = 1 / x - 1 | ||||
|         masks = ~np.isnan(x) | ||||
|         results = [] | ||||
|         for rowd, rowm in zip(x, masks): | ||||
|             if rowm.any(): | ||||
|                 results.append(float(rowd[rowm][-1]))  # last valid entry | ||||
|             else: | ||||
|                 results.append(0)  # no valid entry in this row | ||||
|         return np.array(results, dtype=x.dtype) | ||||
|  | ||||
|     def fit(self, dataset: DatasetH): | ||||
|         def _prepare_dataset(df_data): | ||||
|             features = df_data["feature"].values | ||||
|             features = self.process_data(features) | ||||
|             labels = df_data["label"].values.squeeze() | ||||
|             return dict(features=features, labels=labels) | ||||
|  | ||||
|         df_train, df_valid, df_test = dataset.prepare( | ||||
|             ["train", "valid", "test"], | ||||
|             col_set=["feature", "label"], | ||||
|             data_key=DataHandlerLP.DK_L, | ||||
|         ) | ||||
|         train_dataset, valid_dataset, test_dataset = ( | ||||
|             _prepare_dataset(df_train), | ||||
|             _prepare_dataset(df_valid), | ||||
|             _prepare_dataset(df_test), | ||||
|         ) | ||||
|         # df_train['feature']['CLOSE1'].values | ||||
|         # train_dataset['features'][:, -1] | ||||
|         train_mse_loss = self.mse( | ||||
|             self.model(train_dataset["features"]), train_dataset["labels"] | ||||
|         ) | ||||
|         valid_mse_loss = self.mse( | ||||
|             self.model(valid_dataset["features"]), valid_dataset["labels"] | ||||
|         ) | ||||
|         self.logger.info("Training MSE loss: {:}".format(train_mse_loss)) | ||||
|         self.logger.info("Validation MSE loss: {:}".format(valid_mse_loss)) | ||||
|         self.fitted = True | ||||
|  | ||||
|     def predict(self, dataset): | ||||
|         if not self.fitted: | ||||
|             raise ValueError("The model is not fitted yet!") | ||||
|         x_test = dataset.prepare("test", col_set="feature") | ||||
|         index = x_test.index | ||||
|  | ||||
|         preds = self.model(self.process_data(x_test.values)) | ||||
|         return pd.Series(preds, index=index) | ||||
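|  | ||||
|  | ||||
| # Illustrative sketch, not part of the original module: NAIVE_V2.model() | ||||
| # predicts, per row, the last valid value of (1 / price - 1), or 0 when | ||||
| # the whole row is NaN. The toy row below is an assumption. | ||||
| def _demo_naive_v2_model(): | ||||
|     m = NAIVE_V2(d_feat=6) | ||||
|     row = np.array([[2.0, 4.0, np.nan]]) | ||||
|     preds = m.model(row) | ||||
|     assert np.allclose(preds, [1.0 / 4.0 - 1.0]) | ||||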
							
								
								
									
								358  AutoDL-Projects/xautodl/trade_models/quant_transformer.py  Normal file
									
								
							| @@ -0,0 +1,358 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 # | ||||
| ################################################## | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
|  | ||||
| import os, math, random | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| from typing import Text, Union | ||||
| import copy | ||||
| from functools import partial | ||||
| from typing import Optional, Text | ||||
|  | ||||
| from qlib.utils import get_or_create_path | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torch.optim as optim | ||||
| import torch.utils.data as th_data | ||||
|  | ||||
| from xautodl.xmisc import AverageMeter | ||||
| from xautodl.xmisc import count_parameters | ||||
|  | ||||
| from xautodl.xlayers import super_core | ||||
| from .transformers import DEFAULT_NET_CONFIG | ||||
| from .transformers import get_transformer | ||||
|  | ||||
|  | ||||
| from qlib.model.base import Model | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
|  | ||||
| DEFAULT_OPT_CONFIG = dict( | ||||
|     epochs=200, | ||||
|     lr=0.001, | ||||
|     batch_size=2000, | ||||
|     early_stop=20, | ||||
|     loss="mse", | ||||
|     optimizer="adam", | ||||
|     num_workers=4, | ||||
| ) | ||||
|  | ||||
|  | ||||
| def train_or_test_epoch( | ||||
|     xloader, model, loss_fn, metric_fn, is_train, optimizer, device | ||||
| ): | ||||
|     if is_train: | ||||
|         model.train() | ||||
|     else: | ||||
|         model.eval() | ||||
|     score_meter, loss_meter = AverageMeter(), AverageMeter() | ||||
|     for ibatch, (feats, labels) in enumerate(xloader): | ||||
|         feats, labels = feats.to(device), labels.to(device) | ||||
|         # forward the network | ||||
|         preds = model(feats) | ||||
|         loss = loss_fn(preds, labels) | ||||
|         with torch.no_grad(): | ||||
|             score = metric_fn(preds, labels) | ||||
|             loss_meter.update(loss.item(), feats.size(0)) | ||||
|             score_meter.update(score.item(), feats.size(0)) | ||||
|         # optimize the network | ||||
|         if is_train and optimizer is not None: | ||||
|             optimizer.zero_grad() | ||||
|             loss.backward() | ||||
|             torch.nn.utils.clip_grad_value_(model.parameters(), 3.0) | ||||
|             optimizer.step() | ||||
|     return loss_meter.avg, score_meter.avg | ||||
|  | ||||
|  | ||||
| class QuantTransformer(Model): | ||||
|     """Transformer-based Quant Model""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, net_config=None, opt_config=None, metric="", GPU=0, seed=None, **kwargs | ||||
|     ): | ||||
|         # Set logger. | ||||
|         self.logger = get_module_logger("QuantTransformer") | ||||
|         self.logger.info("QuantTransformer PyTorch version...") | ||||
|  | ||||
|         # set hyper-parameters. | ||||
|         self.net_config = net_config or DEFAULT_NET_CONFIG | ||||
|         self.opt_config = opt_config or DEFAULT_OPT_CONFIG | ||||
|         self.metric = metric | ||||
|         self.device = torch.device( | ||||
|             "cuda:{:}".format(GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" | ||||
|         ) | ||||
|         self.seed = seed | ||||
|  | ||||
|         self.logger.info( | ||||
|             "Transformer parameters setting:" | ||||
|             "\nnet_config : {:}" | ||||
|             "\nopt_config : {:}" | ||||
|             "\nmetric     : {:}" | ||||
|             "\ndevice     : {:}" | ||||
|             "\nseed       : {:}".format( | ||||
|                 self.net_config, | ||||
|                 self.opt_config, | ||||
|                 self.metric, | ||||
|                 self.device, | ||||
|                 self.seed, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         if self.seed is not None: | ||||
|             random.seed(self.seed) | ||||
|             np.random.seed(self.seed) | ||||
|             torch.manual_seed(self.seed) | ||||
|             if self.use_gpu: | ||||
|                 torch.cuda.manual_seed(self.seed) | ||||
|                 torch.cuda.manual_seed_all(self.seed) | ||||
|  | ||||
|         self.model = get_transformer(self.net_config) | ||||
|         self.model.set_super_run_type(super_core.SuperRunMode.FullModel) | ||||
|         self.logger.info("model: {:}".format(self.model)) | ||||
|         self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model))) | ||||
|  | ||||
|         if self.opt_config["optimizer"] == "adam": | ||||
|             self.train_optimizer = optim.Adam( | ||||
|                 self.model.parameters(), lr=self.opt_config["lr"] | ||||
|             ) | ||||
|         elif self.opt_config["optimizer"] == "sgd": | ||||
|             self.train_optimizer = optim.SGD( | ||||
|                 self.model.parameters(), lr=self.opt_config["lr"] | ||||
|             ) | ||||
|         else: | ||||
|             raise NotImplementedError( | ||||
|                 "optimizer {:} is not supported!".format(self.opt_config["optimizer"]) | ||||
|             ) | ||||
|  | ||||
|         self.fitted = False | ||||
|         self.model.to(self.device) | ||||
|  | ||||
|     @property | ||||
|     def use_gpu(self): | ||||
|         return self.device != torch.device("cpu") | ||||
|  | ||||
|     def to(self, device): | ||||
|         if device is None: | ||||
|             device = "cpu" | ||||
|         self.device = device | ||||
|         self.model.to(self.device) | ||||
|         # move the optimizer | ||||
|         for param in self.train_optimizer.state.values(): | ||||
|             # Not sure there are any global tensors in the state dict | ||||
|             if isinstance(param, torch.Tensor): | ||||
|                 param.data = param.data.to(device) | ||||
|                 if param._grad is not None: | ||||
|                     param._grad.data = param._grad.data.to(device) | ||||
|             elif isinstance(param, dict): | ||||
|                 for subparam in param.values(): | ||||
|                     if isinstance(subparam, torch.Tensor): | ||||
|                         subparam.data = subparam.data.to(device) | ||||
|                         if subparam._grad is not None: | ||||
|                             subparam._grad.data = subparam._grad.data.to(device) | ||||
|  | ||||
|     def loss_fn(self, pred, label): | ||||
|         mask = ~torch.isnan(label) | ||||
|         if self.opt_config["loss"] == "mse": | ||||
|             return F.mse_loss(pred[mask], label[mask]) | ||||
|         else: | ||||
|             raise ValueError("unknown loss `{:}`".format(self.loss)) | ||||
|  | ||||
|     def metric_fn(self, pred, label): | ||||
|         # the metric score : higher is better | ||||
|         if self.metric == "" or self.metric == "loss": | ||||
|             return -self.loss_fn(pred, label) | ||||
|         else: | ||||
|             raise ValueError("unknown metric `{:}`".format(self.metric)) | ||||
|  | ||||
|     def fit( | ||||
|         self, | ||||
|         dataset: DatasetH, | ||||
|         save_dir: Optional[Text] = None, | ||||
|     ): | ||||
|         def _prepare_dataset(df_data): | ||||
|             return th_data.TensorDataset( | ||||
|                 torch.from_numpy(df_data["feature"].values).float(), | ||||
|                 torch.from_numpy(df_data["label"].values).squeeze().float(), | ||||
|             ) | ||||
|  | ||||
|         def _prepare_loader(dataset, shuffle): | ||||
|             return th_data.DataLoader( | ||||
|                 dataset, | ||||
|                 batch_size=self.opt_config["batch_size"], | ||||
|                 drop_last=False, | ||||
|                 pin_memory=True, | ||||
|                 num_workers=self.opt_config["num_workers"], | ||||
|                 shuffle=shuffle, | ||||
|             ) | ||||
|  | ||||
|         df_train, df_valid, df_test = dataset.prepare( | ||||
|             ["train", "valid", "test"], | ||||
|             col_set=["feature", "label"], | ||||
|             data_key=DataHandlerLP.DK_L, | ||||
|         ) | ||||
|         train_dataset, valid_dataset, test_dataset = ( | ||||
|             _prepare_dataset(df_train), | ||||
|             _prepare_dataset(df_valid), | ||||
|             _prepare_dataset(df_test), | ||||
|         ) | ||||
|         train_loader, valid_loader, test_loader = ( | ||||
|             _prepare_loader(train_dataset, True), | ||||
|             _prepare_loader(valid_dataset, False), | ||||
|             _prepare_loader(test_dataset, False), | ||||
|         ) | ||||
|  | ||||
|         save_dir = get_or_create_path(save_dir, return_dir=True) | ||||
|         self.logger.info( | ||||
|             "Fit procedure for [{:}] with save path={:}".format( | ||||
|                 self.__class__.__name__, save_dir | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         def _internal_test(ckp_epoch=None, results_dict=None): | ||||
|             with torch.no_grad(): | ||||
|                 shared_kwargs = { | ||||
|                     "model": self.model, | ||||
|                     "loss_fn": self.loss_fn, | ||||
|                     "metric_fn": self.metric_fn, | ||||
|                     "is_train": False, | ||||
|                     "optimizer": None, | ||||
|                     "device": self.device, | ||||
|                 } | ||||
|                 train_loss, train_score = train_or_test_epoch( | ||||
|                     train_loader, **shared_kwargs | ||||
|                 ) | ||||
|                 valid_loss, valid_score = train_or_test_epoch( | ||||
|                     valid_loader, **shared_kwargs | ||||
|                 ) | ||||
|                 test_loss, test_score = train_or_test_epoch( | ||||
|                     test_loader, **shared_kwargs | ||||
|                 ) | ||||
|                 xstr = ( | ||||
|                     "train-score={:.6f}, valid-score={:.6f}, test-score={:.6f}".format( | ||||
|                         train_score, valid_score, test_score | ||||
|                     ) | ||||
|                 ) | ||||
|                 if ckp_epoch is not None and isinstance(results_dict, dict): | ||||
|                     results_dict["train"][ckp_epoch] = train_score | ||||
|                     results_dict["valid"][ckp_epoch] = valid_score | ||||
|                     results_dict["test"][ckp_epoch] = test_score | ||||
|                 return dict(train=train_score, valid=valid_score, test=test_score), xstr | ||||
|  | ||||
|         # Pre-fetch the potential checkpoints | ||||
|         ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__)) | ||||
|         if os.path.exists(ckp_path): | ||||
|             ckp_data = torch.load(ckp_path, map_location=self.device) | ||||
|             stop_steps, best_score, best_epoch = ( | ||||
|                 ckp_data["stop_steps"], | ||||
|                 ckp_data["best_score"], | ||||
|                 ckp_data["best_epoch"], | ||||
|             ) | ||||
|             start_epoch, best_param = ckp_data["start_epoch"], ckp_data["best_param"] | ||||
|             results_dict = ckp_data["results_dict"] | ||||
|             self.model.load_state_dict(ckp_data["net_state_dict"]) | ||||
|             self.train_optimizer.load_state_dict(ckp_data["opt_state_dict"]) | ||||
|             self.logger.info("Resume from existing checkpoint: {:}".format(ckp_path)) | ||||
|         else: | ||||
|             stop_steps, best_score, best_epoch = 0, -np.inf, -1 | ||||
|             start_epoch, best_param = 0, None | ||||
|             results_dict = dict( | ||||
|                 train=OrderedDict(), valid=OrderedDict(), test=OrderedDict() | ||||
|             ) | ||||
|             _, eval_str = _internal_test(-1, results_dict) | ||||
|             self.logger.info( | ||||
|                 "Training from scratch, metrics@start: {:}".format(eval_str) | ||||
|             ) | ||||
|  | ||||
|         for iepoch in range(start_epoch, self.opt_config["epochs"]): | ||||
|             self.logger.info( | ||||
|                 "Epoch={:03d}/{:03d} ::==>> Best valid @{:03d} ({:.6f})".format( | ||||
|                     iepoch, self.opt_config["epochs"], best_epoch, best_score | ||||
|                 ) | ||||
|             ) | ||||
|             train_loss, train_score = train_or_test_epoch( | ||||
|                 train_loader, | ||||
|                 self.model, | ||||
|                 self.loss_fn, | ||||
|                 self.metric_fn, | ||||
|                 True, | ||||
|                 self.train_optimizer, | ||||
|                 self.device, | ||||
|             ) | ||||
|             self.logger.info( | ||||
|                 "Training :: loss={:.6f}, score={:.6f}".format(train_loss, train_score) | ||||
|             ) | ||||
|  | ||||
|             current_eval_scores, eval_str = _internal_test(iepoch, results_dict) | ||||
|             self.logger.info("Evaluating :: {:}".format(eval_str)) | ||||
|  | ||||
|             if current_eval_scores["valid"] > best_score: | ||||
|                 stop_steps, best_epoch, best_score = ( | ||||
|                     0, | ||||
|                     iepoch, | ||||
|                     current_eval_scores["valid"], | ||||
|                 ) | ||||
|                 best_param = copy.deepcopy(self.model.state_dict()) | ||||
|             else: | ||||
|                 stop_steps += 1 | ||||
|                 if stop_steps >= self.opt_config["early_stop"]: | ||||
|                     self.logger.info( | ||||
|                         "early stop at {:}-th epoch, where the best is @{:}".format( | ||||
|                             iepoch, best_epoch | ||||
|                         ) | ||||
|                     ) | ||||
|                     break | ||||
|             save_info = dict( | ||||
|                 net_config=self.net_config, | ||||
|                 opt_config=self.opt_config, | ||||
|                 net_state_dict=self.model.state_dict(), | ||||
|                 opt_state_dict=self.train_optimizer.state_dict(), | ||||
|                 best_param=best_param, | ||||
|                 stop_steps=stop_steps, | ||||
|                 best_score=best_score, | ||||
|                 best_epoch=best_epoch, | ||||
|                 results_dict=results_dict, | ||||
|                 start_epoch=iepoch + 1, | ||||
|             ) | ||||
|             torch.save(save_info, ckp_path) | ||||
|         self.logger.info( | ||||
|             "The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch) | ||||
|         ) | ||||
|         self.model.load_state_dict(best_param) | ||||
|         _, eval_str = _internal_test("final", results_dict) | ||||
|         self.logger.info("Reload the best parameter :: {:}".format(eval_str)) | ||||
|  | ||||
|         if self.use_gpu: | ||||
|             with torch.cuda.device(self.device): | ||||
|                 torch.cuda.empty_cache() | ||||
|         self.fitted = True | ||||
|  | ||||
|     def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): | ||||
|         if not self.fitted: | ||||
|             raise ValueError("The model is not fitted yet!") | ||||
|         x_test = dataset.prepare( | ||||
|             segment, col_set="feature", data_key=DataHandlerLP.DK_I | ||||
|         ) | ||||
|         index = x_test.index | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             self.model.eval() | ||||
|             x_values = x_test.values | ||||
|             sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"] | ||||
|             preds = [] | ||||
|             for begin in range(sample_num)[::batch_size]: | ||||
|                 if sample_num - begin < batch_size: | ||||
|                     end = sample_num | ||||
|                 else: | ||||
|                     end = begin + batch_size | ||||
|                 x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) | ||||
|                 # the outer torch.no_grad() context already disables autograd here | ||||
|                 pred = self.model(x_batch).detach().cpu().numpy() | ||||
|                 preds.append(pred) | ||||
|         return pd.Series(np.concatenate(preds), index=index) | ||||
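|  | ||||
|  | ||||
| # Hedged usage sketch (not part of the original module). It assumes a | ||||
| # qlib DatasetH named `dataset` and a writable checkpoint directory, | ||||
| # neither of which is constructed here: | ||||
| # | ||||
| #     model = QuantTransformer(net_config=None, opt_config=None, GPU=-1) | ||||
| #     model.fit(dataset, save_dir="./checkpoints")  # resumes if a ckpt exists | ||||
| #     preds = model.predict(dataset, segment="test")  # pd.Series of scores | ||||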
							
								
								
									
								199  AutoDL-Projects/xautodl/trade_models/transformers.py  Normal file
									
								
							| @@ -0,0 +1,199 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
|  | ||||
| import math | ||||
| from functools import partial | ||||
| from typing import Optional, Text, List | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xautodl import spaces | ||||
| from xautodl.xlayers import weight_init | ||||
| from xautodl.xlayers import super_core | ||||
|  | ||||
|  | ||||
| __all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"] | ||||
|  | ||||
|  | ||||
| def _get_mul_specs(candidates, num): | ||||
|     results = [] | ||||
|     for i in range(num): | ||||
|         results.append(spaces.Categorical(*candidates)) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def _get_list_mul(num, multipler): | ||||
|     results = [] | ||||
|     for i in range(1, num + 1): | ||||
|         results.append(i * multipler) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def _assert_types(x, expected_types): | ||||
|     if not isinstance(x, expected_types): | ||||
|         raise TypeError( | ||||
|             "The type [{:}] is expected to be {:}.".format(type(x), expected_types) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| DEFAULT_NET_CONFIG = None | ||||
| _default_max_depth = 6 | ||||
| DefaultSearchSpace = dict( | ||||
|     d_feat=6, | ||||
|     embed_dim=32, | ||||
|     # embed_dim=spaces.Categorical(*_get_list_mul(8, 16)), | ||||
|     num_heads=[4] * _default_max_depth, | ||||
|     mlp_hidden_multipliers=[4] * _default_max_depth, | ||||
|     qkv_bias=True, | ||||
|     pos_drop=0.0, | ||||
|     other_drop=0.0, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class SuperTransformer(super_core.SuperModule): | ||||
|     """The super model for transformer.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         d_feat: int = 6, | ||||
|         embed_dim: List[super_core.IntSpaceType] = DefaultSearchSpace["embed_dim"], | ||||
|         num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"], | ||||
|         mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[ | ||||
|             "mlp_hidden_multipliers" | ||||
|         ], | ||||
|         qkv_bias: bool = DefaultSearchSpace["qkv_bias"], | ||||
|         pos_drop: float = DefaultSearchSpace["pos_drop"], | ||||
|         other_drop: float = DefaultSearchSpace["other_drop"], | ||||
|         max_seq_len: int = 65, | ||||
|     ): | ||||
|         super(SuperTransformer, self).__init__() | ||||
|         self._embed_dim = embed_dim | ||||
|         self._num_heads = num_heads | ||||
|         self._mlp_hidden_multipliers = mlp_hidden_multipliers | ||||
|  | ||||
|         # the stem part | ||||
|         self.input_embed = super_core.SuperAlphaEBDv1(d_feat, embed_dim) | ||||
|         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) | ||||
|         self.pos_embed = super_core.SuperPositionalEncoder( | ||||
|             d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop | ||||
|         ) | ||||
|         # build the transformer encode layers -->> check params | ||||
|         _assert_types(num_heads, (tuple, list)) | ||||
|         _assert_types(mlp_hidden_multipliers, (tuple, list)) | ||||
|         assert len(num_heads) == len(mlp_hidden_multipliers), "{:} vs {:}".format( | ||||
|             len(num_heads), len(mlp_hidden_multipliers) | ||||
|         ) | ||||
|         # build the transformer encode layers -->> backbone | ||||
|         layers = [] | ||||
|         for num_head, mlp_hidden_multiplier in zip(num_heads, mlp_hidden_multipliers): | ||||
|             layer = super_core.SuperTransformerEncoderLayer( | ||||
|                 embed_dim, | ||||
|                 num_head, | ||||
|                 qkv_bias, | ||||
|                 mlp_hidden_multiplier, | ||||
|                 other_drop, | ||||
|             ) | ||||
|             layers.append(layer) | ||||
|         self.backbone = super_core.SuperSequential(*layers) | ||||
|  | ||||
|         # the regression head | ||||
|         self.head = super_core.SuperSequential( | ||||
|             super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1) | ||||
|         ) | ||||
|         weight_init.trunc_normal_(self.cls_token, std=0.02) | ||||
|         self.apply(self._init_weights) | ||||
|  | ||||
|     @property | ||||
|     def embed_dim(self): | ||||
|         return spaces.get_max(self._embed_dim) | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         if not spaces.is_determined(self._embed_dim): | ||||
|             root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True)) | ||||
|         xdict = dict( | ||||
|             input_embed=self.input_embed.abstract_search_space, | ||||
|             pos_embed=self.pos_embed.abstract_search_space, | ||||
|             backbone=self.backbone.abstract_search_space, | ||||
|             head=self.head.abstract_search_space, | ||||
|         ) | ||||
|         for key, space in xdict.items(): | ||||
|             if not spaces.is_determined(space): | ||||
|                 root_node.append(key, space) | ||||
|         return root_node | ||||
|  | ||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||
|         super(SuperTransformer, self).apply_candidate(abstract_child) | ||||
|         xkeys = ("input_embed", "pos_embed", "backbone", "head") | ||||
|         for key in xkeys: | ||||
|             if key in abstract_child: | ||||
|                 getattr(self, key).apply_candidate(abstract_child[key]) | ||||
|  | ||||
|     def _init_weights(self, m): | ||||
|         if isinstance(m, nn.Linear): | ||||
|             weight_init.trunc_normal_(m.weight, std=0.02) | ||||
|             if m.bias is not None:  # the isinstance check above suffices | ||||
|                 nn.init.constant_(m.bias, 0) | ||||
|         elif isinstance(m, super_core.SuperLinear): | ||||
|             weight_init.trunc_normal_(m._super_weight, std=0.02) | ||||
|             if m._super_bias is not None: | ||||
|                 nn.init.constant_(m._super_bias, 0) | ||||
|         elif isinstance(m, super_core.SuperLayerNorm1D): | ||||
|             nn.init.constant_(m.weight, 1.0) | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         batch, flatten_size = input.shape | ||||
|         feats = self.input_embed(input)  # batch * 60 * 64 | ||||
|         if not spaces.is_determined(self._embed_dim): | ||||
|             embed_dim = self.abstract_child["_embed_dim"].value | ||||
|         else: | ||||
|             embed_dim = spaces.get_determined_value(self._embed_dim) | ||||
|         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||
|         cls_tokens = F.interpolate( | ||||
|             cls_tokens, size=(embed_dim), mode="linear", align_corners=True | ||||
|         ) | ||||
|         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) | ||||
|         feats_w_tp = self.pos_embed(feats_w_ct) | ||||
|         xfeats = self.backbone(feats_w_tp) | ||||
|         xfeats = xfeats[:, 0, :]  # use the feature for the first token | ||||
|         predicts = self.head(xfeats).squeeze(-1) | ||||
|         return predicts | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         batch, flatten_size = input.shape | ||||
|         feats = self.input_embed(input)  # batch * 60 * 64 | ||||
|         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||
|         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) | ||||
|         feats_w_tp = self.pos_embed(feats_w_ct) | ||||
|         xfeats = self.backbone(feats_w_tp) | ||||
|         xfeats = xfeats[:, 0, :]  # use the feature for the first token | ||||
|         predicts = self.head(xfeats).squeeze(-1) | ||||
|         return predicts | ||||
|  | ||||
|  | ||||
| def get_transformer(config): | ||||
|     if config is None: | ||||
|         return SuperTransformer(6) | ||||
|     if not isinstance(config, dict): | ||||
|         raise ValueError("Invalid Configuration: {:}".format(config)) | ||||
|     name = config.get("name", "basic") | ||||
|     if name == "basic": | ||||
|         model = SuperTransformer( | ||||
|             d_feat=config.get("d_feat"), | ||||
|             embed_dim=config.get("embed_dim"), | ||||
|             num_heads=config.get("num_heads"), | ||||
|             mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"), | ||||
|             qkv_bias=config.get("qkv_bias"), | ||||
|             pos_drop=config.get("pos_drop"), | ||||
|             other_drop=config.get("other_drop"), | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("Unknown model name: {:}".format(name)) | ||||
|     return model | ||||
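|  | ||||
|  | ||||
| # Hedged smoke test (not part of the original module). It assumes the | ||||
| # default SuperAlphaEBDv1 stem consumes a flattened (batch, d_feat * 60) | ||||
| # tensor, as the shape comments in forward_raw suggest: | ||||
| # | ||||
| #     net = get_transformer(None)  # default SuperTransformer, d_feat=6 | ||||
| #     net.set_super_run_type(super_core.SuperRunMode.FullModel) | ||||
| #     scores = net(torch.randn(2, 6 * 60))  # expected shape: (2,) | ||||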
							
								
								
									
								14  AutoDL-Projects/xautodl/utils/__init__.py  Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # This directory contains some ad-hoc functions, classes, etc. | ||||
| # It will be re-formulated in the future. | ||||
| ##################################################### | ||||
| from .evaluation_utils import obtain_accuracy | ||||
| from .gpu_manager import GPUManager | ||||
| from .flop_benchmark import get_model_infos, count_parameters, count_parameters_in_MB | ||||
| from .affine_utils import normalize_points, denormalize_points | ||||
| from .affine_utils import identity2affine, solve2theta, affine2image | ||||
| from .hash_utils import get_md5_file | ||||
| from .str_utils import split_str2indexes | ||||
| from .str_utils import show_mean_var | ||||
							
								
								
									
								159  AutoDL-Projects/xautodl/utils/affine_utils.py  Normal file
									
								
							| @@ -0,0 +1,159 @@ | ||||
| # functions for affine transformation | ||||
| import math | ||||
| import torch | ||||
| import numpy as np | ||||
| import torch.nn.functional as F | ||||
|  | ||||
|  | ||||
| def identity2affine(full=False): | ||||
|     if not full: | ||||
|         parameters = torch.zeros((2, 3)) | ||||
|         parameters[0, 0] = parameters[1, 1] = 1 | ||||
|     else: | ||||
|         parameters = torch.zeros((3, 3)) | ||||
|         parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1 | ||||
|     return parameters | ||||
|  | ||||
|  | ||||
| def normalize_L(x, L): | ||||
|     return -1.0 + 2.0 * x / (L - 1) | ||||
|  | ||||
|  | ||||
| def denormalize_L(x, L): | ||||
|     return (x + 1.0) / 2.0 * (L - 1) | ||||
|  | ||||
|  | ||||
| def crop2affine(crop_box, W, H): | ||||
|     assert len(crop_box) == 4, "Invalid crop-box : {:}".format(crop_box) | ||||
|     parameters = torch.zeros(3, 3) | ||||
|     x1, y1 = normalize_L(crop_box[0], W), normalize_L(crop_box[1], H) | ||||
|     x2, y2 = normalize_L(crop_box[2], W), normalize_L(crop_box[3], H) | ||||
|     parameters[0, 0] = (x2 - x1) / 2 | ||||
|     parameters[0, 2] = (x2 + x1) / 2 | ||||
|  | ||||
|     parameters[1, 1] = (y2 - y1) / 2 | ||||
|     parameters[1, 2] = (y2 + y1) / 2 | ||||
|     parameters[2, 2] = 1 | ||||
|     return parameters | ||||
|  | ||||
|  | ||||
| def scale2affine(scalex, scaley): | ||||
|     parameters = torch.zeros(3, 3) | ||||
|     parameters[0, 0] = scalex | ||||
|     parameters[1, 1] = scaley | ||||
|     parameters[2, 2] = 1 | ||||
|     return parameters | ||||
|  | ||||
|  | ||||
| def offset2affine(offx, offy): | ||||
|     parameters = torch.zeros(3, 3) | ||||
|     parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1 | ||||
|     parameters[0, 2] = offx | ||||
|     parameters[1, 2] = offy | ||||
|     return parameters | ||||
|  | ||||
|  | ||||
| def horizontalmirror2affine(): | ||||
|     parameters = torch.zeros(3, 3) | ||||
|     parameters[0, 0] = -1 | ||||
|     parameters[1, 1] = parameters[2, 2] = 1 | ||||
|     return parameters | ||||
|  | ||||
|  | ||||
| # clockwise rotate image = counterclockwise rotate the rectangle | ||||
| # degree is between [0, 360] | ||||
| def rotate2affine(degree): | ||||
|     assert degree >= 0 and degree <= 360, "Invalid degree : {:}".format(degree) | ||||
|     degree = degree / 180 * math.pi | ||||
|     parameters = torch.zeros(3, 3) | ||||
|     parameters[0, 0] = math.cos(-degree) | ||||
|     parameters[0, 1] = -math.sin(-degree) | ||||
|     parameters[1, 0] = math.sin(-degree) | ||||
|     parameters[1, 1] = math.cos(-degree) | ||||
|     parameters[2, 2] = 1 | ||||
|     return parameters | ||||
|  | ||||
|  | ||||
| # shape is a tuple [H, W] | ||||
| def normalize_points(shape, points): | ||||
|     assert (isinstance(shape, tuple) or isinstance(shape, list)) and len( | ||||
|         shape | ||||
|     ) == 2, "invalid shape : {:}".format(shape) | ||||
|     assert isinstance(points, torch.Tensor) and ( | ||||
|         points.shape[0] == 2 | ||||
|     ), "points are wrong : {:}".format(points.shape) | ||||
|     (H, W), points = shape, points.clone() | ||||
|     points[0, :] = normalize_L(points[0, :], W) | ||||
|     points[1, :] = normalize_L(points[1, :], H) | ||||
|     return points | ||||
|  | ||||
|  | ||||
| # shape is a tuple [H, W] | ||||
| def normalize_points_batch(shape, points): | ||||
|     assert (isinstance(shape, tuple) or isinstance(shape, list)) and len( | ||||
|         shape | ||||
|     ) == 2, "invalid shape : {:}".format(shape) | ||||
|     assert isinstance(points, torch.Tensor) and ( | ||||
|         points.size(-1) == 2 | ||||
|     ), "points are wrong : {:}".format(points.shape) | ||||
|     (H, W), points = shape, points.clone() | ||||
|     x = normalize_L(points[..., 0], W) | ||||
|     y = normalize_L(points[..., 1], H) | ||||
|     return torch.stack((x, y), dim=-1) | ||||
|  | ||||
|  | ||||
| # shape is a tuple [H, W] | ||||
| def denormalize_points(shape, points): | ||||
|     assert (isinstance(shape, tuple) or isinstance(shape, list)) and len( | ||||
|         shape | ||||
|     ) == 2, "invalid shape : {:}".format(shape) | ||||
|     assert isinstance(points, torch.Tensor) and ( | ||||
|         points.shape[0] == 2 | ||||
|     ), "points are wrong : {:}".format(points.shape) | ||||
|     (H, W), points = shape, points.clone() | ||||
|     points[0, :] = denormalize_L(points[0, :], W) | ||||
|     points[1, :] = denormalize_L(points[1, :], H) | ||||
|     return points | ||||
|  | ||||
|  | ||||
| # shape is a tuple [H, W] | ||||
| def denormalize_points_batch(shape, points): | ||||
|     assert (isinstance(shape, tuple) or isinstance(shape, list)) and len( | ||||
|         shape | ||||
|     ) == 2, "invalid shape : {:}".format(shape) | ||||
|     assert isinstance(points, torch.Tensor) and ( | ||||
|         points.shape[-1] == 2 | ||||
|     ), "points are wrong : {:}".format(points.shape) | ||||
|     (H, W), points = shape, points.clone() | ||||
|     x = denormalize_L(points[..., 0], W) | ||||
|     y = denormalize_L(points[..., 1], H) | ||||
|     return torch.stack((x, y), dim=-1) | ||||
|  | ||||
|  | ||||
| # make target * theta = source | ||||
| def solve2theta(source, target): | ||||
|     source, target = source.clone(), target.clone() | ||||
|     oks = source[2, :] == 1 | ||||
|     assert torch.sum(oks).item() >= 3, "valid points : {:} is short".format(oks) | ||||
|     if target.size(0) == 2: | ||||
|         target = torch.cat((target, oks.unsqueeze(0).float()), dim=0) | ||||
|     source, target = source[:, oks], target[:, oks] | ||||
|     source, target = source.transpose(1, 0), target.transpose(1, 0) | ||||
|     assert source.size(1) == target.size(1) == 3 | ||||
|     # Solve the least-squares system: target @ X = source. Note that | ||||
|     # torch.gels was removed from recent PyTorch; torch.linalg.lstsq | ||||
|     # computes the same minimum-norm solution. | ||||
|     X_ = torch.linalg.lstsq(target, source).solution | ||||
|     theta = X_[:3, :2].transpose(1, 0) | ||||
|     return theta | ||||
|  | ||||
|  | ||||
| # shape = [H,W] | ||||
| def affine2image(image, theta, shape): | ||||
|     C, H, W = image.size() | ||||
|     theta = theta[:2, :].unsqueeze(0) | ||||
|     grid_size = torch.Size([1, C, shape[0], shape[1]]) | ||||
|     grid = F.affine_grid(theta, grid_size) | ||||
|     affI = F.grid_sample( | ||||
|         image.unsqueeze(0), grid, mode="bilinear", padding_mode="border" | ||||
|     ) | ||||
|     return affI.squeeze(0) | ||||
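A minimal usage sketch for the point helpers above, round-tripping pixel coordinates through the normalize/denormalize pair. This is not part of the commit, and the import path xautodl.utils.affine_utils is an assumption based on the package layout:

    import torch
    from xautodl.utils.affine_utils import normalize_points, denormalize_points  # path assumed

    H, W = 64, 48
    pts = torch.tensor([[10.0, 20.0, 30.0],   # row 0: x coordinates (width axis)
                        [5.0, 15.0, 25.0]])   # row 1: y coordinates (height axis)
    norm = normalize_points((H, W), pts)      # map pixel coordinates into the normalized range
    back = denormalize_points((H, W), norm)   # inverse mapping back to pixel coordinates
    assert torch.allclose(back, pts)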
							
								
								
									
17	AutoDL-Projects/xautodl/utils/evaluation_utils.py	Normal file
							| @@ -0,0 +1,17 @@ | ||||
| import torch | ||||
|  | ||||
|  | ||||
| def obtain_accuracy(output, target, topk=(1,)): | ||||
|     """Computes the precision@k for the specified values of k""" | ||||
|     maxk = max(topk) | ||||
|     batch_size = target.size(0) | ||||
|  | ||||
|     _, pred = output.topk(maxk, 1, True, True) | ||||
|     pred = pred.t() | ||||
|     correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||||
|  | ||||
|     res = [] | ||||
|     for k in topk: | ||||
|         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: correct[:k] is non-contiguous after t() | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
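A quick self-contained check of obtain_accuracy; the import path is assumed from this commit's layout:

    import torch
    from xautodl.utils.evaluation_utils import obtain_accuracy  # path assumed

    logits = torch.randn(8, 10)       # a batch of 8 samples over 10 classes
    labels = torch.randint(0, 10, (8,))
    top1, top5 = obtain_accuracy(logits, labels, topk=(1, 5))
    print(top1.item(), top5.item())   # accuracies in percent (0-100)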
							
								
								
									
227	AutoDL-Projects/xautodl/utils/flop_benchmark.py	Normal file
							| @@ -0,0 +1,227 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def count_parameters_in_MB(model): | ||||
|     return count_parameters(model, "mb", deprecated=True) | ||||
|  | ||||
|  | ||||
| def count_parameters(model_or_parameters, unit="mb", deprecated=False): | ||||
|     if isinstance(model_or_parameters, nn.Module): | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters()) | ||||
|     elif isinstance(model_or_parameters, nn.Parameter): | ||||
|         counts = model_or_parameters.numel() | ||||
|     elif isinstance(model_or_parameters, (list, tuple)): | ||||
|         counts = sum( | ||||
|             count_parameters(x, None, deprecated) for x in model_or_parameters | ||||
|         ) | ||||
|     else: | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters) | ||||
|     if not isinstance(unit, str) and unit is not None: | ||||
|         raise ValueError("Unknow type of unit: {:}".format(unit)) | ||||
|     elif unit is None: | ||||
|         pass  # keep the raw count | ||||
|     elif unit.lower() == "kb" or unit.lower() == "k": | ||||
|         counts /= 1e3 if deprecated else 2 ** 10  # changed from 1e3 to 2^10 | ||||
|     elif unit.lower() == "mb" or unit.lower() == "m": | ||||
|         counts /= 1e6 if deprecated else 2 ** 20  # changed from 1e6 to 2^20 | ||||
|     elif unit.lower() == "gb" or unit.lower() == "g": | ||||
|         counts /= 1e9 if deprecated else 2 ** 30  # changed from 1e9 to 2^30 | ||||
|     else: | ||||
|         raise ValueError("Unknow unit: {:}".format(unit)) | ||||
|     return counts | ||||
|  | ||||
|  | ||||
| def get_model_infos(model, shape): | ||||
|     # model = copy.deepcopy( model ) | ||||
|  | ||||
|     model = add_flops_counting_methods(model) | ||||
|     # model = model.cuda() | ||||
|     model.eval() | ||||
|  | ||||
|     # cache_inputs = torch.zeros(*shape).cuda() | ||||
|     # cache_inputs = torch.zeros(*shape) | ||||
|     cache_inputs = torch.rand(*shape) | ||||
|     if next(model.parameters()).is_cuda: | ||||
|         cache_inputs = cache_inputs.cuda() | ||||
|     # print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log) | ||||
|     with torch.no_grad(): | ||||
|         _____ = model(cache_inputs) | ||||
|     FLOPs = compute_average_flops_cost(model) / 1e6 | ||||
|     Param = count_parameters_in_MB(model) | ||||
|  | ||||
|     if hasattr(model, "auxiliary_param"): | ||||
|         aux_params = count_parameters_in_MB(model.auxiliary_param()) | ||||
|         print("The auxiliary params of this model is : {:}".format(aux_params)) | ||||
|         print( | ||||
|             "We remove the auxiliary params from the total params ({:}) when counting".format( | ||||
|                 Param | ||||
|             ) | ||||
|         ) | ||||
|         Param = Param - aux_params | ||||
|  | ||||
|     # print_log('FLOPs : {:} MB'.format(FLOPs), log) | ||||
|     torch.cuda.empty_cache() | ||||
|     model.apply(remove_hook_function) | ||||
|     return FLOPs, Param | ||||
|  | ||||
|  | ||||
| # ---- Public functions | ||||
| def add_flops_counting_methods(model): | ||||
|     model.__batch_counter__ = 0 | ||||
|     add_batch_counter_hook_function(model) | ||||
|     model.apply(add_flops_counter_variable_or_reset) | ||||
|     model.apply(add_flops_counter_hook_function) | ||||
|     return model | ||||
|  | ||||
|  | ||||
| def compute_average_flops_cost(model): | ||||
|     """ | ||||
|     A method that will be available after add_flops_counting_methods() is called on a desired net object. | ||||
|     Returns current mean flops consumption per image. | ||||
|     """ | ||||
|     batches_count = model.__batch_counter__ | ||||
|     flops_sum = 0 | ||||
|     # or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ | ||||
|     for module in model.modules(): | ||||
|         if ( | ||||
|             isinstance(module, torch.nn.Conv2d) | ||||
|             or isinstance(module, torch.nn.Linear) | ||||
|             or isinstance(module, torch.nn.Conv1d) | ||||
|             or hasattr(module, "calculate_flop_self") | ||||
|         ): | ||||
|             flops_sum += module.__flops__ | ||||
|     return flops_sum / batches_count | ||||
|  | ||||
|  | ||||
| # ---- Internal functions | ||||
| def pool_flops_counter_hook(pool_module, inputs, output): | ||||
|     batch_size = inputs[0].size(0) | ||||
|     kernel_size = pool_module.kernel_size | ||||
|     out_C, output_height, output_width = output.shape[1:] | ||||
|     assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size()) | ||||
|  | ||||
|     overall_flops = ( | ||||
|         batch_size * out_C * output_height * output_width * kernel_size * kernel_size | ||||
|     ) | ||||
|     pool_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def self_calculate_flops_counter_hook(self_module, inputs, output): | ||||
|     overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape) | ||||
|     self_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def fc_flops_counter_hook(fc_module, inputs, output): | ||||
|     batch_size = inputs[0].size(0) | ||||
|     xin, xout = fc_module.in_features, fc_module.out_features | ||||
|     assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format( | ||||
|         xin, xout | ||||
|     ) | ||||
|     overall_flops = batch_size * xin * xout | ||||
|     if fc_module.bias is not None: | ||||
|         overall_flops += batch_size * xout | ||||
|     fc_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def conv1d_flops_counter_hook(conv_module, inputs, outputs): | ||||
|     batch_size = inputs[0].size(0) | ||||
|     outL = outputs.shape[-1] | ||||
|     [kernel] = conv_module.kernel_size | ||||
|     in_channels = conv_module.in_channels | ||||
|     out_channels = conv_module.out_channels | ||||
|     groups = conv_module.groups | ||||
|     conv_per_position_flops = kernel * in_channels * out_channels / groups | ||||
|  | ||||
|     active_elements_count = batch_size * outL | ||||
|     overall_flops = conv_per_position_flops * active_elements_count | ||||
|  | ||||
|     if conv_module.bias is not None: | ||||
|         overall_flops += out_channels * active_elements_count | ||||
|     conv_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def conv2d_flops_counter_hook(conv_module, inputs, output): | ||||
|     batch_size = inputs[0].size(0) | ||||
|     output_height, output_width = output.shape[2:] | ||||
|  | ||||
|     kernel_height, kernel_width = conv_module.kernel_size | ||||
|     in_channels = conv_module.in_channels | ||||
|     out_channels = conv_module.out_channels | ||||
|     groups = conv_module.groups | ||||
|     conv_per_position_flops = ( | ||||
|         kernel_height * kernel_width * in_channels * out_channels / groups | ||||
|     ) | ||||
|  | ||||
|     active_elements_count = batch_size * output_height * output_width | ||||
|     overall_flops = conv_per_position_flops * active_elements_count | ||||
|  | ||||
|     if conv_module.bias is not None: | ||||
|         overall_flops += out_channels * active_elements_count | ||||
|     conv_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def batch_counter_hook(module, inputs, output): | ||||
|     # Can have multiple inputs, getting the first one | ||||
|     inputs = inputs[0] | ||||
|     batch_size = inputs.shape[0] | ||||
|     module.__batch_counter__ += batch_size | ||||
|  | ||||
|  | ||||
| def add_batch_counter_hook_function(module): | ||||
|     if not hasattr(module, "__batch_counter_handle__"): | ||||
|         handle = module.register_forward_hook(batch_counter_hook) | ||||
|         module.__batch_counter_handle__ = handle | ||||
|  | ||||
|  | ||||
| def add_flops_counter_variable_or_reset(module): | ||||
|     if ( | ||||
|         isinstance(module, torch.nn.Conv2d) | ||||
|         or isinstance(module, torch.nn.Linear) | ||||
|         or isinstance(module, torch.nn.Conv1d) | ||||
|         or isinstance(module, torch.nn.AvgPool2d) | ||||
|         or isinstance(module, torch.nn.MaxPool2d) | ||||
|         or hasattr(module, "calculate_flop_self") | ||||
|     ): | ||||
|         module.__flops__ = 0 | ||||
|  | ||||
|  | ||||
| def add_flops_counter_hook_function(module): | ||||
|     if isinstance(module, torch.nn.Conv2d): | ||||
|         if not hasattr(module, "__flops_handle__"): | ||||
|             handle = module.register_forward_hook(conv2d_flops_counter_hook) | ||||
|             module.__flops_handle__ = handle | ||||
|     elif isinstance(module, torch.nn.Conv1d): | ||||
|         if not hasattr(module, "__flops_handle__"): | ||||
|             handle = module.register_forward_hook(conv1d_flops_counter_hook) | ||||
|             module.__flops_handle__ = handle | ||||
|     elif isinstance(module, torch.nn.Linear): | ||||
|         if not hasattr(module, "__flops_handle__"): | ||||
|             handle = module.register_forward_hook(fc_flops_counter_hook) | ||||
|             module.__flops_handle__ = handle | ||||
|     elif isinstance(module, torch.nn.AvgPool2d) or isinstance( | ||||
|         module, torch.nn.MaxPool2d | ||||
|     ): | ||||
|         if not hasattr(module, "__flops_handle__"): | ||||
|             handle = module.register_forward_hook(pool_flops_counter_hook) | ||||
|             module.__flops_handle__ = handle | ||||
|     elif hasattr(module, "calculate_flop_self"):  # self-defined module | ||||
|         if not hasattr(module, "__flops_handle__"): | ||||
|             handle = module.register_forward_hook(self_calculate_flops_counter_hook) | ||||
|             module.__flops_handle__ = handle | ||||
|  | ||||
|  | ||||
| def remove_hook_function(module): | ||||
|     hookers = ["__batch_counter_handle__", "__flops_handle__"] | ||||
|     for hooker in hookers: | ||||
|         if hasattr(module, hooker): | ||||
|             handle = getattr(module, hooker) | ||||
|             handle.remove() | ||||
|     keys = ["__flops__", "__batch_counter__", "__flops__"] + hookers | ||||
|     for ckey in keys: | ||||
|         if hasattr(module, ckey): | ||||
|             delattr(module, ckey) | ||||
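A hedged usage sketch for the FLOP counter above (import path assumed). The convolution and linear hooks count one multiply-accumulate per output element, so papers that count a multiply-add as two operations would report roughly twice these numbers:

    import torch.nn as nn
    from xautodl.utils.flop_benchmark import get_model_infos  # path assumed

    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.Flatten(),
        nn.Linear(8 * 32 * 32, 10),
    )
    flops, params = get_model_infos(model, (1, 3, 32, 32))
    # get_model_infos reports MFLOPs and the parameter count divided by 1e6
    print("{:.3f} MFLOPs, {:.3f}M params".format(flops, params))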
							
								
								
									
86	AutoDL-Projects/xautodl/utils/gpu_manager.py	Normal file
							| @@ -0,0 +1,86 @@ | ||||
| import os | ||||
|  | ||||
|  | ||||
| class GPUManager: | ||||
|     queries = ( | ||||
|         "index", | ||||
|         "gpu_name", | ||||
|         "memory.free", | ||||
|         "memory.used", | ||||
|         "memory.total", | ||||
|         "power.draw", | ||||
|         "power.limit", | ||||
|     ) | ||||
|  | ||||
|     def __init__(self): | ||||
|         all_gpus = self.query_gpu(False)  # probe once at construction; fails early if nvidia-smi is unavailable | ||||
|  | ||||
|     def get_info(self, ctype): | ||||
|         cmd = "nvidia-smi --query-gpu={} --format=csv,noheader".format(ctype) | ||||
|         lines = os.popen(cmd).readlines() | ||||
|         lines = [line.strip("\n") for line in lines] | ||||
|         return lines | ||||
|  | ||||
|     def query_gpu(self, show=True): | ||||
|         num_gpus = len(self.get_info("index")) | ||||
|         all_gpus = [{} for i in range(num_gpus)] | ||||
|         for query in self.queries: | ||||
|             infos = self.get_info(query) | ||||
|             for idx, info in enumerate(infos): | ||||
|                 all_gpus[idx][query] = info | ||||
|  | ||||
|         if "CUDA_VISIBLE_DEVICES" in os.environ: | ||||
|             CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"].split(",") | ||||
|             selected_gpus = [] | ||||
|             for idx, CUDA_VISIBLE_DEVICE in enumerate(CUDA_VISIBLE_DEVICES): | ||||
|                 find = False | ||||
|                 for gpu in all_gpus: | ||||
|                     if gpu["index"] == CUDA_VISIBLE_DEVICE: | ||||
|                         assert not find, "Duplicate cuda device index : {}".format( | ||||
|                             CUDA_VISIBLE_DEVICE | ||||
|                         ) | ||||
|                         find = True | ||||
|                         selected_gpus.append(gpu.copy()) | ||||
|                         selected_gpus[-1]["index"] = "{}".format(idx) | ||||
|                 assert find, "Could not find the device : {}".format(CUDA_VISIBLE_DEVICE) | ||||
|             all_gpus = selected_gpus | ||||
|  | ||||
|         if show: | ||||
|             allstrings = "" | ||||
|             for gpu in all_gpus: | ||||
|                 string = "| " | ||||
|                 for query in self.queries: | ||||
|                     if query.find("memory") == 0: | ||||
|                         xinfo = "{:>9}".format(gpu[query]) | ||||
|                     else: | ||||
|                         xinfo = gpu[query] | ||||
|                     string = string + query + " : " + xinfo + " | " | ||||
|                 allstrings = allstrings + string + "\n" | ||||
|             return allstrings | ||||
|         else: | ||||
|             return all_gpus | ||||
|  | ||||
|     def select_by_memory(self, numbers=1): | ||||
|         all_gpus = self.query_gpu(False) | ||||
|         assert numbers <= len(all_gpus), "Require {} gpus but only {} are available".format( | ||||
|             numbers, len(all_gpus) | ||||
|         ) | ||||
|         alls = [] | ||||
|         for idx, gpu in enumerate(all_gpus): | ||||
|             free_memory = gpu["memory.free"] | ||||
|             free_memory = free_memory.split(" ")[0] | ||||
|             free_memory = int(free_memory) | ||||
|             index = gpu["index"] | ||||
|             alls.append((free_memory, index)) | ||||
|         alls.sort(reverse=True) | ||||
|         alls = [int(alls[i][1]) for i in range(numbers)] | ||||
|         return sorted(alls) | ||||
|  | ||||
|  | ||||
| """ | ||||
| if __name__ == '__main__': | ||||
|   manager = GPUManager() | ||||
|   manager.query_gpu(True) | ||||
|   indexes = manager.select_by_memory(3) | ||||
|   print (indexes) | ||||
| """ | ||||
							
								
								
									
17	AutoDL-Projects/xautodl/utils/hash_utils.py	Normal file
							| @@ -0,0 +1,17 @@ | ||||
| import os | ||||
| import hashlib | ||||
|  | ||||
|  | ||||
| def get_md5_file(file_path, post_truncated=5): | ||||
|     md5_hash = hashlib.md5() | ||||
|     if os.path.exists(file_path): | ||||
|         with open(file_path, "rb") as xfile:  # with-block so the handle is always closed | ||||
|             content = xfile.read() | ||||
|         md5_hash.update(content) | ||||
|         digest = md5_hash.hexdigest() | ||||
|     else: | ||||
|         raise ValueError("[get_md5_file] {:} does not exist".format(file_path)) | ||||
|     if post_truncated is None: | ||||
|         return digest | ||||
|     else: | ||||
|         return digest[-post_truncated:] | ||||
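A small self-contained demonstration of get_md5_file (import path assumed); by default it returns only the last five hex characters of the digest:

    import os, tempfile
    from xautodl.utils.hash_utils import get_md5_file  # path assumed

    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"hello world")
        path = tmp.name
    print(get_md5_file(path))         # last 5 hex characters of the MD5 digest
    print(get_md5_file(path, None))   # the full 32-character digest
    os.remove(path)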
							
								
								
									
76	AutoDL-Projects/xautodl/utils/nas_utils.py	Normal file
							| @@ -0,0 +1,76 @@ | ||||
| # This file is for experimental usage | ||||
| import torch, random | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch.nn as nn | ||||
|  | ||||
| # modules in AutoDL | ||||
| from models import CellStructure | ||||
| from log_utils import time_string | ||||
|  | ||||
|  | ||||
| def evaluate_one_shot(model, xloader, api, cal_mode, seed=111): | ||||
|     print( | ||||
|         "This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function." | ||||
|     ) | ||||
|     weights = deepcopy(model.state_dict()) | ||||
|     model.train(cal_mode) | ||||
|     with torch.no_grad(): | ||||
|         logits = nn.functional.log_softmax(model.arch_parameters, dim=-1) | ||||
|         archs = CellStructure.gen_all(model.op_names, model.max_nodes, False) | ||||
|         probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], [] | ||||
|         loader_iter = iter(xloader) | ||||
|         random.seed(seed) | ||||
|         random.shuffle(archs) | ||||
|         for idx, arch in enumerate(archs): | ||||
|             arch_index = api.query_index_by_arch(arch) | ||||
|             metrics = api.get_more_info(arch_index, "cifar10-valid", None, False, False) | ||||
|             gt_accs_10_valid.append(metrics["valid-accuracy"]) | ||||
|             metrics = api.get_more_info(arch_index, "cifar10", None, False, False) | ||||
|             gt_accs_10_test.append(metrics["test-accuracy"]) | ||||
|             select_logits = [] | ||||
|             for i, node_info in enumerate(arch.nodes): | ||||
|                 for op, xin in node_info: | ||||
|                     node_str = "{:}<-{:}".format(i + 1, xin) | ||||
|                     op_index = model.op_names.index(op) | ||||
|                     select_logits.append(logits[model.edge2index[node_str], op_index]) | ||||
|             cur_prob = sum(select_logits).item() | ||||
|             probs.append(cur_prob) | ||||
|         cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0, 1] | ||||
|         cor_prob_test = np.corrcoef(probs, gt_accs_10_test)[0, 1] | ||||
|         print( | ||||
|             "{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test".format( | ||||
|                 time_string(), cor_prob_valid, cor_prob_test | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         for idx, arch in enumerate(archs): | ||||
|             model.set_cal_mode("dynamic", arch) | ||||
|             try: | ||||
|                 inputs, targets = next(loader_iter) | ||||
|             except StopIteration:  # restart the loader when it is exhausted | ||||
|                 loader_iter = iter(xloader) | ||||
|                 inputs, targets = next(loader_iter) | ||||
|             _, logits = model(inputs.cuda()) | ||||
|             _, preds = torch.max(logits, dim=-1) | ||||
|             correct = (preds == targets.cuda()).float() | ||||
|             accuracies.append(correct.mean().item()) | ||||
|             if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)): | ||||
|                 cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[: idx + 1])[ | ||||
|                     0, 1 | ||||
|                 ] | ||||
|                 cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[: idx + 1])[ | ||||
|                     0, 1 | ||||
|                 ] | ||||
|                 print( | ||||
|                     "{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.".format( | ||||
|                         time_string(), | ||||
|                         idx, | ||||
|                         len(archs), | ||||
|                         "Train" if cal_mode else "Eval", | ||||
|                         cor_accs_valid, | ||||
|                         cor_accs_test, | ||||
|                     ) | ||||
|                 ) | ||||
|     model.load_state_dict(weights) | ||||
|     return archs, probs, accuracies | ||||
							
								
								
									
129	AutoDL-Projects/xautodl/utils/qlib_utils.py	Normal file
							| @@ -0,0 +1,129 @@ | ||||
| import os | ||||
| import numpy as np | ||||
| from typing import List, Text | ||||
| from collections import defaultdict, OrderedDict | ||||
|  | ||||
|  | ||||
| class QResult: | ||||
|     """A class to maintain the results of a qlib experiment.""" | ||||
|  | ||||
|     def __init__(self, name): | ||||
|         self._result = defaultdict(list) | ||||
|         self._name = name | ||||
|         self._recorder_paths = [] | ||||
|         self._date2ICs = [] | ||||
|  | ||||
|     def append(self, key, value): | ||||
|         self._result[key].append(value) | ||||
|  | ||||
|     def append_path(self, xpath): | ||||
|         self._recorder_paths.append(xpath) | ||||
|  | ||||
|     def append_date2ICs(self, date2IC): | ||||
|         if self._date2ICs:  # not empty | ||||
|             keys = sorted(list(date2IC.keys())) | ||||
|             pre_keys = sorted(list(self._date2ICs[0].keys())) | ||||
|             assert len(keys) == len(pre_keys) | ||||
|             for i, (x, y) in enumerate(zip(keys, pre_keys)): | ||||
|                 assert x == y, "[{:}] {:} vs {:}".format(i, x, y) | ||||
|         self._date2ICs.append(date2IC) | ||||
|  | ||||
|     def find_all_dates(self): | ||||
|         dates = self._date2ICs[-1].keys() | ||||
|         return sorted(list(dates)) | ||||
|  | ||||
|     def get_IC_by_date(self, date, scale=1.0): | ||||
|         values = [] | ||||
|         for date2IC in self._date2ICs: | ||||
|             values.append(date2IC[date] * scale) | ||||
|         return float(np.mean(values)), float(np.std(values)) | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return self._name | ||||
|  | ||||
|     @property | ||||
|     def paths(self): | ||||
|         return self._recorder_paths | ||||
|  | ||||
|     @property | ||||
|     def result(self): | ||||
|         return self._result | ||||
|  | ||||
|     @property | ||||
|     def keys(self): | ||||
|         return list(self._result.keys()) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._result) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({xname}, {num} metrics)".format( | ||||
|             name=self.__class__.__name__, xname=self.name, num=len(self.result) | ||||
|         ) | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         if key not in self._result: | ||||
|             raise ValueError( | ||||
|                 "Invalid key {:}, please use one of {:}".format(key, self.keys) | ||||
|             ) | ||||
|         values = self._result[key] | ||||
|         return float(np.mean(values)) | ||||
|  | ||||
|     def update(self, metrics, filter_keys=None): | ||||
|         for key, value in metrics.items(): | ||||
|             if filter_keys is not None and key in filter_keys: | ||||
|                 key = filter_keys[key] | ||||
|             elif filter_keys is not None: | ||||
|                 continue | ||||
|             self.append(key, value) | ||||
|  | ||||
|     @staticmethod | ||||
|     def full_str(xstr, space): | ||||
|         xformat = "{:" + str(space) + "s}" | ||||
|         return xformat.format(str(xstr)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def merge_dict(dict_list): | ||||
|         new_dict = dict() | ||||
|         for xkey in dict_list[0].keys(): | ||||
|             values = [x for xdict in dict_list for x in xdict[xkey]] | ||||
|             new_dict[xkey] = values | ||||
|         return new_dict | ||||
|  | ||||
|     def info( | ||||
|         self, | ||||
|         keys: List[Text], | ||||
|         separate: Text = "& ", | ||||
|         space: int = 20, | ||||
|         verbose: bool = True, | ||||
|         version: str = "v1", | ||||
|     ): | ||||
|         available_keys = [] | ||||
|         for key in keys: | ||||
|             if key not in self.result: | ||||
|                 print("Invalid key [{:}]; skip it.".format(key)) | ||||
|             else: | ||||
|                 available_keys.append(key) | ||||
|         head_str = separate.join([self.full_str(x, space) for x in available_keys]) | ||||
|         values = [] | ||||
|         for key in available_keys: | ||||
|             if "IR" in key: | ||||
|                 current_values = [x * 100 for x in self._result[key]] | ||||
|             else: | ||||
|                 current_values = self._result[key] | ||||
|             mean = np.mean(current_values) | ||||
|             std = np.std(current_values) | ||||
|             if version == "v0": | ||||
|                 values.append("{:.2f} $\pm$ {:.2f}".format(mean, std)) | ||||
|             elif version == "v1": | ||||
|                 values.append( | ||||
|                     "{:.2f}".format(mean) + " \\subs{" + "{:.2f}".format(std) + "}" | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("Unknown version") | ||||
|         value_str = separate.join([self.full_str(x, space) for x in values]) | ||||
|         if verbose: | ||||
|             print(head_str) | ||||
|             print(value_str) | ||||
|         return head_str, value_str | ||||
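A hedged sketch of how QResult aggregates metrics across repeated runs (import path assumed). Note that info scales any key containing "IR" by 100 before averaging:

    from xautodl.utils.qlib_utils import QResult  # path assumed

    result = QResult("LSTM-baseline")
    for run_metrics in ({"IC": 0.041, "ICIR": 0.30}, {"IC": 0.045, "ICIR": 0.34}):
        result.update(run_metrics)              # append each metric of one run
    head, row = result.info(["IC", "ICIR"])     # prints and returns a LaTeX-style "mean \subs{std}" row
    print(result["IC"])                         # mean IC over the two runs, about 0.043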
Some files were not shown because too many files have changed in this diff.