update README
This commit is contained in:
		| @@ -22,6 +22,8 @@ from utils import get_model_infos | |||||||
| flop, param  = get_model_infos(net, (1,3,32,32)) | flop, param  = get_model_infos(net, (1,3,32,32)) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  | 2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/NAS-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py). | ||||||
|  |  | ||||||
|  |  | ||||||
| ## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717) | ## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717) | ||||||
| In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network. | In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network. | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								exps/vis/random-nn.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								exps/vis/random-nn.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | from graphviz import Digraph | ||||||
							
								
								
									
										67
									
								
								exps/vis/show-results.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								exps/vis/show-results.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | |||||||
|  | # python ./vis-exps/show-results.py | ||||||
|  | import os, sys | ||||||
|  | from pathlib import Path | ||||||
|  | import torch | ||||||
|  | import numpy as np | ||||||
|  | from collections import OrderedDict | ||||||
|  | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
|  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | from aa_nas_api   import AANASBenchAPI | ||||||
|  |  | ||||||
|  | api = AANASBenchAPI('./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth') | ||||||
|  |  | ||||||
|  | def plot_results_nas(dataset, xset, file_name, y_lims): | ||||||
|  |   import matplotlib | ||||||
|  |   matplotlib.use('agg') | ||||||
|  |   import matplotlib.pyplot as plt | ||||||
|  |   root = Path('./output/cell-search-tiny-vis').resolve() | ||||||
|  |   print ('root path : {:}'.format( root )) | ||||||
|  |   root.mkdir(parents=True, exist_ok=True) | ||||||
|  |   checkpoints = ['./output/cell-search-tiny/R-EA-cifar10/results.pth', | ||||||
|  |                  './output/cell-search-tiny/REINFORCE-cifar10/results.pth', | ||||||
|  |                  './output/cell-search-tiny/RAND-cifar10/results.pth', | ||||||
|  |                  './output/cell-search-tiny/BOHB-cifar10/results.pth' | ||||||
|  |                 ] | ||||||
|  |   legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None | ||||||
|  |   All_Accs = OrderedDict() | ||||||
|  |   for legend, checkpoint in zip(legends, checkpoints): | ||||||
|  |     all_indexes = torch.load(checkpoint, map_location='cpu') | ||||||
|  |     accuracies  = [] | ||||||
|  |     for x in all_indexes: | ||||||
|  |       info = api.arch2infos[ x ] | ||||||
|  |       _, accy = info.get_metrics(dataset, xset, None, False) | ||||||
|  |       accuracies.append( accy ) | ||||||
|  |     if indexes is None: indexes = list(range(len(all_indexes))) | ||||||
|  |     All_Accs[legend] = sorted(accuracies) | ||||||
|  |    | ||||||
|  |   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||||
|  |   dpi, width, height = 300, 3400, 2600 | ||||||
|  |   LabelSize, LegendFontsize = 26, 26 | ||||||
|  |   figsize = width / float(dpi), height / float(dpi) | ||||||
|  |   fig = plt.figure(figsize=figsize) | ||||||
|  |   x_axis = np.arange(0, 600) | ||||||
|  |   plt.xlim(0, max(indexes)) | ||||||
|  |   plt.ylim(y_lims[0], y_lims[1]) | ||||||
|  |   interval_x, interval_y = 100, y_lims[2] | ||||||
|  |   plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) | ||||||
|  |   plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) | ||||||
|  |   plt.grid() | ||||||
|  |   plt.xlabel('The index of runs', fontsize=LabelSize) | ||||||
|  |   plt.ylabel('The accuracy (%)', fontsize=LabelSize) | ||||||
|  |  | ||||||
|  |   for idx, legend in enumerate(legends): | ||||||
|  |     plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2) | ||||||
|  |     print ('{:} : mean = {:}, std = {:}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) | ||||||
|  |   plt.legend(loc=4, fontsize=LegendFontsize) | ||||||
|  |   save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name) | ||||||
|  |   print('save figure into {:}\n'.format(save_path)) | ||||||
|  |   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |   plot_results_nas('cifar10', 'ori-test', 'nas-com.pdf', (85,95, 1)) | ||||||
|  |   plot_results_nas('cifar100', 'x-valid', 'nas-com.pdf', (55,75, 3)) | ||||||
|  |   plot_results_nas('cifar100', 'x-test' , 'nas-com.pdf', (55,75, 3)) | ||||||
|  |   plot_results_nas('ImageNet16-120', 'x-valid', 'nas-com.pdf', (35,50, 3)) | ||||||
|  |   plot_results_nas('ImageNet16-120', 'x-test' , 'nas-com.pdf', (35,50, 3)) | ||||||
| @@ -74,18 +74,13 @@ class Structure: | |||||||
|       nodes[i+1] = sum(sums) > 0 |       nodes[i+1] = sum(sums) > 0 | ||||||
|     return nodes[len(self.nodes)] |     return nodes[len(self.nodes)] | ||||||
|  |  | ||||||
|   def to_unique_str(self, consider_zero=False): |   def to_unique_str(self): | ||||||
|     # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation |     # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation | ||||||
|     # two operations are special, i.e., none and skip_connect |     # two operations are special, i.e., none and skip_connect | ||||||
|     nodes = {0: '0'} |     nodes = {0: '0'} | ||||||
|     for i_node, node_info in enumerate(self.nodes): |     for i_node, node_info in enumerate(self.nodes): | ||||||
|       cur_node = [] |       cur_node = [] | ||||||
|       for op, xin in node_info: |       for op, xin in node_info: | ||||||
|         if consider_zero: |  | ||||||
|           if op == 'none' or nodes[xin] == '#': x = '#' # zero |  | ||||||
|           elif op == 'skip_connect': x = nodes[xin] |  | ||||||
|           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) |  | ||||||
|         else: |  | ||||||
|         if op == 'skip_connect': x = nodes[xin] |         if op == 'skip_connect': x = nodes[xin] | ||||||
|         else: x = '('+nodes[xin]+')' + '@{:}'.format(op) |         else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||||
|         cur_node.append(x) |         cur_node.append(x) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user