update
This commit is contained in:
52
nasbench201/architect_ig.py
Normal file
52
nasbench201/architect_ig.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Architect(object):
    """Updates a model's architecture parameters (alphas) with Adam.

    The wrapped ``model`` must provide ``arch_parameters()`` (an iterable of
    tensors to optimize) and ``_loss(input, target)`` (a scalar loss tensor).

    ``args`` must provide ``momentum``, ``weight_decay``,
    ``arch_learning_rate``, ``arch_weight_decay`` and ``method``.

    Supported ``args.method`` values:

    * ``'darts'``, ``'darts-proj'``, ``'sdarts'``, ``'sdarts-proj'``:
      first-order update (``self.method == 'fo'``).
    * ``'blank'``, ``'blank-proj'``: alphas are never updated
      (``self.method == 'blank'``).
    * anything containing ``'so'``: second order is not implemented in this
      class; an error is printed, ``self.method`` is set to ``'so'`` and
      ``step`` raises ``NotImplementedError``.
    * anything else: an error is printed and ``SystemExit(1)`` is raised.
    """

    _FIRST_ORDER_METHODS = ('darts', 'darts-proj', 'sdarts', 'sdarts-proj')
    _BLANK_METHODS = ('blank', 'blank-proj')

    def __init__(self, model, args):
        """Build the alpha optimizer and snapshot the initial alphas."""
        # Stored for reference only; nothing in this class reads them back.
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay,
        )

        # Detached copies so reset_arch_parameters() can restore the
        # starting point without touching the autograd graph.
        self._init_arch_parameters = [
            alpha.detach().clone() for alpha in self.model.arch_parameters()
        ]

        #### mode
        if args.method in self._FIRST_ORDER_METHODS:
            self.method = 'fo'  # first order update
        elif 'so' in args.method:
            print('ERROR: PLEASE USE architect.py for second order darts')
            # Record the mode so step() fails with NotImplementedError
            # instead of an AttributeError on a missing attribute.
            self.method = 'so'
        elif args.method in self._BLANK_METHODS:
            self.method = 'blank'
        else:
            print('ERROR: WRONG ARCH UPDATE METHOD', args.method)
            # Non-zero status: this is a fatal configuration error.
            raise SystemExit(1)

    def reset_arch_parameters(self):
        """Copy the alphas captured at construction back into the model."""
        current = self.model.arch_parameters()
        for alpha, alpha_init in zip(current, self._init_arch_parameters):
            alpha.data.copy_(alpha_init.data)

    def step(self, input_train, target_train, input_valid, target_valid, *args, **kwargs):
        """Run one architecture update according to ``self.method``.

        Extra positional and keyword arguments are accepted and ignored.

        Returns:
            ``None`` for both the first-order and the blank mode.

        Raises:
            NotImplementedError: if the mode is second order (``'so'``).
            ValueError: if ``self.method`` holds an unrecognized value.
        """
        if self.method == 'fo':
            return self._step_fo(input_train, target_train, input_valid, target_valid)
        if self.method == 'so':
            raise NotImplementedError
        if self.method == 'blank':  ## do not update alpha
            return None
        raise ValueError('unknown architect update mode: %r' % (self.method,))

    #### first order
    def _step_fo(self, input_train, target_train, input_valid, target_valid):
        """First-order update: backprop the validation loss, step the optimizer.

        The training batch is unused. Gradients are not zeroed here.
        NOTE(review): presumably the caller clears gradients before calling
        ``step`` -- confirm against the training loop.
        """
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()
        self.optimizer.step()
        return None

    #### darts 2nd order
    def _step_darts_so(self, input_train, target_train, input_valid, target_valid, eta, model_optimizer):
        """Second-order DARTS update; not implemented in this class."""
        raise NotImplementedError
|
Reference in New Issue
Block a user