# This file is for experimental usage
import os, sys, torch, random
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch.nn as nn

from utils import obtain_accuracy
from models import CellStructure
from log_utils import time_string


def _arch_log_prob(arch, log_probs, model):
    """Return the summed entries of ``log_probs`` selected by ``arch``'s operations.

    For every ``(op, xin)`` pair of node ``i`` in ``arch.nodes``, the row is
    ``model.edge2index['{i+1}<-{xin}']`` and the column is
    ``model.op_names.index(op)``. The selected entries are summed and returned
    as a Python float.
    """
    selected = []
    for node_idx, node_info in enumerate(arch.nodes):
        for op_name, xin in node_info:
            edge_key = '{:}<-{:}'.format(node_idx + 1, xin)
            op_index = model.op_names.index(op_name)
            selected.append(log_probs[model.edge2index[edge_key], op_index])
    # float() also copes with an arch that has no edges (sum([]) == 0).
    return float(sum(selected))


def _next_batch(loader_iter, xloader):
    """Return ``(batch, iterator)``, restarting ``xloader`` once it is exhausted.

    Raises:
        ValueError: if a freshly created iterator over ``xloader`` is empty.
    """
    try:
        return next(loader_iter), loader_iter
    except StopIteration:
        pass
    loader_iter = iter(xloader)
    try:
        return next(loader_iter), loader_iter
    except StopIteration:
        raise ValueError('xloader yielded no batches') from None


def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
    """Correlate one-shot (supernet) scores with benchmark accuracies.

    All candidate architectures are enumerated with ``CellStructure.gen_all``
    and shuffled with ``seed``. Two scores are computed for each one:

    * ``probs``: the sum of ``log_softmax(model.arch_parameters, dim=-1)``
      entries for the architecture's chosen op on each edge.
    * ``accuracies``: the fraction of correct top-1 predictions of the
      supernet, switched to that architecture via
      ``model.set_cal_mode('dynamic', arch)``, on ONE batch from ``xloader``.
      The loader is cycled as needed.

    The Pearson correlation (``np.corrcoef``) of each score with the
    ground-truth accuracies reported by ``api`` is printed. For ``probs`` it is
    printed once. For ``accuracies`` it is printed every 500 architectures and
    at the end.

    Args:
        model: The supernet. It must provide ``state_dict``/``load_state_dict``,
            ``train``, ``training``, ``arch_parameters``, ``op_names``,
            ``max_nodes``, ``edge2index`` and ``set_cal_mode``. Its forward
            must return a 2-tuple whose second item holds class scores on the
            last dim.
        xloader: Iterable of ``(inputs, targets)`` batches. Tensors are moved
            with ``.cuda()``.
        api: Benchmark lookup object providing ``query_index_by_arch`` and
            ``get_more_info`` (presumably the NAS-Bench-201 API; verify
            against the caller).
        cal_mode: Passed to ``model.train``. Truthy means training mode.
        seed: Seed for the architecture shuffle. The global ``random`` state
            is not modified.

    Returns:
        ``(archs, probs, accuracies)``. ``archs`` is the shuffled architecture
        list. The other two are float lists aligned with ``archs``.

    The model's state dict and its training flag are restored before
    returning, even if an exception is raised. Presumably this undoes
    train-mode side effects such as normalization running statistics.
    """
    weights = deepcopy(model.state_dict())
    was_training = model.training
    model.train(cal_mode)
    try:
        with torch.no_grad():
            arch_log_probs = nn.functional.log_softmax(model.arch_parameters, dim=-1)
            archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
            probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
            loader_iter = iter(xloader)
            # A dedicated RNG produces the same permutation as
            # random.seed(seed); random.shuffle(archs), but leaves the
            # caller's global random state untouched.
            random.Random(seed).shuffle(archs)

            for arch in archs:
                arch_index = api.query_index_by_arch(arch)
                metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
                gt_accs_10_valid.append(metrics['valid-accuracy'])
                metrics = api.get_more_info(arch_index, 'cifar10', None, False, False)
                gt_accs_10_test.append(metrics['test-accuracy'])
                probs.append(_arch_log_prob(arch, arch_log_probs, model))

            cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0, 1]
            cor_prob_test = np.corrcoef(probs, gt_accs_10_test)[0, 1]
            print('{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test'.format(time_string(), cor_prob_valid, cor_prob_test))

            num_archs = len(archs)
            for idx, arch in enumerate(archs):
                model.set_cal_mode('dynamic', arch)
                (inputs, targets), loader_iter = _next_batch(loader_iter, xloader)
                _, out_logits = model(inputs.cuda())
                _, preds = torch.max(out_logits, dim=-1)
                correct = (preds == targets.cuda()).float()
                accuracies.append(correct.mean().item())
                if idx != 0 and (idx % 500 == 0 or idx + 1 == num_archs):
                    cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[:idx + 1])[0, 1]
                    cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[:idx + 1])[0, 1]
                    print('{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.'.format(time_string(), idx, num_archs, 'Train' if cal_mode else 'Eval', cor_accs_valid, cor_accs_test))
    finally:
        model.load_state_dict(weights)
        model.train(was_training)
    return archs, probs, accuracies