54 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			54 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | ################################################## | ||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||
|  | ################################################## | ||
|  | # For evaluating the learned model | ||
|  | import os, sys, time, glob, random, argparse | ||
|  | import numpy as np | ||
|  | from copy import deepcopy | ||
|  | import torch | ||
|  | import torch.nn as nn | ||
|  | import torch.nn.functional as F | ||
|  | import torchvision.datasets as dset | ||
|  | import torch.backends.cudnn as cudnn | ||
|  | import torchvision.transforms as transforms | ||
|  | from pathlib import Path | ||
|  | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||
|  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||
|  | from utils import AverageMeter, time_string, convert_secs2time | ||
|  | from utils import print_log, obtain_accuracy | ||
|  | from utils import Cutout, count_parameters_in_MB | ||
|  | from nas import model_types as models | ||
|  | from train_utils import main_procedure | ||
|  | from train_utils_imagenet import main_procedure_imagenet | ||
|  | from scheduler import load_config | ||
|  | 
 | ||
|  | 
 | ||
# Command-line interface: the dataset location and the checkpoint to evaluate.
# Both are mandatory -- main() cannot run without them, and argparse then
# reports a missing value as a usage error instead of a later TypeError.
parser = argparse.ArgumentParser("Evaluate-CNN")
parser.add_argument('--data_path',         type=str,   required=True, help='Path to dataset.')
parser.add_argument('--checkpoint',        type=str,   required=True, help="Path to a checkpoint file containing the training 'args' and the model 'state_dict'.")
args = parser.parse_args()

# Evaluation requires a GPU. Use an explicit check rather than `assert` so the
# guard is not silently removed when Python runs with -O.
if not torch.cuda.is_available():
  raise RuntimeError('torch.cuda is not available')
|  | 
 | ||
|  | 
 | ||
def main():
  """Evaluate a trained CNN stored in a checkpoint.

  Reads the module-level ``args`` (``--data_path`` and ``--checkpoint``),
  loads the checkpoint, rebuilds the model configuration and genotype from the
  training-time arguments saved inside it (``checkpoint['args']``), and
  dispatches to the ImageNet or the CIFAR evaluation procedure depending on
  ``xargs.dataset``.

  Exits via ``parser.error`` (usage message, exit status 2) when the data
  path is not a directory or the checkpoint is not a file.
  """
  # Validate CLI inputs explicitly: `assert` is stripped under `python -O`,
  # and a missing (None) path would otherwise raise a TypeError from os.path.
  if args.data_path is None or not os.path.isdir(args.data_path):
    parser.error('invalid data-path : {:}'.format(args.data_path))
  if args.checkpoint is None or not os.path.isfile(args.checkpoint):
    parser.error('invalid checkpoint : {:}'.format(args.checkpoint))

  # NOTE(review): torch.load unpickles arbitrary Python objects ('args' here is
  # presumably an argparse.Namespace), so only load checkpoints from trusted
  # sources. Recent PyTorch releases default to weights_only=True, which may
  # reject such a checkpoint -- TODO confirm against the installed version.
  checkpoint = torch.load(args.checkpoint)
  xargs      = checkpoint['args']              # arguments used at training time
  config     = load_config(xargs.model_config)
  genotype   = models[xargs.arch]              # architecture looked up by name

  # clear GPU cache before building the model
  torch.cuda.empty_cache()
  # NOTE(review): the trailing None argument's meaning is defined by the
  # train_utils procedures -- verify there before changing it.
  if xargs.dataset == 'imagenet':
    main_procedure_imagenet(config, args.data_path, xargs, genotype, xargs.init_channels, xargs.layers, checkpoint['state_dict'], None)
  else:
    main_procedure(config, xargs.dataset, args.data_path, xargs, genotype, xargs.init_channels, xargs.layers, checkpoint['state_dict'], None)
|  | 
 | ||
|  | 
 | ||
# Script entry point: call main() only when this file is executed directly.
if __name__ == "__main__":
  main()