90 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			90 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | ################################################## | ||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||
|  | ################################################## | ||
|  | import os, sys, time, glob, random, argparse | ||
|  | import numpy as np | ||
|  | from copy import deepcopy | ||
|  | import torch | ||
|  | import torch.nn as nn | ||
|  | import torch.nn.functional as F | ||
|  | import torchvision.datasets as dset | ||
|  | import torch.backends.cudnn as cudnn | ||
|  | import torchvision.transforms as transforms | ||
|  | from pathlib import Path | ||
|  | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||
|  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||
|  | from utils import AverageMeter, time_string, convert_secs2time | ||
|  | from utils import print_log, obtain_accuracy | ||
|  | from utils import Cutout, count_parameters_in_MB | ||
|  | from nas import model_types as models | ||
|  | from train_utils import main_procedure | ||
|  | from train_utils_imagenet import main_procedure_imagenet | ||
|  | from scheduler import load_config | ||
|  | 
 | ||
|  | 
 | ||
# Command-line interface of the training script.
parser = argparse.ArgumentParser("Train-CNN")
# data and architecture selection
parser.add_argument('--data_path',         type=str,   help='Path to dataset')
parser.add_argument('--dataset',           type=str,   choices=['imagenet', 'cifar10', 'cifar100'], help='Choose between Cifar10/100 and ImageNet.')
parser.add_argument('--arch',              type=str,   choices=models.keys(), help='the searched model.')
# optimization and model shape
parser.add_argument('--grad_clip',      type=float, help='gradient clipping')
parser.add_argument('--model_config',   type=str  , help='the model configuration')
parser.add_argument('--init_channels',  type=int  , help='the initial number of channels')
parser.add_argument('--layers',         type=int  , help='the number of layers.')

# logging, data loading and reproducibility
parser.add_argument('--workers',       type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_path',     type=str, help='Folder to save checkpoints and log.')
# The help text promises a default of 200, so actually provide it instead of
# silently leaving the value as None when the flag is omitted.
parser.add_argument('--print_freq',    type=int, default=200, help='print frequency (default: 200)')
parser.add_argument('--manualSeed',    type=int, help='manual seed')
args = parser.parse_args()
|  | 
 | ||
# Report whether the launcher restricted the GPUs visible to this process.
if 'CUDA_VISIBLE_DEVICES' in os.environ:
  print('Find CUDA_VISIBLE_DEVICES={:}'.format(os.environ['CUDA_VISIBLE_DEVICES']))
else:
  print('Can not find CUDA_VISIBLE_DEVICES in os.environ')

# Stop right away when PyTorch reports that CUDA is unavailable.
# NOTE: as an assert, this check is skipped when Python runs with -O.
assert torch.cuda.is_available(), 'torch.cuda is not available'
|  | 
 | ||
|  | 
 | ||
# cuDNN is enabled and allowed to benchmark its algorithms.
cudnn.enabled   = True
cudnn.benchmark = True

# Draw a seed in [1, 10000] unless a non-negative one was given on the command line.
if not (args.manualSeed is not None and args.manualSeed >= 0):
  args.manualSeed = random.randint(1, 10000)

# Apply the same seed to Python's RNG and to PyTorch (CPU and every CUDA device).
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)
|  | 
 | ||
|  | 
 | ||
def main():
  """Run the training job described by the module-level ``args``.

  Creates ``args.save_path`` when it does not exist, opens a seed-specific
  log file inside it, records the parsed arguments and the runtime
  environment, loads the model configuration and the genotype selected by
  ``args.arch``, and finally dispatches to ``main_procedure_imagenet`` when
  the dataset is ``imagenet`` or to ``main_procedure`` otherwise.

  Side effects: ``args.dataset`` is lower-cased in place.
  """
  # exist_ok removes the window between checking for the directory and creating it.
  os.makedirs(args.save_path, exist_ok=True)
  log_path = os.path.join(args.save_path, 'seed-{:}-log.txt'.format(args.manualSeed))

  # The context manager guarantees the log file is closed even if training raises.
  with open(log_path, 'w') as log:
    print_log('Save Path      : {:}'.format(args.save_path), log)
    state = dict(args._get_kwargs())
    print_log(state, log)
    print_log("Random Seed    : {:}".format(args.manualSeed), log)
    print_log("Python version : {:}".format(sys.version.replace('\n', ' ')), log)
    print_log("Torch  version : {:}".format(torch.__version__), log)
    print_log("CUDA   version : {:}".format(torch.version.cuda), log)
    print_log("cuDNN  version : {:}".format(cudnn.version()), log)
    print_log("Num of GPUs    : {:}".format(torch.cuda.device_count()), log)
    args.dataset = args.dataset.lower()

    config = load_config(args.model_config)
    genotype = models[args.arch]
    print_log('configuration : {:}'.format(config), log)
    print_log('genotype      : {:}'.format(genotype), log)
    # clear GPU cache
    torch.cuda.empty_cache()
    if args.dataset == 'imagenet':
      main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, None, log)
    else:
      main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, None, log)


# Start training only when the file is executed as a script.
if __name__ == '__main__':
  main()