upload
This commit is contained in:
336
sota/cnn/init_projection.py
Normal file
336
sota/cnn/init_projection.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import torch.utils
|
||||
from copy import deepcopy
|
||||
from foresight.pruners import *
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
|
||||
def sample_op(model, input, target, args, cell_type, selected_eid=None):
    """Randomly pick an (edge, op) pair for the given cell type.

    When *selected_eid* is None an edge is drawn uniformly from the edges
    whose candidate flag is still set; the op id is always drawn uniformly
    from all candidate operations.  *input* / *target* are accepted for
    signature parity with the projection routines but are not used.
    """
    # Macros mirrored from model/args; some are read only for symmetry
    # with the sibling selection routines.
    n_ops = model.num_ops
    n_edges = model.num_edges          # unused here; kept for parity
    flags = model.candidate_flags[cell_type]
    crit = args.proj_crit[cell_type]   # unused here; kept for parity

    # Draw an edge only when the caller did not pin one down.
    if selected_eid is None:
        open_edges = torch.nonzero(flags).cpu().numpy().T[0]
        selected_eid = np.random.choice(open_edges, size=1)[0]
        logging.info('selected edge: %d %s', selected_eid, cell_type)

    chosen_op = np.random.choice(np.array(range(n_ops)), size=1)[0]
    return selected_eid, chosen_op
|
||||
|
||||
def project_op(model, input, target, args, cell_type, proj_queue=None, selected_eid=None):
    ''' operation '''
    # Choose one edge of the given cell type and decide which candidate op
    # to project on it.  Each op is scored by masking it out of the
    # projected (soft) architecture weights and measuring the resulting
    # network with a zero-cost proxy; the op with the lowest score (see
    # np.nanargmin below) becomes `best_opid`.
    #
    # Returns:
    #     (selected_eid, best_opid) -- the edge that was examined and the
    #     op id chosen for it; the caller passes both to model.project_op.

    #### macros
    num_edges, num_ops = model.num_edges, model.num_ops
    candidate_flags = model.candidate_flags[cell_type]  # True for edges not yet projected
    proj_crit = args.proj_crit[cell_type]

    #### select an edge (only when the caller did not pin one)
    if selected_eid is None:
        remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
        if args.edge_decision == "random":
            selected_eid = np.random.choice(remain_eids, size=1)[0]
            logging.info('selected edge: %d %s', selected_eid, cell_type)
        elif args.edge_decision == 'reverse':
            # Walk the remaining edges from the back.
            selected_eid = remain_eids[-1]
            logging.info('selected edge: %d %s', selected_eid, cell_type)
        else:
            # Default: first remaining edge, in index order.
            selected_eid = remain_eids[0]
            logging.info('selected node: %d %s', selected_eid, cell_type)

    #### select the best operation
    # NOTE(review): both branches assign the same comparator, and
    # crit_idx / compare / crit_extrema are never used below -- selection
    # is done via np.nanargmin on crit_list instead.
    if proj_crit == 'jacob':
        crit_idx = 3
        compare = lambda x, y: x < y
    else:
        crit_idx = 0
        compare = lambda x, y: x < y

    # Class count feeds the ('random', 1, n_classes) dataload spec of the
    # foresight zero-cost measures.
    if args.dataset == 'cifar100':
        n_classes = 100
    elif args.dataset == 'imagenet16-120':
        n_classes = 120
    else:
        n_classes = 10

    best_opid = 0
    crit_extrema = None
    crit_list = []
    op_ids = []
    for opid in range(num_ops):
        ## projection: zero out op `opid` on the selected edge
        weights = model.get_projected_weights(cell_type)
        proj_mask = torch.ones_like(weights[selected_eid])
        proj_mask[opid] = 0
        weights[selected_eid] = weights[selected_eid] * proj_mask

        ## proj evaluation
        if proj_crit == 'jacob':
            crit = Jocab_Score(model,cell_type, input, target, weights=weights)
        else:
            # Temporarily bake the masked distribution into the model so the
            # predictive measures see it, then restore the state afterwards.
            cache_weight = model.proj_weights[cell_type][selected_eid]  # NOTE(review): cached but never restored
            cache_flag = model.candidate_flags[cell_type][selected_eid]

            # Uniform weight over all ops except the masked one.
            for idx in range(num_ops):
                if idx == opid:
                    model.proj_weights[cell_type][selected_eid][opid] = 0
                else:
                    model.proj_weights[cell_type][selected_eid][idx] = 1.0 / num_ops

            model.candidate_flags[cell_type][selected_eid] = False
            if proj_crit == 'comb':
                # Combined criterion: log-synflow divided by (variance + 0.1).
                synflow = predictive.find_measures(model,
                                                   proj_queue,
                                                   ('random', 1, n_classes),
                                                   torch.device("cuda"),
                                                   measure_names=['synflow'])
                var = predictive.find_measures(model,
                                               proj_queue,
                                               ('random', 1, n_classes),
                                               torch.device("cuda"),
                                               measure_names=['var'])
                comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
                measures = {'comb': comb}
            else:
                measures = predictive.find_measures(model,
                                                    proj_queue,
                                                    ('random', 1, n_classes),
                                                    torch.device("cuda"),
                                                    measure_names=[proj_crit])

            # Undo the temporary projection state.
            for idx in range(num_ops):
                model.proj_weights[cell_type][selected_eid][idx] = 0
            model.candidate_flags[cell_type][selected_eid] = cache_flag
            crit = measures[proj_crit]

        crit_list.append(crit)
        op_ids.append(opid)

    # Lowest criterion wins; nanargmin skips NaN scores from failed measures.
    best_opid = op_ids[np.nanargmin(crit_list)]

    #### project
    logging.info('best opid: %d', best_opid)
    logging.info(crit_list)
    return selected_eid, best_opid
