clarify restrictions
This commit is contained in:
		| @@ -105,20 +105,13 @@ def main(xargs): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |   #config_path = 'configs/nas-benchmark/algos/DARTS.config' | ||||||
|  |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|  |   if xargs.dataset == 'cifar10': | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |     cifar_split = load_config(split_Fpath, None, None) | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid |     train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |     logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | ||||||
|   elif xargs.dataset.startswith('ImageNet16'): |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) |  | ||||||
|     imagenet16_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |  | ||||||
|   config_path = 'configs/nas-benchmark/algos/DARTS.config' |  | ||||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) |  | ||||||
|     # To split data |     # To split data | ||||||
|     train_data_v2 = deepcopy(train_data) |     train_data_v2 = deepcopy(train_data) | ||||||
|     train_data_v2.transform = valid_data.transform |     train_data_v2.transform = valid_data.transform | ||||||
| @@ -126,7 +119,13 @@ def main(xargs): | |||||||
|     search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) |     search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||||
|     # data loader |     # data loader | ||||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) |     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||||
|  |   elif xargs.dataset == 'cifar100': | ||||||
|  |     raise ValueError('not support yet : {:}'.format(xargs.dataset)) | ||||||
|  |   elif xargs.dataset.startswith('ImageNet16'): | ||||||
|  |     raise ValueError('not support yet : {:}'.format(xargs.dataset)) | ||||||
|  |   else: | ||||||
|  |     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |  | ||||||
| @@ -231,6 +230,7 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--data_path',          type=str,   help='Path to dataset') |   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') |   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||||
|   # channels and number-of-cells |   # channels and number-of-cells | ||||||
|  |   parser.add_argument('--config_path',        type=str,   help='The config paths.') | ||||||
|   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') |   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | ||||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') |   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') |   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||||
|   | |||||||
| @@ -184,6 +184,7 @@ def main(xargs): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|  |   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |     cifar_split = load_config(split_Fpath, None, None) | ||||||
|   | |||||||
| @@ -81,6 +81,7 @@ def main(xargs): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|  |   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |     cifar_split = load_config(split_Fpath, None, None) | ||||||
|   | |||||||
| @@ -135,6 +135,7 @@ def main(xargs): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|  |   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |     cifar_split = load_config(split_Fpath, None, None) | ||||||
|   | |||||||
| @@ -5,7 +5,6 @@ import os, sys, torch | |||||||
| import os.path as osp | import os.path as osp | ||||||
| import numpy as np | import numpy as np | ||||||
| import torchvision.datasets as dset | import torchvision.datasets as dset | ||||||
| import torch.backends.cudnn as cudnn |  | ||||||
| import torchvision.transforms as transforms | import torchvision.transforms as transforms | ||||||
| from PIL import Image | from PIL import Image | ||||||
| from .DownsampledImageNet import ImageNet16 | from .DownsampledImageNet import ImageNet16 | ||||||
|   | |||||||
| @@ -33,6 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
|  | 	--config_path configs/nas-benchmark/algos/DARTS.config \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user