2020-12-01 15:25:23 +01:00
###############################################################
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf) #
# The code to draw Figure 6 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig8.py #
###############################################################
import os , gc , sys , time , torch , argparse
import numpy as np
from typing import List , Text , Dict , Any
from shutil import copyfile
from collections import defaultdict , OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib . use ( ' agg ' )
import matplotlib . pyplot as plt
import matplotlib . ticker as ticker
lib_dir = ( Path ( __file__ ) . parent / ' .. ' / ' .. ' / ' lib ' ) . resolve ( )
if str ( lib_dir ) not in sys . path : sys . path . insert ( 0 , str ( lib_dir ) )
from config_utils import dict2config , load_config
from nats_bench import create
from log_utils import time_string
2020-12-02 14:43:35 +01:00
2020-12-01 15:25:23 +01:00
plt . rcParams . update ( {
" text.usetex " : True ,
" font.family " : " sans-serif " ,
" font.sans-serif " : [ " Helvetica " ] } )
## for Palatino and other serif fonts use:
plt . rcParams . update ( {
" text.usetex " : True ,
" font.family " : " serif " ,
" font.serif " : [ " Palatino " ] ,
} )
def fetch_data ( root_dir = ' ./output/search ' , search_space = ' tss ' , dataset = None ) :
ss_dir = ' {:} - {:} ' . format ( root_dir , search_space )
alg2all = OrderedDict ( )
# alg2name['REINFORCE'] = 'REINFORCE-0.01'
# alg2name['RANDOM'] = 'RANDOM'
# alg2name['BOHB'] = 'BOHB'
if search_space == ' tss ' :
hp = ' $ \ mathcal {H} ^ {1} $ '
2020-12-02 01:05:03 +01:00
if dataset == ' cifar10 ' :
2020-12-02 14:43:35 +01:00
suffixes = [ ' -T1200000 ' , ' -T1200000-FULL ' ]
2020-12-01 15:25:23 +01:00
elif search_space == ' sss ' :
hp = ' $ \ mathcal {H} ^ {2} $ '
2020-12-02 01:05:03 +01:00
if dataset == ' cifar10 ' :
suffixes = [ ' -T200000 ' , ' -T200000-FULL ' ]
2020-12-01 15:25:23 +01:00
else :
raise ValueError ( ' Unkonwn search space: {:} ' . format ( search_space ) )
alg2all [ r ' REA ($ \ mathcal {H} ^ {0} $) ' ] = dict (
path = os . path . join ( ss_dir , dataset + suffixes [ 0 ] , ' R-EA-SS3 ' , ' results.pth ' ) ,
color = ' b ' , linestyle = ' - ' )
alg2all [ r ' REA ( {:} ) ' . format ( hp ) ] = dict (
path = os . path . join ( ss_dir , dataset + suffixes [ 1 ] , ' R-EA-SS3 ' , ' results.pth ' ) ,
color = ' b ' , linestyle = ' -- ' )
for alg , xdata in alg2all . items ( ) :
data = torch . load ( xdata [ ' path ' ] )
for index , info in data . items ( ) :
info [ ' time_w_arch ' ] = [ ( x , y ) for x , y in zip ( info [ ' all_total_times ' ] , info [ ' all_archs ' ] ) ]
for j , arch in enumerate ( info [ ' all_archs ' ] ) :
assert arch != - 1 , ' invalid arch from {:} {:} {:} ( {:} , {:} ) ' . format ( alg , search_space , dataset , index , j )
xdata [ ' data ' ] = data
return alg2all
def query_performance ( api , data , dataset , ticket ) :
results , is_size_space = [ ] , api . search_space_name == ' size '
for i , info in data . items ( ) :
time_w_arch = sorted ( info [ ' time_w_arch ' ] , key = lambda x : abs ( x [ 0 ] - ticket ) )
time_a , arch_a = time_w_arch [ 0 ]
time_b , arch_b = time_w_arch [ 1 ]
info_a = api . get_more_info ( arch_a , dataset = dataset , hp = 90 if is_size_space else 200 , is_random = False )
info_b = api . get_more_info ( arch_b , dataset = dataset , hp = 90 if is_size_space else 200 , is_random = False )
accuracy_a , accuracy_b = info_a [ ' test-accuracy ' ] , info_b [ ' test-accuracy ' ]
interplate = ( time_b - ticket ) / ( time_b - time_a ) * accuracy_a + ( ticket - time_a ) / ( time_b - time_a ) * accuracy_b
results . append ( interplate )
# return sum(results) / len(results)
return np . mean ( results ) , np . std ( results )
2020-12-02 01:05:03 +01:00
y_min_s = { ( ' cifar10 ' , ' tss ' ) : 91 ,
( ' cifar10 ' , ' sss ' ) : 91 ,
2020-12-01 15:25:23 +01:00
( ' cifar100 ' , ' tss ' ) : 65 ,
( ' cifar100 ' , ' sss ' ) : 65 ,
( ' ImageNet16-120 ' , ' tss ' ) : 36 ,
( ' ImageNet16-120 ' , ' sss ' ) : 40 }
y_max_s = { ( ' cifar10 ' , ' tss ' ) : 94.5 ,
2020-12-02 01:05:03 +01:00
( ' cifar10 ' , ' sss ' ) : 93.5 ,
2020-12-01 15:25:23 +01:00
( ' cifar100 ' , ' tss ' ) : 72.5 ,
( ' cifar100 ' , ' sss ' ) : 70.5 ,
( ' ImageNet16-120 ' , ' tss ' ) : 46 ,
( ' ImageNet16-120 ' , ' sss ' ) : 46 }
2020-12-02 14:43:35 +01:00
x_axis_s = { ( ' cifar10 ' , ' tss ' ) : 1200000 ,
2020-12-01 15:25:23 +01:00
( ' cifar10 ' , ' sss ' ) : 200000 ,
( ' cifar100 ' , ' tss ' ) : 400 ,
( ' cifar100 ' , ' sss ' ) : 400 ,
( ' ImageNet16-120 ' , ' tss ' ) : 1200 ,
( ' ImageNet16-120 ' , ' sss ' ) : 600 }
name2label = { ' cifar10 ' : ' CIFAR-10 ' ,
' cifar100 ' : ' CIFAR-100 ' ,
' ImageNet16-120 ' : ' ImageNet-16-120 ' }
spaces2latex = { ' tss ' : r ' $ \ mathcal {S} _ {t} $ ' ,
' sss ' : r ' $ \ mathcal {S} _ {s} $ ' , }
2020-12-02 14:43:35 +01:00
# FuncFormatter can be used as a decorator
@ticker.FuncFormatter
def major_formatter ( x , pos ) :
if x == 0 :
return ' 0 '
else :
return " {:.2f} e5 " . format ( x / 1e5 )
2020-12-01 15:25:23 +01:00
def visualize_curve ( api_dict , vis_save_dir ) :
vis_save_dir = vis_save_dir . resolve ( )
vis_save_dir . mkdir ( parents = True , exist_ok = True )
2020-12-02 01:05:03 +01:00
dpi , width , height = 250 , 5000 , 2000
2020-12-01 15:25:23 +01:00
figsize = width / float ( dpi ) , height / float ( dpi )
2020-12-02 01:05:03 +01:00
LabelSize , LegendFontsize = 28 , 28
2020-12-01 15:25:23 +01:00
def sub_plot_fn ( ax , search_space , dataset ) :
max_time = x_axis_s [ ( dataset , search_space ) ]
alg2data = fetch_data ( search_space = search_space , dataset = dataset )
alg2accuracies = OrderedDict ( )
total_tickets = 200
time_tickets = [ float ( i ) / total_tickets * int ( max_time ) for i in range ( total_tickets ) ]
ax . set_xlim ( 0 , x_axis_s [ ( dataset , search_space ) ] )
ax . set_ylim ( y_min_s [ ( dataset , search_space ) ] ,
y_max_s [ ( dataset , search_space ) ] )
2020-12-02 01:05:03 +01:00
for tick in ax . get_xticklabels ( ) :
tick . set_rotation ( 25 )
tick . set_fontsize ( LabelSize - 6 )
for tick in ax . get_yticklabels ( ) :
tick . set_fontsize ( LabelSize - 6 )
2020-12-02 14:43:35 +01:00
ax . xaxis . set_major_formatter ( major_formatter )
2020-12-01 15:25:23 +01:00
for idx , ( alg , xdata ) in enumerate ( alg2data . items ( ) ) :
accuracies = [ ]
for ticket in time_tickets :
# import pdb; pdb.set_trace()
accuracy , accuracy_std = query_performance (
api_dict [ search_space ] , xdata [ ' data ' ] , dataset , ticket )
accuracies . append ( accuracy )
# print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std))
print ( ' {:} plot alg : {:10s} on {:} ' . format ( time_string ( ) , alg , search_space ) )
alg2accuracies [ alg ] = accuracies
ax . plot ( time_tickets , accuracies , c = xdata [ ' color ' ] , linestyle = xdata [ ' linestyle ' ] , label = ' {:} ' . format ( alg ) )
ax . set_xlabel ( ' Estimated wall-clock time ' , fontsize = LabelSize )
ax . set_ylabel ( ' Test accuracy ' , fontsize = LabelSize )
2020-12-02 01:05:03 +01:00
ax . set_title ( r ' Results on {:} over {:} ' . format ( name2label [ dataset ] , spaces2latex [ search_space ] ) ,
fontsize = LabelSize )
2020-12-01 15:25:23 +01:00
ax . legend ( loc = 4 , fontsize = LegendFontsize )
fig , axs = plt . subplots ( 1 , 2 , figsize = figsize )
sub_plot_fn ( axs [ 0 ] , ' tss ' , ' cifar10 ' )
sub_plot_fn ( axs [ 1 ] , ' sss ' , ' cifar10 ' )
save_path = ( vis_save_dir / ' full-curve.png ' ) . resolve ( )
fig . savefig ( save_path , dpi = dpi , bbox_inches = ' tight ' , format = ' png ' )
print ( ' {:} save into {:} ' . format ( time_string ( ) , save_path ) )
plt . close ( ' all ' )
if __name__ == ' __main__ ' :
parser = argparse . ArgumentParser ( description = ' NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size ' , formatter_class = argparse . ArgumentDefaultsHelpFormatter )
2020-12-02 01:05:03 +01:00
parser . add_argument ( ' --save_dir ' , type = str , default = ' output/vis-nas-bench/nas-algos-vs-h ' , help = ' Folder to save checkpoints and log. ' )
2020-12-01 15:25:23 +01:00
args = parser . parse_args ( )
save_dir = Path ( args . save_dir )
api_tss = create ( None , ' tss ' , fast_mode = True , verbose = False )
api_sss = create ( None , ' sss ' , fast_mode = True , verbose = False )
visualize_curve ( dict ( tss = api_tss , sss = api_sss ) , save_dir )