2020-01-09 12:26:23 +01:00
# This file is for experimental usage
2020-03-11 08:44:39 +01:00
import torch , random
2020-01-09 12:26:23 +01:00
import numpy as np
from copy import deepcopy
import torch . nn as nn
2020-03-11 08:44:39 +01:00
# from utils import obtain_accuracy
2020-01-09 12:26:23 +01:00
from models import CellStructure
from log_utils import time_string
def _arch_log_prob(arch_logits, arch, op_names, edge2index):
    """Return the summed log-probability of the operations chosen by ``arch``.

    Args:
        arch_logits: log-softmax-normalised architecture parameters, indexed
            as ``arch_logits[edge_index, op_index]``.
        arch: cell structure whose ``nodes`` attribute is enumerated; entry
            ``i`` is an iterable of ``(op_name, input_node)`` pairs.
        op_names: candidate operation names; an op's position in this list is
            its column in ``arch_logits``.
        edge2index: maps edge keys formatted as ``'{i+1}<-{input_node}'`` to
            the row of ``arch_logits`` holding that edge's parameters.

    Returns:
        The sum of the selected log-probabilities as a Python float
        (0.0 if the architecture selects no operations).
    """
    selected = []
    for i, node_info in enumerate(arch.nodes):
        for op, xin in node_info:
            edge_key = '{:}<-{:}'.format(i + 1, xin)
            selected.append(arch_logits[edge2index[edge_key], op_names.index(op)])
    # float() works both for a 0-dim tensor and for the int 0 that sum([])
    # yields, whereas `.item()` would fail on the latter.
    return float(sum(selected))


def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
    """Correlate a one-shot supernet's signals with benchmark ground truth.

    Every architecture returned by ``CellStructure.gen_all`` is scored twice:

    1. by the summed log-probability that the supernet's architecture
       parameters (``model.arch_parameters``, log-softmax over the last dim)
       assign to its chosen operations, and
    2. by its top-1 accuracy on a single minibatch from ``xloader`` with the
       supernet fixed to that architecture via
       ``model.set_cal_mode('dynamic', arch)``.

    Both scores are compared, via Pearson correlation (``np.corrcoef``), with
    the CIFAR-10 validation and test accuracies reported by ``api``. The
    correlations are printed, not returned.

    Args:
        model: the one-shot supernet. It must expose ``state_dict``,
            ``load_state_dict``, ``train``, ``arch_parameters``, ``op_names``,
            ``max_nodes``, ``edge2index`` and ``set_cal_mode``. Its forward
            pass is assumed (from usage) to return a pair whose second
            element is the class logits.
        xloader: iterable of ``(inputs, targets)`` batches. It is restarted
            whenever it is exhausted, so it may yield fewer batches than
            there are architectures.
        api: benchmark API providing ``query_index_by_arch`` and
            ``get_more_info`` (presumably NAS-Bench-201 -- confirm).
        cal_mode: passed to ``model.train``. Truthy runs the evaluation in
            training mode (logged as "Train"); falsy runs it in eval mode
            (logged as "Eval").
        seed: seed for the deterministic shuffle of the architecture list.

    Returns:
        ``(archs, probs, accuracies)``: the shuffled architectures, the
        log-probability of each, and the one-batch accuracy of each, all in
        the same order.

    Notes:
        Requires CUDA, because batches are moved with ``.cuda()``.

        The model's ``state_dict`` is snapshotted on entry and restored on
        exit, even if an exception is raised. This is presumably because
        training-mode forward passes can update buffers such as BatchNorm
        running statistics.

        The train/eval mode set here is *not* restored.
    """
    weights = deepcopy(model.state_dict())
    model.train(cal_mode)
    try:
        with torch.no_grad():
            arch_logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
            # The third argument is forwarded as in the original call; its
            # meaning is defined by CellStructure.gen_all (not visible here).
            archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
            loader_iter = iter(xloader)
            # A private RNG yields the same permutation as
            # `random.seed(seed); random.shuffle(archs)` without reseeding the
            # global `random` module behind the caller's back.
            random.Random(seed).shuffle(archs)

            # Pass 1: ground-truth accuracies and architecture log-probabilities.
            probs, gt_accs_10_valid, gt_accs_10_test = [], [], []
            for arch in archs:
                arch_index = api.query_index_by_arch(arch)
                metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
                gt_accs_10_valid.append(metrics['valid-accuracy'])
                metrics = api.get_more_info(arch_index, 'cifar10', None, False, False)
                gt_accs_10_test.append(metrics['test-accuracy'])
                probs.append(_arch_log_prob(arch_logits, arch, model.op_names, model.edge2index))

            cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0, 1]
            cor_prob_test = np.corrcoef(probs, gt_accs_10_test)[0, 1]
            print('{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test'.format(
                time_string(), cor_prob_valid, cor_prob_test))

            # Pass 2: one-batch accuracy of each architecture inside the supernet.
            accuracies = []
            num_archs = len(archs)
            for idx, arch in enumerate(archs):
                model.set_cal_mode('dynamic', arch)
                try:
                    inputs, targets = next(loader_iter)
                except StopIteration:
                    # Cycle the loader so every architecture gets a batch.
                    loader_iter = iter(xloader)
                    inputs, targets = next(loader_iter)
                _, out_logits = model(inputs.cuda())
                _, preds = torch.max(out_logits, dim=-1)
                correct = (preds == targets.cuda()).float()
                accuracies.append(correct.mean().item())
                # Report running correlations every 500 architectures and after
                # the last one. idx 0 is skipped because a correlation needs at
                # least two samples.
                if idx != 0 and (idx % 500 == 0 or idx + 1 == num_archs):
                    cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[:idx + 1])[0, 1]
                    cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[:idx + 1])[0, 1]
                    print('{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.'.format(
                        time_string(), idx, num_archs, 'Train' if cal_mode else 'Eval', cor_accs_valid, cor_accs_test))
    finally:
        model.load_state_dict(weights)
    return archs, probs, accuracies