update vis
This commit is contained in:
parent
23d5ee08ae
commit
dd6cf5a9c5
@ -18,6 +18,7 @@ The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https:
|
|||||||
You can move it to anywhere you want and send its path to our API for initialization.
|
You can move it to anywhere you want and send its path to our API for initialization.
|
||||||
- v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
|
- v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
|
||||||
- v1.0: The full data of each architecture can be download from [Google Drive](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
|
- v1.0: The full data of each architecture can be download from [Google Drive](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
|
||||||
|
- v1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
|
||||||
|
|
||||||
The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
|
The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
|
||||||
It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-102 or similar NAS datasets or training models by yourself, you need these data.
|
It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-102 or similar NAS datasets or training models by yourself, you need these data.
|
||||||
|
@ -464,18 +464,17 @@ def just_show(api):
|
|||||||
print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
|
print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
|
||||||
|
|
||||||
|
|
||||||
def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims):
|
def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_maxs):
|
||||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||||
dpi, width, height = 300, 3400, 2600
|
dpi, width, height = 300, 3400, 2600
|
||||||
LabelSize, LegendFontsize = 28, 28
|
LabelSize, LegendFontsize = 28, 28
|
||||||
figsize = width / float(dpi), height / float(dpi)
|
figsize = width / float(dpi), height / float(dpi)
|
||||||
fig = plt.figure(figsize=figsize)
|
fig = plt.figure(figsize=figsize)
|
||||||
x_maxs = 250
|
#x_maxs = 250
|
||||||
x_axis = np.arange(0, x_maxs)
|
plt.xlim(0, x_maxs+1)
|
||||||
plt.xlim(0, x_maxs)
|
|
||||||
plt.ylim(y_lims[0], y_lims[1])
|
plt.ylim(y_lims[0], y_lims[1])
|
||||||
interval_x, interval_y = x_maxs // 5, y_lims[2]
|
interval_x, interval_y = x_maxs // 5, y_lims[2]
|
||||||
plt.xticks(np.arange(0, x_maxs, interval_x), fontsize=LegendFontsize)
|
plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize)
|
||||||
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||||
plt.grid()
|
plt.grid()
|
||||||
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
||||||
@ -505,17 +504,24 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims):
|
|||||||
xresults.append( metrics['accuracy'] )
|
xresults.append( metrics['accuracy'] )
|
||||||
return xresults
|
return xresults
|
||||||
|
|
||||||
for idx, method in enumerate(['RSPS', 'GDAS', 'SETN', 'ENAS']):
|
if x_maxs == 50:
|
||||||
|
xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2']
|
||||||
|
elif x_maxs == 250:
|
||||||
|
xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS']
|
||||||
|
else: raise ValueError('invalid x_maxs={:}'.format(x_maxs))
|
||||||
|
|
||||||
|
for idx, method in enumerate(xxxstrs):
|
||||||
xkey = method
|
xkey = method
|
||||||
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
|
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
|
||||||
all_datas = [torch.load(xpath) for xpath in all_paths]
|
all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
|
||||||
accyss = [get_accs(xdatas) for xdatas in all_datas]
|
accyss = [get_accs(xdatas) for xdatas in all_datas]
|
||||||
accyss = np.array( accyss )
|
accyss = np.array( accyss )
|
||||||
epochs = list(range(accyss.shape[1]))
|
epochs = list(range(accyss.shape[1]))
|
||||||
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx], linestyle='-', label='{:}'.format(method), lw=2)
|
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx], linestyle='-', label='{:}'.format(method), lw=2)
|
||||||
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx])
|
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx])
|
||||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
#plt.legend(loc=4, fontsize=LegendFontsize)
|
||||||
save_path = vis_save_dir / '{:}-{:}-{:}'.format(dataset, subset, file_name)
|
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||||
|
save_path = vis_save_dir / '{:}-{:}-{:}-{:}'.format(xox, dataset, subset, file_name)
|
||||||
print('save figure into {:}\n'.format(save_path))
|
print('save figure into {:}\n'.format(save_path))
|
||||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||||
|
|
||||||
@ -540,7 +546,13 @@ if __name__ == '__main__':
|
|||||||
#visualize_relative_ranking(vis_save_dir)
|
#visualize_relative_ranking(vis_save_dir)
|
||||||
|
|
||||||
api = API(args.api_path)
|
api = API(args.api_path)
|
||||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (5,95,10))
|
for x_maxs in [50, 250]:
|
||||||
|
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
|
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
|
show_nas_sharing_w(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
|
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
|
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
|
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
"""
|
"""
|
||||||
just_show(api)
|
just_show(api)
|
||||||
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
# python ./exps/vis/test.py
|
# python ./exps/vis/test.py
|
||||||
import os, sys
|
import os, sys, random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
|
from graphviz import Digraph
|
||||||
|
|
||||||
|
|
||||||
def test_nas_api():
|
def test_nas_api():
|
||||||
@ -23,5 +24,35 @@ def test_nas_api():
|
|||||||
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True))
|
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True))
|
||||||
print(archRes.query('cifar10-valid', 777))
|
print(archRes.query('cifar10-valid', 777))
|
||||||
|
|
||||||
|
|
||||||
|
OPS = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3']
|
||||||
|
COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1']
|
||||||
|
|
||||||
|
def plot(filename):
|
||||||
|
g = Digraph(
|
||||||
|
format='png',
|
||||||
|
edge_attr=dict(fontsize='20', fontname="times"),
|
||||||
|
node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
|
||||||
|
engine='dot')
|
||||||
|
g.body.extend(['rankdir=LR'])
|
||||||
|
|
||||||
|
steps = 5
|
||||||
|
for i in range(0, steps):
|
||||||
|
if i == 0:
|
||||||
|
g.node(str(i), fillcolor='darkseagreen2')
|
||||||
|
elif i+1 == steps:
|
||||||
|
g.node(str(i), fillcolor='palegoldenrod')
|
||||||
|
else: g.node(str(i), fillcolor='lightblue')
|
||||||
|
|
||||||
|
for i in range(1, steps):
|
||||||
|
for xin in range(i):
|
||||||
|
op_i = random.randint(0, len(OPS)-1)
|
||||||
|
#g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
|
||||||
|
g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
|
||||||
|
#import pdb; pdb.set_trace()
|
||||||
|
g.render(filename, cleanup=True, view=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_nas_api()
|
test_nas_api()
|
||||||
|
for i in range(200): plot('{:04d}'.format(i))
|
||||||
|
Loading…
Reference in New Issue
Block a user