Compare commits

11 Commits: 5bf036a763 ... 4612cd198b

| SHA1 |
|---|
| 4612cd198b |
| 889bd1974c |
| af0e7786b6 |
| c6d53f08ae |
| ef2608bb42 |
| 50ff507a15 |
| 03d7d04d41 |
| bb33ca9a68 |
| f46486e21b |
| 5908a1edef |
| ed34024a88 |
| @@ -61,13 +61,13 @@ At this moment, this project provides the following algorithms and scripts to ru | ||||
|     <tr> <!-- (6-th row) --> | ||||
|     <td align="center" valign="middle"> NATS-Bench </td> | ||||
|     <td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td> | ||||
|     <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td> | ||||
|     <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td> | ||||
|     </tr> | ||||
|     <tr> <!-- (7-th row) --> | ||||
|     <td align="center" valign="middle"> ... </td> | ||||
|     <td align="center" valign="middle"> ENAS / REA / REINFORCE / BOHB </td> | ||||
|     <td align="center" valign="middle"> Please check the original papers </td> | ||||
|     <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a>  <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td> | ||||
|     <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a>  <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td> | ||||
|     </tr> | ||||
|     <tr> <!-- (start second block) --> | ||||
|     <td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td> | ||||
|   | ||||
| @@ -29,7 +29,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s | ||||
| You can move it to anywhere you want and send its path to our API for initialization. | ||||
| - [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial. | ||||
| - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be downloaded from [ | ||||
| NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights. | ||||
| NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights. | ||||
| - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi). | ||||
| - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions | ||||
| - [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information on more trials than `NAS-Bench-201-v1_0-e61699.pth`; in particular, all models trained for 12 epochs on all datasets are available. | ||||
|   | ||||
| @@ -27,7 +27,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s | ||||
| You can move it to anywhere you want and send its path to our API for initialization. | ||||
| - [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial. | ||||
| - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be downloaded from [ | ||||
| NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights. | ||||
| NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights. | ||||
| - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi). | ||||
| - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions | ||||
| - [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information on more trials than `NAS-Bench-201-v1_0-e61699.pth`; in particular, all models trained for 12 epochs on all datasets are available. | ||||
|   | ||||
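The update notes above describe installing the `nas-bench-201` package and passing the path of a downloaded benchmark file to the API for initialization. A minimal sketch of that workflow is shown below; it reuses only the calls that appear later in `test_network.py` (`NASBench201API`, `query_by_index`, `get_net_config`) and assumes `NAS-Bench-201-v1_1-096897.pth` has already been downloaded into the working directory.

```python
import os

from nas_201_api import NASBench201API as API

# Assumption: the benchmark file sits in the current working directory;
# any other path can be passed to the API for initialization.
data_path = os.path.join(os.getcwd(), "NAS-Bench-201-v1_1-096897.pth")
api = API(data_path)  # loads the recorded trial information; this file is large and may take a while

# Query the recorded results of architecture index 1 on CIFAR-10,
# mirroring the calls used later in test_network.py.
results = api.query_by_index(1, "cifar10")
config = api.get_net_config(1, "cifar10")
print(config["arch_str"])
```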
| @@ -3,7 +3,7 @@ | ||||
| </p> | ||||
|  | ||||
| --------- | ||||
| [](LICENSE.md) | ||||
| [](../LICENSE.md) | ||||
|  | ||||
| The automated deep learning library (AutoDL-Projects) is an open-source, lightweight, and powerful project. | ||||
| It implements a variety of neural architecture search (NAS) and hyperparameter optimization (HPO) algorithms. | ||||
| @@ -142,8 +142,8 @@ | ||||
|  | ||||
| # Others | ||||
|  | ||||
| If you would like to contribute to this codebase, please see [CONTRIBUTING.md](.github/CONTRIBUTING.md). | ||||
| In addition, please refer to [CODE-OF-CONDUCT.md](.github/CODE-OF-CONDUCT.md) for usage guidelines. | ||||
| If you would like to contribute to this codebase, please see [CONTRIBUTING.md](../.github/CONTRIBUTING.md). | ||||
| In addition, please refer to [CODE-OF-CONDUCT.md](../.github/CODE-OF-CONDUCT.md) for usage guidelines. | ||||
|  | ||||
| # License | ||||
| The entire codebase is under [MIT license](LICENSE.md) | ||||
| The entire codebase is under [MIT license](../LICENSE.md) | ||||
|   | ||||
| @@ -2,11 +2,11 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| import time, torch | ||||
| from procedures import prepare_seed, get_optim_scheduler | ||||
| from utils import get_model_infos, obtain_accuracy | ||||
| from config_utils import dict2config | ||||
| from log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from models import get_cell_based_tiny_net | ||||
| from xautodl.procedures import prepare_seed, get_optim_scheduler | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.config_utils import dict2config | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net | ||||
|  | ||||
|  | ||||
| __all__ = ["evaluate_for_seed", "pure_evaluate"] | ||||
|   | ||||
| @@ -16,10 +16,96 @@ from xautodl.procedures import get_machine_info | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, CellArchitectures, get_search_spaces | ||||
| from xautodl.functions import evaluate_for_seed | ||||
| from functions import evaluate_for_seed | ||||
|  | ||||
| from torchvision import datasets, transforms | ||||
|  | ||||
| # NASBENCH201_CONFIG_PATH = os.path.join( os.getcwd(), 'main_exp', 'transfer_nag') | ||||
|  | ||||
| NASBENCH201_CONFIG_PATH = '/lustre/hpe/ws11/ws11.1/ws/xmuhanma-nbdit/autodl-projects/configs/nas-benchmark' | ||||
|  | ||||
|  | ||||
| def evaluate_all_datasets( | ||||
| def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, | ||||
|                           arch_config, workers, logger): | ||||
|     machine_info, arch_config = get_machine_info(), deepcopy(arch_config) | ||||
|     all_infos = {'info': machine_info} | ||||
|     all_dataset_keys = [] | ||||
|     # look all the datasets | ||||
|     for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||
|         # train valid data | ||||
|         task = None | ||||
|         train_data, valid_data, xshape, class_num = get_datasets( | ||||
|             dataset, xpath, -1, task) | ||||
|  | ||||
|         # load the configuration | ||||
|         if dataset in ['mnist', 'svhn', 'aircraft', 'oxford']: | ||||
|             if use_less: | ||||
|                 # config_path = os.path.join( | ||||
|                 #     NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config') | ||||
|                 config_path = os.path.join( | ||||
|                     NASBENCH201_CONFIG_PATH, 'LESS.config') | ||||
|             else: | ||||
|                 # config_path = os.path.join( | ||||
|                 #     NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset)) | ||||
|                 config_path = os.path.join( | ||||
|                     NASBENCH201_CONFIG_PATH, '{}.config'.format(dataset)) | ||||
|  | ||||
|  | ||||
|             p = os.path.join( | ||||
|                 NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset)) | ||||
|             if not os.path.exists(p): | ||||
|                 import json | ||||
|                 label_list = list(range(len(train_data))) | ||||
|                 random.shuffle(label_list) | ||||
|                 strlist = [str(label_list[i]) for i in range(len(label_list))] | ||||
|                 splited = {'train': ["int", strlist[:len(train_data) // 2]], | ||||
|                            'valid': ["int", strlist[len(train_data) // 2:]]} | ||||
|                 with open(p, 'w') as f: | ||||
|                     f.write(json.dumps(splited)) | ||||
|             split_info = load_config(os.path.join( | ||||
|                 NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset)), None, None) | ||||
|         else: | ||||
|             raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||
|  | ||||
|         config = load_config( | ||||
|             config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|         # data loader | ||||
|         train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, | ||||
|                                                    shuffle=True, num_workers=workers, pin_memory=True) | ||||
|         valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, | ||||
|                                                    shuffle=False, num_workers=workers, pin_memory=True) | ||||
|         splits = load_config(os.path.join( | ||||
|             NASBENCH201_CONFIG_PATH, '{}-test-split.txt'.format(dataset)), None, None) | ||||
|         ValLoaders = {'ori-test': valid_loader, | ||||
|                       'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, | ||||
|                                                              sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                                                                  splits.xvalid), | ||||
|                                                              num_workers=workers, pin_memory=True), | ||||
|                       'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, | ||||
|                                                             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                                                                 splits.xtest), | ||||
|                                                             num_workers=workers, pin_memory=True) | ||||
|                       } | ||||
|         dataset_key = '{:}'.format(dataset) | ||||
|         if bool(split): | ||||
|             dataset_key = dataset_key + '-valid' | ||||
|         logger.log( | ||||
|             'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'. | ||||
|             format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size)) | ||||
|         logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format( | ||||
|             dataset_key, config)) | ||||
|         for key, value in ValLoaders.items(): | ||||
|             logger.log( | ||||
|                 'Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value))) | ||||
|  | ||||
|         results = evaluate_for_seed( | ||||
|             arch_config, config, arch, train_loader, ValLoaders, seed, logger) | ||||
|         all_infos[dataset_key] = results | ||||
|         all_dataset_keys.append(dataset_key) | ||||
|     all_infos['all_dataset_keys'] = all_dataset_keys | ||||
|     return all_infos | ||||
|  | ||||
| def evaluate_all_datasets1( | ||||
|     arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger | ||||
| ): | ||||
|     machine_info, arch_config = get_machine_info(), deepcopy(arch_config) | ||||
| @@ -46,47 +132,117 @@ def evaluate_all_datasets( | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None | ||||
|             ) | ||||
|         elif dataset.startswith("aircraft"): | ||||
|             if use_less: | ||||
|                 config_path = "configs/nas-benchmark/LESS.config" | ||||
|             else: | ||||
|                 config_path = "configs/nas-benchmark/aircraft.config" | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None | ||||
|             ) | ||||
|         elif dataset.startswith("oxford"): | ||||
|             if use_less: | ||||
|                 config_path = "configs/nas-benchmark/LESS.config" | ||||
|             else: | ||||
|                 config_path = "configs/nas-benchmark/oxford.config" | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|         config = load_config( | ||||
|             config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|         ) | ||||
|         # check whether use splited validation set | ||||
|         # if dataset == 'aircraft': | ||||
|         #     split = True | ||||
|         if bool(split): | ||||
|             assert dataset == "cifar10" | ||||
|             ValLoaders = { | ||||
|                 "ori-test": torch.utils.data.DataLoader( | ||||
|                     valid_data, | ||||
|             if dataset == "cifar10" or dataset == "cifar100": | ||||
|                 assert dataset == "cifar10" | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         shuffle=False, | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ) | ||||
|                 } | ||||
|                 assert len(train_data) == len(split_info.train) + len( | ||||
|                     split_info.valid | ||||
|                 ), "invalid length : {:} vs {:} + {:}".format( | ||||
|                     len(train_data), len(split_info.train), len(split_info.valid) | ||||
|                 ) | ||||
|                 train_data_v2 = deepcopy(train_data) | ||||
|                 train_data_v2.transform = valid_data.transform | ||||
|                 valid_data = train_data_v2 | ||||
|                 # data loader | ||||
|                 train_loader = torch.utils.data.DataLoader( | ||||
|                     train_data, | ||||
|                     batch_size=config.batch_size, | ||||
|                     shuffle=False, | ||||
|                     sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True, | ||||
|                 ) | ||||
|             } | ||||
|             assert len(train_data) == len(split_info.train) + len( | ||||
|                 split_info.valid | ||||
|             ), "invalid length : {:} vs {:} + {:}".format( | ||||
|                 len(train_data), len(split_info.train), len(split_info.valid) | ||||
|             ) | ||||
|             train_data_v2 = deepcopy(train_data) | ||||
|             train_data_v2.transform = valid_data.transform | ||||
|             valid_data = train_data_v2 | ||||
|             # data loader | ||||
|             train_loader = torch.utils.data.DataLoader( | ||||
|                 train_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             valid_loader = torch.utils.data.DataLoader( | ||||
|                 valid_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             ValLoaders["x-valid"] = valid_loader | ||||
|                 valid_loader = torch.utils.data.DataLoader( | ||||
|                     valid_data, | ||||
|                     batch_size=config.batch_size, | ||||
|                     sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True, | ||||
|                 ) | ||||
|                 ValLoaders["x-valid"] = valid_loader | ||||
|             elif dataset == "aircraft": | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         shuffle=False, | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ) | ||||
|                 } | ||||
|                 train_data_v2 = deepcopy(train_data) | ||||
|                 train_data_v2.transform = valid_data.transform | ||||
|                 valid_data = train_data_v2 | ||||
|                 # use DataLoader | ||||
|                 train_loader = torch.utils.data.DataLoader( | ||||
|                     train_data,  | ||||
|                     batch_size=config.batch_size,  | ||||
|                     sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True) | ||||
|                 valid_loader = torch.utils.data.DataLoader( | ||||
|                     valid_data, | ||||
|                     batch_size=config.batch_size,  | ||||
|                     sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True) | ||||
|             elif dataset == "oxford": | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         shuffle=False, | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True | ||||
|                     ) | ||||
|                 } | ||||
|                 # train_data_v2 = deepcopy(train_data) | ||||
|                 # train_data_v2.transform = valid_data.transform | ||||
|                 train_loader = torch.utils.data.DataLoader( | ||||
|                     train_data,  | ||||
|                     batch_size=config.batch_size,  | ||||
|                     sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True) | ||||
|                 valid_loader = torch.utils.data.DataLoader( | ||||
|                     valid_data, | ||||
|                     batch_size=config.batch_size,  | ||||
|                     sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True) | ||||
|  | ||||
|         else: | ||||
|             # data loader | ||||
|             train_loader = torch.utils.data.DataLoader( | ||||
| @@ -103,7 +259,7 @@ def evaluate_all_datasets( | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             if dataset == "cifar10": | ||||
|             if dataset == "cifar10" or dataset == "aircraft" or dataset == "oxford": | ||||
|                 ValLoaders = {"ori-test": valid_loader} | ||||
|             elif dataset == "cifar100": | ||||
|                 cifar100_splits = load_config( | ||||
|   | ||||
| @@ -28,16 +28,41 @@ else | ||||
|   mode=cover | ||||
| fi | ||||
|  | ||||
| # OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \ | ||||
| # 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \ | ||||
| # 	--use_less ${use_less} \ | ||||
| # 	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \ | ||||
| # 	--splits   1       0       0        0 \ | ||||
| # 	--xpaths $TORCH_HOME/cifar.python \ | ||||
| # 		 $TORCH_HOME/cifar.python \ | ||||
| # 		 $TORCH_HOME/cifar.python \ | ||||
| # 		 $TORCH_HOME/cifar.python/ImageNet16 \ | ||||
| # 	--channel 16 --num_cells 5 \ | ||||
| # 	--workers 4 \ | ||||
| # 	--srange ${xstart} ${xend} --arch_index ${arch_index} \ | ||||
| # 	--seeds ${all_seeds} | ||||
|  | ||||
| OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \ | ||||
| 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \ | ||||
| 	--use_less ${use_less} \ | ||||
| 	--datasets aircraft \ | ||||
| 	--xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \ | ||||
| 	--channel 16 \ | ||||
| 	--splits 1 \ | ||||
| 	--num_cells 5 \ | ||||
| 	--workers 4 \ | ||||
| 	--srange ${xstart} ${xend} --arch_index ${arch_index} \ | ||||
| 	--seeds ${all_seeds} | ||||
|  | ||||
| # OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \ | ||||
| # 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \ | ||||
| # 	--use_less ${use_less} \ | ||||
| # 	--datasets oxford\ | ||||
| # 	--xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \ | ||||
| # 	--channel 16 \ | ||||
| # 	--splits 1 \ | ||||
| # 	--num_cells 5 \ | ||||
| # 	--workers 4 \ | ||||
| # 	--srange ${xstart} ${xend} --arch_index ${arch_index} \ | ||||
| # 	--seeds ${all_seeds} | ||||
|  | ||||
|   | ||||
							
								
								
									
test.ipynb (104336 lines, Normal file): File diff suppressed because it is too large
							
								
								
									
test_network.py (616 lines, Normal file):
							| @@ -0,0 +1,616 @@ | ||||
| from nas_201_api import NASBench201API as API | ||||
| import os | ||||
|  | ||||
| import os, sys, time, torch, random, argparse | ||||
| from PIL import ImageFile | ||||
|  | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import load_config | ||||
| from xautodl.procedures import save_checkpoint, copy_checkpoint | ||||
| from xautodl.procedures import get_machine_info | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, CellArchitectures, get_search_spaces | ||||
|  | ||||
| import time, torch | ||||
| from xautodl.procedures import prepare_seed, get_optim_scheduler | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.config_utils import dict2config | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net | ||||
|  | ||||
| cur_path = os.path.abspath(os.path.curdir) | ||||
| data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth') | ||||
| print(f'loading data from {data_path}') | ||||
| print(f'loading') | ||||
| api = API(data_path) | ||||
| print(f'loaded') | ||||
|  | ||||
| def find_best_index(dataset): | ||||
|     num_archs = 15625  # NAS-Bench-201 contains 15625 architectures | ||||
|     accs = [] | ||||
|     for i in range(1, num_archs): | ||||
|         results = api.query_by_index(i, dataset) | ||||
|         dict_items = list(results.items()) | ||||
|         train_info = dict_items[0][1].get_train() | ||||
|         acc = train_info['accuracy'] | ||||
|         accs.append((i, acc)) | ||||
|     return max(accs, key=lambda x: x[1]) | ||||
|  | ||||
| best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10') | ||||
| best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100') | ||||
| best_ImageNet16_index, best_ImageNet16_acc= find_best_index('ImageNet16-120') | ||||
| print(f'find best cifar10 index: {best_cifar_10_index}, acc: {best_cifar_10_acc}') | ||||
| print(f'find best cifar100 index: {best_cifar_100_index}, acc: {best_cifar_100_acc}') | ||||
| print(f'find best ImageNet16 index: {best_ImageNet16_index}, acc: {best_ImageNet16_acc}') | ||||
|  | ||||
| from xautodl.models import get_cell_based_tiny_net | ||||
| def get_network_str_by_id(id, dataset): | ||||
|     config = api.get_net_config(id, dataset) | ||||
|     return config['arch_str'] | ||||
|  | ||||
| best_cifar_10_str = get_network_str_by_id(best_cifar_10_index, 'cifar10') | ||||
| best_cifar_100_str = get_network_str_by_id(best_cifar_100_index, 'cifar100') | ||||
| best_ImageNet16_str = get_network_str_by_id(best_ImageNet16_index, 'ImageNet16-120') | ||||
|  | ||||
| def evaluate_all_datasets( | ||||
|     arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger | ||||
| ): | ||||
|     machine_info, arch_config = get_machine_info(), deepcopy(arch_config) | ||||
|     all_infos = {"info": machine_info} | ||||
|     all_dataset_keys = [] | ||||
|     # look all the datasets | ||||
|     for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||
|         # train valid data | ||||
|         train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||
|         # load the configuration | ||||
|         if dataset == "cifar10" or dataset == "cifar100": | ||||
|             if use_less: | ||||
|                 config_path = "configs/nas-benchmark/LESS.config" | ||||
|             else: | ||||
|                 config_path = "configs/nas-benchmark/CIFAR.config" | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/cifar-split.txt", None, None | ||||
|             ) | ||||
|         elif dataset.startswith("ImageNet16"): | ||||
|             if use_less: | ||||
|                 config_path = "configs/nas-benchmark/LESS.config" | ||||
|             else: | ||||
|                 config_path = "configs/nas-benchmark/ImageNet-16.config" | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|         config = load_config( | ||||
|             config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|         ) | ||||
|         # check whether use splited validation set | ||||
|         if bool(split): | ||||
|             assert dataset == "cifar10" | ||||
|             ValLoaders = { | ||||
|                 "ori-test": torch.utils.data.DataLoader( | ||||
|                     valid_data, | ||||
|                     batch_size=config.batch_size, | ||||
|                     shuffle=False, | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True, | ||||
|                 ) | ||||
|             } | ||||
|             assert len(train_data) == len(split_info.train) + len( | ||||
|                 split_info.valid | ||||
|             ), "invalid length : {:} vs {:} + {:}".format( | ||||
|                 len(train_data), len(split_info.train), len(split_info.valid) | ||||
|             ) | ||||
|             train_data_v2 = deepcopy(train_data) | ||||
|             train_data_v2.transform = valid_data.transform | ||||
|             valid_data = train_data_v2 | ||||
|             # data loader | ||||
|             train_loader = torch.utils.data.DataLoader( | ||||
|                 train_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             valid_loader = torch.utils.data.DataLoader( | ||||
|                 valid_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             ValLoaders["x-valid"] = valid_loader | ||||
|         else: | ||||
|             # data loader | ||||
|             train_loader = torch.utils.data.DataLoader( | ||||
|                 train_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 shuffle=True, | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             valid_loader = torch.utils.data.DataLoader( | ||||
|                 valid_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 shuffle=False, | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             if dataset == "cifar10": | ||||
|                 ValLoaders = {"ori-test": valid_loader} | ||||
|             elif dataset == "cifar100": | ||||
|                 cifar100_splits = load_config( | ||||
|                     "configs/nas-benchmark/cifar100-test-split.txt", None, None | ||||
|                 ) | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": valid_loader, | ||||
|                     "x-valid": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             cifar100_splits.xvalid | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                     "x-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             cifar100_splits.xtest | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                 } | ||||
|             elif dataset == "ImageNet16-120": | ||||
|                 imagenet16_splits = load_config( | ||||
|                     "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None | ||||
|                 ) | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": valid_loader, | ||||
|                     "x-valid": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             imagenet16_splits.xvalid | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                     "x-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             imagenet16_splits.xtest | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                 } | ||||
|             else: | ||||
|                 raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|  | ||||
|         dataset_key = "{:}".format(dataset) | ||||
|         if bool(split): | ||||
|             dataset_key = dataset_key + "-valid" | ||||
|         logger.log( | ||||
|             "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|                 dataset_key, | ||||
|                 len(train_data), | ||||
|                 len(valid_data), | ||||
|                 len(train_loader), | ||||
|                 len(valid_loader), | ||||
|                 config.batch_size, | ||||
|             ) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config) | ||||
|         ) | ||||
|         for key, value in ValLoaders.items(): | ||||
|             logger.log( | ||||
|                 "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)) | ||||
|             ) | ||||
|         results = evaluate_for_seed( | ||||
|             arch_config, config, arch, train_loader, ValLoaders, seed, logger | ||||
|         ) | ||||
|         all_infos[dataset_key] = results | ||||
|         all_dataset_keys.append(dataset_key) | ||||
|     all_infos["all_dataset_keys"] = all_dataset_keys | ||||
|     return all_infos | ||||
|  | ||||
| def evaluate_for_seed( | ||||
|     arch_config, config, arch, train_loader, valid_loaders, seed, logger | ||||
| ): | ||||
|  | ||||
|     prepare_seed(seed)  # random seed | ||||
|     net = get_cell_based_tiny_net( | ||||
|         dict2config( | ||||
|             { | ||||
|                 "name": "infer.tiny", | ||||
|                 "C": arch_config["channel"], | ||||
|                 "N": arch_config["num_cells"], | ||||
|                 "genotype": arch, | ||||
|                 "num_classes": config.class_num, | ||||
|             }, | ||||
|             None, | ||||
|         ) | ||||
|     ) | ||||
|     # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|     flop, param = get_model_infos(net, config.xshape) | ||||
|     logger.log("Network : {:}".format(net.get_message()), False) | ||||
|     logger.log( | ||||
|         "{:} Seed-------------------------- {:} --------------------------".format( | ||||
|             time_string(), seed | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) | ||||
|     # train and valid | ||||
|     optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config) | ||||
|     network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda() | ||||
|     # start training | ||||
|     start_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     ( | ||||
|         train_losses, | ||||
|         train_acc1es, | ||||
|         train_acc5es, | ||||
|         valid_losses, | ||||
|         valid_acc1es, | ||||
|         valid_acc5es, | ||||
|     ) = ({}, {}, {}, {}, {}, {}) | ||||
|     train_times, valid_times = {}, {} | ||||
|     for epoch in range(total_epoch): | ||||
|         scheduler.update(epoch, 0.0) | ||||
|  | ||||
|         train_loss, train_acc1, train_acc5, train_tm = procedure( | ||||
|             train_loader, network, criterion, scheduler, optimizer, "train" | ||||
|         ) | ||||
|         train_losses[epoch] = train_loss | ||||
|         train_acc1es[epoch] = train_acc1 | ||||
|         train_acc5es[epoch] = train_acc5 | ||||
|         train_times[epoch] = train_tm | ||||
|         with torch.no_grad(): | ||||
|             for key, xloder in valid_loaders.items(): | ||||
|                 valid_loss, valid_acc1, valid_acc5, valid_tm = procedure( | ||||
|                     xloder, network, criterion, None, None, "valid" | ||||
|                 ) | ||||
|                 valid_losses["{:}@{:}".format(key, epoch)] = valid_loss | ||||
|                 valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1 | ||||
|                 valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5 | ||||
|                 valid_times["{:}@{:}".format(key, epoch)] = valid_tm | ||||
|  | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format( | ||||
|                 time_string(), | ||||
|                 need_time, | ||||
|                 epoch, | ||||
|                 total_epoch, | ||||
|                 train_loss, | ||||
|                 train_acc1, | ||||
|                 train_acc5, | ||||
|                 valid_loss, | ||||
|                 valid_acc1, | ||||
|                 valid_acc5, | ||||
|             ) | ||||
|         ) | ||||
|     info_seed = { | ||||
|         "flop": flop, | ||||
|         "param": param, | ||||
|         "channel": arch_config["channel"], | ||||
|         "num_cells": arch_config["num_cells"], | ||||
|         "config": config._asdict(), | ||||
|         "total_epoch": total_epoch, | ||||
|         "train_losses": train_losses, | ||||
|         "train_acc1es": train_acc1es, | ||||
|         "train_acc5es": train_acc5es, | ||||
|         "train_times": train_times, | ||||
|         "valid_losses": valid_losses, | ||||
|         "valid_acc1es": valid_acc1es, | ||||
|         "valid_acc5es": valid_acc5es, | ||||
|         "valid_times": valid_times, | ||||
|         "net_state_dict": net.state_dict(), | ||||
|         "net_string": "{:}".format(net), | ||||
|         "finish-train": True, | ||||
|     } | ||||
|     return info_seed | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|     data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     latencies = [] | ||||
|     network.eval() | ||||
|     with torch.no_grad(): | ||||
|         end = time.time() | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             targets = targets.cuda(non_blocking=True) | ||||
|             inputs = inputs.cuda(non_blocking=True) | ||||
|             data_time.update(time.time() - end) | ||||
|             # forward | ||||
|             features, logits = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             batch_time.update(time.time() - end) | ||||
|             if batch is None or batch == inputs.size(0): | ||||
|                 batch = inputs.size(0) | ||||
|                 latencies.append(batch_time.val - data_time.val) | ||||
|             # record loss and accuracy | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|             end = time.time() | ||||
|     if len(latencies) > 2: | ||||
|         latencies = latencies[1:] | ||||
|     return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode): | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|         # forward | ||||
|         features, logits = network(inputs) | ||||
|         loss = criterion(logits, targets) | ||||
|         # backward | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|         # record loss and accuracy | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|         # count time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|     data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     latencies = [] | ||||
|     network.eval() | ||||
|     with torch.no_grad(): | ||||
|         end = time.time() | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             targets = targets.cuda(non_blocking=True) | ||||
|             inputs = inputs.cuda(non_blocking=True) | ||||
|             data_time.update(time.time() - end) | ||||
|             # forward | ||||
|             features, logits = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             batch_time.update(time.time() - end) | ||||
|             if batch is None or batch == inputs.size(0): | ||||
|                 batch = inputs.size(0) | ||||
|                 latencies.append(batch_time.val - data_time.val) | ||||
|             # record loss and accuracy | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|             end = time.time() | ||||
|     if len(latencies) > 2: | ||||
|         latencies = latencies[1:] | ||||
|     return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode): | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|         # forward | ||||
|         features, logits = network(inputs) | ||||
|         loss = criterion(logits, targets) | ||||
|         # backward | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|         # record loss and accuracy | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|         # count time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
| def train_single_model( | ||||
|     save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config | ||||
| ): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     # torch.backends.cudnn.benchmark = True | ||||
|     torch.set_num_threads(workers) | ||||
|  | ||||
|     save_dir = ( | ||||
|         Path(save_dir) | ||||
|         / "specifics" | ||||
|         / "{:}-{:}-{:}-{:}".format( | ||||
|             "LESS" if use_less else "FULL", | ||||
|             model_str, | ||||
|             arch_config["channel"], | ||||
|             arch_config["num_cells"], | ||||
|         ) | ||||
|     ) | ||||
|     logger = Logger(str(save_dir), 0, False) | ||||
|     print(CellArchitectures) | ||||
|     if model_str in CellArchitectures: | ||||
|         arch = CellArchitectures[model_str] | ||||
|         logger.log( | ||||
|             "The model string is found in pre-defined architecture dict : {:}".format( | ||||
|                 model_str | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         try: | ||||
|             arch = CellStructure.str2structure(model_str) | ||||
|         except: | ||||
|             raise ValueError( | ||||
|                 "Invalid model string : {:}. It can not be found or parsed.".format( | ||||
|                     model_str | ||||
|                 ) | ||||
|             ) | ||||
|     assert arch.check_valid_op( | ||||
|         get_search_spaces("cell", "nas-bench-201") | ||||
|     ), "{:} has the invalid op.".format(arch) | ||||
|     logger.log("Start train-evaluate {:}".format(arch.tostr())) | ||||
|     logger.log("arch_config : {:}".format(arch_config)) | ||||
|  | ||||
|     start_time, seed_time = time.time(), AverageMeter() | ||||
|     for _is, seed in enumerate(seeds): | ||||
|         logger.log( | ||||
|             "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format( | ||||
|                 _is, len(seeds), seed | ||||
|             ) | ||||
|         ) | ||||
|         to_save_name = save_dir / "seed-{:04d}.pth".format(seed) | ||||
|         if to_save_name.exists(): | ||||
|             logger.log( | ||||
|                 "Find the existing file {:}, directly load!".format(to_save_name) | ||||
|             ) | ||||
|             checkpoint = torch.load(to_save_name) | ||||
|         else: | ||||
|             logger.log( | ||||
|                 "Does not find the existing file {:}, train and evaluate!".format( | ||||
|                     to_save_name | ||||
|                 ) | ||||
|             ) | ||||
|             checkpoint = evaluate_all_datasets( | ||||
|                 arch, | ||||
|                 datasets, | ||||
|                 xpaths, | ||||
|                 splits, | ||||
|                 use_less, | ||||
|                 seed, | ||||
|                 arch_config, | ||||
|                 workers, | ||||
|                 logger, | ||||
|             ) | ||||
|             torch.save(checkpoint, to_save_name) | ||||
|         # log information | ||||
|         logger.log("{:}".format(checkpoint["info"])) | ||||
|         all_dataset_keys = checkpoint["all_dataset_keys"] | ||||
|         for dataset_key in all_dataset_keys: | ||||
|             logger.log( | ||||
|                 "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15) | ||||
|             ) | ||||
|             dataset_info = checkpoint[dataset_key] | ||||
|             # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] )) | ||||
|             logger.log( | ||||
|                 "Flops = {:} MB, Params = {:} MB".format( | ||||
|                     dataset_info["flop"], dataset_info["param"] | ||||
|                 ) | ||||
|             ) | ||||
|             logger.log("config : {:}".format(dataset_info["config"])) | ||||
|             logger.log( | ||||
|                 "Training State (finish) = {:}".format(dataset_info["finish-train"]) | ||||
|             ) | ||||
|             last_epoch = dataset_info["total_epoch"] - 1 | ||||
|             train_acc1es, train_acc5es = ( | ||||
|                 dataset_info["train_acc1es"], | ||||
|                 dataset_info["train_acc5es"], | ||||
|             ) | ||||
|             valid_acc1es, valid_acc5es = ( | ||||
|                 dataset_info["valid_acc1es"], | ||||
|                 dataset_info["valid_acc5es"], | ||||
|             ) | ||||
|             print(dataset_info["train_acc1es"]) | ||||
|             print(dataset_info["train_acc5es"]) | ||||
|             print(dataset_info["valid_acc1es"]) | ||||
|             print(dataset_info["valid_acc5es"]) | ||||
|             logger.log( | ||||
|                 "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format( | ||||
|                     train_acc1es[last_epoch], | ||||
|                     train_acc5es[last_epoch], | ||||
|                     100 - train_acc1es[last_epoch], | ||||
|                     valid_acc1es['ori-test@'+str(last_epoch)], | ||||
|                     valid_acc5es['ori-test@'+str(last_epoch)], | ||||
|                     100 - valid_acc1es['ori-test@'+str(last_epoch)], | ||||
|                 ) | ||||
|             ) | ||||
|         # measure elapsed time | ||||
|         seed_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format( | ||||
|                 _is, len(seeds), seed, need_time | ||||
|             ) | ||||
|         ) | ||||
|     logger.close() | ||||
|  | ||||
| # |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2| | ||||
| train_strs = [best_cifar_10_str, best_cifar_100_str, best_ImageNet16_str] | ||||
| train_single_model( | ||||
|     save_dir="./outputs", | ||||
|     workers=8, | ||||
|     datasets=["ImageNet16-120"],  | ||||
|     xpaths=["./datasets/imagenet16-120"], | ||||
|     splits=[0, 0, 0], | ||||
|     use_less=False, | ||||
|     seeds=[777], | ||||
|     model_str=best_ImageNet16_str, | ||||
|     arch_config={"channel": 16, "num_cells": 8},) | ||||
| train_single_model( | ||||
|     save_dir="./outputs", | ||||
|     workers=8, | ||||
|     datasets=["cifar10"],  | ||||
|     xpaths=["./datasets/cifar10"], | ||||
|     splits=[0, 0, 0], | ||||
|     use_less=False, | ||||
|     seeds=[777], | ||||
|     model_str=best_cifar_10_str, | ||||
|     arch_config={"channel": 16, "num_cells": 8},) | ||||
| train_single_model( | ||||
|     save_dir="./outputs", | ||||
|     workers=8, | ||||
|     datasets=["cifar100"],  | ||||
|     xpaths=["./datasets/cifar100"], | ||||
|     splits=[0, 0, 0], | ||||
|     use_less=False, | ||||
|     seeds=[777], | ||||
|     model_str=best_cifar_100_str, | ||||
|     arch_config={"channel": 16, "num_cells": 8},) | ||||
| @@ -1,40 +1,39 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| # Modified by Hayeon Lee, Eunyoung Hyung 2021. 03. | ||||
| ################################################## | ||||
| import os, sys, torch | ||||
| import os | ||||
| import sys | ||||
| import torch | ||||
| import os.path as osp | ||||
| import numpy as np | ||||
| import torchvision.datasets as dset | ||||
| import torchvision.transforms as transforms | ||||
| from copy import deepcopy | ||||
| from PIL import Image | ||||
|  | ||||
| from xautodl.config_utils import load_config | ||||
|  | ||||
| from .DownsampledImageNet import ImageNet16 | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
|  | ||||
| # from PIL import Image | ||||
| import random | ||||
| import pdb | ||||
| from .aircraft import FGVCAircraft | ||||
| from .pets import PetDataset | ||||
| from config_utils import load_config | ||||
|  | ||||
| Dataset2Class = { | ||||
|     "cifar10": 10, | ||||
|     "cifar100": 100, | ||||
|     "imagenet-1k-s": 1000, | ||||
|     "imagenet-1k": 1000, | ||||
|     "ImageNet16": 1000, | ||||
|     "ImageNet16-150": 150, | ||||
|     "ImageNet16-120": 120, | ||||
|     "ImageNet16-200": 200, | ||||
| } | ||||
| Dataset2Class = {'cifar10': 10, | ||||
|                  'cifar100': 100, | ||||
|                  'mnist': 10, | ||||
|                  'svhn': 10, | ||||
|                  'aircraft': 30, | ||||
|                  'oxford': 37} | ||||
|  | ||||
|  | ||||
| class CUTOUT(object): | ||||
|  | ||||
|     def __init__(self, length): | ||||
|         self.length = length | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(length={length})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|         return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|     def __call__(self, img): | ||||
|         h, w = img.size(1), img.size(2) | ||||
| @@ -47,7 +46,7 @@ class CUTOUT(object): | ||||
|         x1 = np.clip(x - self.length // 2, 0, w) | ||||
|         x2 = np.clip(x + self.length // 2, 0, w) | ||||
|  | ||||
|         mask[y1:y2, x1:x2] = 0.0 | ||||
|         mask[y1: y2, x1: x2] = 0. | ||||
|         mask = torch.from_numpy(mask) | ||||
|         mask = mask.expand_as(img) | ||||
|         img *= mask | ||||
| @@ -55,21 +54,19 @@ class CUTOUT(object): | ||||
|  | ||||
|  | ||||
| imagenet_pca = { | ||||
|     "eigval": np.asarray([0.2175, 0.0188, 0.0045]), | ||||
|     "eigvec": np.asarray( | ||||
|         [ | ||||
|             [-0.5675, 0.7192, 0.4009], | ||||
|             [-0.5808, -0.0045, -0.8140], | ||||
|             [-0.5836, -0.6948, 0.4203], | ||||
|         ] | ||||
|     ), | ||||
|     'eigval': np.asarray([0.2175, 0.0188, 0.0045]), | ||||
|     'eigvec': np.asarray([ | ||||
|         [-0.5675, 0.7192, 0.4009], | ||||
|         [-0.5808, -0.0045, -0.8140], | ||||
|         [-0.5836, -0.6948, 0.4203], | ||||
|     ]) | ||||
| } | ||||
|  | ||||
|  | ||||
| class Lighting(object): | ||||
|     def __init__( | ||||
|         self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"] | ||||
|     ): | ||||
|     def __init__(self, alphastd, | ||||
|                  eigval=imagenet_pca['eigval'], | ||||
|                  eigvec=imagenet_pca['eigvec']): | ||||
|         self.alphastd = alphastd | ||||
|         assert eigval.shape == (3,) | ||||
|         assert eigvec.shape == (3, 3) | ||||
| @@ -77,10 +74,10 @@ class Lighting(object): | ||||
|         self.eigvec = eigvec | ||||
|  | ||||
|     def __call__(self, img): | ||||
|         if self.alphastd == 0.0: | ||||
|         if self.alphastd == 0.: | ||||
|             return img | ||||
|         rnd = np.random.randn(3) * self.alphastd | ||||
|         rnd = rnd.astype("float32") | ||||
|         rnd = rnd.astype('float32') | ||||
|         v = rnd | ||||
|         old_dtype = np.asarray(img).dtype | ||||
|         v = v * self.eigval | ||||
| @@ -89,275 +86,222 @@ class Lighting(object): | ||||
|         img = np.add(img, inc) | ||||
|         if old_dtype == np.uint8: | ||||
|             img = np.clip(img, 0, 255) | ||||
|         img = Image.fromarray(img.astype(old_dtype), "RGB") | ||||
|         img = Image.fromarray(img.astype(old_dtype), 'RGB') | ||||
|         return img | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return self.__class__.__name__ + "()" | ||||
|         return self.__class__.__name__ + '()' | ||||
|  | ||||
|  | ||||
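# Usage sketch for Lighting (it is appended to the imagenet-1k pipeline below as
# Lighting(0.1)): it jitters a PIL RGB image along the ImageNet PCA color
# directions, and alphastd=0 returns the image unchanged. `pil_image` is a
# placeholder PIL.Image in RGB mode.
# lighting = Lighting(alphastd=0.1)
# jittered = lighting(pil_image)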
| def get_datasets(name, root, cutout): | ||||
|  | ||||
|     if name == "cifar10": | ||||
| def get_datasets(name, root, cutout, use_num_cls=None): | ||||
|     if name == 'cifar10': | ||||
|         mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|         std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|     elif name == "cifar100": | ||||
|     elif name == 'cifar100': | ||||
|         mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|         std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|     elif name.startswith("ImageNet16"): | ||||
|         mean = [x / 255 for x in [122.68, 116.66, 104.01]] | ||||
|         std = [x / 255 for x in [63.22, 61.26, 65.09]] | ||||
|     elif name.startswith('mnist'): | ||||
|         mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081] | ||||
|     elif name.startswith('svhn'): | ||||
|         mean, std = [0.4376821, 0.4437697, 0.47280442], [0.19803012, 0.20101562, 0.19703614] | ||||
|     elif name.startswith('aircraft'): | ||||
|         mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883] | ||||
|         std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115] | ||||
|     elif name.startswith('oxford'): | ||||
|         mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783] | ||||
|         std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917] | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     # Data Augmentation | ||||
|     if name == "cifar10" or name == "cifar100": | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(32, padding=4), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|     if name == 'cifar10' or name == 'cifar100': | ||||
|         lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), | ||||
|                  transforms.Normalize(mean, std)] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         ) | ||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith("ImageNet16"): | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(16, padding=2), | ||||
|     elif name.startswith('cub200'): | ||||
|         train_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         ) | ||||
|         xshape = (1, 3, 16, 16) | ||||
|     elif name == "tiered": | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(80, padding=4), | ||||
|             transforms.Normalize(mean=mean, std=std) | ||||
|         ]) | ||||
|         test_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [ | ||||
|                 transforms.CenterCrop(80), | ||||
|                 transforms.ToTensor(), | ||||
|                 transforms.Normalize(mean, std), | ||||
|             ] | ||||
|         ) | ||||
|             transforms.Normalize(mean=mean, std=std) | ||||
|         ]) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith('mnist'): | ||||
|         train_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ]) | ||||
|         test_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||||
|             transforms.Normalize(mean, std) | ||||
|         ]) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith('svhn'): | ||||
|         train_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean=mean, std=std) | ||||
|         ]) | ||||
|         test_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean=mean, std=std) | ||||
|         ]) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith('aircraft'): | ||||
|         train_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean=mean, std=std) | ||||
|         ]) | ||||
|         test_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean=mean, std=std), | ||||
|         ]) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith('oxford'): | ||||
|         train_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean=mean, std=std) | ||||
|         ]) | ||||
|         test_transform = transforms.Compose([ | ||||
|             transforms.Resize((32, 32)), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean=mean, std=std), | ||||
|         ]) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         normalize = transforms.Normalize( | ||||
|             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | ||||
|         ) | ||||
|         if name == "imagenet-1k": | ||||
|             xlists = [transforms.RandomResizedCrop(224)] | ||||
|             xlists.append( | ||||
|                 transforms.ColorJitter( | ||||
|                     brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2 | ||||
|                 ) | ||||
|             ) | ||||
|             xlists.append(Lighting(0.1)) | ||||
|         elif name == "imagenet-1k-s": | ||||
|             xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] | ||||
|         else: | ||||
|             raise ValueError("invalid name : {:}".format(name)) | ||||
|         xlists.append(transforms.RandomHorizontalFlip(p=0.5)) | ||||
|         xlists.append(transforms.ToTensor()) | ||||
|         xlists.append(normalize) | ||||
|         train_transform = transforms.Compose(xlists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [ | ||||
|                 transforms.Resize(256), | ||||
|                 transforms.CenterCrop(224), | ||||
|                 transforms.ToTensor(), | ||||
|                 normalize, | ||||
|             ] | ||||
|         ) | ||||
|         xshape = (1, 3, 224, 224) | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     if name == "cifar10": | ||||
|     if name == 'cifar10': | ||||
|         train_data = dset.CIFAR10( | ||||
|             root, train=True, transform=train_transform, download=True | ||||
|         ) | ||||
|             root, train=True, transform=train_transform, download=True) | ||||
|         test_data = dset.CIFAR10( | ||||
|             root, train=False, transform=test_transform, download=True | ||||
|         ) | ||||
|             root, train=False, transform=test_transform, download=True) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name == "cifar100": | ||||
|     elif name == 'cifar100': | ||||
|         train_data = dset.CIFAR100( | ||||
|             root, train=True, transform=train_transform, download=True | ||||
|         ) | ||||
|             root, train=True, transform=train_transform, download=True) | ||||
|         test_data = dset.CIFAR100( | ||||
|             root, train=False, transform=test_transform, download=True | ||||
|         ) | ||||
|             root, train=False, transform=test_transform, download=True) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         train_data = dset.ImageFolder(osp.join(root, "train"), train_transform) | ||||
|         test_data = dset.ImageFolder(osp.join(root, "val"), test_transform) | ||||
|         assert ( | ||||
|             len(train_data) == 1281167 and len(test_data) == 50000 | ||||
|         ), "invalid number of images : {:} & {:} vs {:} & {:}".format( | ||||
|             len(train_data), len(test_data), 1281167, 50000 | ||||
|         ) | ||||
|     elif name == "ImageNet16": | ||||
|         train_data = ImageNet16(root, True, train_transform) | ||||
|         test_data = ImageNet16(root, False, test_transform) | ||||
|         assert len(train_data) == 1281167 and len(test_data) == 50000 | ||||
|     elif name == "ImageNet16-120": | ||||
|         train_data = ImageNet16(root, True, train_transform, 120) | ||||
|         test_data = ImageNet16(root, False, test_transform, 120) | ||||
|         assert len(train_data) == 151700 and len(test_data) == 6000 | ||||
|     elif name == "ImageNet16-150": | ||||
|         train_data = ImageNet16(root, True, train_transform, 150) | ||||
|         test_data = ImageNet16(root, False, test_transform, 150) | ||||
|         assert len(train_data) == 190272 and len(test_data) == 7500 | ||||
|     elif name == "ImageNet16-200": | ||||
|         train_data = ImageNet16(root, True, train_transform, 200) | ||||
|         test_data = ImageNet16(root, False, test_transform, 200) | ||||
|         assert len(train_data) == 254775 and len(test_data) == 10000 | ||||
|     elif name == 'mnist': | ||||
|         train_data = dset.MNIST( | ||||
|             root, train=True, transform=train_transform, download=True) | ||||
|         test_data = dset.MNIST( | ||||
|             root, train=False, transform=test_transform, download=True) | ||||
|         assert len(train_data) == 60000 and len(test_data) == 10000 | ||||
|     elif name == 'svhn': | ||||
|         train_data = dset.SVHN(root, split='train', | ||||
|                                transform=train_transform, download=True) | ||||
|         test_data = dset.SVHN(root, split='test', | ||||
|                               transform=test_transform, download=True) | ||||
|         assert len(train_data) == 73257 and len(test_data) == 26032 | ||||
|     elif name == 'aircraft': | ||||
|         train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval', | ||||
|                                   transform=train_transform, download=False) | ||||
|         test_data = FGVCAircraft(root, class_type='manufacturer', split='test', | ||||
|                                  transform=test_transform, download=False) | ||||
|         assert len(train_data) == 6667 and len(test_data) == 3333 | ||||
|     elif name == 'oxford': | ||||
|         train_data = PetDataset(root, train=True, num_cl=37, | ||||
|                                 val_split=0.15, transforms=train_transform) | ||||
|         test_data = PetDataset(root, train=False, num_cl=37, | ||||
|                                val_split=0.15, transforms=test_transform) | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     class_num = Dataset2Class[name] | ||||
|     class_num = Dataset2Class[name] if use_num_cls is None else len( | ||||
|         use_num_cls) | ||||
|     return train_data, test_data, xshape, class_num | ||||
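# Minimal usage sketch (the root path is a placeholder): build CIFAR-10 without
# cutout; the returned class_num falls back to Dataset2Class[name] when
# use_num_cls is None.
# train_data, test_data, xshape, class_num = get_datasets('cifar10', './data/cifar.python', -1)
# assert xshape == (1, 3, 32, 32) and class_num == 10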
|  | ||||
|  | ||||
| def get_nas_search_loaders( | ||||
|     train_data, valid_data, dataset, config_root, batch_size, workers | ||||
| ): | ||||
| def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None): | ||||
|     if isinstance(batch_size, (list, tuple)): | ||||
|         batch, test_batch = batch_size | ||||
|     else: | ||||
|         batch, test_batch = batch_size, batch_size | ||||
|     if dataset == "cifar10": | ||||
|     if dataset == 'cifar10': | ||||
|         # split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|         cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None) | ||||
|         train_split, valid_split = ( | ||||
|             cifar_split.train, | ||||
|             cifar_split.valid, | ||||
|         )  # search over the proposed training and validation set | ||||
|         cifar_split = load_config( | ||||
|             '{:}/cifar-split.txt'.format(config_root), None, None) | ||||
|         # search over the proposed training and validation set | ||||
|         train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|         # logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | ||||
|         # To split data | ||||
|         xvalid_data = deepcopy(train_data) | ||||
|         if hasattr(xvalid_data, "transforms"):  # to avoid a print issue | ||||
|         if hasattr(xvalid_data, 'transforms'):  # to avoid a print issue | ||||
|             xvalid_data.transforms = valid_data.transform | ||||
|         xvalid_data.transform = deepcopy(valid_data.transform) | ||||
|         search_data = SearchDataset(dataset, train_data, train_split, valid_split) | ||||
|         search_data = SearchDataset( | ||||
|             dataset, train_data, train_split, valid_split) | ||||
|         # data loader | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             xvalid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     elif dataset == "cifar100": | ||||
|         search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers, | ||||
|                                                     pin_memory=True) | ||||
|         train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, | ||||
|                                                    sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                                                        train_split), | ||||
|                                                    num_workers=workers, pin_memory=True) | ||||
|         valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, | ||||
|                                                    sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                                                        valid_split), | ||||
|                                                    num_workers=workers, pin_memory=True) | ||||
|     elif dataset == 'cifar100': | ||||
|         cifar100_test_split = load_config( | ||||
|             "{:}/cifar100-test-split.txt".format(config_root), None, None | ||||
|         ) | ||||
|             '{:}/cifar100-test-split.txt'.format(config_root), None, None) | ||||
|         search_train_data = train_data | ||||
|         search_valid_data = deepcopy(valid_data) | ||||
|         search_valid_data.transform = train_data.transform | ||||
|         search_data = SearchDataset( | ||||
|             dataset, | ||||
|             [search_train_data, search_valid_data], | ||||
|             list(range(len(search_train_data))), | ||||
|             cifar100_test_split.xvalid, | ||||
|         ) | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                 cifar100_test_split.xvalid | ||||
|             ), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     elif dataset == "ImageNet16-120": | ||||
|         imagenet_test_split = load_config( | ||||
|             "{:}/imagenet-16-120-test-split.txt".format(config_root), None, None | ||||
|         ) | ||||
|         search_data = SearchDataset(dataset, [search_train_data, search_valid_data], | ||||
|                                     list(range(len(search_train_data))), | ||||
|                                     cifar100_test_split.xvalid) | ||||
|         search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers, | ||||
|                                                     pin_memory=True) | ||||
|         train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers, | ||||
|                                                    pin_memory=True) | ||||
|         valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch, | ||||
|                                                    sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                                                        cifar100_test_split.xvalid), num_workers=workers, pin_memory=True) | ||||
|     elif dataset in ['mnist', 'svhn', 'aircraft', 'oxford']: | ||||
|         if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)): | ||||
|             import json | ||||
|             label_list = list(range(len(valid_data))) | ||||
|             random.shuffle(label_list) | ||||
|             strlist = [str(label_list[i]) for i in range(len(label_list))] | ||||
|             split = {'xvalid': ["int", strlist[:len(valid_data) // 2]], | ||||
|                      'xtest': ["int", strlist[len(valid_data) // 2:]]} | ||||
|             with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f: | ||||
|                 f.write(json.dumps(split)) | ||||
|         test_split = load_config( | ||||
|             '{:}/{}-test-split.txt'.format(config_root, dataset), None, None) | ||||
|  | ||||
|         search_train_data = train_data | ||||
|         search_valid_data = deepcopy(valid_data) | ||||
|         search_valid_data.transform = train_data.transform | ||||
|         search_data = SearchDataset( | ||||
|             dataset, | ||||
|             [search_train_data, search_valid_data], | ||||
|             list(range(len(search_train_data))), | ||||
|             imagenet_test_split.xvalid, | ||||
|         ) | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                 imagenet_test_split.xvalid | ||||
|             ), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         search_data = SearchDataset(dataset, [search_train_data, search_valid_data], | ||||
|                                     list(range(len(search_train_data))), test_split.xvalid) | ||||
|         search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, | ||||
|                                                     num_workers=workers, pin_memory=True) | ||||
|         train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, | ||||
|                                                    num_workers=workers, pin_memory=True) | ||||
|         valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch, | ||||
|                                                    sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                                                        test_split.xvalid), num_workers=workers, pin_memory=True) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|         raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||
|     return search_loader, train_loader, valid_loader | ||||
|  | ||||
|  | ||||
| # if __name__ == '__main__': | ||||
| #  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1) | ||||
| #  import pdb; pdb.set_trace() | ||||
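# Minimal search-loader sketch (paths are placeholders; config_root must contain
# the split files referenced above, e.g. configs/nas-benchmark/cifar-split.txt):
# train_data, valid_data, xshape, class_num = get_datasets('cifar10', './data/cifar.python', -1)
# search_loader, train_loader, valid_loader = get_nas_search_loaders(
#     train_data, valid_data, 'cifar10', 'configs/nas-benchmark', (64, 256), 4)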
|   | ||||
| @@ -213,6 +213,13 @@ AllConv3x3_CODE = Structure( | ||||
|         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
| Number_5374 = Structure( | ||||
|     [ | ||||
|         (("nor_conv_3x3", 0),),  # node-1 | ||||
|         (("nor_conv_1x1", 0), ("nor_conv_3x3", 1)),  # node-2 | ||||
|         (("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)),  # node-3 | ||||
|     ] | ||||
| ) | ||||
|  | ||||
| AllFull_CODE = Structure( | ||||
|     [ | ||||
| @@ -271,4 +278,5 @@ architectures = { | ||||
|     "all_c1x1": AllConv1x1_CODE, | ||||
|     "all_idnt": AllIdentity_CODE, | ||||
|     "all_full": AllFull_CODE, | ||||
|     "5374": Number_5374, | ||||
| } | ||||
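# Sketch: the new cell is fetched by name like the other reference topologies,
# e.g. `cell = architectures["5374"]` returns the Structure defined above
# (node-1: 3x3 conv; node-2: 1x1 + 3x3 conv; node-3: skip + none + 3x3 conv).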
|   | ||||
| @@ -12,6 +12,7 @@ def obtain_accuracy(output, target, topk=(1,)): | ||||
|  | ||||
|     res = [] | ||||
|     for k in topk: | ||||
|         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||||
|         # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||||
|         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
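# Why the change from .view to .reshape (a sketch of the reasoning): `correct` is
# typically produced by a transposed comparison such as pred.eq(target).t(), which
# is non-contiguous; .view(-1) then raises a RuntimeError on recent PyTorch, while
# .reshape(-1) copies when needed. Minimal reproduction:
#   import torch
#   x = torch.zeros(4, 3).t()   # transposed tensor -> non-contiguous
#   x.reshape(-1)               # works
#   x.view(-1)                  # RuntimeError: view size is not compatible ...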
|   | ||||