From 487fec21bf13f133bf4bc4d46364c38266aaf330 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 21 Nov 2019 00:52:17 +1100 Subject: [PATCH] update README --- README.md | 2 + exps/vis/random-nn.py | 1 + exps/vis/show-results.py | 67 ++++++++++++++++++++++++++++ lib/models/cell_searchs/genotypes.py | 11 ++--- 4 files changed, 73 insertions(+), 8 deletions(-) create mode 100644 exps/vis/random-nn.py create mode 100644 exps/vis/show-results.py diff --git a/README.md b/README.md index 0f036e4..bc9348d 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ from utils import get_model_infos flop, param = get_model_infos(net, (1,3,32,32)) ``` +2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/NAS-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py). + ## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717) In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network. diff --git a/exps/vis/random-nn.py b/exps/vis/random-nn.py new file mode 100644 index 0000000..e441dde --- /dev/null +++ b/exps/vis/random-nn.py @@ -0,0 +1 @@ +from graphviz import Digraph diff --git a/exps/vis/show-results.py b/exps/vis/show-results.py new file mode 100644 index 0000000..2231e92 --- /dev/null +++ b/exps/vis/show-results.py @@ -0,0 +1,67 @@ +# python ./vis-exps/show-results.py +import os, sys +from pathlib import Path +import torch +import numpy as np +from collections import OrderedDict +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +from aa_nas_api import AANASBenchAPI + +api = AANASBenchAPI('./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth') + +def plot_results_nas(dataset, xset, file_name, y_lims): + import matplotlib + matplotlib.use('agg') + import matplotlib.pyplot as plt + root = Path('./output/cell-search-tiny-vis').resolve() + print ('root path : {:}'.format( root )) + root.mkdir(parents=True, exist_ok=True) + checkpoints = ['./output/cell-search-tiny/R-EA-cifar10/results.pth', + './output/cell-search-tiny/REINFORCE-cifar10/results.pth', + './output/cell-search-tiny/RAND-cifar10/results.pth', + './output/cell-search-tiny/BOHB-cifar10/results.pth' + ] + legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None + All_Accs = OrderedDict() + for legend, checkpoint in zip(legends, checkpoints): + all_indexes = torch.load(checkpoint, map_location='cpu') + accuracies = [] + for x in all_indexes: + info = api.arch2infos[ x ] + _, accy = info.get_metrics(dataset, xset, None, False) + accuracies.append( accy ) + if indexes is None: indexes = list(range(len(all_indexes))) + All_Accs[legend] = sorted(accuracies) + + color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 26, 26 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + x_axis = np.arange(0, 600) + plt.xlim(0, max(indexes)) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = 100, y_lims[2] + plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel('The index of runs', fontsize=LabelSize) + plt.ylabel('The accuracy (%)', fontsize=LabelSize) + + for idx, legend in enumerate(legends): + plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2) + print ('{:} : mean = {:}, std = {:}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) + plt.legend(loc=4, fontsize=LegendFontsize) + save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name) + print('save figure into {:}\n'.format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + + +if __name__ == '__main__': + plot_results_nas('cifar10', 'ori-test', 'nas-com.pdf', (85,95, 1)) + plot_results_nas('cifar100', 'x-valid', 'nas-com.pdf', (55,75, 3)) + plot_results_nas('cifar100', 'x-test' , 'nas-com.pdf', (55,75, 3)) + plot_results_nas('ImageNet16-120', 'x-valid', 'nas-com.pdf', (35,50, 3)) + plot_results_nas('ImageNet16-120', 'x-test' , 'nas-com.pdf', (35,50, 3)) diff --git a/lib/models/cell_searchs/genotypes.py b/lib/models/cell_searchs/genotypes.py index e0f2e2e..52f0e25 100644 --- a/lib/models/cell_searchs/genotypes.py +++ b/lib/models/cell_searchs/genotypes.py @@ -74,20 +74,15 @@ class Structure: nodes[i+1] = sum(sums) > 0 return nodes[len(self.nodes)] - def to_unique_str(self, consider_zero=False): + def to_unique_str(self): # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation # two operations are special, i.e., none and skip_connect nodes = {0: '0'} for i_node, node_info in enumerate(self.nodes): cur_node = [] for op, xin in node_info: - if consider_zero: - if op == 'none' or nodes[xin] == '#': x = '#' # zero - elif op == 'skip_connect': x = nodes[xin] - else: x = '('+nodes[xin]+')' + '@{:}'.format(op) - else: - if op == 'skip_connect': x = nodes[xin] - else: x = '('+nodes[xin]+')' + '@{:}'.format(op) + if op == 'skip_connect': x = nodes[xin] + else: x = '('+nodes[xin]+')' + '@{:}'.format(op) cur_node.append(x) nodes[i_node+1] = '+'.join( sorted(cur_node) ) return nodes[ len(self.nodes) ]