Compare commits

11 Commits

5bf036a763 ... 4612cd198b
| Author | SHA1 | Date |
|---|---|---|
| | 4612cd198b | |
| | 889bd1974c | |
| | af0e7786b6 | |
| | c6d53f08ae | |
| | ef2608bb42 | |
| | 50ff507a15 | |
| | 03d7d04d41 | |
| | bb33ca9a68 | |
| | f46486e21b | |
| | 5908a1edef | |
| | ed34024a88 | |
@@ -61,13 +61,13 @@ At this moment, this project provides the following algorithms and scripts to ru
     <tr> <!-- (6-th row) -->
     <td align="center" valign="middle"> NATS-Bench </td>
     <td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
-    <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
+    <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
     </tr>
     <tr> <!-- (7-th row) -->
     <td align="center" valign="middle"> ... </td>
     <td align="center" valign="middle"> ENAS / REA / REINFORCE / BOHB </td>
     <td align="center" valign="middle"> Please check the original papers </td>
-    <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a>  <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
+    <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a>  <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
     </tr>
     <tr> <!-- (start second block) -->
     <td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td>

@@ -29,7 +29,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
 You can move it to anywhere you want and send its path to our API for initialization.
 - [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
 - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
-NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
+NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
 - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
 - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
 - [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.

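For reference, the downloaded `.pth` file is consumed by handing its path to the benchmark API constructor. A minimal sketch, assuming the `nas_201_api` package is installed and the file sits in the current directory (the same pattern appears in the `test_network.py` added later in this comparison; the index 123 is just an illustrative value):

```python
from nas_201_api import NASBench201API as API

# Initialize the API from the downloaded benchmark file (v1.1 shown here).
api = API("NAS-Bench-201-v1_1-096897.pth")

# Query the recorded results of one architecture (index 123) on CIFAR-10.
results = api.query_by_index(123, "cifar10")
```
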
@@ -27,7 +27,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
 You can move it to anywhere you want and send its path to our API for initialization.
 - [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
 - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
-NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
+NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
 - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
 - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
 - [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.

@@ -3,7 +3,7 @@
 </p>

 ---------
-[](LICENSE.md)
+[](../LICENSE.md)

 The Automated Deep Learning library (AutoDL-Projects) is an open-source, lightweight, yet powerful project.
 It implements a variety of neural architecture search (NAS) and hyper-parameter optimization (HPO) algorithms.
@@ -142,8 +142,8 @@

 # Others

-If you would like to contribute to this codebase, please see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
+If you would like to contribute to this codebase, please see [CONTRIBUTING.md](../.github/CONTRIBUTING.md).
-In addition, please follow [CODE-OF-CONDUCT.md](.github/CODE-OF-CONDUCT.md) for usage guidelines.
+In addition, please follow [CODE-OF-CONDUCT.md](../.github/CODE-OF-CONDUCT.md) for usage guidelines.

 # License
-The entire codebase is under [MIT license](LICENSE.md)
+The entire codebase is under [MIT license](../LICENSE.md)

@@ -2,11 +2,11 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
 #####################################################
 import time, torch
-from procedures import prepare_seed, get_optim_scheduler
-from utils import get_model_infos, obtain_accuracy
-from config_utils import dict2config
-from log_utils import AverageMeter, time_string, convert_secs2time
-from models import get_cell_based_tiny_net
+from xautodl.procedures import prepare_seed, get_optim_scheduler
+from xautodl.utils import get_model_infos, obtain_accuracy
+from xautodl.config_utils import dict2config
+from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
+from xautodl.models import get_cell_based_tiny_net


 __all__ = ["evaluate_for_seed", "pure_evaluate"]

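The same helpers are now imported from the installed `xautodl` package rather than from top-level modules. As a rough illustration of what they provide, this sketch builds the small `infer.tiny` network used throughout these scripts; the architecture string, channel count, cell count, and class count are illustrative assumptions, mirroring the config dict used by `evaluate_for_seed` later in this comparison:

```python
from xautodl.config_utils import dict2config
from xautodl.models import CellStructure, get_cell_based_tiny_net

# Illustrative NAS-Bench-201 cell description (an assumption for this sketch).
arch = CellStructure.str2structure(
    "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|"
)

# "infer.tiny" with C=16 channels and N=5 cells, as in the scripts below.
net = get_cell_based_tiny_net(
    dict2config(
        {"name": "infer.tiny", "C": 16, "N": 5, "genotype": arch, "num_classes": 10},
        None,
    )
)
print(net)
```
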
@@ -16,10 +16,96 @@ from xautodl.procedures import get_machine_info
 from xautodl.datasets import get_datasets
 from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
 from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
-from xautodl.functions import evaluate_for_seed
+from functions import evaluate_for_seed

+from torchvision import datasets, transforms
+
+# NASBENCH201_CONFIG_PATH = os.path.join( os.getcwd(), 'main_exp', 'transfer_nag')
+
+NASBENCH201_CONFIG_PATH = '/lustre/hpe/ws11/ws11.1/ws/xmuhanma-nbdit/autodl-projects/configs/nas-benchmark'
+
+
-def evaluate_all_datasets(
+def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
+                          arch_config, workers, logger):
+    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
+    all_infos = {'info': machine_info}
+    all_dataset_keys = []
+    # look all the datasets
+    for dataset, xpath, split in zip(datasets, xpaths, splits):
+        # train valid data
+        task = None
+        train_data, valid_data, xshape, class_num = get_datasets(
+            dataset, xpath, -1, task)
+
+        # load the configuration
+        if dataset in ['mnist', 'svhn', 'aircraft', 'oxford']:
+            if use_less:
+                # config_path = os.path.join(
+                #     NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
+                config_path = os.path.join(
+                    NASBENCH201_CONFIG_PATH, 'LESS.config')
+            else:
+                # config_path = os.path.join(
+                #     NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
+                config_path = os.path.join(
+                    NASBENCH201_CONFIG_PATH, '{}.config'.format(dataset))
+
+
+            p = os.path.join(
+                NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset))
+            if not os.path.exists(p):
+                import json
+                label_list = list(range(len(train_data)))
+                random.shuffle(label_list)
+                strlist = [str(label_list[i]) for i in range(len(label_list))]
+                splited = {'train': ["int", strlist[:len(train_data) // 2]],
+                           'valid': ["int", strlist[len(train_data) // 2:]]}
+                with open(p, 'w') as f:
+                    f.write(json.dumps(splited))
+            split_info = load_config(os.path.join(
+                NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset)), None, None)
+        else:
+            raise ValueError('invalid dataset : {:}'.format(dataset))
+
+        config = load_config(
+            config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+        # data loader
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
+                                                   shuffle=True, num_workers=workers, pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                   shuffle=False, num_workers=workers, pin_memory=True)
+        splits = load_config(os.path.join(
+            NASBENCH201_CONFIG_PATH, '{}-test-split.txt'.format(dataset)), None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                             sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                                 splits.xvalid),
+                                                             num_workers=workers, pin_memory=True),
+                      'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                            sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                                splits.xtest),
+                                                            num_workers=workers, pin_memory=True)
+                      }
+        dataset_key = '{:}'.format(dataset)
+        if bool(split):
+            dataset_key = dataset_key + '-valid'
+        logger.log(
+            'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
+            format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
+        logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
+            dataset_key, config))
+        for key, value in ValLoaders.items():
+            logger.log(
+                'Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
+
+        results = evaluate_for_seed(
+            arch_config, config, arch, train_loader, ValLoaders, seed, logger)
+        all_infos[dataset_key] = results
+        all_dataset_keys.append(dataset_key)
+    all_infos['all_dataset_keys'] = all_dataset_keys
+    return all_infos
+
+def evaluate_all_datasets1(
     arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
 ):
     machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
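The new `evaluate_all_datasets` above writes a `{dataset}-split.txt` file on first use: a JSON dict whose `train`/`valid` entries each hold a list of shuffled sample indices covering half of the training set. A standalone sketch of that file format follows; the sample count and output path are assumptions for illustration only:

```python
import json
import os
import random

num_samples = 6667  # assumed dataset size, e.g. an aircraft training split
indices = [str(i) for i in range(num_samples)]
random.shuffle(indices)
half = num_samples // 2

# Same structure as the split files read by load_config in the diff above.
split = {
    "train": ["int", indices[:half]],
    "valid": ["int", indices[half:]],
}
with open(os.path.join("configs", "nas-benchmark", "aircraft-split.txt"), "w") as f:
    f.write(json.dumps(split))
```
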
@@ -46,47 +132,117 @@ def evaluate_all_datasets(
             split_info = load_config(
                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
             )
+        elif dataset.startswith("aircraft"):
+            if use_less:
+                config_path = "configs/nas-benchmark/LESS.config"
+            else:
+                config_path = "configs/nas-benchmark/aircraft.config"
+            split_info = load_config(
+                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
+            )
+        elif dataset.startswith("oxford"):
+            if use_less:
+                config_path = "configs/nas-benchmark/LESS.config"
+            else:
+                config_path = "configs/nas-benchmark/oxford.config"
+            split_info = load_config(
+                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
+            )
         else:
             raise ValueError("invalid dataset : {:}".format(dataset))
         config = load_config(
             config_path, {"class_num": class_num, "xshape": xshape}, logger
         )
         # check whether use splited validation set
+        # if dataset == 'aircraft':
+        #     split = True
         if bool(split):
-            assert dataset == "cifar10"
-            ValLoaders = {
-                "ori-test": torch.utils.data.DataLoader(
-                    valid_data,
-                    batch_size=config.batch_size,
-                    shuffle=False,
-                    num_workers=workers,
-                    pin_memory=True,
-                )
-            }
-            assert len(train_data) == len(split_info.train) + len(
-                split_info.valid
-            ), "invalid length : {:} vs {:} + {:}".format(
-                len(train_data), len(split_info.train), len(split_info.valid)
-            )
-            train_data_v2 = deepcopy(train_data)
-            train_data_v2.transform = valid_data.transform
-            valid_data = train_data_v2
-            # data loader
-            train_loader = torch.utils.data.DataLoader(
-                train_data,
-                batch_size=config.batch_size,
-                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
-                num_workers=workers,
-                pin_memory=True,
-            )
-            valid_loader = torch.utils.data.DataLoader(
-                valid_data,
-                batch_size=config.batch_size,
-                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
-                num_workers=workers,
-                pin_memory=True,
-            )
-            ValLoaders["x-valid"] = valid_loader
+            if dataset == "cifar10" or dataset == "cifar100":
+                assert dataset == "cifar10"
+                ValLoaders = {
+                    "ori-test": torch.utils.data.DataLoader(
+                        valid_data,
+                        batch_size=config.batch_size,
+                        shuffle=False,
+                        num_workers=workers,
+                        pin_memory=True,
+                    )
+                }
+                assert len(train_data) == len(split_info.train) + len(
+                    split_info.valid
+                ), "invalid length : {:} vs {:} + {:}".format(
+                    len(train_data), len(split_info.train), len(split_info.valid)
+                )
+                train_data_v2 = deepcopy(train_data)
+                train_data_v2.transform = valid_data.transform
+                valid_data = train_data_v2
+                # data loader
+                train_loader = torch.utils.data.DataLoader(
+                    train_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
+                    num_workers=workers,
+                    pin_memory=True,
+                )
+                valid_loader = torch.utils.data.DataLoader(
+                    valid_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
+                    num_workers=workers,
+                    pin_memory=True,
+                )
+                ValLoaders["x-valid"] = valid_loader
+            elif dataset == "aircraft":
+                ValLoaders = {
+                    "ori-test": torch.utils.data.DataLoader(
+                        valid_data,
+                        batch_size=config.batch_size,
+                        shuffle=False,
+                        num_workers=workers,
+                        pin_memory=True,
+                    )
+                }
+                train_data_v2 = deepcopy(train_data)
+                train_data_v2.transform = valid_data.transform
+                valid_data = train_data_v2
+                # use DataLoader
+                train_loader = torch.utils.data.DataLoader(
+                    train_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
+                    num_workers=workers,
+                    pin_memory=True)
+                valid_loader = torch.utils.data.DataLoader(
+                    valid_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
+                    num_workers=workers,
+                    pin_memory=True)
+            elif dataset == "oxford":
+                ValLoaders = {
+                    "ori-test": torch.utils.data.DataLoader(
+                        valid_data,
+                        batch_size=config.batch_size,
+                        shuffle=False,
+                        num_workers=workers,
+                        pin_memory=True
+                    )
+                }
+                # train_data_v2 = deepcopy(train_data)
+                # train_data_v2.transform = valid_data.transform
+                train_loader = torch.utils.data.DataLoader(
+                    train_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
+                    num_workers=workers,
+                    pin_memory=True)
+                valid_loader = torch.utils.data.DataLoader(
+                    valid_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
+                    num_workers=workers,
+                    pin_memory=True)
+
         else:
             # data loader
             train_loader = torch.utils.data.DataLoader(
@@ -103,7 +259,7 @@ def evaluate_all_datasets(
                 num_workers=workers,
                 pin_memory=True,
             )
-            if dataset == "cifar10":
+            if dataset == "cifar10" or dataset == "aircraft" or dataset == "oxford":
                 ValLoaders = {"ori-test": valid_loader}
             elif dataset == "cifar100":
                 cifar100_splits = load_config(

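Both the existing cifar branch and the new aircraft/oxford branches build their loaders the same way: a `SubsetRandomSampler` over the index lists from the split file keeps the train and validation subsets disjoint. A condensed sketch of that pattern (the helper name and its arguments are placeholders, not part of the diff):

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler


def make_split_loaders(train_data, valid_data, split_info, batch_size, workers):
    """Build disjoint train/valid loaders from pre-computed index lists."""
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(split_info.train),
        num_workers=workers,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_data,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(split_info.valid),
        num_workers=workers,
        pin_memory=True,
    )
    return train_loader, valid_loader
```
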
@@ -28,16 +28,41 @@ else
   mode=cover
 fi

+# OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
+# 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \
+# 	--use_less ${use_less} \
+# 	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
+# 	--splits   1       0       0        0 \
+# 	--xpaths $TORCH_HOME/cifar.python \
+# 		 $TORCH_HOME/cifar.python \
+# 		 $TORCH_HOME/cifar.python \
+# 		 $TORCH_HOME/cifar.python/ImageNet16 \
+# 	--channel 16 --num_cells 5 \
+# 	--workers 4 \
+# 	--srange ${xstart} ${xend} --arch_index ${arch_index} \
+# 	--seeds ${all_seeds}
+
 OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \
 	--use_less ${use_less} \
-	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
-	--splits   1       0       0        0 \
-	--xpaths $TORCH_HOME/cifar.python \
-		 $TORCH_HOME/cifar.python \
-		 $TORCH_HOME/cifar.python \
-		 $TORCH_HOME/cifar.python/ImageNet16 \
-	--channel 16 --num_cells 5 \
+	--datasets aircraft \
+	--xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \
+	--channel 16 \
+	--splits 1 \
+	--num_cells 5 \
 	--workers 4 \
 	--srange ${xstart} ${xend} --arch_index ${arch_index} \
 	--seeds ${all_seeds}
+
+# OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
+# 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \
+# 	--use_less ${use_less} \
+# 	--datasets oxford\
+# 	--xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \
+# 	--channel 16 \
+# 	--splits 1 \
+# 	--num_cells 5 \
+# 	--workers 4 \
+# 	--srange ${xstart} ${xend} --arch_index ${arch_index} \
+# 	--seeds ${all_seeds}

							
								
								
									
test.ipynb (new file, 104336 lines) — file diff suppressed because it is too large.

test_network.py (new file, 616 lines):
from nas_201_api import NASBench201API as API
import os

import os, sys, time, torch, random, argparse
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path

from xautodl.config_utils import load_config
from xautodl.procedures import save_checkpoint, copy_checkpoint
from xautodl.procedures import get_machine_info
from xautodl.datasets import get_datasets
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces

import time, torch
from xautodl.procedures import prepare_seed, get_optim_scheduler
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.config_utils import dict2config
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net

cur_path = os.path.abspath(os.path.curdir)
data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')
print(f'loading data from {data_path}')
print(f'loading')
api = API(data_path)
print(f'loaded')

def find_best_index(dataset):
    # NAS-Bench-201 indexes its 15625 architectures from 0 to 15624.
    num_archs = 15625
    accs = []
    for i in range(num_archs):
        results = api.query_by_index(i, dataset)
        dict_items = list(results.items())
        train_info = dict_items[0][1].get_train()
        acc = train_info['accuracy']
        accs.append((i, acc))
    return max(accs, key=lambda x: x[1])

best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')
best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')
best_ImageNet16_index, best_ImageNet16_acc = find_best_index('ImageNet16-120')
print(f'find best cifar10 index: {best_cifar_10_index}, acc: {best_cifar_10_acc}')
print(f'find best cifar100 index: {best_cifar_100_index}, acc: {best_cifar_100_acc}')
print(f'find best ImageNet16 index: {best_ImageNet16_index}, acc: {best_ImageNet16_acc}')

from xautodl.models import get_cell_based_tiny_net
def get_network_str_by_id(id, dataset):
    config = api.get_net_config(id, dataset)
    return config['arch_str']

best_cifar_10_str = get_network_str_by_id(best_cifar_10_index, 'cifar10')
best_cifar_100_str = get_network_str_by_id(best_cifar_100_index, 'cifar100')
best_ImageNet16_str = get_network_str_by_id(best_ImageNet16_index, 'ImageNet16-120')

def evaluate_all_datasets(
    arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
):
    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {"info": machine_info}
    all_dataset_keys = []
    # look all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # train valid data
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
        # load the configuration
        if dataset == "cifar10" or dataset == "cifar100":
            if use_less:
                config_path = "configs/nas-benchmark/LESS.config"
            else:
                config_path = "configs/nas-benchmark/CIFAR.config"
            split_info = load_config(
                "configs/nas-benchmark/cifar-split.txt", None, None
            )
        elif dataset.startswith("ImageNet16"):
            if use_less:
                config_path = "configs/nas-benchmark/LESS.config"
            else:
                config_path = "configs/nas-benchmark/ImageNet-16.config"
            split_info = load_config(
                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
            )
        else:
            raise ValueError("invalid dataset : {:}".format(dataset))
        config = load_config(
            config_path, {"class_num": class_num, "xshape": xshape}, logger
        )
        # check whether use splited validation set
        if bool(split):
            assert dataset == "cifar10"
            ValLoaders = {
                "ori-test": torch.utils.data.DataLoader(
                    valid_data,
                    batch_size=config.batch_size,
                    shuffle=False,
                    num_workers=workers,
                    pin_memory=True,
                )
            }
            assert len(train_data) == len(split_info.train) + len(
                split_info.valid
            ), "invalid length : {:} vs {:} + {:}".format(
                len(train_data), len(split_info.train), len(split_info.valid)
            )
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            # data loader
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
                num_workers=workers,
                pin_memory=True,
            )
            valid_loader = torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
                num_workers=workers,
                pin_memory=True,
            )
            ValLoaders["x-valid"] = valid_loader
        else:
            # data loader
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=config.batch_size,
                shuffle=True,
                num_workers=workers,
                pin_memory=True,
            )
            valid_loader = torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=True,
            )
            if dataset == "cifar10":
                ValLoaders = {"ori-test": valid_loader}
            elif dataset == "cifar100":
                cifar100_splits = load_config(
                    "configs/nas-benchmark/cifar100-test-split.txt", None, None
                )
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            cifar100_splits.xvalid
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                    "x-test": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            cifar100_splits.xtest
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                }
            elif dataset == "ImageNet16-120":
                imagenet16_splits = load_config(
                    "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
                )
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            imagenet16_splits.xvalid
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                    "x-test": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            imagenet16_splits.xtest
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                }
            else:
                raise ValueError("invalid dataset : {:}".format(dataset))

        dataset_key = "{:}".format(dataset)
        if bool(split):
            dataset_key = dataset_key + "-valid"
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                dataset_key,
                len(train_data),
                len(valid_data),
                len(train_loader),
                len(valid_loader),
                config.batch_size,
            )
        )
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
        )
        for key, value in ValLoaders.items():
            logger.log(
                "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
            )
        results = evaluate_for_seed(
            arch_config, config, arch, train_loader, ValLoaders, seed, logger
        )
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos["all_dataset_keys"] = all_dataset_keys
    return all_infos

def evaluate_for_seed(
    arch_config, config, arch, train_loader, valid_loaders, seed, logger
):

    prepare_seed(seed)  # random seed
    net = get_cell_based_tiny_net(
        dict2config(
            {
                "name": "infer.tiny",
                "C": arch_config["channel"],
                "N": arch_config["num_cells"],
                "genotype": arch,
                "num_classes": config.class_num,
            },
            None,
        )
    )
    # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, config.xshape)
    logger.log("Network : {:}".format(net.get_message()), False)
    logger.log(
        "{:} Seed-------------------------- {:} --------------------------".format(
            time_string(), seed
        )
    )
    logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
    # start training
    start_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    (
        train_losses,
        train_acc1es,
        train_acc5es,
        valid_losses,
        valid_acc1es,
        valid_acc5es,
    ) = ({}, {}, {}, {}, {}, {})
    train_times, valid_times = {}, {}
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)

        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, "train"
        )
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        with torch.no_grad():
            for key, xloder in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloder, network, criterion, None, None, "valid"
                )
                valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
                valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
                valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
                valid_times["{:}@{:}".format(key, epoch)] = valid_tm

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
        )
        logger.log(
            "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format(
                time_string(),
                need_time,
                epoch,
                total_epoch,
                train_loss,
                train_acc1,
                train_acc5,
                valid_loss,
                valid_acc1,
                valid_acc5,
            )
        )
    info_seed = {
        "flop": flop,
        "param": param,
        "channel": arch_config["channel"],
        "num_cells": arch_config["num_cells"],
        "config": config._asdict(),
        "total_epoch": total_epoch,
        "train_losses": train_losses,
        "train_acc1es": train_acc1es,
        "train_acc5es": train_acc5es,
        "train_times": train_times,
        "valid_losses": valid_losses,
        "valid_acc1es": valid_acc1es,
        "valid_acc5es": valid_acc5es,
        "valid_times": valid_times,
        "net_state_dict": net.state_dict(),
        "net_string": "{:}".format(net),
        "finish-train": True,
    }
    return info_seed

def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies = []
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(non_blocking=True)
            inputs = inputs.cuda(non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2:
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies


def procedure(xloader, network, criterion, scheduler, optimizer, mode):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(non_blocking=True)
        if mode == "train":
            optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == "train":
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum

def train_single_model(
    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
    torch.set_num_threads(workers)

    save_dir = (
        Path(save_dir)
        / "specifics"
        / "{:}-{:}-{:}-{:}".format(
            "LESS" if use_less else "FULL",
            model_str,
            arch_config["channel"],
            arch_config["num_cells"],
        )
    )
    logger = Logger(str(save_dir), 0, False)
    print(CellArchitectures)
    if model_str in CellArchitectures:
        arch = CellArchitectures[model_str]
        logger.log(
            "The model string is found in pre-defined architecture dict : {:}".format(
                model_str
            )
        )
    else:
        try:
            arch = CellStructure.str2structure(model_str)
        except:
            raise ValueError(
                "Invalid model string : {:}. It can not be found or parsed.".format(
                    model_str
                )
            )
    assert arch.check_valid_op(
        get_search_spaces("cell", "nas-bench-201")
    ), "{:} has the invalid op.".format(arch)
    logger.log("Start train-evaluate {:}".format(arch.tostr()))
    logger.log("arch_config : {:}".format(arch_config))

    start_time, seed_time = time.time(), AverageMeter()
    for _is, seed in enumerate(seeds):
        logger.log(
            "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format(
                _is, len(seeds), seed
            )
        )
        to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
        if to_save_name.exists():
            logger.log(
                "Find the existing file {:}, directly load!".format(to_save_name)
            )
            checkpoint = torch.load(to_save_name)
        else:
            logger.log(
                "Does not find the existing file {:}, train and evaluate!".format(
                    to_save_name
                )
            )
            checkpoint = evaluate_all_datasets(
                arch,
                datasets,
                xpaths,
                splits,
                use_less,
                seed,
                arch_config,
                workers,
                logger,
            )
            torch.save(checkpoint, to_save_name)
        # log information
        logger.log("{:}".format(checkpoint["info"]))
        all_dataset_keys = checkpoint["all_dataset_keys"]
        for dataset_key in all_dataset_keys:
            logger.log(
                "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
            )
            dataset_info = checkpoint[dataset_key]
            # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
            logger.log(
                "Flops = {:} MB, Params = {:} MB".format(
                    dataset_info["flop"], dataset_info["param"]
                )
            )
            logger.log("config : {:}".format(dataset_info["config"]))
            logger.log(
                "Training State (finish) = {:}".format(dataset_info["finish-train"])
            )
            last_epoch = dataset_info["total_epoch"] - 1
            train_acc1es, train_acc5es = (
                dataset_info["train_acc1es"],
                dataset_info["train_acc5es"],
            )
            valid_acc1es, valid_acc5es = (
                dataset_info["valid_acc1es"],
                dataset_info["valid_acc5es"],
            )
            print(dataset_info["train_acc1es"])
            print(dataset_info["train_acc5es"])
            print(dataset_info["valid_acc1es"])
            print(dataset_info["valid_acc5es"])
            logger.log(
                "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
                    train_acc1es[last_epoch],
                    train_acc5es[last_epoch],
                    100 - train_acc1es[last_epoch],
                    valid_acc1es['ori-test@'+str(last_epoch)],
                    valid_acc5es['ori-test@'+str(last_epoch)],
                    100 - valid_acc1es['ori-test@'+str(last_epoch)],
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         # measure elapsed time | ||||||
|  |         seed_time.update(time.time() - start_time) | ||||||
|  |         start_time = time.time() | ||||||
|  |         need_time = "Time Left: {:}".format( | ||||||
|  |             convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True) | ||||||
|  |         ) | ||||||
|  |         logger.log( | ||||||
|  |             "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format( | ||||||
|  |                 _is, len(seeds), seed, need_time | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |     logger.close() | ||||||
|  |  | ||||||
|  | # |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2| | ||||||
|  | train_strs = [best_cifar_10_str, best_cifar_100_str, best_ImageNet16_str] | ||||||
|  | train_single_model( | ||||||
|  |     save_dir="./outputs", | ||||||
|  |     workers=8, | ||||||
|  |     datasets=["ImageNet16-120"],  | ||||||
|  |     xpaths="./datasets/imagenet16-120", | ||||||
|  |     splits=[0, 0, 0], | ||||||
|  |     use_less=False, | ||||||
|  |     seeds=[777], | ||||||
|  |     model_str=best_ImageNet16_str, | ||||||
|  |     arch_config={"channel": 16, "num_cells": 8},) | ||||||
|  | train_single_model( | ||||||
|  |     save_dir="./outputs", | ||||||
|  |     workers=8, | ||||||
|  |     datasets=["cifar10"],  | ||||||
|  |     xpaths="./datasets/cifar10", | ||||||
|  |     splits=[0, 0, 0], | ||||||
|  |     use_less=False, | ||||||
|  |     seeds=[777], | ||||||
|  |     model_str=best_cifar_10_str, | ||||||
|  |     arch_config={"channel": 16, "num_cells": 8},) | ||||||
|  | train_single_model( | ||||||
|  |     save_dir="./outputs", | ||||||
|  |     workers=8, | ||||||
|  |     datasets=["cifar100"],  | ||||||
|  |     xpaths="./datasets/cifar100", | ||||||
|  |     splits=[0, 0, 0], | ||||||
|  |     use_less=False, | ||||||
|  |     seeds=[777], | ||||||
|  |     model_str=best_cifar_100_str, | ||||||
|  |     arch_config={"channel": 16, "num_cells": 8},) | ||||||
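The three calls above train the best-found cells for ImageNet16-120, CIFAR-10, and CIFAR-100 with seed 777 and write one checkpoint per seed under `./outputs/specifics/`. A minimal sketch of how to inspect such a checkpoint afterwards — the directory layout and dictionary keys are read off the `save_dir` construction and logging code in `train_single_model` above, while the concrete paths are only illustrative:

```python
import torch
from pathlib import Path

# "FULL-{model_str}-{channel}-{num_cells}" mirrors the save_dir format string above;
# best_cifar_10_str is the architecture string used in the second call.
save_dir = Path("./outputs") / "specifics" / "FULL-{:}-16-8".format(best_cifar_10_str)
ckpt = torch.load(save_dir / "seed-0777.pth")

print(ckpt["info"])
for dataset_key in ckpt["all_dataset_keys"]:
    info = ckpt[dataset_key]
    last = info["total_epoch"] - 1
    print(dataset_key,
          "train@1 = {:.2f}%".format(info["train_acc1es"][last]),
          "test@1 = {:.2f}%".format(info["valid_acc1es"]["ori-test@" + str(last)]))
```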
| @@ -1,40 +1,39 @@ | |||||||
| ################################################## | ################################################## | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
|  | # Modified by Hayeon Lee, Eunyoung Hyung 2021. 03. | ||||||
| ################################################## | ################################################## | ||||||
| import os, sys, torch | import os | ||||||
|  | import sys | ||||||
|  | import torch | ||||||
| import os.path as osp | import os.path as osp | ||||||
| import numpy as np | import numpy as np | ||||||
| import torchvision.datasets as dset | import torchvision.datasets as dset | ||||||
| import torchvision.transforms as transforms | import torchvision.transforms as transforms | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from PIL import Image |  | ||||||
|  |  | ||||||
| from xautodl.config_utils import load_config |  | ||||||
|  |  | ||||||
| from .DownsampledImageNet import ImageNet16 |  | ||||||
| from .SearchDatasetWrap import SearchDataset | from .SearchDatasetWrap import SearchDataset | ||||||
|  |  | ||||||
|  | # from PIL import Image | ||||||
|  | import random | ||||||
|  | import pdb | ||||||
|  | from .aircraft import FGVCAircraft | ||||||
|  | from .pets import PetDataset | ||||||
|  | from config_utils import load_config | ||||||
|  |  | ||||||
| Dataset2Class = { | Dataset2Class = {'cifar10': 10, | ||||||
|     "cifar10": 10, |                  'cifar100': 100, | ||||||
|     "cifar100": 100, |                  'mnist': 10, | ||||||
|     "imagenet-1k-s": 1000, |                  'svhn': 10, | ||||||
|     "imagenet-1k": 1000, |                  'aircraft': 30, | ||||||
|     "ImageNet16": 1000, |                  'oxford': 37} | ||||||
|     "ImageNet16-150": 150, |  | ||||||
|     "ImageNet16-120": 120, |  | ||||||
|     "ImageNet16-200": 200, |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class CUTOUT(object): | class CUTOUT(object): | ||||||
|  |  | ||||||
|     def __init__(self, length): |     def __init__(self, length): | ||||||
|         self.length = length |         self.length = length | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}(length={length})".format( |         return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||||
|             name=self.__class__.__name__, **self.__dict__ |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def __call__(self, img): |     def __call__(self, img): | ||||||
|         h, w = img.size(1), img.size(2) |         h, w = img.size(1), img.size(2) | ||||||
| @@ -47,7 +46,7 @@ class CUTOUT(object): | |||||||
|         x1 = np.clip(x - self.length // 2, 0, w) |         x1 = np.clip(x - self.length // 2, 0, w) | ||||||
|         x2 = np.clip(x + self.length // 2, 0, w) |         x2 = np.clip(x + self.length // 2, 0, w) | ||||||
|  |  | ||||||
|         mask[y1:y2, x1:x2] = 0.0 |         mask[y1: y2, x1: x2] = 0. | ||||||
|         mask = torch.from_numpy(mask) |         mask = torch.from_numpy(mask) | ||||||
|         mask = mask.expand_as(img) |         mask = mask.expand_as(img) | ||||||
|         img *= mask |         img *= mask | ||||||
| @@ -55,21 +54,19 @@ class CUTOUT(object): | |||||||
|  |  | ||||||
|  |  | ||||||
| imagenet_pca = { | imagenet_pca = { | ||||||
|     "eigval": np.asarray([0.2175, 0.0188, 0.0045]), |     'eigval': np.asarray([0.2175, 0.0188, 0.0045]), | ||||||
|     "eigvec": np.asarray( |     'eigvec': np.asarray([ | ||||||
|         [ |         [-0.5675, 0.7192, 0.4009], | ||||||
|             [-0.5675, 0.7192, 0.4009], |         [-0.5808, -0.0045, -0.8140], | ||||||
|             [-0.5808, -0.0045, -0.8140], |         [-0.5836, -0.6948, 0.4203], | ||||||
|             [-0.5836, -0.6948, 0.4203], |     ]) | ||||||
|         ] |  | ||||||
|     ), |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| class Lighting(object): | class Lighting(object): | ||||||
|     def __init__( |     def __init__(self, alphastd, | ||||||
|         self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"] |                  eigval=imagenet_pca['eigval'], | ||||||
|     ): |                  eigvec=imagenet_pca['eigvec']): | ||||||
|         self.alphastd = alphastd |         self.alphastd = alphastd | ||||||
|         assert eigval.shape == (3,) |         assert eigval.shape == (3,) | ||||||
|         assert eigvec.shape == (3, 3) |         assert eigvec.shape == (3, 3) | ||||||
| @@ -77,10 +74,10 @@ class Lighting(object): | |||||||
|         self.eigvec = eigvec |         self.eigvec = eigvec | ||||||
|  |  | ||||||
|     def __call__(self, img): |     def __call__(self, img): | ||||||
|         if self.alphastd == 0.0: |         if self.alphastd == 0.: | ||||||
|             return img |             return img | ||||||
|         rnd = np.random.randn(3) * self.alphastd |         rnd = np.random.randn(3) * self.alphastd | ||||||
|         rnd = rnd.astype("float32") |         rnd = rnd.astype('float32') | ||||||
|         v = rnd |         v = rnd | ||||||
|         old_dtype = np.asarray(img).dtype |         old_dtype = np.asarray(img).dtype | ||||||
|         v = v * self.eigval |         v = v * self.eigval | ||||||
| @@ -89,275 +86,222 @@ class Lighting(object): | |||||||
|         img = np.add(img, inc) |         img = np.add(img, inc) | ||||||
|         if old_dtype == np.uint8: |         if old_dtype == np.uint8: | ||||||
|             img = np.clip(img, 0, 255) |             img = np.clip(img, 0, 255) | ||||||
|         img = Image.fromarray(img.astype(old_dtype), "RGB") |         img = Image.fromarray(img.astype(old_dtype), 'RGB') | ||||||
|         return img |         return img | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return self.__class__.__name__ + "()" |         return self.__class__.__name__ + '()' | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_datasets(name, root, cutout): | def get_datasets(name, root, cutout, use_num_cls=None): | ||||||
|  |     if name == 'cifar10': | ||||||
|     if name == "cifar10": |  | ||||||
|         mean = [x / 255 for x in [125.3, 123.0, 113.9]] |         mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||||
|         std = [x / 255 for x in [63.0, 62.1, 66.7]] |         std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||||
|     elif name == "cifar100": |     elif name == 'cifar100': | ||||||
|         mean = [x / 255 for x in [129.3, 124.1, 112.4]] |         mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||||
|         std = [x / 255 for x in [68.2, 65.4, 70.4]] |         std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||||
|     elif name.startswith("imagenet-1k"): |     elif name.startswith('mnist'): | ||||||
|         mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] |         mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081] | ||||||
|     elif name.startswith("ImageNet16"): |     elif name.startswith('svhn'): | ||||||
|         mean = [x / 255 for x in [122.68, 116.66, 104.01]] |         mean, std = [0.4376821, 0.4437697, 0.47280442], [ 0.19803012, 0.20101562, 0.19703614] | ||||||
|         std = [x / 255 for x in [63.22, 61.26, 65.09]] |     elif name.startswith('aircraft'): | ||||||
|  |         mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883] | ||||||
|  |         std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115] | ||||||
|  |     elif name.startswith('oxford'): | ||||||
|  |         mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783] | ||||||
|  |         std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917] | ||||||
|     else: |     else: | ||||||
|         raise TypeError("Unknow dataset : {:}".format(name)) |         raise TypeError("Unknow dataset : {:}".format(name)) | ||||||
|  |  | ||||||
|     # Data Argumentation |     # Data Argumentation | ||||||
|     if name == "cifar10" or name == "cifar100": |     if name == 'cifar10' or name == 'cifar100': | ||||||
|         lists = [ |         lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), | ||||||
|             transforms.RandomHorizontalFlip(), |                  transforms.Normalize(mean, std)] | ||||||
|             transforms.RandomCrop(32, padding=4), |  | ||||||
|             transforms.ToTensor(), |  | ||||||
|             transforms.Normalize(mean, std), |  | ||||||
|         ] |  | ||||||
|         if cutout > 0: |         if cutout > 0: | ||||||
|             lists += [CUTOUT(cutout)] |             lists += [CUTOUT(cutout)] | ||||||
|         train_transform = transforms.Compose(lists) |         train_transform = transforms.Compose(lists) | ||||||
|         test_transform = transforms.Compose( |         test_transform = transforms.Compose( | ||||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)] |             [transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||||
|         ) |  | ||||||
|         xshape = (1, 3, 32, 32) |         xshape = (1, 3, 32, 32) | ||||||
|     elif name.startswith("ImageNet16"): |     elif name.startswith('cub200'): | ||||||
|         lists = [ |         train_transform = transforms.Compose([ | ||||||
|             transforms.RandomHorizontalFlip(), |             transforms.Resize((32, 32)), | ||||||
|             transforms.RandomCrop(16, padding=2), |  | ||||||
|             transforms.ToTensor(), |             transforms.ToTensor(), | ||||||
|             transforms.Normalize(mean, std), |             transforms.Normalize(mean=mean, std=std) | ||||||
|         ] |         ]) | ||||||
|         if cutout > 0: |         test_transform = transforms.Compose([ | ||||||
|             lists += [CUTOUT(cutout)] |             transforms.Resize((32, 32)), | ||||||
|         train_transform = transforms.Compose(lists) |  | ||||||
|         test_transform = transforms.Compose( |  | ||||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)] |  | ||||||
|         ) |  | ||||||
|         xshape = (1, 3, 16, 16) |  | ||||||
|     elif name == "tiered": |  | ||||||
|         lists = [ |  | ||||||
|             transforms.RandomHorizontalFlip(), |  | ||||||
|             transforms.RandomCrop(80, padding=4), |  | ||||||
|             transforms.ToTensor(), |             transforms.ToTensor(), | ||||||
|             transforms.Normalize(mean, std), |             transforms.Normalize(mean=mean, std=std) | ||||||
|         ] |         ]) | ||||||
|         if cutout > 0: |         xshape = (1, 3, 32, 32) | ||||||
|             lists += [CUTOUT(cutout)] |     elif name.startswith('mnist'): | ||||||
|         train_transform = transforms.Compose(lists) |         train_transform = transforms.Compose([ | ||||||
|         test_transform = transforms.Compose( |             transforms.Resize((32, 32)), | ||||||
|             [ |             transforms.ToTensor(), | ||||||
|                 transforms.CenterCrop(80), |             transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||||||
|                 transforms.ToTensor(), |             transforms.Normalize(mean, std), | ||||||
|                 transforms.Normalize(mean, std), |         ]) | ||||||
|             ] |         test_transform = transforms.Compose([ | ||||||
|         ) |             transforms.Resize((32, 32)), | ||||||
|  |             transforms.ToTensor(), | ||||||
|  |             transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||||||
|  |             transforms.Normalize(mean, std) | ||||||
|  |         ]) | ||||||
|  |         xshape = (1, 3, 32, 32) | ||||||
|  |     elif name.startswith('svhn'): | ||||||
|  |         train_transform = transforms.Compose([ | ||||||
|  |             transforms.Resize((32, 32)), | ||||||
|  |             transforms.ToTensor(), | ||||||
|  |             transforms.Normalize(mean=mean, std=std) | ||||||
|  |         ]) | ||||||
|  |         test_transform = transforms.Compose([ | ||||||
|  |             transforms.Resize((32, 32)), | ||||||
|  |             transforms.ToTensor(), | ||||||
|  |             transforms.Normalize(mean=mean, std=std) | ||||||
|  |         ]) | ||||||
|  |         xshape = (1, 3, 32, 32) | ||||||
|  |     elif name.startswith('aircraft'): | ||||||
|  |         train_transform = transforms.Compose([ | ||||||
|  |             transforms.Resize((32, 32)), | ||||||
|  |             transforms.ToTensor(), | ||||||
|  |             transforms.Normalize(mean=mean, std=std) | ||||||
|  |         ]) | ||||||
|  |         test_transform = transforms.Compose([ | ||||||
|  |             transforms.Resize((32, 32)), | ||||||
|  |             transforms.ToTensor(), | ||||||
|  |             transforms.Normalize(mean=mean, std=std), | ||||||
|  |         ]) | ||||||
|  |         xshape = (1, 3, 32, 32) | ||||||
|  |     elif name.startswith('oxford'): | ||||||
|  |         train_transform = transforms.Compose([ | ||||||
|  |             transforms.Resize((32, 32)), | ||||||
|  |             transforms.ToTensor(), | ||||||
|  |             transforms.Normalize(mean=mean, std=std) | ||||||
|  |         ]) | ||||||
|  |         test_transform = transforms.Compose([ | ||||||
|  |             transforms.Resize((32, 32)), | ||||||
|  |             transforms.ToTensor(), | ||||||
|  |             transforms.Normalize(mean=mean, std=std), | ||||||
|  |         ]) | ||||||
|         xshape = (1, 3, 32, 32) |         xshape = (1, 3, 32, 32) | ||||||
|     elif name.startswith("imagenet-1k"): |  | ||||||
|         normalize = transforms.Normalize( |  | ||||||
|             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |  | ||||||
|         ) |  | ||||||
|         if name == "imagenet-1k": |  | ||||||
|             xlists = [transforms.RandomResizedCrop(224)] |  | ||||||
|             xlists.append( |  | ||||||
|                 transforms.ColorJitter( |  | ||||||
|                     brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2 |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
|             xlists.append(Lighting(0.1)) |  | ||||||
|         elif name == "imagenet-1k-s": |  | ||||||
|             xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] |  | ||||||
|         else: |  | ||||||
|             raise ValueError("invalid name : {:}".format(name)) |  | ||||||
|         xlists.append(transforms.RandomHorizontalFlip(p=0.5)) |  | ||||||
|         xlists.append(transforms.ToTensor()) |  | ||||||
|         xlists.append(normalize) |  | ||||||
|         train_transform = transforms.Compose(xlists) |  | ||||||
|         test_transform = transforms.Compose( |  | ||||||
|             [ |  | ||||||
|                 transforms.Resize(256), |  | ||||||
|                 transforms.CenterCrop(224), |  | ||||||
|                 transforms.ToTensor(), |  | ||||||
|                 normalize, |  | ||||||
|             ] |  | ||||||
|         ) |  | ||||||
|         xshape = (1, 3, 224, 224) |  | ||||||
|     else: |     else: | ||||||
|         raise TypeError("Unknow dataset : {:}".format(name)) |         raise TypeError("Unknow dataset : {:}".format(name)) | ||||||
|  |  | ||||||
|     if name == "cifar10": |     if name == 'cifar10': | ||||||
|         train_data = dset.CIFAR10( |         train_data = dset.CIFAR10( | ||||||
|             root, train=True, transform=train_transform, download=True |             root, train=True, transform=train_transform, download=True) | ||||||
|         ) |  | ||||||
|         test_data = dset.CIFAR10( |         test_data = dset.CIFAR10( | ||||||
|             root, train=False, transform=test_transform, download=True |             root, train=False, transform=test_transform, download=True) | ||||||
|         ) |  | ||||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 |         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||||
|     elif name == "cifar100": |     elif name == 'cifar100': | ||||||
|         train_data = dset.CIFAR100( |         train_data = dset.CIFAR100( | ||||||
|             root, train=True, transform=train_transform, download=True |             root, train=True, transform=train_transform, download=True) | ||||||
|         ) |  | ||||||
|         test_data = dset.CIFAR100( |         test_data = dset.CIFAR100( | ||||||
|             root, train=False, transform=test_transform, download=True |             root, train=False, transform=test_transform, download=True) | ||||||
|         ) |  | ||||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 |         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||||
|     elif name.startswith("imagenet-1k"): |     elif name == 'mnist': | ||||||
|         train_data = dset.ImageFolder(osp.join(root, "train"), train_transform) |         train_data = dset.MNIST( | ||||||
|         test_data = dset.ImageFolder(osp.join(root, "val"), test_transform) |             root, train=True, transform=train_transform, download=True) | ||||||
|         assert ( |         test_data = dset.MNIST( | ||||||
|             len(train_data) == 1281167 and len(test_data) == 50000 |             root, train=False, transform=test_transform, download=True) | ||||||
|         ), "invalid number of images : {:} & {:} vs {:} & {:}".format( |         assert len(train_data) == 60000 and len(test_data) == 10000 | ||||||
|             len(train_data), len(test_data), 1281167, 50000 |     elif name == 'svhn': | ||||||
|         ) |         train_data = dset.SVHN(root, split='train', | ||||||
|     elif name == "ImageNet16": |                                transform=train_transform, download=True) | ||||||
|         train_data = ImageNet16(root, True, train_transform) |         test_data = dset.SVHN(root, split='test', | ||||||
|         test_data = ImageNet16(root, False, test_transform) |                               transform=test_transform, download=True) | ||||||
|         assert len(train_data) == 1281167 and len(test_data) == 50000 |         assert len(train_data) == 73257 and len(test_data) == 26032 | ||||||
|     elif name == "ImageNet16-120": |     elif name == 'aircraft': | ||||||
|         train_data = ImageNet16(root, True, train_transform, 120) |         train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval', | ||||||
|         test_data = ImageNet16(root, False, test_transform, 120) |                                   transform=train_transform, download=False) | ||||||
|         assert len(train_data) == 151700 and len(test_data) == 6000 |         test_data = FGVCAircraft(root, class_type='manufacturer', split='test', | ||||||
|     elif name == "ImageNet16-150": |                                  transform=test_transform, download=False) | ||||||
|         train_data = ImageNet16(root, True, train_transform, 150) |         assert len(train_data) == 6667 and len(test_data) == 3333 | ||||||
|         test_data = ImageNet16(root, False, test_transform, 150) |     elif name == 'oxford': | ||||||
|         assert len(train_data) == 190272 and len(test_data) == 7500 |         train_data = PetDataset(root, train=True, num_cl=37, | ||||||
|     elif name == "ImageNet16-200": |                                 val_split=0.15, transforms=train_transform) | ||||||
|         train_data = ImageNet16(root, True, train_transform, 200) |         test_data = PetDataset(root, train=False, num_cl=37, | ||||||
|         test_data = ImageNet16(root, False, test_transform, 200) |                                val_split=0.15, transforms=test_transform) | ||||||
|         assert len(train_data) == 254775 and len(test_data) == 10000 |  | ||||||
|     else: |     else: | ||||||
|         raise TypeError("Unknow dataset : {:}".format(name)) |         raise TypeError("Unknow dataset : {:}".format(name)) | ||||||
|  |  | ||||||
|     class_num = Dataset2Class[name] |     class_num = Dataset2Class[name] if use_num_cls is None else len( | ||||||
|  |         use_num_cls) | ||||||
|     return train_data, test_data, xshape, class_num |     return train_data, test_data, xshape, class_num | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_nas_search_loaders( | def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None): | ||||||
|     train_data, valid_data, dataset, config_root, batch_size, workers |  | ||||||
| ): |  | ||||||
|     if isinstance(batch_size, (list, tuple)): |     if isinstance(batch_size, (list, tuple)): | ||||||
|         batch, test_batch = batch_size |         batch, test_batch = batch_size | ||||||
|     else: |     else: | ||||||
|         batch, test_batch = batch_size, batch_size |         batch, test_batch = batch_size, batch_size | ||||||
|     if dataset == "cifar10": |     if dataset == 'cifar10': | ||||||
|         # split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |         # split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|         cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None) |         cifar_split = load_config( | ||||||
|         train_split, valid_split = ( |             '{:}/cifar-split.txt'.format(config_root), None, None) | ||||||
|             cifar_split.train, |         # search over the proposed training and validation set | ||||||
|             cifar_split.valid, |         train_split, valid_split = cifar_split.train, cifar_split.valid | ||||||
|         )  # search over the proposed training and validation set |  | ||||||
|         # logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set |         # logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | ||||||
|         # To split data |         # To split data | ||||||
|         xvalid_data = deepcopy(train_data) |         xvalid_data = deepcopy(train_data) | ||||||
|         if hasattr(xvalid_data, "transforms"):  # to avoid a print issue |         if hasattr(xvalid_data, 'transforms'):  # to avoid a print issue | ||||||
|             xvalid_data.transforms = valid_data.transform |             xvalid_data.transforms = valid_data.transform | ||||||
|         xvalid_data.transform = deepcopy(valid_data.transform) |         xvalid_data.transform = deepcopy(valid_data.transform) | ||||||
|         search_data = SearchDataset(dataset, train_data, train_split, valid_split) |         search_data = SearchDataset( | ||||||
|  |             dataset, train_data, train_split, valid_split) | ||||||
|         # data loader |         # data loader | ||||||
|         search_loader = torch.utils.data.DataLoader( |         search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers, | ||||||
|             search_data, |                                                     pin_memory=True) | ||||||
|             batch_size=batch, |         train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, | ||||||
|             shuffle=True, |                                                    sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||||
|             num_workers=workers, |                                                        train_split), | ||||||
|             pin_memory=True, |                                                    num_workers=workers, pin_memory=True) | ||||||
|         ) |         valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, | ||||||
|         train_loader = torch.utils.data.DataLoader( |                                                    sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||||
|             train_data, |                                                        valid_split), | ||||||
|             batch_size=batch, |                                                    num_workers=workers, pin_memory=True) | ||||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), |     elif dataset == 'cifar100': | ||||||
|             num_workers=workers, |  | ||||||
|             pin_memory=True, |  | ||||||
|         ) |  | ||||||
|         valid_loader = torch.utils.data.DataLoader( |  | ||||||
|             xvalid_data, |  | ||||||
|             batch_size=test_batch, |  | ||||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), |  | ||||||
|             num_workers=workers, |  | ||||||
|             pin_memory=True, |  | ||||||
|         ) |  | ||||||
|     elif dataset == "cifar100": |  | ||||||
|         cifar100_test_split = load_config( |         cifar100_test_split = load_config( | ||||||
|             "{:}/cifar100-test-split.txt".format(config_root), None, None |             '{:}/cifar100-test-split.txt'.format(config_root), None, None) | ||||||
|         ) |  | ||||||
|         search_train_data = train_data |         search_train_data = train_data | ||||||
|         search_valid_data = deepcopy(valid_data) |         search_valid_data = deepcopy(valid_data) | ||||||
|         search_valid_data.transform = train_data.transform |         search_valid_data.transform = train_data.transform | ||||||
|         search_data = SearchDataset( |         search_data = SearchDataset(dataset, [search_train_data, search_valid_data], | ||||||
|             dataset, |                                     list(range(len(search_train_data))), | ||||||
|             [search_train_data, search_valid_data], |                                     cifar100_test_split.xvalid) | ||||||
|             list(range(len(search_train_data))), |         search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers, | ||||||
|             cifar100_test_split.xvalid, |                                                     pin_memory=True) | ||||||
|         ) |         train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers, | ||||||
|         search_loader = torch.utils.data.DataLoader( |                                                    pin_memory=True) | ||||||
|             search_data, |         valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch, | ||||||
|             batch_size=batch, |                                                    sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||||
|             shuffle=True, |                                                        cifar100_test_split.xvalid), num_workers=workers, pin_memory=True) | ||||||
|             num_workers=workers, |     elif dataset in ['mnist', 'svhn', 'aircraft', 'oxford']: | ||||||
|             pin_memory=True, |         if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)): | ||||||
|         ) |             import json | ||||||
|         train_loader = torch.utils.data.DataLoader( |             label_list = list(range(len(valid_data))) | ||||||
|             train_data, |             random.shuffle(label_list) | ||||||
|             batch_size=batch, |             strlist = [str(label_list[i]) for i in range(len(label_list))] | ||||||
|             shuffle=True, |             split = {'xvalid': ["int", strlist[:len(valid_data) // 2]], | ||||||
|             num_workers=workers, |                      'xtest': ["int", strlist[len(valid_data) // 2:]]} | ||||||
|             pin_memory=True, |             with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f: | ||||||
|         ) |                 f.write(json.dumps(split)) | ||||||
|         valid_loader = torch.utils.data.DataLoader( |         test_split = load_config( | ||||||
|             valid_data, |             '{:}/{}-test-split.txt'.format(config_root, dataset), None, None) | ||||||
|             batch_size=test_batch, |  | ||||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler( |  | ||||||
|                 cifar100_test_split.xvalid |  | ||||||
|             ), |  | ||||||
|             num_workers=workers, |  | ||||||
|             pin_memory=True, |  | ||||||
|         ) |  | ||||||
|     elif dataset == "ImageNet16-120": |  | ||||||
|         imagenet_test_split = load_config( |  | ||||||
|             "{:}/imagenet-16-120-test-split.txt".format(config_root), None, None |  | ||||||
|         ) |  | ||||||
|         search_train_data = train_data |         search_train_data = train_data | ||||||
|         search_valid_data = deepcopy(valid_data) |         search_valid_data = deepcopy(valid_data) | ||||||
|         search_valid_data.transform = train_data.transform |         search_valid_data.transform = train_data.transform | ||||||
|         search_data = SearchDataset( |         search_data = SearchDataset(dataset, [search_train_data, search_valid_data], | ||||||
|             dataset, |                                     list(range(len(search_train_data))), test_split.xvalid) | ||||||
|             [search_train_data, search_valid_data], |         search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, | ||||||
|             list(range(len(search_train_data))), |                                                     num_workers=workers, pin_memory=True) | ||||||
|             imagenet_test_split.xvalid, |         train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, | ||||||
|         ) |                                                    num_workers=workers, pin_memory=True) | ||||||
|         search_loader = torch.utils.data.DataLoader( |         valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch, | ||||||
|             search_data, |                                                    sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||||
|             batch_size=batch, |                                                        test_split.xvalid), num_workers=workers, pin_memory=True) | ||||||
|             shuffle=True, |  | ||||||
|             num_workers=workers, |  | ||||||
|             pin_memory=True, |  | ||||||
|         ) |  | ||||||
|         train_loader = torch.utils.data.DataLoader( |  | ||||||
|             train_data, |  | ||||||
|             batch_size=batch, |  | ||||||
|             shuffle=True, |  | ||||||
|             num_workers=workers, |  | ||||||
|             pin_memory=True, |  | ||||||
|         ) |  | ||||||
|         valid_loader = torch.utils.data.DataLoader( |  | ||||||
|             valid_data, |  | ||||||
|             batch_size=test_batch, |  | ||||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler( |  | ||||||
|                 imagenet_test_split.xvalid |  | ||||||
|             ), |  | ||||||
|             num_workers=workers, |  | ||||||
|             pin_memory=True, |  | ||||||
|         ) |  | ||||||
|     else: |     else: | ||||||
|         raise ValueError("invalid dataset : {:}".format(dataset)) |         raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||||
|     return search_loader, train_loader, valid_loader |     return search_loader, train_loader, valid_loader | ||||||
|  |  | ||||||
|  |  | ||||||
| # if __name__ == '__main__': |  | ||||||
| #  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1) |  | ||||||
| #  import pdb; pdb.set_trace() |  | ||||||
|   | |||||||
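The rewritten `get_datasets` / `get_nas_search_loaders` add MNIST, SVHN, FGVC-Aircraft, and Oxford-IIIT Pets on top of CIFAR. A minimal usage sketch, assuming this file is importable as `get_dataset_with_transform` and that an FGVC-Aircraft copy already sits under the given root (the module path, data root, and `config_root` below are illustrative, not taken from the commits):

```python
from get_dataset_with_transform import get_datasets, get_nas_search_loaders  # import path is an assumption

# 'aircraft' uses the manufacturer split with 30 classes and 32x32 inputs (see Dataset2Class above).
train_data, test_data, xshape, class_num = get_datasets("aircraft", "./data/fgvc-aircraft", cutout=-1)
print(xshape, class_num)  # (1, 3, 32, 32) 30

search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data, test_data, "aircraft",
    config_root="./configs/nas-benchmark",  # an aircraft-test-split.txt is generated here if missing
    batch_size=(64, 256), workers=4,
)
inputs, targets = next(iter(train_loader))
print(inputs.shape, targets.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])
```

Note that, unlike the CIFAR branches, the new branch shuffles and writes its own test split on first use, so `config_root` must be writable.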
| @@ -213,6 +213,13 @@ AllConv3x3_CODE = Structure( | |||||||
|         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)), |         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)), | ||||||
|     ]  # node-3 |     ]  # node-3 | ||||||
| ) | ) | ||||||
|  | Number_5374 = Structure( | ||||||
|  |     [ | ||||||
|  |         (("nor_conv_3x3", 0),),  # node-1 | ||||||
|  |         (("nor_conv_1x1", 0), ("nor_conv_3x3", 1)),  # node-2 | ||||||
|  |         (("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)),  # node-3 | ||||||
|  |     ] | ||||||
|  | ) | ||||||
|  |  | ||||||
| AllFull_CODE = Structure( | AllFull_CODE = Structure( | ||||||
|     [ |     [ | ||||||
| @@ -271,4 +278,5 @@ architectures = { | |||||||
|     "all_c1x1": AllConv1x1_CODE, |     "all_c1x1": AllConv1x1_CODE, | ||||||
|     "all_idnt": AllIdentity_CODE, |     "all_idnt": AllIdentity_CODE, | ||||||
|     "all_full": AllFull_CODE, |     "all_full": AllFull_CODE, | ||||||
|  |     "5374": Number_5374, | ||||||
| } | } | ||||||
|   | |||||||
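Registering `Number_5374` under the key `"5374"` lets this cell be referenced by name; the training script shown earlier resolves model strings through a `CellArchitectures` dictionary, which presumably aliases this `architectures` dict. A minimal sketch:

```python
# Resolve the newly registered cell by its short key, as train_single_model does
# with CellArchitectures; tostr() yields the canonical |op~index|... string.
arch = architectures["5374"]
print(arch.tostr())
```

With this entry in place, passing `model_str="5374"` to `train_single_model` should resolve to the same structure.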
| @@ -12,6 +12,7 @@ def obtain_accuracy(output, target, topk=(1,)): | |||||||
|  |  | ||||||
|     res = [] |     res = [] | ||||||
|     for k in topk: |     for k in topk: | ||||||
|         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) |         # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||||||
|  |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) | ||||||
|         res.append(correct_k.mul_(100.0 / batch_size)) |         res.append(correct_k.mul_(100.0 / batch_size)) | ||||||
|     return res |     return res | ||||||
|   | |||||||
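The switch from `view(-1)` to `reshape(-1)` matters because `correct` is typically built from a transposed prediction tensor, so the `correct[:k]` slice is non-contiguous; on recent PyTorch versions `view` then raises a RuntimeError, while `reshape` falls back to a copy. A small self-contained illustration (the shapes are arbitrary, not taken from the benchmark code):

```python
import torch

correct = torch.ones(5, 8, dtype=torch.bool).t()  # transpose makes the tensor non-contiguous
topk_slice = correct[:3]
print(topk_slice.is_contiguous())                  # False

try:
    topk_slice.view(-1)
except RuntimeError as err:
    print("view failed:", err)

print(topk_slice.reshape(-1).float().sum())        # tensor(15.) -- reshape copies when needed
```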