Update visualization codes for NATS-Bench
This commit is contained in:
parent
bda30c7098
commit
c8ca1790e9
@ -25,6 +25,7 @@ from config_utils import dict2config, load_config
|
|||||||
from nats_bench import create
|
from nats_bench import create
|
||||||
from log_utils import time_string
|
from log_utils import time_string
|
||||||
|
|
||||||
|
|
||||||
plt.rcParams.update({
|
plt.rcParams.update({
|
||||||
"text.usetex": True,
|
"text.usetex": True,
|
||||||
"font.family": "sans-serif",
|
"font.family": "sans-serif",
|
||||||
@ -46,7 +47,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
|
|||||||
if search_space == 'tss':
|
if search_space == 'tss':
|
||||||
hp = '$\mathcal{H}^{1}$'
|
hp = '$\mathcal{H}^{1}$'
|
||||||
if dataset == 'cifar10':
|
if dataset == 'cifar10':
|
||||||
suffixes = ['-T800000', '-T800000-FULL']
|
suffixes = ['-T1200000', '-T1200000-FULL']
|
||||||
elif search_space == 'sss':
|
elif search_space == 'sss':
|
||||||
hp = '$\mathcal{H}^{2}$'
|
hp = '$\mathcal{H}^{2}$'
|
||||||
if dataset == 'cifar10':
|
if dataset == 'cifar10':
|
||||||
@ -100,7 +101,7 @@ y_max_s = {('cifar10', 'tss'): 94.5,
|
|||||||
('ImageNet16-120', 'tss'): 46,
|
('ImageNet16-120', 'tss'): 46,
|
||||||
('ImageNet16-120', 'sss'): 46}
|
('ImageNet16-120', 'sss'): 46}
|
||||||
|
|
||||||
x_axis_s = {('cifar10', 'tss'): 800000,
|
x_axis_s = {('cifar10', 'tss'): 1200000,
|
||||||
('cifar10', 'sss'): 200000,
|
('cifar10', 'sss'): 200000,
|
||||||
('cifar100', 'tss'): 400,
|
('cifar100', 'tss'): 400,
|
||||||
('cifar100', 'sss'): 400,
|
('cifar100', 'sss'): 400,
|
||||||
@ -114,6 +115,16 @@ name2label = {'cifar10': 'CIFAR-10',
|
|||||||
spaces2latex = {'tss': r'$\mathcal{S}_{t}$',
|
spaces2latex = {'tss': r'$\mathcal{S}_{t}$',
|
||||||
'sss': r'$\mathcal{S}_{s}$',}
|
'sss': r'$\mathcal{S}_{s}$',}
|
||||||
|
|
||||||
|
|
||||||
|
# FuncFormatter can be used as a decorator
|
||||||
|
@ticker.FuncFormatter
|
||||||
|
def major_formatter(x, pos):
|
||||||
|
if x == 0:
|
||||||
|
return '0'
|
||||||
|
else:
|
||||||
|
return "{:.2f}e5".format(x/1e5)
|
||||||
|
|
||||||
|
|
||||||
def visualize_curve(api_dict, vis_save_dir):
|
def visualize_curve(api_dict, vis_save_dir):
|
||||||
vis_save_dir = vis_save_dir.resolve()
|
vis_save_dir = vis_save_dir.resolve()
|
||||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@ -136,6 +147,7 @@ def visualize_curve(api_dict, vis_save_dir):
|
|||||||
tick.set_fontsize(LabelSize - 6)
|
tick.set_fontsize(LabelSize - 6)
|
||||||
for tick in ax.get_yticklabels():
|
for tick in ax.get_yticklabels():
|
||||||
tick.set_fontsize(LabelSize - 6)
|
tick.set_fontsize(LabelSize - 6)
|
||||||
|
ax.xaxis.set_major_formatter(major_formatter)
|
||||||
for idx, (alg, xdata) in enumerate(alg2data.items()):
|
for idx, (alg, xdata) in enumerate(alg2data.items()):
|
||||||
accuracies = []
|
accuracies = []
|
||||||
for ticket in time_tickets:
|
for ticket in time_tickets:
|
||||||
|
Loading…
Reference in New Issue
Block a user