###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-correlations.py #
###############################################################
import os , gc , sys , time , scipy , torch , argparse
import numpy as np
from typing import List , Text , Dict , Any
from shutil import copyfile
from collections import defaultdict , OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib . use ( ' agg ' )
import matplotlib . pyplot as plt
import matplotlib . ticker as ticker
lib_dir = ( Path ( __file__ ) . parent / ' .. ' / ' .. ' / ' lib ' ) . resolve ( )
if str ( lib_dir ) not in sys . path : sys . path . insert ( 0 , str ( lib_dir ) )
from config_utils import dict2config , load_config
from nats_bench import create
from log_utils import time_string
def get_valid_test_acc(api, arch, dataset):
    """Fetch the validation and test accuracy of ``arch`` on ``dataset``.

    The ``hp`` setting passed to ``api.get_more_info`` is 90 when the API's
    search space is ``'size'`` and 200 otherwise (presumably the training
    budget of the benchmark entry -- confirm against the NATS-Bench API).
    For ``'cifar10'`` the test accuracy comes from the ``'cifar10'`` record
    while the validation accuracy comes from ``'cifar10-valid'``; any other
    dataset supplies both numbers from a single record.

    Returns:
        ``(valid_acc, test_acc, summary)`` where ``summary`` is a one-line,
        two-decimal human-readable description of both accuracies.
    """
    hp = 90 if api.search_space_name == 'size' else 200

    def _lookup(name):
        # Non-random lookup so repeated calls give the same numbers.
        return api.get_more_info(arch, dataset=name, hp=hp, is_random=False)

    if dataset == 'cifar10':
        test_acc = _lookup(dataset)['test-accuracy']
        valid_acc = _lookup('cifar10-valid')['valid-accuracy']
    else:
        record = _lookup(dataset)
        valid_acc, test_acc = record['valid-accuracy'], record['test-accuracy']
    summary = 'validation = {:.2f}, test = {:.2f}\n'.format(valid_acc, test_acc)
    return valid_acc, test_acc, summary
def compute_kendalltau(vectori, vectorj):
    """Return Kendall's tau rank correlation between two paired sequences.

    ``scipy.stats`` is imported explicitly: on older SciPy releases a bare
    ``import scipy`` does not load the ``stats`` submodule, so
    ``scipy.stats.kendalltau`` would only resolve if some other import had
    already loaded it as a side effect.

    Args:
        vectori: first sequence of scores.
        vectorj: second sequence of scores, paired element-wise with ``vectori``.

    Returns:
        The tau coefficient computed by ``scipy.stats.kendalltau``; the
        p-value is discarded.
    """
    from scipy import stats
    coef, _ = stats.kendalltau(vectori, vectorj)
    return coef
def compute_spearmanr(vectori, vectorj):
    """Return Spearman's rank correlation between two paired sequences.

    ``scipy.stats`` is imported explicitly: on older SciPy releases a bare
    ``import scipy`` does not load the ``stats`` submodule, so
    ``scipy.stats.spearmanr`` would only resolve if some other import had
    already loaded it as a side effect.

    Args:
        vectori: first sequence of scores.
        vectorj: second sequence of scores, paired element-wise with ``vectori``.

    Returns:
        The rho coefficient computed by ``scipy.stats.spearmanr``; the
        p-value is discarded.
    """
    from scipy import stats
    coef, _ = stats.spearmanr(vectori, vectorj)
    return coef
if __name__ == '__main__':
    # Sample architectures at a fixed stride, collect their CIFAR-10 validation
    # and test accuracies, report rank correlations, and save a scatter plot.
    parser = argparse.ArgumentParser(
        description='NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos',
                        help='Folder to save checkpoints and log.')
    # Default 'tss' preserves the previous behaviour, which hard-coded 'tss'.
    parser.add_argument('--search_space', type=str, default='tss', choices=['tss', 'sss'],
                        help='Choose the search space.')
    args = parser.parse_args()

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Honour --search_space instead of silently ignoring it.
    api = create(None, args.search_space, fast_mode=True, verbose=False)

    # Every 300th architecture index in [1, 10000).
    indexes = list(range(1, 10000, 300))
    scores_1, scores_2 = [], []  # validation / test accuracies, aligned with `indexes`
    for index in indexes:
        valid_acc, test_acc, _ = get_valid_test_acc(api, index, 'cifar10')
        scores_1.append(valid_acc)
        scores_2.append(test_acc)

    correlation = compute_kendalltau(scores_1, scores_2)
    print('The kendall tau correlation of {:} samples : {:}'.format(len(indexes), correlation))
    correlation = compute_spearmanr(scores_1, scores_2)
    print('The spearmanr correlation of {:} samples : {:}'.format(len(indexes), correlation))

    dpi, width, height = 250, 1000, 1000
    figsize = width / float(dpi), height / float(dpi)
    LabelSize = 14
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.scatter(scores_1, scores_2, marker='^', s=0.5, c='tab:green', alpha=0.8)
    ax.set_xlabel('validation accuracy', fontsize=LabelSize)
    ax.set_ylabel('test accuracy', fontsize=LabelSize)
    # Write into --save_dir rather than a hard-coded path on the author's machine.
    save_path = save_dir / 'test-temp-rank.png'
    fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='png')
    print('Save the figure into {:}'.format(save_path))
    plt.close('all')