174 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			174 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import math | ||
|  | from copy import deepcopy | ||
|  | import torch | ||
|  | import torch.nn as nn | ||
|  | import torch.nn.functional as F | ||
|  | from .construct_utils import drop_path | ||
|  | from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN | ||
|  | 
 | ||
|  | 
 | ||
class MixedOp(nn.Module):
  """All candidate operations for a single edge of a search cell.

  One instance of every primitive named in ``PRIMITIVES`` is built via
  ``OPS[name](C, C, stride, False)``. Each primitive's position is recorded in
  ``name2idx`` so a single op can later be selected by name.
  """

  def __init__(self, C, stride, PRIMITIVES):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    self.name2idx = {}
    for position, name in enumerate(PRIMITIVES):
      self._ops.append(OPS[name](C, C, stride, False))
      # Primitive names must be unique, otherwise name-based lookup is ambiguous.
      assert name not in self.name2idx, '{:} has already in'.format(name)
      self.name2idx[name] = position

  def forward(self, x, weights, op_name):
    """Apply the candidates to ``x``.

    Args:
      x: input passed to every candidate op.
      weights: iterable of per-candidate weights (same order as PRIMITIVES),
        or None.
      op_name: name of a single primitive to run, or None.

    Returns:
      - ``op_name`` given: that primitive's output alone (``weights`` ignored).
      - ``op_name`` None, ``weights`` None: list with every candidate's output.
      - ``op_name`` None, ``weights`` given: ``sum(w * op(x))`` over candidates.
    """
    if op_name is not None:
      return self._ops[self.name2idx[op_name]](x)
    if weights is None:
      return [candidate(x) for candidate in self._ops]
    return sum(w * candidate(x) for w, candidate in zip(weights, self._ops))
|  | 
 | ||
|  | 
 | ||
|  | 
 | ||
class SearchCell(nn.Module):
  """Searchable cell in which every edge is a MixedOp over ``PRIMITIVES``.

  Node ``i`` (0 <= i < steps) receives one MixedOp edge from each of the
  ``2 + i`` states preceding it: the two preprocessed cell inputs plus all
  earlier nodes. Edges are stored flat in ``self._ops``, node by node.

  Args:
    steps: number of intermediate nodes.
    multiplier: number of trailing states concatenated to form the output in
      the search modes ('normal' and 'only_W').
    C_prev_prev: channel count of the first input ``S0``.
    C_prev: channel count of the second input ``S1``.
    C: channel count used inside the cell.
    reduction: whether this is a reduction cell. If so, edges leaving the two
      cell inputs are built with stride 2.
    reduction_prev: if truthy, ``S0`` is preprocessed with FactorizedReduce,
      otherwise with ReLUConvBN.
    PRIMITIVES: candidate operation names; deep-copied onto the cell.
    use_residual: if True, ``S1`` is added to the output whenever their sizes
      match.
  """

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
    super(SearchCell, self).__init__()
    self.reduction  = reduction
    self.PRIMITIVES = deepcopy(PRIMITIVES)

    # Map both inputs to C channels; S0 needs FactorizedReduce when the
    # previous cell was a reduction cell.
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps        = steps
    self._multiplier   = multiplier
    self._use_residual = use_residual

    # Node i owns the 2+i consecutive entries starting at sum_{k<i}(2+k); entry
    # j within that run is the edge from state j. Only edges from the two cell
    # inputs (j < 2) use stride 2 in a reduction cell.
    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride, self.PRIMITIVES)
        self._ops.append(op)

  def extra_repr(self):
    return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
    """Run the cell in one of its three modes.

    Args:
      S0, S1: the two cell inputs.
      weights: per-edge candidate weights indexed like ``self._ops``
        ('normal' mode only).
      connect: mapping from ``str(node_index)`` to a tensor that is
        matrix-multiplied with ``adjacency[node_index]`` to obtain the weights
        of that node's incoming edges ('normal' mode only).
      adjacency: per-node tensors, see ``connect``.
      drop_prob: drop-path rate used while training in the search modes;
        ignored when a genotype is given.
      modes: sequence whose first item is None or a genotype.
        - ``modes[0]`` is None: ``modes[1]`` selects 'normal' (weighted
          MixedOps) or 'only_W' (twice the mean of every candidate output).
        - ``modes[0]`` is a genotype: only the ops chosen by the genotype are
          run.

    Returns:
      The concatenated cell output, plus ``S1`` when ``use_residual`` is set
      and the sizes match.

    Raises:
      ValueError: if ``modes[0]`` is None and ``modes[1]`` is not 'normal' or
        'only_W'.
    """
    if modes[0] is None:
      if modes[1] == 'normal':
        output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
      elif modes[1] == 'only_W':
        output = self.__forwardOnlyW(S0, S1, drop_prob)
      else:
        # Without this, `output` would be unbound below and fail obscurely.
        raise ValueError('invalid search mode : {:}'.format(modes[1]))
    else:
      output = self.__forwardGenotype(S0, S1, modes[0])
    if self._use_residual and S1.size() == output.size():
      return S1 + output
    else: return output

  def __forwardGenotype(self, S0, S1, genotype):
    """Evaluate the cell on a fixed genotype, running one named op per chosen edge."""
    if self.reduction:
      operations, concats = genotype.reduce, genotype.reduce_concat
    else:
      operations, concats = genotype.normal, genotype.normal_concat
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
    for opA, opB in operations:
      # Each entry is (op_name, input_state_index); the MixedOp for the edge
      # from state k into the current node lives at offset + k.
      A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
      B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
      state = A + B
      # Advance past this node's edges (one per existing state) before appending.
      offset += len(states)
      states.append(state)
    return torch.cat([states[i] for i in concats], dim=1)

  def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
    """'normal' mode: weighted MixedOps on every edge, edges combined by connection weights."""
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j], None)
        if self.training and drop_prob > 0.:
          # Per-edge rate handed to drop_path is drop_prob ** (1 / #incoming edges).
          x = drop_path(x, math.pow(drop_prob, 1./len(states)))
        clist.append( x )
      connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
      state = sum(w * node for w, node in zip(connection, clist))
      offset += len(states)
      states.append(state)
    return torch.cat(states[-self._multiplier:], dim=1)

  def __forwardOnlyW(self, S0, S1, drop_prob):
    """'only_W' mode: every candidate of every incoming edge, combined as 2 * mean."""
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        # With weights and op_name both None, MixedOp returns a list of all candidate outputs.
        xs = self._ops[offset+j](h, None, None)
        clist += xs
      if self.training and drop_prob > 0.:
        xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist]
      else: xlist = clist
      # NOTE(review): the factor 2 presumably matches the two-input sum of a
      # discrete cell -- confirm intent.
      state = sum(xlist) * 2 / len(xlist)
      offset += len(states)
      states.append(state)
    return torch.cat(states[-self._multiplier:], dim=1)
