add autodl
This commit is contained in:
		
							
								
								
									
										20
									
								
								AutoDL-Projects/xautodl/config_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								AutoDL-Projects/xautodl/config_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # general config related functions | ||||
| from .config_utils import load_config, dict2config, configure2str | ||||
|  | ||||
| # the args setting for different experiments | ||||
| from .basic_args import obtain_basic_args | ||||
| from .attention_args import obtain_attention_args | ||||
| from .random_baseline import obtain_RandomSearch_args | ||||
| from .cls_kd_args import obtain_cls_kd_args | ||||
| from .cls_init_args import obtain_cls_init_args | ||||
| from .search_single_args import obtain_search_single_args | ||||
| from .search_args import obtain_search_args | ||||
|  | ||||
| # for network pruning | ||||
| from .pruning_args import obtain_pruning_args | ||||
|  | ||||
| # utils for args | ||||
| from .args_utils import arg_str2bool | ||||
							
								
								
									
										12
									
								
								AutoDL-Projects/xautodl/config_utils/args_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								AutoDL-Projects/xautodl/config_utils/args_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | ||||
| import argparse | ||||
|  | ||||
|  | ||||
| def arg_str2bool(v): | ||||
|     if isinstance(v, bool): | ||||
|         return v | ||||
|     elif v.lower() in ("yes", "true", "t", "y", "1"): | ||||
|         return True | ||||
|     elif v.lower() in ("no", "false", "f", "n", "0"): | ||||
|         return False | ||||
|     else: | ||||
|         raise argparse.ArgumentTypeError("Boolean value expected.") | ||||
							
								
								
									
										32
									
								
								AutoDL-Projects/xautodl/config_utils/attention_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								AutoDL-Projects/xautodl/config_utils/attention_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_attention_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--att_channel", type=int, help=".") | ||||
|     parser.add_argument("--att_spatial", type=str, help=".") | ||||
|     parser.add_argument("--att_active", type=str, help=".") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     return args | ||||
							
								
								
									
										44
									
								
								AutoDL-Projects/xautodl/config_utils/basic_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								AutoDL-Projects/xautodl/config_utils/basic_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################## | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_basic_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--model_source", | ||||
|         type=str, | ||||
|         default="normal", | ||||
|         help="The source of model defination.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--extra_model_path", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="The extra model ckp file (help to indicate the searched architecture).", | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     return args | ||||
							
								
								
									
										32
									
								
								AutoDL-Projects/xautodl/config_utils/cls_init_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								AutoDL-Projects/xautodl/config_utils/cls_init_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_cls_init_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--init_checkpoint", type=str, help="The checkpoint path to the initial model." | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     return args | ||||
							
								
								
									
										43
									
								
								AutoDL-Projects/xautodl/config_utils/cls_kd_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								AutoDL-Projects/xautodl/config_utils/cls_kd_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_cls_kd_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--KD_checkpoint", | ||||
|         type=str, | ||||
|         help="The teacher checkpoint in knowledge distillation.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_temperature", | ||||
|         type=float, | ||||
|         help="The temperature parameter in knowledge distillation.", | ||||
|     ) | ||||
|     # parser.add_argument('--KD_feature',       type=float,                 help='Knowledge distillation at the feature level.') | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     return args | ||||
							
								
								
									
										135
									
								
								AutoDL-Projects/xautodl/config_utils/config_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								AutoDL-Projects/xautodl/config_utils/config_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,135 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import os, json | ||||
