Update visualization codes for NATS-Bench

This commit is contained in:
D-X-Y 2020-12-02 21:43:35 +08:00
parent bda30c7098
commit c8ca1790e9

View File

@ -25,6 +25,7 @@ from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
plt.rcParams.update({
"text.usetex": True,
"font.family": "sans-serif",
@ -46,7 +47,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
if search_space == 'tss':
hp = '$\mathcal{H}^{1}$'
if dataset == 'cifar10':
suffixes = ['-T800000', '-T800000-FULL']
suffixes = ['-T1200000', '-T1200000-FULL']
elif search_space == 'sss':
hp = '$\mathcal{H}^{2}$'
if dataset == 'cifar10':
@ -100,7 +101,7 @@ y_max_s = {('cifar10', 'tss'): 94.5,
('ImageNet16-120', 'tss'): 46,
('ImageNet16-120', 'sss'): 46}
x_axis_s = {('cifar10', 'tss'): 800000,
x_axis_s = {('cifar10', 'tss'): 1200000,
('cifar10', 'sss'): 200000,
('cifar100', 'tss'): 400,
('cifar100', 'sss'): 400,
@ -114,6 +115,16 @@ name2label = {'cifar10': 'CIFAR-10',
spaces2latex = {'tss': r'$\mathcal{S}_{t}$',
'sss': r'$\mathcal{S}_{s}$',}
# FuncFormatter can be used as a decorator
@ticker.FuncFormatter
def major_formatter(x, pos):
if x == 0:
return '0'
else:
return "{:.2f}e5".format(x/1e5)
def visualize_curve(api_dict, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
@ -136,6 +147,7 @@ def visualize_curve(api_dict, vis_save_dir):
tick.set_fontsize(LabelSize - 6)
for tick in ax.get_yticklabels():
tick.set_fontsize(LabelSize - 6)
ax.xaxis.set_major_formatter(major_formatter)
for idx, (alg, xdata) in enumerate(alg2data.items()):
accuracies = []
for ticket in time_tickets: