import math
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from .construct_utils import drop_path
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN


class MixedOp(nn.Module):
    """One search edge holding a candidate operation for every primitive.

    Args:
        C: channel count passed as both input and output channels to every
            candidate op.
        stride: stride passed to every candidate op.
        PRIMITIVES: iterable of primitive names; each must be a key of ``OPS``
            and may appear only once.
    """

    def __init__(self, C, stride, PRIMITIVES):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        # Maps a primitive name to its position in ``self._ops``.
        self.name2idx = {}
        for idx, primitive in enumerate(PRIMITIVES):
            # Reject duplicates before constructing the (discarded) op.
            assert primitive not in self.name2idx, '{:} has already in'.format(primitive)
            # The trailing False is forwarded as OPS' fourth argument
            # (presumably ``affine`` -- confirm against ..operations).
            self._ops.append(OPS[primitive](C, C, stride, False))
            self.name2idx[primitive] = idx

    def forward(self, x, weights, op_name):
        """Apply the candidate operations to ``x``.

        Args:
            x: input tensor.
            weights: per-candidate weights in PRIMITIVES order, or None.
            op_name: name of a single primitive to run, or None.

        Returns:
            * ``op_name`` given: the output of that primitive only
              (``weights`` is ignored).
            * ``op_name`` and ``weights`` both None: a list with every
              candidate's output, in PRIMITIVES order.
            * otherwise: ``sum(w * op(x))`` over all candidates.

        Raises:
            KeyError: if ``op_name`` is not one of the primitives.
        """
        if op_name is not None:
            return self._ops[self.name2idx[op_name]](x)
        if weights is None:
            return [op(x) for op in self._ops]
        return sum(w * op(x) for w, op in zip(weights, self._ops))


class SearchCell(nn.Module):
    """Weight-sharing cell used during architecture search.

    Intermediate node ``i`` (``0 <= i < steps``) receives one ``MixedOp`` edge
    from every earlier state (the two preprocessed inputs plus nodes
    ``0..i-1``), so ``self._ops`` holds ``2 + i`` edges for step ``i``, laid
    out step by step.

    Args:
        steps: number of intermediate nodes.
        multiplier: number of trailing states concatenated by the search
            forward paths (``__forwardBoth`` / ``__forwardOnlyW``).
        C_prev_prev: channel count of input ``S0``.
        C_prev: channel count of input ``S1``.
        C: channel count after preprocessing, used by every edge.
        reduction: if True, edges reading the two cell inputs use stride 2.
        reduction_prev: if truthy, ``S0`` is preprocessed with
            ``FactorizedReduce(..., 2)`` instead of a 1x1 ``ReLUConvBN``.
        PRIMITIVES: candidate primitive names (deep-copied).
        use_residual: if True, ``S1`` is added to the output when the sizes
            match.
    """

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
        super(SearchCell, self).__init__()
        self.reduction = reduction
        self.PRIMITIVES = deepcopy(PRIMITIVES)

        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier
        self._use_residual = use_residual

        self._ops = nn.ModuleList()
        for i in range(self._steps):
            for j in range(2 + i):
                # Only edges reading the two cell inputs (j < 2) downsample.
                stride = 2 if reduction and j < 2 else 1
                self._ops.append(MixedOp(C, stride, self.PRIMITIVES))

    def extra_repr(self):
        return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))

    def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
        """Run the cell in one of three modes selected by ``modes``.

        Args:
            S0, S1: the two cell inputs.
            weights: per-edge candidate weights (used in 'normal' mode).
            connect, adjacency: node-connection weights (used in 'normal'
                mode).
            drop_prob: drop-path probability, applied only in training in the
                search modes; ignored when a genotype is given.
            modes: ``modes[0]`` is a genotype or None. When it is None,
                ``modes[1]`` selects the search mode, ``'normal'`` or
                ``'only_W'``.

        Returns:
            The cell output. If ``use_residual`` is set and ``S1`` has the same
            size as the output, ``S1 + output`` is returned instead.

        Raises:
            ValueError: if ``modes[0]`` is None and ``modes[1]`` is not a
                known search mode.
        """
        if modes[0] is None:
            if modes[1] == 'normal':
                output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
            elif modes[1] == 'only_W':
                output = self.__forwardOnlyW(S0, S1, drop_prob)
            else:
                # Previously fell through with ``output`` unbound.
                raise ValueError('invalid search mode : {:}'.format(modes[1]))
        else:
            output = self.__forwardGenotype(S0, S1, modes[0])

        if self._use_residual and S1.size() == output.size():
            return S1 + output
        return output

    def __forwardGenotype(self, S0, S1, genotype):
        """Run only the operations selected by ``genotype``; no drop-path.

        ``genotype`` must expose ``reduce``/``reduce_concat`` and
        ``normal``/``normal_concat``. Each step is a pair of
        ``(primitive_name, input_state_index)`` tuples.
        """
        if self.reduction:
            operations, concats = genotype.reduce, genotype.reduce_concat
        else:
            operations, concats = genotype.normal, genotype.normal_concat
        s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
        states, offset = [s0, s1], 0
        assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
        for opA, opB in operations:
            # ``offset`` is the index of the first edge of this step in
            # self._ops; the input-state index selects the edge within it.
            A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
            B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
            offset += len(states)
            states.append(A + B)
        # Concatenate along dim 1 (presumably channels of NCHW -- confirm).
        return torch.cat([states[i] for i in concats], dim=1)

    def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
        """Search forward that mixes candidates and node connections.

        ``weights[k]`` mixes the candidates of edge ``k``. Each node's incoming
        edges are combined with weights given by
        ``torch.mm(connect[str(i)], adjacency[i]).squeeze(0)``.
        """
        s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
        states, offset = [s0, s1], 0
        for i in range(self._steps):
            clist = []
            for j, h in enumerate(states):
                x = self._ops[offset + j](h, weights[offset + j], None)
                if self.training and drop_prob > 0.:
                    # Per-edge keep-rate scaled by the number of inputs to
                    # this node.
                    x = drop_path(x, math.pow(drop_prob, 1. / len(states)))
                clist.append(x)
            # NOTE(review): assumed to yield one weight per element of clist.
            connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
            state = sum(w * node for w, node in zip(connection, clist))
            offset += len(states)
            states.append(state)
        return torch.cat(states[-self._multiplier:], dim=1)

    def __forwardOnlyW(self, S0, S1, drop_prob):
        """Search forward that averages candidates, without architecture weights.

        Every candidate output of every incoming edge is summed and divided by
        the number of outputs, then multiplied by 2.
        NOTE(review): the purpose of the factor 2 is not visible here.
        """
        s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
        states, offset = [s0, s1], 0
        for i in range(self._steps):
            clist = []
            for j, h in enumerate(states):
                # MixedOp returns a list of all candidate outputs here.
                clist += self._ops[offset + j](h, None, None)
            if self.training and drop_prob > 0.:
                xlist = [drop_path(x, math.pow(drop_prob, 1. / len(states))) for x in clist]
            else:
                xlist = clist
            state = sum(xlist) * 2 / len(xlist)
            offset += len(states)
            states.append(state)
        return torch.cat(states[-self._multiplier:], dim=1)


