diff --git a/exps/NATS-Bench/draw-fig8.py b/exps/NATS-Bench/draw-fig8.py index d2a4ad5..d287c3c 100644 --- a/exps/NATS-Bench/draw-fig8.py +++ b/exps/NATS-Bench/draw-fig8.py @@ -25,6 +25,7 @@ from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string + plt.rcParams.update({ "text.usetex": True, "font.family": "sans-serif", @@ -46,7 +47,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): if search_space == 'tss': hp = '$\mathcal{H}^{1}$' if dataset == 'cifar10': - suffixes = ['-T800000', '-T800000-FULL'] + suffixes = ['-T1200000', '-T1200000-FULL'] elif search_space == 'sss': hp = '$\mathcal{H}^{2}$' if dataset == 'cifar10': @@ -100,7 +101,7 @@ y_max_s = {('cifar10', 'tss'): 94.5, ('ImageNet16-120', 'tss'): 46, ('ImageNet16-120', 'sss'): 46} -x_axis_s = {('cifar10', 'tss'): 800000, +x_axis_s = {('cifar10', 'tss'): 1200000, ('cifar10', 'sss'): 200000, ('cifar100', 'tss'): 400, ('cifar100', 'sss'): 400, @@ -114,6 +115,16 @@ name2label = {'cifar10': 'CIFAR-10', spaces2latex = {'tss': r'$\mathcal{S}_{t}$', 'sss': r'$\mathcal{S}_{s}$',} + +# FuncFormatter can be used as a decorator +@ticker.FuncFormatter +def major_formatter(x, pos): + if x == 0: + return '0' + else: + return "{:.2f}e5".format(x/1e5) + + def visualize_curve(api_dict, vis_save_dir): vis_save_dir = vis_save_dir.resolve() vis_save_dir.mkdir(parents=True, exist_ok=True) @@ -136,6 +147,7 @@ def visualize_curve(api_dict, vis_save_dir): tick.set_fontsize(LabelSize - 6) for tick in ax.get_yticklabels(): tick.set_fontsize(LabelSize - 6) + ax.xaxis.set_major_formatter(major_formatter) for idx, (alg, xdata) in enumerate(alg2data.items()): accuracies = [] for ticket in time_tickets: