Compare commits

10 Commits

d2cef525f3 ... bb33ca9a68

| Author | SHA1 | Date |
|---|---|---|
|  | bb33ca9a68 |  |
|  | f46486e21b |  |
|  | 5908a1edef |  |
|  | ed34024a88 |  |
|  | 5bf036a763 |  |
|  | b557a22928 |  |
|  | f549ed2e61 |  |
|  | 5a5cb82537 |  |
|  | 676e8e411d |  |
|  | 8d0799dfb1 |  |

.github/workflows/test-basic.yaml (vendored, 2 changes)
@@ -41,7 +41,7 @@ jobs:
 
       - name: Install XAutoDL from source
         run: |
-          python setup.py install
+          pip install .
 
       - name: Test Search Space
         run: |
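Note: recent setuptools deprecates running `python setup.py install` directly; `pip install .` performs the same source install through pip's build front end and records metadata that pip can later uninstall or upgrade, which is presumably the motivation for this change across the workflows and docs below.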
							
								
								
									
.github/workflows/test-misc.yaml (vendored, 2 changes)

@@ -26,7 +26,7 @@ jobs:
 
       - name: Install XAutoDL from source
         run: |
-          python setup.py install
+          pip install .
 
       - name: Test Xmisc
         run: |
@@ -26,7 +26,7 @@ jobs:
 
       - name: Install XAutoDL from source
         run: |
-          python setup.py install
+          pip install .
 
       - name: Test Super Model
         run: |
README.md

@@ -61,13 +61,13 @@ At this moment, this project provides the following algorithms and scripts to ru
     <tr> <!-- (6-th row) -->
     <td align="center" valign="middle"> NATS-Bench </td>
     <td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
-    <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
+    <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
     </tr>
     <tr> <!-- (7-th row) -->
     <td align="center" valign="middle"> ... </td>
     <td align="center" valign="middle"> ENAS / REA / REINFORCE / BOHB </td>
     <td align="center" valign="middle"> Please check the original papers </td>
-    <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a>  <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
+    <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a>  <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
     </tr>
     <tr> <!-- (start second block) -->
     <td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td>
@@ -89,7 +89,7 @@ At this moment, this project provides the following algorithms and scripts to ru
 ## Requirements and Preparation
 
 
-**First of all**, please use `python setup.py install` to install `xautodl` library.
+**First of all**, please use `pip install .` to install `xautodl` library.
 
 Please install `Python>=3.6` and `PyTorch>=1.5.0`. (You could use lower versions of Python and PyTorch, but may have bugs).
 Some visualization codes may require `opencv`.
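Following the updated instruction above, a quick post-install sanity check, as a minimal sketch (it assumes `pip install .` succeeded and made the `xautodl` package importable):

```python
# Minimal environment check after `pip install .` in the repository root.
# Assumes the install exposes the `xautodl` package, as the README states.
import sys

import torch
import xautodl

assert sys.version_info >= (3, 6), "the README asks for Python>=3.6"
print("torch", torch.__version__)  # the README asks for PyTorch>=1.5.0
print("xautodl imported from", xautodl.__file__)
```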
@@ -29,7 +29,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
 You can move it to anywhere you want and send its path to our API for initialization.
 - [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
 - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
-NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
+NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
 - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
 - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
 - [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
@@ -27,7 +27,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
 You can move it to anywhere you want and send its path to our API for initialization.
 - [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
 - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
-NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
+NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
 - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
 - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
 - [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
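For reference, a minimal sketch of initializing the API against one of the files listed above (placing the file in the working directory is an assumption; as the notes say, any path works):

```python
# Initialize the NAS-Bench-201 API from the downloaded benchmark file.
# Assumes NAS-Bench-201-v1_1-096897.pth was saved to the working directory.
from nas_201_api import NASBench201API as API

api = API("NAS-Bench-201-v1_1-096897.pth")
results = api.query_by_index(1234, "cifar10")  # 1234 is an arbitrary example index
```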
docs/README_CN.md

@@ -3,7 +3,7 @@
 </p>
 
 ---------
-[](LICENSE.md)
+[](../LICENSE.md)
 
 The Automated Deep Learning library (AutoDL-Projects) is an open-source, lightweight, yet powerful project.
 It implements a variety of neural architecture search (NAS) and hyperparameter optimization (HPO) algorithms.
@@ -142,8 +142,8 @@
 
 # Others
 
-If you would like to contribute to this codebase, please see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
-In addition, please refer to [CODE-OF-CONDUCT.md](.github/CODE-OF-CONDUCT.md) for the code of conduct.
+If you would like to contribute to this codebase, please see [CONTRIBUTING.md](../.github/CONTRIBUTING.md).
+In addition, please refer to [CODE-OF-CONDUCT.md](../.github/CODE-OF-CONDUCT.md) for the code of conduct.
 
 # License
-The entire codebase is under [MIT license](LICENSE.md)
+The entire codebase is under [MIT license](../LICENSE.md)
exps/NATS-algos/search-cell.py

@@ -24,6 +24,9 @@
 # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
 # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
 # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
+####
+# The following scripts are added in 20 Mar 2022
+# python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777
 ######################################################################################
 import os, sys, time, random, argparse
 import numpy as np
@@ -166,6 +169,8 @@ def search_func(
             network.set_cal_mode("dynamic", sampled_arch)
         elif algo == "gdas":
             network.set_cal_mode("gdas", None)
+        elif algo == "gdas_v1":
+            network.set_cal_mode("gdas_v1", None)
         elif algo.startswith("darts"):
             network.set_cal_mode("joint", None)
         elif algo == "random":
@@ -196,6 +201,8 @@ def search_func(
             network.set_cal_mode("joint")
         elif algo == "gdas":
             network.set_cal_mode("gdas", None)
+        elif algo == "gdas_v1":
+            network.set_cal_mode("gdas_v1", None)
         elif algo.startswith("darts"):
             network.set_cal_mode("joint", None)
         elif algo == "random":
@@ -373,7 +380,7 @@ def get_best_arch(xloader, network, n_samples, algo):
             archs, valid_accs = network.return_topK(n_samples, True), []
         elif algo == "setn":
             archs, valid_accs = network.return_topK(n_samples, False), []
-        elif algo.startswith("darts") or algo == "gdas":
+        elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1":
             arch = network.genotype
             archs, valid_accs = [arch], []
         elif algo == "enas":
@@ -568,7 +575,7 @@ def main(xargs):
         )
 
         network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate)
-        if xargs.algo == "gdas":
+        if xargs.algo == "gdas" or xargs.algo == "gdas_v1":
             network.set_tau(
                 xargs.tau_max
                 - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
@@ -632,6 +639,8 @@ def main(xargs):
             network.set_cal_mode("dynamic", genotype)
         elif xargs.algo == "gdas":
             network.set_cal_mode("gdas", None)
+        elif xargs.algo == "gdas_v1":
+            network.set_cal_mode("gdas_v1", None)
         elif xargs.algo.startswith("darts"):
             network.set_cal_mode("joint", None)
         elif xargs.algo == "random":
@@ -699,6 +708,8 @@ def main(xargs):
         network.set_cal_mode("dynamic", genotype)
     elif xargs.algo == "gdas":
         network.set_cal_mode("gdas", None)
+    elif xargs.algo == "gdas_v1":
+        network.set_cal_mode("gdas_v1", None)
     elif xargs.algo.startswith("darts"):
         network.set_cal_mode("joint", None)
     elif xargs.algo == "random":
@@ -747,7 +758,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--algo",
         type=str,
-        choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"],
+        choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"],
        help="The search space name.",
     )
     parser.add_argument(
							
								
								
									
exps/experimental/test-dks.py (new file, 57 lines)

@@ -0,0 +1,57 @@
from dks.base.activation_getter import (
    get_activation_function as _get_numpy_activation_function,
)
from dks.base.activation_transform import _get_activations_params


def subnet_max_func(x, r_fn):
    depth = 7
    res_x = r_fn(x)
    x = r_fn(x)
    for _ in range(depth):
        x = r_fn(r_fn(x)) + x
    return max(x, res_x)


def subnet_max_func_v2(x, r_fn):
    depth = 2
    res_x = r_fn(x)

    x = r_fn(x)
    for _ in range(depth):
        x = 0.8 * r_fn(r_fn(x)) + 0.2 * x

    return max(x, res_x)


def get_transformed_activations(
    activation_names,
    method="TAT",
    dks_params=None,
    tat_params=None,
    max_slope_func=None,
    max_curv_func=None,
    subnet_max_func=None,
    activation_getter=_get_numpy_activation_function,
):
    params = _get_activations_params(
        activation_names,
        method=method,
        dks_params=dks_params,
        tat_params=tat_params,
        max_slope_func=max_slope_func,
        max_curv_func=max_curv_func,
        subnet_max_func=subnet_max_func,
    )
    return params


params = get_transformed_activations(
    ["swish"], method="TAT", subnet_max_func=subnet_max_func
)
print(params)

params = get_transformed_activations(
    ["leaky_relu"], method="TAT", subnet_max_func=subnet_max_func_v2
)
print(params)
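(`dks` above is presumably DeepMind's Deep Kernel Shaping package, which ships the DKS and TAT activation transforms; under that assumption, `pip install dks` provides the `dks.base` modules this script imports.)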
							
								
								
									
setup.py (2 changes)

@@ -37,7 +37,7 @@ def read(fname="README.md"):
 
 
 # What packages are required for this module to be executed?
-REQUIRED = ["numpy>=1.16.5,<=1.19.5", "pyyaml>=5.0.0", "fvcore"]
+REQUIRED = ["numpy>=1.16.5", "pyyaml>=5.0.0", "fvcore"]
 
 packages = find_packages(
     exclude=("tests", "scripts", "scripts-search", "lib*", "exps*")
							
								
								
									
test.ipynb (new file, 502 lines)
@@ -0,0 +1,502 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nats_bench import create\n",
    "\n",
    "# Create the API for size search space\n",
    "api = create(None, 'sss', fast_mode=True, verbose=True)\n",
    "\n",
    "# Create the API for tologoy search space\n",
    "api = create(None, 'tss', fast_mode=True, verbose=True)\n",
    "\n",
    "# Query the loss / accuracy / time for 1234-th candidate architecture on CIFAR-10\n",
    "# info is a dict, where you can easily figure out the meaning by key\n",
    "info = api.get_more_info(1234, 'cifar10')\n",
    "\n",
    "# Query the flops, params, latency. info is a dict.\n",
    "info = api.get_cost_info(12, 'cifar10')\n",
    "\n",
    "# Simulate the training of the 1224-th candidate:\n",
    "validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')\n",
    "\n",
    "# Clear the parameters of the 12-th candidate.\n",
    "api.clear_params(12)\n",
    "\n",
    "# Reload all information of the 12-th candidate.\n",
    "api.reload(index=12)\n",
    "\n",
    "# Create the instance of th 12-th candidate for CIFAR-10.\n",
    "from models import get_cell_based_tiny_net\n",
    "config = api.get_net_config(12, 'cifar10')\n",
    "network = get_cell_based_tiny_net(config)\n",
    "\n",
    "# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.\n",
    "params = api.get_net_param(12, 'cifar10', None)\n",
    "network.load_state_dict(next(iter(params.values())))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nas_201_api import NASBench201API as API\n",
    "import os\n",
    "# api = API('./NAS-Bench-201-v1_1_096897.pth')\n",
    "# get the current path\n",
    "print(os.path.abspath(os.path.curdir))\n",
    "cur_path = os.path.abspath(os.path.curdir)\n",
    "data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')\n",
    "api = API(data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the best performance on CIFAR-10\n",
    "len = 15625\n",
    "accs = []\n",
    "for i in range(1, len):\n",
    "    results = api.query_by_index(i, 'cifar10')\n",
    "    dict_items = list(results.items())\n",
    "    train_info = dict_items[0][1].get_train()\n",
    "    acc = train_info['accuracy']\n",
    "    accs.append((i, acc))\n",
    "print(max(accs, key=lambda x: x[1]))\n",
    "best_index, best_acc = max(accs, key=lambda x: x[1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_best_index(dataset):\n",
    "    len = 15625\n",
    "    accs = []\n",
    "    for i in range(1, len):\n",
    "        results = api.query_by_index(i, dataset)\n",
    "        dict_items = list(results.items())\n",
    "        train_info = dict_items[0][1].get_train()\n",
    "        acc = train_info['accuracy']\n",
    "        accs.append((i, acc))\n",
    "    return max(accs, key=lambda x: x[1])\n",
    "best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')\n",
    "best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')\n",
    "best_ImageNet16_index, best_ImageNet16_acc= find_best_index('ImageNet16-120')\n",
    "print(best_cifar_10_index, best_cifar_10_acc)\n",
    "print(best_cifar_100_index, best_cifar_100_acc)\n",
    "print(best_ImageNet16_index, best_ImageNet16_acc)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api.show(5374)\n",
    "config = api.get_net_config(best_index, 'cifar10')\n",
    "from models import get_cell_based_tiny_net\n",
    "network = get_cell_based_tiny_net(config)\n",
    "print(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api.get_net_param(5374, 'cifar10', None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys, time, torch, random, argparse\n",
    "from PIL import ImageFile\n",
    "\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "from copy import deepcopy\n",
    "from pathlib import Path\n",
    "\n",
    "from config_utils import load_config\n",
    "from procedures.starts import get_machine_info\n",
    "from datasets.get_dataset_with_transform import get_datasets\n",
    "from log_utils import Logger, AverageMeter, time_string, convert_secs2time\n",
    "from models import CellStructure, CellArchitectures, get_search_spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_all_datasets(\n",
    "    arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger\n",
    "):\n",
    "    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)\n",
    "    all_infos = {\"info\": machine_info}\n",
    "    all_dataset_keys = []\n",
    "    # look all the datasets\n",
    "    for dataset, xpath, split in zip(datasets, xpaths, splits):\n",
    "        # train valid data\n",
    "        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)\n",
    "        # load the configuration\n",
    "        if dataset == \"cifar10\" or dataset == \"cifar100\":\n",
    "            if use_less:\n",
    "                config_path = \"configs/nas-benchmark/LESS.config\"\n",
    "            else:\n",
    "                config_path = \"configs/nas-benchmark/CIFAR.config\"\n",
    "            split_info = load_config(\n",
    "                \"configs/nas-benchmark/cifar-split.txt\", None, None\n",
    "            )\n",
    "        elif dataset.startswith(\"ImageNet16\"):\n",
    "            if use_less:\n",
    "                config_path = \"configs/nas-benchmark/LESS.config\"\n",
    "            else:\n",
    "                config_path = \"configs/nas-benchmark/ImageNet-16.config\"\n",
    "            split_info = load_config(\n",
    "                \"configs/nas-benchmark/{:}-split.txt\".format(dataset), None, None\n",
    "            )\n",
    "        else:\n",
    "            raise ValueError(\"invalid dataset : {:}\".format(dataset))\n",
    "        config = load_config(\n",
    "            config_path, {\"class_num\": class_num, \"xshape\": xshape}, logger\n",
    "        )\n",
    "        # check whether use splited validation set\n",
    "        if bool(split):\n",
    "            assert dataset == \"cifar10\"\n",
    "            ValLoaders = {\n",
    "                \"ori-test\": torch.utils.data.DataLoader(\n",
    "                    valid_data,\n",
    "                    batch_size=config.batch_size,\n",
    "                    shuffle=False,\n",
    "                    num_workers=workers,\n",
    "                    pin_memory=True,\n",
    "                )\n",
    "            }\n",
    "            assert len(train_data) == len(split_info.train) + len(\n",
    "                split_info.valid\n",
    "            ), \"invalid length : {:} vs {:} + {:}\".format(\n",
    "                len(train_data), len(split_info.train), len(split_info.valid)\n",
    "            )\n",
    "            train_data_v2 = deepcopy(train_data)\n",
    "            train_data_v2.transform = valid_data.transform\n",
    "            valid_data = train_data_v2\n",
    "            # data loader\n",
    "            train_loader = torch.utils.data.DataLoader(\n",
    "                train_data,\n",
    "                batch_size=config.batch_size,\n",
    "                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),\n",
    "                num_workers=workers,\n",
    "                pin_memory=True,\n",
    "            )\n",
    "            valid_loader = torch.utils.data.DataLoader(\n",
    "                valid_data,\n",
    "                batch_size=config.batch_size,\n",
    "                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),\n",
    "                num_workers=workers,\n",
    "                pin_memory=True,\n",
    "            )\n",
    "            ValLoaders[\"x-valid\"] = valid_loader\n",
    "        else:\n",
    "            # data loader\n",
    "            train_loader = torch.utils.data.DataLoader(\n",
    "                train_data,\n",
    "                batch_size=config.batch_size,\n",
    "                shuffle=True,\n",
    "                num_workers=workers,\n",
    "                pin_memory=True,\n",
    "            )\n",
    "            valid_loader = torch.utils.data.DataLoader(\n",
    "                valid_data,\n",
    "                batch_size=config.batch_size,\n",
    "                shuffle=False,\n",
    "                num_workers=workers,\n",
    "                pin_memory=True,\n",
    "            )\n",
    "            if dataset == \"cifar10\":\n",
    "                ValLoaders = {\"ori-test\": valid_loader}\n",
    "            elif dataset == \"cifar100\":\n",
    "                cifar100_splits = load_config(\n",
    "                    \"configs/nas-benchmark/cifar100-test-split.txt\", None, None\n",
    "                )\n",
    "                ValLoaders = {\n",
    "                    \"ori-test\": valid_loader,\n",
    "                    \"x-valid\": torch.utils.data.DataLoader(\n",
    "                        valid_data,\n",
    "                        batch_size=config.batch_size,\n",
    "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
    "                            cifar100_splits.xvalid\n",
    "                        ),\n",
    "                        num_workers=workers,\n",
    "                        pin_memory=True,\n",
    "                    ),\n",
    "                    \"x-test\": torch.utils.data.DataLoader(\n",
    "                        valid_data,\n",
    "                        batch_size=config.batch_size,\n",
    "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
    "                            cifar100_splits.xtest\n",
    "                        ),\n",
    "                        num_workers=workers,\n",
    "                        pin_memory=True,\n",
    "                    ),\n",
    "                }\n",
    "            elif dataset == \"ImageNet16-120\":\n",
    "                imagenet16_splits = load_config(\n",
    "                    \"configs/nas-benchmark/imagenet-16-120-test-split.txt\", None, None\n",
    "                )\n",
    "                ValLoaders = {\n",
    "                    \"ori-test\": valid_loader,\n",
    "                    \"x-valid\": torch.utils.data.DataLoader(\n",
    "                        valid_data,\n",
    "                        batch_size=config.batch_size,\n",
    "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
    "                            imagenet16_splits.xvalid\n",
    "                        ),\n",
    "                        num_workers=workers,\n",
    "                        pin_memory=True,\n",
    "                    ),\n",
    "                    \"x-test\": torch.utils.data.DataLoader(\n",
    "                        valid_data,\n",
    "                        batch_size=config.batch_size,\n",
    "                        sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
    "                            imagenet16_splits.xtest\n",
    "                        ),\n",
    "                        num_workers=workers,\n",
    "                        pin_memory=True,\n",
    "                    ),\n",
    "                }\n",
    "            else:\n",
    "                raise ValueError(\"invalid dataset : {:}\".format(dataset))\n",
    "\n",
    "        dataset_key = \"{:}\".format(dataset)\n",
    "        if bool(split):\n",
    "            dataset_key = dataset_key + \"-valid\"\n",
    "        logger.log(\n",
    "            \"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}\".format(\n",
    "                dataset_key,\n",
    "                len(train_data),\n",
    "                len(valid_data),\n",
    "                len(train_loader),\n",
    "                len(valid_loader),\n",
    "                config.batch_size,\n",
    "            )\n",
    "        )\n",
    "        logger.log(\n",
    "            \"Evaluate ||||||| {:10s} ||||||| Config={:}\".format(dataset_key, config)\n",
    "        )\n",
    "        for key, value in ValLoaders.items():\n",
    "            logger.log(\n",
    "                \"Evaluate ---->>>> {:10s} with {:} batchs\".format(key, len(value))\n",
    "            )\n",
    "        results = evaluate_for_seed(\n",
    "            arch_config, config, arch, train_loader, ValLoaders, seed, logger\n",
    "        )\n",
    "        all_infos[dataset_key] = results\n",
    "        all_dataset_keys.append(dataset_key)\n",
    "    all_infos[\"all_dataset_keys\"] = all_dataset_keys\n",
    "    return all_infos\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_single_model(\n",
    "    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config\n",
    "):\n",
    "    assert torch.cuda.is_available(), \"CUDA is not available.\"\n",
    "    torch.backends.cudnn.enabled = True\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    # torch.backends.cudnn.benchmark = True\n",
    "    torch.set_num_threads(workers)\n",
    "\n",
    "    save_dir = (\n",
    "        Path(save_dir)\n",
    "        / \"specifics\"\n",
    "        / \"{:}-{:}-{:}-{:}\".format(\n",
    "            \"LESS\" if use_less else \"FULL\",\n",
    "            model_str,\n",
    "            arch_config[\"channel\"],\n",
    "            arch_config[\"num_cells\"],\n",
    "        )\n",
    "    )\n",
    "    logger = Logger(str(save_dir), 0, False)\n",
    "    if model_str in CellArchitectures:\n",
    "        arch = CellArchitectures[model_str]\n",
    "        logger.log(\n",
    "            \"The model string is found in pre-defined architecture dict : {:}\".format(\n",
    "                model_str\n",
    "            )\n",
    "        )\n",
    "    else:\n",
    "        try:\n",
    "            arch = CellStructure.str2structure(model_str)\n",
    "        except:\n",
    "            raise ValueError(\n",
    "                \"Invalid model string : {:}. It can not be found or parsed.\".format(\n",
    "                    model_str\n",
    "                )\n",
    "            )\n",
    "    assert arch.check_valid_op(\n",
    "        get_search_spaces(\"cell\", \"full\")\n",
    "    ), \"{:} has the invalid op.\".format(arch)\n",
    "    logger.log(\"Start train-evaluate {:}\".format(arch.tostr()))\n",
    "    logger.log(\"arch_config : {:}\".format(arch_config))\n",
    "\n",
    "    start_time, seed_time = time.time(), AverageMeter()\n",
    "    for _is, seed in enumerate(seeds):\n",
    "        logger.log(\n",
    "            \"\\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------\".format(\n",
    "                _is, len(seeds), seed\n",
    "            )\n",
    "        )\n",
    "        to_save_name = save_dir / \"seed-{:04d}.pth\".format(seed)\n",
    "        if to_save_name.exists():\n",
    "            logger.log(\n",
    "                \"Find the existing file {:}, directly load!\".format(to_save_name)\n",
    "            )\n",
    "            checkpoint = torch.load(to_save_name)\n",
    "        else:\n",
    "            logger.log(\n",
    "                \"Does not find the existing file {:}, train and evaluate!\".format(\n",
    "                    to_save_name\n",
    "                )\n",
    "            )\n",
    "            checkpoint = evaluate_all_datasets(\n",
    "                arch,\n",
    "                datasets,\n",
    "                xpaths,\n",
    "                splits,\n",
    "                use_less,\n",
    "                seed,\n",
    "                arch_config,\n",
    "                workers,\n",
    "                logger,\n",
    "            )\n",
    "            torch.save(checkpoint, to_save_name)\n",
    "        # log information\n",
    "        logger.log(\"{:}\".format(checkpoint[\"info\"]))\n",
    "        all_dataset_keys = checkpoint[\"all_dataset_keys\"]\n",
    "        for dataset_key in all_dataset_keys:\n",
    "            logger.log(\n",
    "                \"\\n{:} dataset : {:} {:}\".format(\"-\" * 15, dataset_key, \"-\" * 15)\n",
    "            )\n",
    "            dataset_info = checkpoint[dataset_key]\n",
    "            # logger.log('Network ==>\\n{:}'.format( dataset_info['net_string'] ))\n",
    "            logger.log(\n",
    "                \"Flops = {:} MB, Params = {:} MB\".format(\n",
    "                    dataset_info[\"flop\"], dataset_info[\"param\"]\n",
    "                )\n",
    "            )\n",
    "            logger.log(\"config : {:}\".format(dataset_info[\"config\"]))\n",
    "            logger.log(\n",
    "                \"Training State (finish) = {:}\".format(dataset_info[\"finish-train\"])\n",
    "            )\n",
    "            last_epoch = dataset_info[\"total_epoch\"] - 1\n",
    "            train_acc1es, train_acc5es = (\n",
    "                dataset_info[\"train_acc1es\"],\n",
    "                dataset_info[\"train_acc5es\"],\n",
    "            )\n",
    "            valid_acc1es, valid_acc5es = (\n",
    "                dataset_info[\"valid_acc1es\"],\n",
    "                dataset_info[\"valid_acc5es\"],\n",
    "            )\n",
    "            logger.log(\n",
    "                \"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%\".format(\n",
    "                    train_acc1es[last_epoch],\n",
    "                    train_acc5es[last_epoch],\n",
    "                    100 - train_acc1es[last_epoch],\n",
    "                    valid_acc1es[last_epoch],\n",
    "                    valid_acc5es[last_epoch],\n",
    "                    100 - valid_acc1es[last_epoch],\n",
    "                )\n",
    "            )\n",
    "        # measure elapsed time\n",
    "        seed_time.update(time.time() - start_time)\n",
    "        start_time = time.time()\n",
    "        need_time = \"Time Left: {:}\".format(\n",
    "            convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)\n",
    "        )\n",
    "        logger.log(\n",
    "            \"\\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}\".format(\n",
    "                _is, len(seeds), seed, need_time\n",
    "            )\n",
    "        )\n",
    "    logger.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_single_model(\n",
    "    save_dir=\"./outputs\",\n",
    "    workers=8,\n",
    "    datasets=\"cifar10\", \n",
    "    xpaths=\"/root/cifardata/cifar-10-batches-py\",\n",
    "    splits=[0, 0, 0],\n",
    "    use_less=False,\n",
    "    seeds=[777],\n",
    "    model_str=\"|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\",\n",
    "    arch_config={\"channel\": 16, \"num_cells\": 8},)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "natsbench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
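One caveat on the notebook's final cell: `evaluate_all_datasets` iterates `zip(datasets, xpaths, splits)`, so passing `datasets` and `xpaths` as bare strings would make `zip` pair individual characters. A sketch of the presumably intended call, with parallel lists:

```python
# Same call as the notebook's last cell, but with datasets/xpaths/splits as
# parallel lists so zip() pairs whole entries rather than single characters.
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets=["cifar10"],
    xpaths=["/root/cifardata/cifar-10-batches-py"],
    splits=[0],
    use_less=False,
    seeds=[777],
    model_str="|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|",
    arch_config={"channel": 16, "num_cells": 8},
)
```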
@@ -347,6 +347,10 @@ class GenericNAS201Model(nn.Module):
                     feature = cell.forward_gdas(feature, alphas, index)
                     if self.verbose:
                         verbose_str += "-forward_gdas"
+                elif self.mode == "gdas_v1":
+                    feature = cell.forward_gdas_v1(feature, alphas, index)
+                    if self.verbose:
+                        verbose_str += "-forward_gdas_v1"
                 else:
                     raise ValueError("invalid mode={:}".format(self.mode))
             else:
@@ -213,6 +213,13 @@ AllConv3x3_CODE = Structure(
         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
     ]  # node-3
 )
+Number_5374 = Structure(
+    [
+        (("nor_conv_3x3", 0),),  # node-1
+        (("nor_conv_1x1", 0), ("nor_conv_3x3", 1)),  # node-2
+        (("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)),  # node-3
+    ]
+)
 
 AllFull_CODE = Structure(
     [
@@ -271,4 +278,5 @@ architectures = {
     "all_c1x1": AllConv1x1_CODE,
     "all_idnt": AllIdentity_CODE,
     "all_full": AllFull_CODE,
+    "5374": Number_5374,
 }
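With the `"5374"` entry registered, the candidate can be pulled from the pre-defined dict by name. A sketch following the import style of `test.ipynb` above (note that this `Structure` encodes the same cell as the `model_str` trained there):

```python
# Look up the newly registered architecture #5374 by its dict key.
# The import path mirrors test.ipynb and may differ by install layout.
from models import CellArchitectures

arch = CellArchitectures["5374"]
print(arch.tostr())
# expected: |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
```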
@@ -85,6 +85,20 @@ class NAS201SearchCell(nn.Module):
             nodes.append(sum(inter_nodes))
         return nodes[-1]
 
+    # GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
+    def forward_gdas_v1(self, inputs, hardwts, index):
+        nodes = [inputs]
+        for i in range(1, self.max_nodes):
+            inter_nodes = []
+            for j in range(i):
+                node_str = "{:}<-{:}".format(i, j)
+                weights = hardwts[self.edge2index[node_str]]
+                argmaxs = index[self.edge2index[node_str]].item()
+                weigsum = weights[argmaxs] * self.edges[node_str][argmaxs](nodes[j])
+                inter_nodes.append(weigsum)
+            nodes.append(sum(inter_nodes))
+        return nodes[-1]
+
     # joint
     def forward_joint(self, inputs, weightss):
         nodes = [inputs]
@@ -152,6 +166,9 @@
         return nodes[-1]
 
 
+# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
+
+
 class MixedOp(nn.Module):
     def __init__(self, space, C, stride, affine, track_running_stats):
         super(MixedOp, self).__init__()
@@ -167,7 +184,6 @@
         return sum(w * op(x) for w, op in zip(weights, self._ops))
 
 
-# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
 class NASNetSearchCell(nn.Module):
     def __init__(
         self,
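To make the `forward_gdas_v1` semantics concrete: on each edge, only the argmax operation is evaluated, and its output is scaled by the corresponding Gumbel weight, which stays differentiable through the straight-through trick. A toy, self-contained sketch with made-up ops and sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Three stand-in candidate ops for a single edge of the cell.
ops = nn.ModuleList([nn.Identity(), nn.Linear(4, 4, bias=False), nn.Tanh()])
x = torch.randn(2, 4)

logits = torch.randn(3, requires_grad=True)             # architecture parameters
hardwts = F.gumbel_softmax(logits, tau=1.0, hard=True)  # one-hot, straight-through
k = int(hardwts.argmax())

# forward_gdas_v1 style: only the selected op contributes to the edge output.
edge_out = hardwts[k] * ops[k](x)
edge_out.sum().backward()
print(logits.grad)  # gradient reaches the architecture logits via hardwts[k]
```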
@@ -12,6 +12,7 @@ def obtain_accuracy(output, target, topk=(1,)):
 
     res = []
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
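The `view` to `reshape` swap matters when `correct` arrives non-contiguous (e.g., after a transpose or narrow): `Tensor.view` requires compatible strides and raises at runtime, while `reshape` silently copies when it must. A minimal reproduction:

```python
import torch

x = torch.arange(6).view(2, 3).t()  # transpose makes the strides non-contiguous
print(x.is_contiguous())            # False
print(x.reshape(-1))                # fine: reshape falls back to a copy
print(x.view(-1))                   # raises RuntimeError: view size is not compatible ...
```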