upload
This commit is contained in:
248
sota/cnn/projection.py
Normal file
248
sota/cnn/projection.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import numpy as np
|
||||
import torch
|
||||
import nasbench201.utils as ig_utils
|
||||
import logging
|
||||
import torch.utils
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
|
||||
|
||||
def project_op(model, proj_queue, args, infer, cell_type, selected_eid=None):
    """Select the best operation on one supernet edge (operation projection).

    Each candidate operation on the selected edge is masked out in turn and
    the supernet is evaluated on ``proj_queue``; the operation whose removal
    degrades the projection criterion the most is declared the best one.

    Args:
        model: supernet exposing ``num_ops``, ``candidate_flags`` and
            ``get_projected_weights``.
        proj_queue: data queue passed through to ``infer`` for evaluation.
        args: namespace providing ``proj_crit`` (dict keyed by cell type)
            and ``edge_decision``.
        infer: evaluation callback; returns stats where index 0 is accuracy
            and index 1 is loss.
        cell_type: 'normal' or 'reduce'.
        selected_eid: edge to project; drawn at random from the remaining
            candidate edges when None.

    Returns:
        Tuple ``(selected_eid, best_opid)``.

    Raises:
        ValueError: on an unsupported ``args.edge_decision`` or an unknown
            projection criterion.
    """
    #### macros
    num_ops = model.num_ops
    candidate_flags = model.candidate_flags[cell_type]
    proj_crit = args.proj_crit[cell_type]

    #### select an edge
    if selected_eid is None:
        remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
        if args.edge_decision == "random":
            selected_eid = np.random.choice(remain_eids, size=1)[0]
            logging.info('selected edge: %d %s', selected_eid, cell_type)
        else:
            # Fail fast: leaving selected_eid as None would silently
            # unsqueeze the weight tensor via weights[None] below.
            raise ValueError('unsupported edge_decision: {}'.format(args.edge_decision))

    #### select the best operation
    if proj_crit == 'loss':
        crit_idx = 1
        # higher loss after removing the op => the op matters more
        compare = lambda x, y: x > y
    elif proj_crit == 'acc':
        crit_idx = 0
        # lower accuracy after removing the op => the op matters more
        compare = lambda x, y: x < y
    else:
        # previously fell through with crit_idx/compare unbound (NameError)
        raise ValueError('unsupported projection criterion: {}'.format(proj_crit))

    best_opid = 0
    crit_extrema = None
    for opid in range(num_ops):
        ## projection: zero out this op on the selected edge
        weights = model.get_projected_weights(cell_type)
        proj_mask = torch.ones_like(weights[selected_eid])
        proj_mask[opid] = 0
        weights[selected_eid] = weights[selected_eid] * proj_mask

        ## proj evaluation
        weights_dict = {cell_type: weights}
        valid_stats = infer(proj_queue, model, log=False, _eval=False, weights_dict=weights_dict)
        crit = valid_stats[crit_idx]

        if crit_extrema is None or compare(crit, crit_extrema):
            crit_extrema = crit
            best_opid = opid
        logging.info('valid_acc %f', valid_stats[0])
        logging.info('valid_loss %f', valid_stats[1])

    #### project
    logging.info('best opid: %d', best_opid)
    return selected_eid, best_opid
