add naswot
This commit is contained in:
		
							
								
								
									
										13
									
								
								graph_dit/naswot/config_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								graph_dit/naswot/config_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .configure_utils    import load_config, dict2config, configure2str | ||||
| from .basic_args         import obtain_basic_args | ||||
| from .attention_args     import obtain_attention_args | ||||
| from .random_baseline    import obtain_RandomSearch_args | ||||
| from .cls_kd_args        import obtain_cls_kd_args | ||||
| from .cls_init_args      import obtain_cls_init_args | ||||
| from .search_single_args import obtain_search_single_args | ||||
| from .search_args        import obtain_search_args | ||||
| # for network pruning | ||||
| from .pruning_args       import obtain_pruning_args | ||||
							
								
								
									
										22
									
								
								graph_dit/naswot/config_utils/attention_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								graph_dit/naswot/config_utils/attention_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_attention_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--att_channel' ,     type=int,                   help='.') | ||||
|   parser.add_argument('--att_spatial' ,     type=str,                   help='.') | ||||
|   parser.add_argument('--att_active'  ,     type=str,                   help='.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   return args | ||||
							
								
								
									
										24
									
								
								graph_dit/naswot/config_utils/basic_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								graph_dit/naswot/config_utils/basic_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################## | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_basic_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--model_source',     type=str,  default='normal',help='The source of model defination.') | ||||
|   parser.add_argument('--extra_model_path', type=str,  default=None,    help='The extra model ckp file (help to indicate the searched architecture).') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,  default=2,       help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   return args | ||||
							
								
								
									
										4
									
								
								graph_dit/naswot/config_utils/cifar-split.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								graph_dit/naswot/config_utils/cifar-split.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										20
									
								
								graph_dit/naswot/config_utils/cls_init_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								graph_dit/naswot/config_utils/cls_init_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_cls_init_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--init_checkpoint',  type=str,                   help='The checkpoint path to the initial model.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   return args | ||||
							
								
								
									
										23
									
								
								graph_dit/naswot/config_utils/cls_kd_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								graph_dit/naswot/config_utils/cls_kd_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_cls_kd_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--KD_checkpoint',    type=str,                   help='The teacher checkpoint in knowledge distillation.') | ||||
|   parser.add_argument('--KD_alpha'    ,     type=float,                 help='The alpha parameter in knowledge distillation.') | ||||
|   parser.add_argument('--KD_temperature',   type=float,                 help='The temperature parameter in knowledge distillation.') | ||||
|   #parser.add_argument('--KD_feature',       type=float,                 help='Knowledge distillation at the feature level.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   return args | ||||
							
								
								
									
										106
									
								
								graph_dit/naswot/config_utils/configure_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								graph_dit/naswot/config_utils/configure_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,106 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import os, json | ||||
| from os import path as osp | ||||
| from pathlib import Path | ||||
| from collections import namedtuple | ||||
|  | ||||
| support_types = ('str', 'int', 'bool', 'float', 'none') | ||||
|  | ||||
|  | ||||
| def convert_param(original_lists): | ||||
|   assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists) | ||||
|   ctype, value = original_lists[0], original_lists[1] | ||||
|   assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types) | ||||
|   is_list = isinstance(value, list) | ||||
|   if not is_list: value = [value] | ||||
|   outs = [] | ||||
|   for x in value: | ||||
|     if ctype == 'int': | ||||
|       x = int(x) | ||||
|     elif ctype == 'str': | ||||
|       x = str(x) | ||||
|     elif ctype == 'bool': | ||||
|       x = bool(int(x)) | ||||
|     elif ctype == 'float': | ||||
|       x = float(x) | ||||
|     elif ctype == 'none': | ||||
|       if x.lower() != 'none': | ||||
|         raise ValueError('For the none type, the value must be none instead of {:}'.format(x)) | ||||
|       x = None | ||||
|     else: | ||||
|       raise TypeError('Does not know this type : {:}'.format(ctype)) | ||||
|     outs.append(x) | ||||
|   if not is_list: outs = outs[0] | ||||
|   return outs | ||||
|  | ||||
|  | ||||
| def load_config(path, extra, logger): | ||||
|   path = str(path) | ||||
|   if hasattr(logger, 'log'): logger.log(path) | ||||
|   assert os.path.exists(path), 'Can not find {:}'.format(path) | ||||
|   # Reading data back | ||||
|   with open(path, 'r') as f: | ||||
|     data = json.load(f) | ||||
|   content = { k: convert_param(v) for k,v in data.items()} | ||||
|   assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra) | ||||
|   if isinstance(extra, dict): content = {**content, **extra} | ||||
|   Arguments = namedtuple('Configure', ' '.join(content.keys())) | ||||
|   content   = Arguments(**content) | ||||
|   if hasattr(logger, 'log'): logger.log('{:}'.format(content)) | ||||
|   return content | ||||
|  | ||||
|  | ||||
| def configure2str(config, xpath=None): | ||||
|   if not isinstance(config, dict): | ||||
|     config = config._asdict() | ||||
|   def cstring(x): | ||||
|     return "\"{:}\"".format(x) | ||||
|   def gtype(x): | ||||
|     if isinstance(x, list): x = x[0] | ||||
|     if isinstance(x, str)  : return 'str' | ||||
|     elif isinstance(x, bool) : return 'bool' | ||||
|     elif isinstance(x, int): return 'int' | ||||
|     elif isinstance(x, float): return 'float' | ||||
|     elif x is None           : return 'none' | ||||
|     else: raise ValueError('invalid : {:}'.format(x)) | ||||
|   def cvalue(x, xtype): | ||||
|     if isinstance(x, list): is_list = True | ||||
|     else: | ||||
|       is_list, x = False, [x] | ||||
|     temps = [] | ||||
|     for temp in x: | ||||
|       if xtype == 'bool'  : temp = cstring(int(temp)) | ||||
|       elif xtype == 'none': temp = cstring('None') | ||||
|       else                : temp = cstring(temp) | ||||
|       temps.append( temp ) | ||||
|     if is_list: | ||||
|       return "[{:}]".format( ', '.join( temps ) ) | ||||
|     else: | ||||
|       return temps[0] | ||||
|  | ||||
|   xstrings = [] | ||||
|   for key, value in config.items(): | ||||
|     xtype  = gtype(value) | ||||
|     string = '  {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype)) | ||||
|     xstrings.append(string) | ||||
|   Fstring = '{\n' + ',\n'.join(xstrings) + '\n}' | ||||
|   if xpath is not None: | ||||
|     parent = Path(xpath).resolve().parent | ||||
|     parent.mkdir(parents=True, exist_ok=True) | ||||
|     if osp.isfile(xpath): os.remove(xpath) | ||||
|     with open(xpath, "w") as text_file: | ||||
|       text_file.write('{:}'.format(Fstring)) | ||||
|   return Fstring | ||||
|  | ||||
|  | ||||
| def dict2config(xdict, logger): | ||||
|   assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) ) | ||||
|   Arguments = namedtuple('Configure', ' '.join(xdict.keys())) | ||||
|   content   = Arguments(**xdict) | ||||
|   if hasattr(logger, 'log'): logger.log('{:}'.format(content)) | ||||
|   return content | ||||
							
								
								
									
										26
									
								
								graph_dit/naswot/config_utils/pruning_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								graph_dit/naswot/config_utils/pruning_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_pruning_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--keep_ratio'  ,     type=float,                 help='The left channel ratio compared to the original network.') | ||||
