update
nasbench201/search_model_darts_proj.py · 80 lines · new file
@@ -0,0 +1,80 @@
import torch
from .search_cells import NAS201SearchCell as SearchCell
from .search_model import TinyNetwork as TinyNetwork
from .genotypes import Structure
from torch.autograd import Variable

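# DARTS supernet for NAS-Bench-201 with edge-wise operation projection:
# architecture weights are a per-edge softmax over candidate ops, and
# project_op() hard-fixes an edge to a single op via a one-hot vector.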
class TinyNetworkDartsProj(TinyNetwork):
    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
                 affine=False, track_running_stats=True, stem_channels=3):
        super(TinyNetworkDartsProj, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args,
                                                   affine=affine, track_running_stats=track_running_stats,
                                                   stem_channels=stem_channels)
        self.theta_map = lambda x: torch.softmax(x, dim=-1)

        #### for edgewise projection
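        # candidate_flags[eid] is True while edge eid is still searchable;
        # proj_weights[eid] stores the one-hot op choice once it is projected.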
        self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True],
                                            requires_grad=False, dtype=torch.bool).cuda()
        self.proj_weights = torch.zeros_like(self._arch_parameters)

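    # Fix edge `eid` to operation `opid` (one-hot) and drop the edge from
    # the candidate set.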
    def project_op(self, eid, opid):
        self.proj_weights[eid][opid] = 1  ## hard by default
        self.candidate_flags[eid] = False

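    # Softmax over the architecture parameters, with the rows of already
    # projected edges overwritten by their stored one-hot vectors.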
    def get_projected_weights(self):
        weights = self.theta_map(self._arch_parameters)

        ## proj op
        for eid in range(len(self._arch_parameters)):
            if not self.candidate_flags[eid]:
                weights[eid].data.copy_(self.proj_weights[eid])

        return weights

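    # Standard TinyNetwork forward pass, except that search cells consume the
    # projected weights. Note: set_detect_anomaly(True) is a debugging aid and
    # makes autograd noticeably slower.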
    def forward(self, inputs, weights=None):
        with torch.autograd.set_detect_anomaly(True):
            if weights is None:
                weights = self.get_projected_weights()

            feature = self.stem(inputs)
            for i, cell in enumerate(self.cells):
                if isinstance(cell, SearchCell):
                    feature = cell(feature, weights)
                else:
                    feature = cell(feature)

            out = self.lastact(feature)
            out = self.global_pooling(out)
            out = out.view(out.size(0), -1)
            logits = self.classifier(out)

            return logits

    #### utils
    def get_theta(self):
        return self.get_projected_weights()

    def arch_parameters(self):
        return [self._arch_parameters]

    def set_arch_parameters(self, new_alphas):
        for eid, alpha in enumerate(self.arch_parameters()):
            alpha.data.copy_(new_alphas[eid])

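    # Re-initialize the architecture parameters and clear all projections.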
    def reset_arch_parameters(self):
        self._arch_parameters = Variable(1e-3 * torch.randn(self.num_edge, len(self.op_names)).cuda(),
                                         requires_grad=True)
        self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True],
                                            requires_grad=False, dtype=torch.bool).cuda()
        self.proj_weights = torch.zeros_like(self._arch_parameters)

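    # Derive the discrete genotype: for each edge (i <- j), pick the op with
    # the highest projected weight.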
    def genotype(self):
        proj_weights = self.get_projected_weights()

        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                with torch.no_grad():
                    weights = proj_weights[self.edge2index[node_str]]
                    op_name = self.op_names[weights.argmax().item()]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return Structure(genotypes)
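For reference, a minimal usage sketch of the new class. Everything here is hypothetical: the argument values, the `args` config object, and the five-operation NAS-Bench-201 search space are assumptions, not part of this commit.

    # Hypothetical setup; `args` stands in for the experiment config object
    # expected by TinyNetwork and is not defined in this commit.
    search_space = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
    model = TinyNetworkDartsProj(C=16, N=5, max_nodes=4, num_classes=10,
                                 criterion=torch.nn.CrossEntropyLoss(),
                                 search_space=search_space, args=args).cuda()

    model.project_op(eid=0, opid=2)                    # fix edge 0 to candidate op 2
    logits = model(torch.randn(2, 3, 32, 32).cuda())   # forward uses projected weights
    print(model.genotype())                            # discrete architecture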