|
||||
|
||||
|
||||
def project_edge(model, proj_queue, args, infer, cell_type):
    """Select the two strongest input edges of one node (topology projection).

    For the selected node, edges are pruned one at a time: each remaining
    edge is zeroed out and the supernet evaluated; the edge whose removal
    hurts the criterion the *least* is deemed unimportant and deleted,
    until only two edges survive.

    Args:
        model: supernet exposing ``candidate_flags_edge``, ``nid2eids`` and
            ``get_projected_weights``.
        proj_queue: data queue passed through to ``infer`` for evaluation.
        args: namespace providing ``proj_crit`` (dict keyed by cell type)
            and ``edge_decision``.
        infer: evaluation callback; returns stats where index 0 is accuracy
            and index 1 is loss.
        cell_type: 'normal' or 'reduce'.

    Returns:
        Tuple ``(selected_nid, eids)`` where ``eids`` holds the two
        surviving edge ids.

    Raises:
        ValueError: on an unsupported ``args.edge_decision`` or an unknown
            projection criterion.
    """
    #### macros
    candidate_flags = model.candidate_flags_edge[cell_type]
    proj_crit = args.proj_crit[cell_type]

    #### select a node
    remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
    if args.edge_decision == "random":
        selected_nid = np.random.choice(remain_nids, size=1)[0]
        logging.info('selected node: %d %s', selected_nid, cell_type)
    else:
        # previously fell through with selected_nid unbound (NameError)
        raise ValueError('unsupported edge_decision: {}'.format(args.edge_decision))

    #### select top2 edges
    if proj_crit == 'loss':
        crit_idx = 1
        compare = lambda x, y: x > y
    elif proj_crit == 'acc':
        crit_idx = 0
        compare = lambda x, y: x < y
    else:
        # previously fell through with crit_idx/compare unbound (NameError)
        raise ValueError('unsupported projection criterion: {}'.format(proj_crit))

    eids = deepcopy(model.nid2eids[selected_nid])
    while len(eids) > 2:
        eid_todel = None
        crit_extrema = None
        for eid in eids:
            weights = model.get_projected_weights(cell_type)
            weights[eid].data.fill_(0)
            weights_dict = {cell_type: weights}

            ## proj evaluation
            valid_stats = infer(proj_queue, model, log=False, _eval=False, weights_dict=weights_dict)
            crit = valid_stats[crit_idx]

            if crit_extrema is None or not compare(crit, crit_extrema):  # find out bad edges
                crit_extrema = crit
                eid_todel = eid
            logging.info('valid_acc %f', valid_stats[0])
            logging.info('valid_loss %f', valid_stats[1])
        eids.remove(eid_todel)

    #### project
    logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
    return selected_nid, eids