class InferCell(nn.Module):
    """Fixed cell built from a discovered genotype.

    Args:
        genotype: object exposing ``normal``/``normal_concat`` and
            ``reduce``/``reduce_concat``. Each step is a sequence of
            ``(primitive_name, input_state_index)`` pairs.
        C_prev_prev: channel count of input ``S0``.
        C_prev: channel count of input ``S1``.
        C: channel count used inside the cell.
        reduction: if True, use the reduce genotype; ops reading the two cell
            inputs use stride 2.
        reduction_prev: None -> ``S0`` is passed through unchanged (Identity),
            and ops reading state 0 take ``C_prev_prev`` input channels.
            Truthy -> ``FactorizedReduce(..., 2)``. Otherwise -> 1x1
            ``ReLUConvBN``.
    """

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(InferCell, self).__init__()
        # (A leftover debug print of the channel counts was removed here.)
        if reduction_prev is None:
            self.preprocess0 = Identity()
        elif reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if reduction:
            step_ops, concat = genotype.reduce, genotype.reduce_concat
        else:
            step_ops, concat = genotype.normal, genotype.normal_concat
        self._steps = len(step_ops)
        self._concat = concat
        self._multiplier = len(concat)
        self._ops = nn.ModuleList()
        # self._indices[k] is the input-state index read by self._ops[k].
        self._indices = []
        for operations in step_ops:
            for name, index in operations:
                stride = 2 if reduction and index < 2 else 1
                if reduction_prev is None and index == 0:
                    # State 0 was not preprocessed, so it still has
                    # C_prev_prev channels.
                    op = OPS[name](C_prev_prev, C, stride, True)
                else:
                    op = OPS[name](C, C, stride, True)
                self._ops.append(op)
                self._indices.append(index)

    def extra_repr(self):
        return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))

    def forward(self, S0, S1, drop_prob):
        """Evaluate the fixed cell.

        Step ``i`` applies ops ``2*i`` and ``2*i+1`` to their input states and
        sums the results. In training, drop-path with ``drop_prob`` is applied
        to every non-Identity op output. Returns the states listed in
        ``concat``, concatenated along dim 1.
        """
        s0 = self.preprocess0(S0)
        s1 = self.preprocess1(S1)
        states = [s0, s1]
        for i in range(self._steps):
            h1 = states[self._indices[2 * i]]
            h2 = states[self._indices[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            states.append(h1 + h2)
        return torch.cat([states[i] for i in self._concat], dim=1)