| from os import path as osp | ||||
| from pathlib import Path | ||||
| from collections import namedtuple | ||||
|  | ||||
| support_types = ("str", "int", "bool", "float", "none") | ||||
|  | ||||
|  | ||||
| def convert_param(original_lists): | ||||
|     assert isinstance(original_lists, list), "The type is not right : {:}".format( | ||||
|         original_lists | ||||
|     ) | ||||
|     ctype, value = original_lists[0], original_lists[1] | ||||
|     assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types) | ||||
|     is_list = isinstance(value, list) | ||||
|     if not is_list: | ||||
|         value = [value] | ||||
|     outs = [] | ||||
|     for x in value: | ||||
|         if ctype == "int": | ||||
|             x = int(x) | ||||
|         elif ctype == "str": | ||||
|             x = str(x) | ||||
|         elif ctype == "bool": | ||||
|             x = bool(int(x)) | ||||
|         elif ctype == "float": | ||||
|             x = float(x) | ||||
|         elif ctype == "none": | ||||
|             if x.lower() != "none": | ||||
|                 raise ValueError( | ||||
|                     "For the none type, the value must be none instead of {:}".format(x) | ||||
|                 ) | ||||
|             x = None | ||||
|         else: | ||||
|             raise TypeError("Does not know this type : {:}".format(ctype)) | ||||
|         outs.append(x) | ||||
|     if not is_list: | ||||
|         outs = outs[0] | ||||
|     return outs | ||||
|  | ||||
|  | ||||
| def load_config(path, extra, logger): | ||||
|     path = str(path) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log(path) | ||||
|     assert os.path.exists(path), "Can not find {:}".format(path) | ||||
|     # Reading data back | ||||
|     with open(path, "r") as f: | ||||
|         data = json.load(f) | ||||
|     content = {k: convert_param(v) for k, v in data.items()} | ||||
|     assert extra is None or isinstance( | ||||
|         extra, dict | ||||
|     ), "invalid type of extra : {:}".format(extra) | ||||
|     if isinstance(extra, dict): | ||||
|         content = {**content, **extra} | ||||
|     Arguments = namedtuple("Configure", " ".join(content.keys())) | ||||
|     content = Arguments(**content) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("{:}".format(content)) | ||||
|     return content | ||||
|  | ||||
|  | ||||
| def configure2str(config, xpath=None): | ||||
|     if not isinstance(config, dict): | ||||
|         config = config._asdict() | ||||
|  | ||||
|     def cstring(x): | ||||
|         return '"{:}"'.format(x) | ||||
|  | ||||
|     def gtype(x): | ||||
|         if isinstance(x, list): | ||||
|             x = x[0] | ||||
|         if isinstance(x, str): | ||||
|             return "str" | ||||
|         elif isinstance(x, bool): | ||||
|             return "bool" | ||||
|         elif isinstance(x, int): | ||||
|             return "int" | ||||
|         elif isinstance(x, float): | ||||
|             return "float" | ||||
|         elif x is None: | ||||
|             return "none" | ||||
|         else: | ||||
|             raise ValueError("invalid : {:}".format(x)) | ||||
|  | ||||
|     def cvalue(x, xtype): | ||||
|         if isinstance(x, list): | ||||
|             is_list = True | ||||
|         else: | ||||
|             is_list, x = False, [x] | ||||
|         temps = [] | ||||
|         for temp in x: | ||||
|             if xtype == "bool": | ||||
|                 temp = cstring(int(temp)) | ||||
|             elif xtype == "none": | ||||
|                 temp = cstring("None") | ||||
|             else: | ||||
|                 temp = cstring(temp) | ||||
|             temps.append(temp) | ||||
|         if is_list: | ||||
|             return "[{:}]".format(", ".join(temps)) | ||||
|         else: | ||||
|             return temps[0] | ||||
|  | ||||
|     xstrings = [] | ||||
|     for key, value in config.items(): | ||||
|         xtype = gtype(value) | ||||
|         string = "  {:20s} : [{:8s}, {:}]".format( | ||||
|             cstring(key), cstring(xtype), cvalue(value, xtype) | ||||
|         ) | ||||
|         xstrings.append(string) | ||||
|     Fstring = "{\n" + ",\n".join(xstrings) + "\n}" | ||||
|     if xpath is not None: | ||||
|         parent = Path(xpath).resolve().parent | ||||
|         parent.mkdir(parents=True, exist_ok=True) | ||||
|         if osp.isfile(xpath): | ||||
|             os.remove(xpath) | ||||
|         with open(xpath, "w") as text_file: | ||||
|             text_file.write("{:}".format(Fstring)) | ||||
|     return Fstring | ||||
|  | ||||
|  | ||||
| def dict2config(xdict, logger): | ||||
|     assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict)) | ||||
|     Arguments = namedtuple("Configure", " ".join(xdict.keys())) | ||||
|     content = Arguments(**xdict) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("{:}".format(content)) | ||||
|     return content | ||||
							
								
								
									
										48
									
								
								AutoDL-Projects/xautodl/config_utils/pruning_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								AutoDL-Projects/xautodl/config_utils/pruning_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_pruning_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--keep_ratio", | ||||
