Update visualization codes for NATS-Bench
This commit is contained in:
parent
bda30c7098
commit
c8ca1790e9
@ -25,6 +25,7 @@ from config_utils import dict2config, load_config
|
||||
from nats_bench import create
|
||||
from log_utils import time_string
|
||||
|
||||
|
||||
plt.rcParams.update({
|
||||
"text.usetex": True,
|
||||
"font.family": "sans-serif",
|
||||
@ -46,7 +47,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
|
||||
if search_space == 'tss':
|
||||
hp = '$\mathcal{H}^{1}$'
|
||||
if dataset == 'cifar10':
|
||||
suffixes = ['-T800000', '-T800000-FULL']
|
||||
suffixes = ['-T1200000', '-T1200000-FULL']
|
||||
elif search_space == 'sss':
|
||||
hp = '$\mathcal{H}^{2}$'
|
||||
if dataset == 'cifar10':
|
||||
@ -100,7 +101,7 @@ y_max_s = {('cifar10', 'tss'): 94.5,
|
||||
('ImageNet16-120', 'tss'): 46,
|
||||
('ImageNet16-120', 'sss'): 46}
|
||||
|
||||
x_axis_s = {('cifar10', 'tss'): 800000,
|
||||
x_axis_s = {('cifar10', 'tss'): 1200000,
|
||||
('cifar10', 'sss'): 200000,
|
||||
('cifar100', 'tss'): 400,
|
||||
('cifar100', 'sss'): 400,
|
||||
@ -114,6 +115,16 @@ name2label = {'cifar10': 'CIFAR-10',
|
||||
spaces2latex = {'tss': r'$\mathcal{S}_{t}$',
|
||||
'sss': r'$\mathcal{S}_{s}$',}
|
||||
|
||||
|
||||
# FuncFormatter can be used as a decorator
|
||||
@ticker.FuncFormatter
|
||||
def major_formatter(x, pos):
|
||||
if x == 0:
|
||||
return '0'
|
||||
else:
|
||||
return "{:.2f}e5".format(x/1e5)
|
||||
|
||||
|
||||
def visualize_curve(api_dict, vis_save_dir):
|
||||
vis_save_dir = vis_save_dir.resolve()
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
@ -136,6 +147,7 @@ def visualize_curve(api_dict, vis_save_dir):
|
||||
tick.set_fontsize(LabelSize - 6)
|
||||
for tick in ax.get_yticklabels():
|
||||
tick.set_fontsize(LabelSize - 6)
|
||||
ax.xaxis.set_major_formatter(major_formatter)
|
||||
for idx, (alg, xdata) in enumerate(alg2data.items()):
|
||||
accuracies = []
|
||||
for ticket in time_tickets:
|
||||
|
Loading…
Reference in New Issue
Block a user