|
||||
|
||||
def project_global_op(model, input, target, args, infer, cell_type, selected_eid=None):
    ''' operation '''
    # Globally greedy variant of project_op: scan ALL remaining edges and
    # all ops, score each (edge, op) masking with Jocab_Score, and return
    # the best pair.  *infer* and *selected_eid* are unused; they are kept
    # for signature compatibility with callers.
    #
    # Returns:
    #     (best_eid, best_opid).  NOTE(review): best_eid stays None if no
    #     candidate edges remain -- callers should not reach that state.

    #### macros
    num_edges, num_ops = model.num_edges, model.num_ops
    candidate_flags = model.candidate_flags[cell_type]
    proj_crit = args.proj_crit[cell_type]

    remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]

    #### select the best operation
    # BUGFIX: crit_idx/compare were previously bound only when
    # proj_crit == 'jacob', so any other criterion raised NameError at the
    # first comparison below.  Bind them unconditionally with the same
    # values as the equivalent branch in project_edge.
    if proj_crit == 'jacob':
        crit_idx = 3
        compare = lambda x, y: x < y
    else:
        crit_idx = 3
        compare = lambda x, y: x < y

    best_opid = 0
    crit_extrema = None
    best_eid = None
    for eid in remain_eids:
        for opid in range(num_ops):
            ## projection: zero out op `opid` on edge `eid`
            weights = model.get_projected_weights(cell_type)
            proj_mask = torch.ones_like(weights[eid])
            proj_mask[opid] = 0
            weights[eid] = weights[eid] * proj_mask

            ## proj evaluation (lower score wins under `compare`)
            with torch.no_grad():
                valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
            crit = valid_stats
            if crit_extrema is None or compare(crit, crit_extrema):
                crit_extrema = crit
                best_opid = opid
                best_eid = eid

    #### project
    logging.info('best opid: %d', best_opid)
    return best_eid, best_opid