|
||||
|
||||
|
||||
def pt_project(train_queue, valid_queue, model, architect, optimizer,
               epoch, args, infer, perturb_alpha, epsilon_alpha):
    """Perturbation-based projection loop: alternately project (discretize)
    one op/edge decision and fine-tune the supernet for ``args.proj_intv``
    epochs, until the full architecture is derived.

    Side effects: mutates ``args.proj_crit``, resets the model optimizer,
    saves/loads checkpoints under the ``dev_*`` directories, and logs
    progress via ``logging``. Calls ``exit(0)`` when a requested resume
    checkpoint is missing.

    NOTE(review): the ``epoch`` parameter is shadowed by the tuning loop
    variable below — the incoming value is only used until the loop starts.
    """
    model.train()
    model.printing(logging)

    # baseline metrics before any projection
    train_acc, train_obj = infer(train_queue, model, log=False)
    logging.info('train_acc %f', train_acc)
    logging.info('train_loss %f', train_obj)

    valid_acc, valid_obj = infer(valid_queue, model, log=False)
    logging.info('valid_acc %f', valid_acc)
    logging.info('valid_loss %f', valid_obj)

    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()

    #### macros
    # one projection per edge (op selection) plus one per node (topology);
    num_projs = model.num_edges + len(model.nid2eids.keys()) - 1 ## -1 because we project at both epoch 0 and -1
    tune_epochs = args.proj_intv * num_projs + 1
    proj_intv = args.proj_intv
    # per-cell-type criterion consumed by project_op / project_edge
    args.proj_crit = {'normal':args.proj_crit_normal, 'reduce':args.proj_crit_reduce}
    proj_queue = valid_queue

    #### reset optimizer
    # tuning uses a 10x smaller LR than the search phase
    model.reset_optimizer(args.learning_rate / 10, args.momentum, args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        model.optimizer, float(tune_epochs), eta_min=args.learning_rate_min)

    #### load proj checkpoints
    start_epoch = 0
    if args.dev_resume_epoch >= 0:
        filename = os.path.join(args.dev_resume_checkpoint_dir, 'checkpoint_{}.pth.tar'.format(args.dev_resume_epoch))
        if os.path.isfile(filename):
            logging.info("=> loading projection checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location='cpu')
            start_epoch = checkpoint['epoch']
            model.set_state_dict(architect, scheduler, checkpoint)
            model.set_arch_parameters(checkpoint['alpha'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            model.optimizer.load_state_dict(checkpoint['optimizer']) # optimizer
        else:
            # hard-exit on a bad resume request rather than silently restarting
            logging.info("=> no checkpoint found at '{}'".format(filename))
            exit(0)

    #### projecting and tuning
    for epoch in range(start_epoch, tune_epochs):
        logging.info('epoch %d', epoch)

        ## project
        if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
            ## saving every projection
            save_state_dict = model.get_state_dict(epoch, architect, scheduler)
            ig_utils.save_checkpoint(save_state_dict, False, args.dev_save_checkpoint_dir, per_epoch=True)

            # first all op projections (one per edge), then edge projections
            if epoch < proj_intv * model.num_edges:
                logging.info('project op')

                selected_eid_normal, best_opid_normal = project_op(model, proj_queue, args, infer, cell_type='normal')
                model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
                selected_eid_reduce, best_opid_reduce = project_op(model, proj_queue, args, infer, cell_type='reduce')
                model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')

                model.printing(logging)
            else:
                logging.info('project edge')

                selected_nid_normal, eids_normal = project_edge(model, proj_queue, args, infer, cell_type='normal')
                model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
                selected_nid_reduce, eids_reduce = project_edge(model, proj_queue, args, infer, cell_type='reduce')
                model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')

                model.printing(logging)

        ## tune
        for step, (input, target) in enumerate(train_queue):
            model.train()
            n = input.size(0)

            ## fetch data
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            # NOTE(review): iter(valid_queue) is recreated every step, so this
            # always draws from the start of the loader — confirm intended.
            input_search, target_search = next(iter(valid_queue))
            input_search = input_search.cuda()
            target_search = target_search.cuda(non_blocking=True)

            ## train alpha
            optimizer.zero_grad(); architect.optimizer.zero_grad()
            architect.step(input, target, input_search, target_search,
                           return_logits=True)

            ## sdarts
            if perturb_alpha:
                # transform arch_parameters to prob (for perturbation)
                model.softmax_arch_parameters()
                optimizer.zero_grad(); architect.optimizer.zero_grad()
                perturb_alpha(model, input, target, epsilon_alpha)

            ## train weight
            optimizer.zero_grad(); architect.optimizer.zero_grad()
            logits, loss = model.step(input, target, args)

            ## sdarts
            if perturb_alpha:
                ## restore alpha to unperturbed arch_parameters
                model.restore_arch_parameters()

            ## logging
            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.data, n)
            top1.update(prec1.data, n)
            top5.update(prec5.data, n)
            if step % args.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            # fast mode: single training step per epoch (debug/dev runs)
            if args.fast:
                break

        ## one epoch end
        model.printing(logging)

        train_acc, train_obj = infer(train_queue, model, log=False)
        logging.info('train_acc %f', train_acc)
        logging.info('train_loss %f', train_obj)

        valid_acc, valid_obj = infer(valid_queue, model, log=False)
        logging.info('valid_acc %f', valid_acc)
        logging.info('valid_loss %f', valid_obj)

    logging.info('projection finished')
    model.printing(logging)
    num_params = ig_utils.count_parameters_in_Compact(model)
    genotype = model.genotype()
    logging.info('param size = %f', num_params)
    logging.info('genotype = %s', genotype)

    return
|
Reference in New Issue
Block a user