MeCo/nasbench201/init_projection.py
import os
import sys
import numpy as np
import torch
import torch.nn.functional as f
sys.path.insert(0, '../')
import nasbench201.utils as ig_utils
import logging
import torch.utils
import copy
import scipy.stats as ss
from collections import OrderedDict
from foresight.pruners import *
from op_score import Jocab_Score, get_ntk_n
import gc
from nasbench201.linear_region import Linear_Region_Collector
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)
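# This module implements the initial projection (edge/operation discretization)
# strategies used on the NAS-Bench-201 supernet. Each entry point below takes a
# data queue, a supernet `model` exposing get_projected_weights(), project_op(),
# candidate_flags and arch_parameters(), and an `args` namespace (tenas_project
# additionally takes a thin copy of the model for linear-region counting), and
# discretizes the cell one edge at a time. Lower proxy scores are treated as
# better throughout this file.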
# global-op-iter: iteratively evaluates S(A(e, o)) for every remaining edge e and operation o, and greedily discretizes the best-scoring (edge, operation) pair
def global_op_greedy_pt_project(proj_queue, model, args):
def project(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
crit_extrema = None
best_eid = None
input, target = next(iter(proj_queue))
for eid in remain_eids:
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
if crit_extrema is None or compare(crit, crit_extrema):
crit_extrema = crit
best_opid = opid
best_eid = eid
logging.info('best opid %d', best_opid)
return best_eid, best_opid
tune_epochs = model.arch_parameters()[0].shape[0]
for epoch in range(tune_epochs):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = project(model, args)
model.project_op(selected_eid, best_opid)
return
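# Every strategy in this file scores a candidate the same way: mask one operation
# out of the projected architecture weights on one edge and evaluate the resulting
# sub-supernet. A minimal sketch of that shared primitive, using only names
# imported above (the helper itself is illustrative and not part of the original code):
def _score_without_op(model, proj_queue, eid, opid):
    input, target = next(iter(proj_queue))
    weights = model.get_projected_weights()
    proj_mask = torch.ones_like(weights[eid])
    proj_mask[opid] = 0  # drop operation `opid` on edge `eid`
    weights[eid] = weights[eid] * proj_mask
    # as in the callers above and below, a lower Jocab_Score is treated as better
    return Jocab_Score(model, input, target, weights=weights)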
# global-edge-iter: similar to global-op-iter but iteratively selects edge e from E based on the average score of all operations on each edge
def global_edge_greedy_pt_project(proj_queue, model, args):
def select_eid(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
crit_extrema = None
best_eid = None
input, target = next(iter(proj_queue))
for eid in remain_eids:
eid_score = []
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
eid_score.append(crit)
eid_score = np.mean(eid_score)
if crit_extrema is None or compare(eid_score, crit_extrema):
crit_extrema = eid_score
best_eid = eid
return best_eid
def project(model, args, selected_eid):
## macros
num_edge, num_op = model.num_edge, model.num_op
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
best_opid = 0
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
logging.info('best opid %d', best_opid)
logging.info(crit_list)
return selected_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid = select_eid(model, args)
selected_eid, best_opid = project(model, args, selected_eid)
model.project_op(selected_eid, best_opid)
return
# global-op-once: only evaluates S(A(e, o)) for all operations once to obtain a ranking order of the operations, and discretizes the edges E according to this order
def global_op_once_pt_project(proj_queue, model, args):
def order(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
edge_score = OrderedDict()
input, target = next(iter(proj_queue))
for eid in remain_eids:
crit_list = []
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
crit_list.append(crit)
edge_score[eid] = np.nanargmin(crit_list)
return edge_score
def project(model, args, selected_eid):
## macros
num_edge, num_op = model.num_edge, model.num_op
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
best_opid = 0
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
logging.info('best opid %d', best_opid)
logging.info(crit_list)
return selected_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
eid_order = order(model, args)
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, _ = eid_order.popitem()
selected_eid, best_opid = project(model, args, selected_eid)
model.project_op(selected_eid, best_opid)
return
# global-edge-once: similar to global-op-once but uses the average score of operations on edges to obtain the edge discretization order
def global_edge_once_pt_project(proj_queue, model, args):
def order(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
edge_score = OrderedDict()
crit_extrema = None
best_eid = None
input, target = next(iter(proj_queue))
for eid in remain_eids:
crit_list = []
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
crit_list.append(crit)
edge_score[eid] = np.mean(crit_list)
return edge_score
def project(model, args, selected_eid):
## macros
num_edge, num_op = model.num_edge, model.num_op
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
best_opid = 0
crit_extrema = None
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
logging.info('best opid %d', best_opid)
logging.info(crit_list)
return selected_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
eid_order = order(model, args)
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, _ = eid_order.popitem()
selected_eid, best_opid = project(model, args, selected_eid)
model.project_op(selected_eid, best_opid)
return
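# Note on the four strategies above: the *-iter (greedy) variants re-score the
# remaining candidates after every committed (edge, operation) decision, whereas
# the *-once variants compute the edge order a single time up front and then
# discretize the edges in that fixed order via eid_order.popitem().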
# fixed [reverse, order]: discretizes the edges in a fixed order; in our experiments we discretize from the input towards the output of the cell structure
# random: discretizes the edges in a random order (DARTS-PT)
# NOTE: only this method allows using other zero-cost proxy metrics (selected via args.proj_crit)
def pt_project(proj_queue, model, args):
def project(model, args):
## macros: there are 6 edges in total, and each edge has 5 candidate operations
num_edge, num_op = model.num_edge, model.num_op
## select an edge
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
# print('candidate_flags:', model.candidate_flags)
# print(model.candidate_flags)
# edge selection strategy
if args.edge_decision == "random":
# np.random.choice returns an array; take its single element
selected_eid = np.random.choice(remain_eids, size=1)[0]
elif args.edge_decision == "reverse":
selected_eid = remain_eids[-1]
else:
selected_eid = remain_eids[0]
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet16-120':
n_classes = 120
else:
n_classes = 10
best_opid = 0
crit_extrema = None
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
# print(selected_eid, weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
else:
cache_weight = model.proj_weights[selected_eid]
cache_flag = model.candidate_flags[selected_eid]
for idx in range(num_op):
if idx == opid:
model.proj_weights[selected_eid][opid] = 0
else:
model.proj_weights[selected_eid][idx] = 1.0/num_op
model.candidate_flags[selected_eid] = False
# print(model.get_projected_weights())
if args.proj_crit == 'comb':
synflow = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=['synflow'])
var = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=['var'])
# print(synflow, var)
comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
measures = {'comb': comb}
else:
measures = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[args.proj_crit])
# print(measures)
for idx in range(num_op):
model.proj_weights[selected_eid][idx] = 0
model.candidate_flags[selected_eid] = cache_flag
crit = measures[args.proj_crit]
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
# best_opid = op_ids[np.nanargmax(crit_list)]
logging.info('best opid %d', best_opid)
logging.info('current edge id %d', selected_eid)
logging.info(crit_list)
return selected_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = project(model, args)
model.project_op(selected_eid, best_opid)
return
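# A minimal usage sketch for pt_project above, assuming an `args` namespace with
# the fields it reads (edge_decision, proj_crit, dataset) and a supernet/data
# queue built elsewhere in this repository; the two builder names below are
# placeholders, not functions defined in this file:
#
#   model = build_nasbench201_supernet(args)    # hypothetical constructor
#   proj_queue = build_projection_loader(args)  # hypothetical data loader
#   args.edge_decision, args.proj_crit = 'random', 'jacob'
#   pt_project(proj_queue, model, args)         # commits one edge per epoch
#   model.printing(logging)                     # log the projected architecture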
def tenas_project(proj_queue, model, model_thin, args):
def project(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
ntks = []
lrs = []
edge_op_id = []
best_eid = None
if args.proj_crit == 'tenas':
lrc_model = Linear_Region_Collector(input_size=(1000, 1, 3, 3), sample_batch=3, dataset=args.dataset, data_path=args.data, seed=args.seed)
for eid in remain_eids:
for opid in range(num_op):
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'tenas':
lrc_model.reinit(ori_models=[model_thin], seed=args.seed, weights=weights)
lr = lrc_model.forward_batch_sample()
lrc_model.clear()
ntk = get_ntk_n(proj_queue, [model], recalbn=0, train_mode=True, num_batch=1, weights=weights)
ntks.append(ntk)
lrs.append(lr)
edge_op_id.append('{}:{}'.format(eid, opid))
print('ntks', ntks)
print('lrs', lrs)
ntks_ranks = ss.rankdata(ntks)
lrs_ranks = ss.rankdata(lrs)
ntks_ranks = len(ntks_ranks) - ntks_ranks.astype(int)
op_ranks = []
for i in range(len(edge_op_id)):
op_ranks.append(ntks_ranks[i]+lrs_ranks[i])
best_op_index = edge_op_id[np.nanargmin(op_ranks[0:num_op])]
best_eid, best_opid = [int(x) for x in best_op_index.split(':')]
logging.info(op_ranks)
logging.info('best eid %d', best_eid)
logging.info('best opid %d', best_opid)
return best_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = project(model, args)
model.project_op(selected_eid, best_opid)
return
# new method:
# randomly propose candidate networks and transfer them to the supernet, then perform global operation selection in this subspace
def shrink_pt_project(proj_queue, model, args):
def project(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
## select an edge
remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
selected_eid = np.random.choice(remain_eids, size=1)[0]
## select the best operation
if args.proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 4
compare = lambda x, y: x < y
if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet16-120':
n_classes = 120
else:
n_classes = 10
best_opid = 0
crit_extrema = None
crit_list = []
op_ids = []
input, target = next(iter(proj_queue))
for opid in range(num_op):
## projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
crit = Jocab_Score(model, input, target, weights=weights)
else:
cache_weight = model.proj_weights[selected_eid]
cache_flag = model.candidate_flags[selected_eid]
for idx in range(num_op):
if idx == opid:
model.proj_weights[selected_eid][opid] = 0
else:
model.proj_weights[selected_eid][idx] = 1.0/num_op
model.candidate_flags[selected_eid] = False
measures = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[args.proj_crit])
for idx in range(num_op):
model.proj_weights[selected_eid][idx] = 0
model.candidate_flags[selected_eid] = cache_flag
crit = measures[args.proj_crit]
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
logging.info('best opid %d', best_opid)
logging.info('current edge id %d', selected_eid)
logging.info(crit_list)
return selected_eid, best_opid
def global_project(model, args):
## macros
num_edge, num_op = model.num_edge, model.num_op
##get remain eid numbers
remain_eids = torch.nonzero(model.subspace_candidate_flags).cpu().numpy().T[0]
compare = lambda x, y : x < y
crit_extrema = None
best_eid = None
best_opid = None
input, target = next(iter(proj_queue))
for eid in remain_eids:
remain_oids = torch.nonzero(model.proj_weights[eid]).cpu().numpy().T[0]
for opid in remain_oids:
# projection
weights = model.get_projected_weights()
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
if args.proj_crit == 'jacob':
valid_stats = Jocab_Score(model, input, target, weights=weights)
crit = valid_stats
if crit_extrema is None or compare(crit, crit_extrema):
crit_extrema = crit
best_opid = opid
best_eid = eid
logging.info('best eid %d', best_eid)
logging.info('best opid %d', best_opid)
model.subspace_candidate_flags[best_eid] = False
proj_mask = torch.zeros_like(model.proj_weights[best_eid])
model.proj_weights[best_eid] = model.proj_weights[best_eid] * proj_mask
model.proj_weights[best_eid][best_opid] = 1
return best_eid, best_opid
num_edges = model.arch_parameters()[0].shape[0]
#subspace
logging.info('Start subspace proposal')
subspace = copy.deepcopy(model.proj_weights)
for i in range(20):
model.reset_arch_parameters()
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = project(model, args)
model.project_op(selected_eid, best_opid)
subspace += model.proj_weights
model.reset_arch_parameters()
subspace = torch.gt(subspace, 0).int().float()
subspace = f.normalize(subspace, p=1, dim=1)
model.proj_weights += subspace
for i in range(num_edges):
model.candidate_flags[i] = False
logging.info('Start final search in subspace')
logging.info(subspace)
model.subspace_candidate_flags = torch.tensor(len(model._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
for epoch in range(num_edges):
logging.info('epoch %d', epoch)
logging.info('project')
selected_eid, best_opid = global_project(model, args)
model.printing(logging)
#model.project_op(selected_eid, best_opid)
return
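# Mechanics of shrink_pt_project above: 20 independent random-edge projection
# passes each propose a fully discretized cell; their selections are accumulated
# in `subspace`, binarized with torch.gt(subspace, 0) and L1-normalized per edge,
# so every edge ends up with uniform weight over the operations picked at least
# once. The final loop then calls global_project, which greedily fixes one
# (edge, operation) pair per epoch, chosen over all not-yet-fixed edges and
# restricted to that shrunken operation set.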