|  |  |  | @@ -105,28 +105,27 @@ def main(xargs): | 
		
	
		
			
				|  |  |  |  |   logger = prepare_logger(args) | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | 
		
	
		
			
				|  |  |  |  |   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': | 
		
	
		
			
				|  |  |  |  |   #config_path = 'configs/nas-benchmark/algos/DARTS.config' | 
		
	
		
			
				|  |  |  |  |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | 
		
	
		
			
				|  |  |  |  |   if xargs.dataset == 'cifar10': | 
		
	
		
			
				|  |  |  |  |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | 
		
	
		
			
				|  |  |  |  |     cifar_split = load_config(split_Fpath, None, None) | 
		
	
		
			
				|  |  |  |  |     train_split, valid_split = cifar_split.train, cifar_split.valid | 
		
	
		
			
				|  |  |  |  |     logger.log('Load split file from {:}'.format(split_Fpath)) | 
		
	
		
			
				|  |  |  |  |     train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set | 
		
	
		
			
				|  |  |  |  |     logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | 
		
	
		
			
				|  |  |  |  |     # To split data | 
		
	
		
			
				|  |  |  |  |     train_data_v2 = deepcopy(train_data) | 
		
	
		
			
				|  |  |  |  |     train_data_v2.transform = valid_data.transform | 
		
	
		
			
				|  |  |  |  |     valid_data    = train_data_v2 | 
		
	
		
			
				|  |  |  |  |     search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | 
		
	
		
			
				|  |  |  |  |     # data loader | 
		
	
		
			
				|  |  |  |  |     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | 
		
	
		
			
				|  |  |  |  |     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | 
		
	
		
			
				|  |  |  |  |   elif xargs.dataset == 'cifar100': | 
		
	
		
			
				|  |  |  |  |     raise ValueError('not support yet : {:}'.format(xargs.dataset)) | 
		
	
		
			
				|  |  |  |  |   elif xargs.dataset.startswith('ImageNet16'): | 
		
	
		
			
				|  |  |  |  |     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) | 
		
	
		
			
				|  |  |  |  |     imagenet16_split = load_config(split_Fpath, None, None) | 
		
	
		
			
				|  |  |  |  |     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid | 
		
	
		
			
				|  |  |  |  |     logger.log('Load split file from {:}'.format(split_Fpath)) | 
		
	
		
			
				|  |  |  |  |     raise ValueError('not support yet : {:}'.format(xargs.dataset)) | 
		
	
		
			
				|  |  |  |  |   else: | 
		
	
		
			
				|  |  |  |  |     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | 
		
	
		
			
				|  |  |  |  |   config_path = 'configs/nas-benchmark/algos/DARTS.config' | 
		
	
		
			
				|  |  |  |  |   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | 
		
	
		
			
				|  |  |  |  |   # To split data | 
		
	
		
			
				|  |  |  |  |   train_data_v2 = deepcopy(train_data) | 
		
	
		
			
				|  |  |  |  |   train_data_v2.transform = valid_data.transform | 
		
	
		
			
				|  |  |  |  |   valid_data    = train_data_v2 | 
		
	
		
			
				|  |  |  |  |   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | 
		
	
		
			
				|  |  |  |  |   # data loader | 
		
	
		
			
				|  |  |  |  |   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | 
		
	
		
			
				|  |  |  |  |   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | 
		
	
		
			
				|  |  |  |  |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | 
		
	
		
			
				|  |  |  |  |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | 
		
	
		
			
				|  |  |  |  |  | 
		
	
	
		
			
				
					
					|  |  |  | @@ -231,6 +230,7 @@ if __name__ == '__main__': | 
		
	
		
			
				|  |  |  |  |   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | 
		
	
		
			
				|  |  |  |  |   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | 
		
	
		
			
				|  |  |  |  |   # channels and number-of-cells | 
		
	
		
			
				|  |  |  |  |   parser.add_argument('--config_path',        type=str,   help='The config paths.') | 
		
	
		
			
				|  |  |  |  |   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | 
		
	
		
			
				|  |  |  |  |   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | 
		
	
		
			
				|  |  |  |  |   parser.add_argument('--channel',            type=int,   help='The number of channels.') | 
		
	
	
		
			
				
					
					|  |  |  |   |