|   parser.add_argument('--model_version',    type=str,                   help='The network version.') | ||||
|   parser.add_argument('--KD_alpha'    ,     type=float,                 help='The alpha parameter in knowledge distillation.') | ||||
|   parser.add_argument('--KD_temperature',   type=float,                 help='The temperature parameter in knowledge distillation.') | ||||
|   parser.add_argument('--Regular_W_feat',   type=float,                 help='The .') | ||||
|   parser.add_argument('--Regular_W_conv',   type=float,                 help='The .') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,  default=2,       help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   assert args.keep_ratio > 0 and args.keep_ratio <= 1, 'invalid keep ratio : {:}'.format(args.keep_ratio) | ||||
|   return args | ||||
							
								
								
									
										24
									
								
								graph_dit/naswot/config_utils/random_baseline.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								graph_dit/naswot/config_utils/random_baseline.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_RandomSearch_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--expect_flop',      type=float,                 help='The expected flop keep ratio.') | ||||
|   parser.add_argument('--arch_nums'   ,     type=int,                   help='The maximum number of running random arch generating..') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--random_mode', type=str, choices=['random', 'fix'], help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   #assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) | ||||
|   return args | ||||
							
								
								
									
										32
									
								
								graph_dit/naswot/config_utils/search_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								graph_dit/naswot/config_utils/search_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_search_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'        ,   type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--model_config'  ,   type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config'  ,   type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--split_path'    ,   type=str,                   help='The split file path.') | ||||