|
||||
|
||||
def sample_edge(model, input, target, args, cell_type, selected_eid=None):
    """Randomly keep exactly two incoming edges of a randomly chosen node.

    A node is drawn uniformly from the nodes whose edge candidate flag is
    still set, then edges of that node are discarded one at a time, at
    random, until only two survive.  *input* / *target* / *selected_eid*
    are accepted for signature parity and not used.
    """
    # Macros mirrored from model/args.
    flags = model.candidate_flags_edge[cell_type]
    crit = args.proj_crit[cell_type]   # unused here; kept for parity

    # Pick a node that still has all its candidate edges.
    open_nids = torch.nonzero(flags).cpu().numpy().T[0]
    selected_nid = np.random.choice(open_nids, size=1)[0]
    logging.info('selected node: %d %s', selected_nid, cell_type)

    # Work on a private copy of the node's edge list.
    kept = deepcopy(model.nid2eids[selected_nid])

    # Discard random edges one at a time until exactly two survive.
    while len(kept) > 2:
        victim = np.random.choice(kept, size=1)[0]
        kept.remove(victim)

    return selected_nid, kept
|
||||
|
||||
def project_edge(model, input, target, args, cell_type):
    ''' topology '''
    # Select a node and greedily prune its incoming edges down to two:
    # in each round, every surviving edge is zeroed out in turn and the
    # network is re-scored with Jocab_Score; the edge whose removal gives
    # the best (largest, per `not compare`) score is considered the least
    # useful and is deleted.
    #
    # Returns:
    #     (selected_nid, eids) -- the projected node and the two edge ids
    #     kept for it.

    #### macros
    candidate_flags = model.candidate_flags_edge[cell_type]  # True for nodes not yet projected
    proj_crit = args.proj_crit[cell_type]

    #### select an node
    remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
    if args.edge_decision == "random":
        selected_nid = np.random.choice(remain_nids, size=1)[0]
        logging.info('selected node: %d %s', selected_nid, cell_type)
    elif args.edge_decision == 'reverse':
        # Walk the remaining nodes from the back.
        selected_nid = remain_nids[-1]
        logging.info('selected node: %d %s', selected_nid, cell_type)
    else:
        # Fallback behaves exactly like the 'random' branch.
        selected_nid = np.random.choice(remain_nids, size=1)[0]
        logging.info('selected node: %d %s', selected_nid, cell_type)

    #### select top2 edges
    # NOTE(review): both branches are identical and crit_idx is unused;
    # `compare` (less-than) is negated below to find the WORST edge.
    if proj_crit == 'jacob':
        crit_idx = 3
        compare = lambda x, y: x < y
    else:
        crit_idx = 3
        compare = lambda x, y: x < y

    # Private copy of the node's edge list; pruned in place below.
    eids = deepcopy(model.nid2eids[selected_nid])
    crit_list = []
    while len(eids) > 2:
        eid_todel = None
        crit_extrema = None
        for eid in eids:
            # Zero this edge in a fresh copy of the projected weights.
            weights = model.get_projected_weights(cell_type)
            weights[eid].data.fill_(0)

            ## proj evaluation
            with torch.no_grad():
                valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
            crit = valid_stats

            crit_list.append(crit)
            if crit_extrema is None or not compare(crit, crit_extrema): # find out bad edges
                crit_extrema = crit
                eid_todel = eid

        eids.remove(eid_todel)

    #### project
    logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
    return selected_nid, eids