|  | 
 | ||
|  | 
 | ||
|  | 
 | ||
class InferCell(nn.Module):
  """Discrete cell built from a genotype.

  The genotype's ``normal``/``reduce`` entry (chosen by ``reduction``) is a list
  of steps. Each step is a list of ``(op_name, input_state_index)`` pairs, and
  each pair becomes one op in ``self._ops``. ``forward`` assumes exactly two
  pairs per step and sums their outputs into a new state. The states listed in
  the matching ``*_concat`` entry are concatenated along dim 1 as the output.

  Args:
    genotype: object exposing ``normal``, ``normal_concat``, ``reduce`` and
      ``reduce_concat``.
    C_prev_prev: channel count of the first input ``S0``.
    C_prev: channel count of the second input ``S1``.
    C: channel count used inside the cell.
    reduction: if True, ops reading from the two cell inputs use stride 2.
    reduction_prev: None, truthy or falsy. With None, ``S0`` is passed through
      unchanged (Identity) and ops reading state 0 are built with
      ``C_prev_prev`` input channels. Truthy uses FactorizedReduce for ``S0``;
      falsy uses ReLUConvBN.
  """

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(InferCell, self).__init__()

    if reduction_prev is None:
      self.preprocess0 = Identity()
    elif reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1   = ReLUConvBN(C_prev, C, 1, 1, 0)

    if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
    else        : step_ops, concat = genotype.normal, genotype.normal_concat
    self._steps        = len(step_ops)
    self._concat       = concat
    self._multiplier   = len(concat)
    self._ops          = nn.ModuleList()
    self._indices      = []  # input state index for each op in self._ops
    for operations in step_ops:
      for name, index in operations:
        stride = 2 if reduction and index < 2 else 1
        if reduction_prev is None and index == 0:
          # S0 is not preprocessed in this case, so the op sees C_prev_prev channels.
          op = OPS[name](C_prev_prev, C, stride, True)
        else:
          op = OPS[name](C          , C, stride, True)
        self._ops.append( op )
        self._indices.append( index )

  def extra_repr(self):
    return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, S0, S1, drop_prob):
    """Compute the cell output from inputs ``S0`` and ``S1``.

    While training with ``drop_prob > 0``, ``drop_path`` is applied to the
    output of every op that is not an ``Identity``.
    """
    s0 = self.preprocess0(S0)
    s1 = self.preprocess1(S1)

    states = [s0, s1]
    for i in range(self._steps):
      # Step i owns ops 2*i and 2*i+1 (two inputs per step).
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      if self.training and drop_prob > 0.:
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)

      state = h1 + h2
      states += [state]
    output = torch.cat([states[i] for i in self._concat], dim=1)
    return output