|   #parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|   parser.add_argument('--gumbel_tau_max',   type=float,                 help='The maximum tau for Gumbel.') | ||||
|   parser.add_argument('--gumbel_tau_min',   type=float,                 help='The minimum tau for Gumbel.') | ||||
|   parser.add_argument('--procedure'     ,   type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--FLOP_ratio'    ,   type=float,                 help='The expected FLOP ratio.') | ||||
|   parser.add_argument('--FLOP_weight'   ,   type=float,                 help='The loss weight for FLOP.') | ||||
|   parser.add_argument('--FLOP_tolerant' ,   type=float,                 help='The tolerant range for FLOP.') | ||||
|   # ablation studies | ||||
|   parser.add_argument('--ablation_num_select', type=int,                help='The number of randomly selected channels.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size'    ,   type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|   assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant) | ||||
|   #assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|   #args.arch_para_pure = bool(args.arch_para_pure) | ||||
|   return args | ||||
							
								
								
									
										31
									
								
								graph_dit/naswot/config_utils/search_single_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								graph_dit/naswot/config_utils/search_single_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_search_single_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'        ,   type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--model_config'  ,   type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config'  ,   type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--split_path'    ,   type=str,                   help='The split file path.') | ||||
|   parser.add_argument('--search_shape'  ,   type=str,                   help='The shape to be searched.') | ||||
|   #parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|   parser.add_argument('--gumbel_tau_max',   type=float,                 help='The maximum tau for Gumbel.') | ||||
|   parser.add_argument('--gumbel_tau_min',   type=float,                 help='The minimum tau for Gumbel.') | ||||
|   parser.add_argument('--procedure'     ,   type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--FLOP_ratio'    ,   type=float,                 help='The expected FLOP ratio.') | ||||
|   parser.add_argument('--FLOP_weight'   ,   type=float,                 help='The loss weight for FLOP.') | ||||
|   parser.add_argument('--FLOP_tolerant' ,   type=float,                 help='The tolerant range for FLOP.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size'    ,   type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|   assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant) | ||||
|   #assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|   #args.arch_para_pure = bool(args.arch_para_pure) | ||||
|   return args | ||||
							
								
								
									
										17
									
								
								graph_dit/naswot/config_utils/share_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								graph_dit/naswot/config_utils/share_args.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | ||||
| import os, sys, time, random, argparse | ||||
|  | ||||
| def add_shared_args( parser ): | ||||
|   # Data Generation | ||||
|   parser.add_argument('--dataset',          type=str,                   help='The dataset name.') | ||||
|   parser.add_argument('--data_path',        type=str,                   help='The dataset name.') | ||||
|   parser.add_argument('--cutout_length',    type=int,                   help='The cutout length, negative means not use.') | ||||
|   # Printing | ||||
|   parser.add_argument('--print_freq',       type=int,   default=100,    help='print frequency (default: 200)') | ||||
|   parser.add_argument('--print_freq_eval',  type=int,   default=100,    help='print frequency (default: 200)') | ||||
|   # Checkpoints | ||||
|   parser.add_argument('--eval_frequency',   type=int,   default=1,      help='evaluation frequency (default: 200)') | ||||
|   parser.add_argument('--save_dir',         type=str,                   help='Folder to save checkpoints and log.') | ||||
|   # Acceleration | ||||
|   parser.add_argument('--workers',          type=int,   default=8,      help='number of data loading workers (default: 8)') | ||||
|   # Random Seed | ||||
|   parser.add_argument('--rand_seed',        type=int,   default=-1,     help='manual seed') | ||||
		Reference in New Issue
	
	Block a user