|         type=float, | ||||
|         help="The left channel ratio compared to the original network.", | ||||
|     ) | ||||
|     parser.add_argument("--model_version", type=str, help="The network version.") | ||||
|     parser.add_argument( | ||||
|         "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_temperature", | ||||
|         type=float, | ||||
|         help="The temperature parameter in knowledge distillation.", | ||||
|     ) | ||||
|     parser.add_argument("--Regular_W_feat", type=float, help="The .") | ||||
|     parser.add_argument("--Regular_W_conv", type=float, help="The .") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     assert ( | ||||
|         args.keep_ratio > 0 and args.keep_ratio <= 1 | ||||
|     ), "invalid keep ratio : {:}".format(args.keep_ratio) | ||||
|     return args | ||||
							
								
								
									
										44
									
								
								AutoDL-Projects/xautodl/config_utils/random_baseline.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								AutoDL-Projects/xautodl/config_utils/random_baseline.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_RandomSearch_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--expect_flop", type=float, help="The expected flop keep ratio." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nums", | ||||
|         type=int, | ||||
|         help="The maximum number of running random arch generating..", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--random_mode", | ||||
|         type=str, | ||||
|         choices=["random", "fix"], | ||||
|         help="The path to the optimizer configuration", | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) | ||||
|     return args | ||||
							
								
								
									
										53
									
								
								AutoDL-Projects/xautodl/config_utils/search_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								AutoDL-Projects/xautodl/config_utils/search_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_search_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--split_path", type=str, help="The split file path.") | ||||
|     # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") | ||||
|     parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") | ||||
|     parser.add_argument( | ||||
|         "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." | ||||
|     ) | ||||
|     # ablation studies | ||||
|     parser.add_argument( | ||||
|         "--ablation_num_select", | ||||
|         type=int, | ||||
|         help="The number of randomly selected channels.", | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|     assert ( | ||||
|         args.FLOP_tolerant is not None and args.FLOP_tolerant > 0 | ||||
|     ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant) | ||||
|     # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|     # args.arch_para_pure = bool(args.arch_para_pure) | ||||
|     return args | ||||
							
								
								
									
										48
									
								
								AutoDL-Projects/xautodl/config_utils/search_single_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								AutoDL-Projects/xautodl/config_utils/search_single_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_search_single_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--split_path", type=str, help="The split file path.") | ||||
|     parser.add_argument("--search_shape", type=str, help="The shape to be searched.") | ||||
|     # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") | ||||
|     parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") | ||||
|     parser.add_argument( | ||||
|         "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|     assert ( | ||||
|         args.FLOP_tolerant is not None and args.FLOP_tolerant > 0 | ||||
|     ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant) | ||||
|     # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|     # args.arch_para_pure = bool(args.arch_para_pure) | ||||
|     return args | ||||
							
								
								
									
										39
									
								
								AutoDL-Projects/xautodl/config_utils/share_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								AutoDL-Projects/xautodl/config_utils/share_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| import os, sys, time, random, argparse | ||||
|  | ||||
|  | ||||
| def add_shared_args(parser): | ||||
|     # Data Generation | ||||
|     parser.add_argument("--dataset", type=str, help="The dataset name.") | ||||
|     parser.add_argument("--data_path", type=str, help="The dataset name.") | ||||
|     parser.add_argument( | ||||
|         "--cutout_length", type=int, help="The cutout length, negative means not use." | ||||
|     ) | ||||
|     # Printing | ||||
|     parser.add_argument( | ||||
|         "--print_freq", type=int, default=100, help="print frequency (default: 200)" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--print_freq_eval", | ||||
|         type=int, | ||||
|         default=100, | ||||
|         help="print frequency (default: 200)", | ||||
|     ) | ||||
|     # Checkpoints | ||||
|     parser.add_argument( | ||||
|         "--eval_frequency", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         help="evaluation frequency (default: 200)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     # Acceleration | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="number of data loading workers (default: 8)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
		Reference in New Issue
	
	Block a user