90 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			90 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 | |
| ##################################################
 | |
| import os, sys, time, glob, random, argparse
 | |
| import numpy as np
 | |
| from copy import deepcopy
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| import torchvision.datasets as dset
 | |
| import torch.backends.cudnn as cudnn
 | |
| import torchvision.transforms as transforms
 | |
| from pathlib import Path
 | |
# Put the repository's ``lib`` folder (one level up from this script) at the
# front of the import path so the project-local imports that follow can be
# resolved; presumably utils/nas/train_utils live there — confirm.
lib_dir = Path(__file__).parent.joinpath('..', 'lib').resolve()
if str(lib_dir) not in sys.path:
  sys.path.insert(0, str(lib_dir))
 | |
| from utils import AverageMeter, time_string, convert_secs2time
 | |
| from utils import print_log, obtain_accuracy
 | |
| from utils import Cutout, count_parameters_in_MB
 | |
| from nas import model_types as models
 | |
| from train_utils import main_procedure
 | |
| from train_utils_imagenet import main_procedure_imagenet
 | |
| from scheduler import load_config
 | |
| 
 | |
| 
 | |
parser = argparse.ArgumentParser("Train-CNN")
# data
parser.add_argument('--data_path',     type=str, help='Path to dataset')
# main() calls args.dataset.lower() and looks up models[args.arch], so both
# must always be supplied; `required` turns a late AttributeError/KeyError
# into an immediate, readable usage error.
parser.add_argument('--dataset',       type=str, required=True, choices=['imagenet', 'cifar10', 'cifar100'], help='Choose between Cifar10/100 and ImageNet.')
parser.add_argument('--arch',          type=str, required=True, choices=models.keys(), help='the searched model.')
# network / optimisation settings forwarded to the training procedure
parser.add_argument('--grad_clip',     type=float, help='gradient clipping')
parser.add_argument('--model_config',  type=str,   help='the model configuration')
parser.add_argument('--init_channels', type=int,   help='the initial number of channels')
parser.add_argument('--layers',        type=int,   help='the number of layers.')
# logging / runtime
parser.add_argument('--workers',       type=int, default=2, help='number of data loading workers (default: 2)')
# main() creates this folder and writes the run log into it, so it cannot be
# omitted.
parser.add_argument('--save_path',     type=str, required=True, help='Folder to save checkpoints and log.')
# The help text advertises a default of 200; set it so the value is not
# silently None when the flag is omitted.
parser.add_argument('--print_freq',    type=int, default=200, help='print frequency (default: 200)')
# None or a negative value means "pick a random seed" (handled after parsing).
parser.add_argument('--manualSeed',    type=int, help='manual seed')
args = parser.parse_args()
 | |
| 
 | |
# Informational only: report which GPUs this process has been restricted to.
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
  print('Can not find CUDA_VISIBLE_DEVICES in os.environ')
else:
  print('Find CUDA_VISIBLE_DEVICES={:}'.format(os.environ['CUDA_VISIBLE_DEVICES']))

# Training requires CUDA. Raise explicitly instead of using `assert`, which is
# stripped under `python -O` and would let the script continue without a GPU.
if not torch.cuda.is_available():
  raise RuntimeError('torch.cuda is not available')

# A missing or negative --manualSeed means "choose one at random". The chosen
# value is written back to args so it is logged and can be reused to
# reproduce the run.
if args.manualSeed is None or args.manualSeed < 0:
  args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
# Seed numpy as well so numpy-based randomness follows the same seed (it may
# be used by data augmentation helpers — TODO confirm). numpy only accepts
# seeds in [0, 2**32), so reduce larger user-supplied values into that range.
np.random.seed(args.manualSeed % (2 ** 32))
# benchmark=True lets cuDNN autotune convolution algorithms; this favours
# speed over bit-exact reproducibility between runs.
cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)
 | |
| 
 | |
| 
 | |
def main():
  """Create the run's log file, record the environment, and launch training.

  Reads the module-level ``args`` produced by the argument parser. The log is
  written to ``<save_path>/seed-<manualSeed>-log.txt``. ImageNet runs are
  dispatched to ``main_procedure_imagenet``; every other dataset goes to
  ``main_procedure``. Both receive the open log handle.
  """
  # exist_ok=True avoids the race between a separate isdir() check and
  # makedirs(); a non-directory at that path still raises, as before.
  os.makedirs(args.save_path, exist_ok=True)
  log_file = os.path.join(args.save_path, 'seed-{:}-log.txt'.format(args.manualSeed))
  # The `with` block guarantees the log is flushed and closed even if
  # training raises part-way through. An explicit encoding avoids
  # locale-dependent failures when writing non-ASCII text.
  with open(log_file, 'w', encoding='utf-8') as log:
    print_log('Save Path      : {:}'.format(args.save_path), log)
    # Snapshot of every parsed option, so the run can be reconstructed.
    print_log(dict(vars(args)), log)
    print_log("Random Seed    : {:}".format(args.manualSeed), log)
    print_log("Python version : {:}".format(sys.version.replace('\n', ' ')), log)
    print_log("Torch  version : {:}".format(torch.__version__), log)
    print_log("CUDA   version : {:}".format(torch.version.cuda), log)
    print_log("cuDNN  version : {:}".format(cudnn.version()), log)
    print_log("Num of GPUs    : {:}".format(torch.cuda.device_count()), log)
    args.dataset = args.dataset.lower()

    config = load_config(args.model_config)
    genotype = models[args.arch]
    print_log('configuration : {:}'.format(config), log)
    print_log('genotype      : {:}'.format(genotype), log)
    # Release memory held by PyTorch's CUDA caching allocator before training.
    torch.cuda.empty_cache()
    if args.dataset == 'imagenet':
      main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, None, log)
    else:
      main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, None, log)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
  # Start training only when this file is executed directly as a script.
  main()
 |