update README
This commit is contained in:
parent
5bd503aed4
commit
487fec21bf
@ -22,6 +22,8 @@ from utils import get_model_infos
|
||||
flop, param = get_model_infos(net, (1,3,32,32))
|
||||
```
|
||||
|
||||
2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/NAS-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
|
||||
|
||||
|
||||
## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717)
|
||||
In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network.
|
||||
|
1
exps/vis/random-nn.py
Normal file
1
exps/vis/random-nn.py
Normal file
@ -0,0 +1 @@
|
||||
from graphviz import Digraph
|
67
exps/vis/show-results.py
Normal file
67
exps/vis/show-results.py
Normal file
@ -0,0 +1,67 @@
|
||||
# python ./vis-exps/show-results.py
|
||||
import os, sys
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from aa_nas_api import AANASBenchAPI
|
||||
|
||||
api = AANASBenchAPI('./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth')
|
||||
|
||||
def plot_results_nas(dataset, xset, file_name, y_lims):
|
||||
import matplotlib
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
root = Path('./output/cell-search-tiny-vis').resolve()
|
||||
print ('root path : {:}'.format( root ))
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
checkpoints = ['./output/cell-search-tiny/R-EA-cifar10/results.pth',
|
||||
'./output/cell-search-tiny/REINFORCE-cifar10/results.pth',
|
||||
'./output/cell-search-tiny/RAND-cifar10/results.pth',
|
||||
'./output/cell-search-tiny/BOHB-cifar10/results.pth'
|
||||
]
|
||||
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
|
||||
All_Accs = OrderedDict()
|
||||
for legend, checkpoint in zip(legends, checkpoints):
|
||||
all_indexes = torch.load(checkpoint, map_location='cpu')
|
||||
accuracies = []
|
||||
for x in all_indexes:
|
||||
info = api.arch2infos[ x ]
|
||||
_, accy = info.get_metrics(dataset, xset, None, False)
|
||||
accuracies.append( accy )
|
||||
if indexes is None: indexes = list(range(len(all_indexes)))
|
||||
All_Accs[legend] = sorted(accuracies)
|
||||
|
||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||
dpi, width, height = 300, 3400, 2600
|
||||
LabelSize, LegendFontsize = 26, 26
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
fig = plt.figure(figsize=figsize)
|
||||
x_axis = np.arange(0, 600)
|
||||
plt.xlim(0, max(indexes))
|
||||
plt.ylim(y_lims[0], y_lims[1])
|
||||
interval_x, interval_y = 100, y_lims[2]
|
||||
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
|
||||
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||
plt.grid()
|
||||
plt.xlabel('The index of runs', fontsize=LabelSize)
|
||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
for idx, legend in enumerate(legends):
|
||||
plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2)
|
||||
print ('{:} : mean = {:}, std = {:}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name)
|
||||
print('save figure into {:}\n'.format(save_path))
|
||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plot_results_nas('cifar10', 'ori-test', 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas('cifar100', 'x-valid', 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas('cifar100', 'x-test' , 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas('ImageNet16-120', 'x-valid', 'nas-com.pdf', (35,50, 3))
|
||||
plot_results_nas('ImageNet16-120', 'x-test' , 'nas-com.pdf', (35,50, 3))
|
@ -74,20 +74,15 @@ class Structure:
|
||||
nodes[i+1] = sum(sums) > 0
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def to_unique_str(self, consider_zero=False):
|
||||
def to_unique_str(self):
|
||||
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
|
||||
# two operations are special, i.e., none and skip_connect
|
||||
nodes = {0: '0'}
|
||||
for i_node, node_info in enumerate(self.nodes):
|
||||
cur_node = []
|
||||
for op, xin in node_info:
|
||||
if consider_zero:
|
||||
if op == 'none' or nodes[xin] == '#': x = '#' # zero
|
||||
elif op == 'skip_connect': x = nodes[xin]
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
else:
|
||||
if op == 'skip_connect': x = nodes[xin]
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
if op == 'skip_connect': x = nodes[xin]
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
cur_node.append(x)
|
||||
nodes[i_node+1] = '+'.join( sorted(cur_node) )
|
||||
return nodes[ len(self.nodes) ]
|
||||
|
Loading…
Reference in New Issue
Block a user