|
||||
|
||||
|
||||
def pt_project(train_queue, model, args):
    """Drive the full projection schedule over batches from *train_queue*.

    The first model.num_edges iterations project one op per cell type
    (normal, then reduce); the remaining iterations project node topology
    (keep two incoming edges per node).  Stops after
    num_edges + num_nodes total projection steps.

    Side effects: sets args.proj_crit and mutates *model* via
    model.project_op / model.project_edge.
    """
    model.eval()

    #### macros
    num_projs = model.num_edges + len(model.nid2eids.keys())
    args.proj_crit = {'normal': args.proj_crit_normal, 'reduce': args.proj_crit_reduce}
    proj_queue = train_queue

    epoch = 0
    for step, (input, target) in enumerate(proj_queue):
        if epoch < model.num_edges:
            logging.info('project op')
            # Same op-projection step for both cell types, normal first.
            for cell_type in ('normal', 'reduce'):
                if args.edge_decision == 'global_op_greedy':
                    # BUGFIX: the original call omitted the required
                    # positional `infer` argument and raised TypeError.
                    # `infer` is unused by project_global_op, so pass None.
                    selected_eid, best_opid = project_global_op(
                        model, input, target, args, None, cell_type=cell_type)
                elif args.edge_decision == 'sample':
                    selected_eid, best_opid = sample_op(
                        model, input, target, args, cell_type=cell_type)
                else:
                    selected_eid, best_opid = project_op(
                        model, input, target, args, proj_queue=proj_queue, cell_type=cell_type)
                model.project_op(selected_eid, best_opid, cell_type=cell_type)
        else:
            logging.info('project edge')
            # Topology step: random sampling or criterion-driven pruning.
            select_fn = sample_edge if args.edge_decision == 'sample' else project_edge
            for cell_type in ('normal', 'reduce'):
                selected_nid, eids = select_fn(model, input, target, args, cell_type=cell_type)
                model.project_edge(selected_nid, eids, cell_type=cell_type)
        epoch += 1

        if epoch == num_projs:
            break

    return
|
||||
|
||||
def Jocab_Score(ori_model, cell_type, input, target, weights=None):
    """Score an architecture via ReLU activation-pattern statistics.

    A deep copy of *ori_model* runs one forward pass; every ReLU module
    gets a hook that accumulates binary activation-pattern agreement into
    a batch_size x batch_size kernel K, and the score is log|det K|
    (see hooklogdet).  *weights* is installed as the projected weights of
    *cell_type*; the other cell type keeps its current projected weights.
    *target* is unused.

    NOTE(review): requires CUDA -- both K and the input are moved with
    .cuda().  Hook handles are not removed, but the copy is discarded.
    """
    model = deepcopy(ori_model)  # never mutate the caller's model
    model.eval()
    if cell_type == 'reduce':
        model.proj_weights['reduce'] = weights
        model.proj_weights['normal'] = model.get_projected_weights('normal')
    else:
        model.proj_weights['normal'] = weights
        model.proj_weights['reduce'] = model.get_projected_weights('reduce')

    batch_size = input.shape[0]
    model.K = torch.zeros(batch_size, batch_size).cuda()

    def counting_forward_hook(module, inp, out):
        # Accumulate agreement of binarized activations across the batch:
        # K counts co-active units, K2 counts co-inactive units.
        try:
            if isinstance(inp, tuple):
                inp = inp[0]
            inp = inp.view(inp.size(0), -1)
            x = (inp > 0).float()
            K = x @ x.t()
            K2 = (1.-x) @ (1.-x.t())
            model.K = model.K + K + K2
        except:
            # Deliberate best-effort: modules whose inputs do not support
            # the ops above are silently skipped.
            pass

    # Hook every ReLU-like module (matched by class name).
    for name, module in model.named_modules():
        if 'ReLU' in str(type(module)):
            module.register_forward_hook(counting_forward_hook)

    input = input.cuda()

    model(input, using_proj=True)
    score = hooklogdet(model.K.cpu().numpy())

    del model  # free the copied network promptly
    return score
|
||||
|
||||
def hooklogdet(K, labels=None):
    """Return the log of |det(K)| for the accumulated kernel matrix K.

    The sign returned by slogdet is discarded; only the log-magnitude of
    the determinant is used as the score.  *labels* is accepted for API
    compatibility and ignored.
    """
    sign, log_abs_det = np.linalg.slogdet(K)
    return log_abs_det
|
Reference in New Issue
Block a user