import math
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .construct_utils import drop_path
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN
class MixedOp(nn.Module):
    """All candidate operations of one searchable edge.

    One module per entry of ``PRIMITIVES`` is built as
    ``OPS[primitive](C, C, stride, False)`` and stored in ``self._ops`` in
    the same order; ``self.name2idx`` maps each primitive name to its index.
    """

    def __init__(self, C, stride, PRIMITIVES):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        # primitive name -> position in self._ops
        self.name2idx = {}
        for idx, primitive in enumerate(PRIMITIVES):
            # Reject duplicate names before building anything for them.
            assert primitive not in self.name2idx, '{:} has already in'.format(primitive)
            op = OPS[primitive](C, C, stride, False)
            # Register the op so it is both callable via name2idx and tracked
            # as a submodule (parameters, .to(), state_dict, ...).
            self._ops.append(op)
            self.name2idx[primitive] = idx

    def forward(self, x, weights, op_name):
        """Apply the candidate operations to ``x``.

        Args:
            x: input given to every selected operation.
            weights: per-operation coefficients, paired with ``self._ops`` by
                position; only used when ``op_name`` is None.
            op_name: name of a single primitive to run, or None.

        Returns:
            * ``op_name`` given: the output of that one operation.
            * ``op_name`` None and ``weights`` None: a list of every
              operation's output, in ``self._ops`` order.
            * otherwise: ``sum(w * op(x))`` over ``zip(weights, self._ops)``.

        Raises:
            KeyError: if ``op_name`` is not a known primitive.
        """
        if op_name is not None:
            return self._ops[self.name2idx[op_name]](x)
        if weights is None:
            return [op(x) for op in self._ops]
        return sum(w * op(x) for w, op in zip(weights, self._ops))
class SearchCell(nn.Module):
    """Searchable cell with ``steps`` intermediate nodes.

    Node ``i`` (0-based) receives one ``MixedOp`` edge from each of the
    ``2 + i`` states before it (the two preprocessed inputs plus all earlier
    nodes). The edges are stored flat in ``self._ops`` in that order.
    """

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
        super(SearchCell, self).__init__()
        self.reduction = reduction
        # Private copy so later changes to the caller's list cannot leak in.
        self.PRIMITIVES = deepcopy(PRIMITIVES)

        if reduction_prev:
            # Stride-2 FactorizedReduce on input 0 when the previous cell was a
            # reduction cell (presumably to match input 1's spatial size).
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier
        self._use_residual = use_residual

        self._ops = nn.ModuleList()
        for i in range(self._steps):
            for j in range(2 + i):
                # In a reduction cell only edges leaving the two inputs stride.
                stride = 2 if reduction and j < 2 else 1
                self._ops.append(MixedOp(C, stride, self.PRIMITIVES))

    def extra_repr(self):
        return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))

    def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
        """Run the cell, then optionally add ``S1`` as a residual.

        Args:
            S0, S1: inputs fed to ``preprocess0`` / ``preprocess1``.
            weights: per-edge ``MixedOp`` weights (``'normal'`` mode only).
            connect, adjacency: per-node connection coefficients
                (``'normal'`` mode only).
            drop_prob: drop-path probability, applied only in training mode.
            modes: ``(genotype, search_mode)``. When ``genotype`` is None the
                cell runs in search mode ``'normal'`` or ``'only_W'``;
                otherwise only the operations named by the genotype run.

        Returns:
            The concatenated node outputs, plus ``S1`` when residuals are
            enabled and the sizes match.

        Raises:
            ValueError: if ``genotype`` is None and ``search_mode`` is unknown.
        """
        if modes[0] is None:
            if modes[1] == 'normal':
                output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
            elif modes[1] == 'only_W':
                output = self.__forwardOnlyW(S0, S1, drop_prob)
            else:
                raise ValueError('invalid search mode : {:}'.format(modes[1]))
        else:
            output = self.__forwardGenotype(S0, S1, modes[0])
        if self._use_residual and S1.size() == output.size():
            return S1 + output
        return output

    def __forwardGenotype(self, S0, S1, genotype):
        """Evaluate only the (name, input-index) pairs listed in ``genotype``."""
        if self.reduction:
            operations, concats = genotype.reduce, genotype.reduce_concat
        else:
            operations, concats = genotype.normal, genotype.normal_concat
        s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
        states, offset = [s0, s1], 0
        assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
        for (nameA, idxA), (nameB, idxB) in operations:
            A = self._ops[offset + idxA](states[idxA], None, nameA)
            B = self._ops[offset + idxB](states[idxB], None, nameB)
            # offset must advance by the number of edges of THIS node,
            # i.e. the state count before the new node is appended.
            offset += len(states)
            states.append(A + B)
        return torch.cat([states[i] for i in concats], dim=1)

    def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
        """Search mode: weighted MixedOps combined by learned connections."""
        s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
        states, offset = [s0, s1], 0
        for i in range(self._steps):
            clist = []
            for j, h in enumerate(states):
                x = self._ops[offset + j](h, weights[offset + j], None)
                if self.training and drop_prob > 0.:
                    x = drop_path(x, math.pow(drop_prob, 1. / len(states)))
                clist.append(x)
            # NOTE(review): assumes connect['i'] @ adjacency[i] is (1, 2 + i),
            # giving one coefficient per incoming state after squeeze — confirm.
            connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
            state = sum(w * node for w, node in zip(connection, clist))
            offset += len(states)
            states.append(state)
        return torch.cat(states[-self._multiplier:], dim=1)

    def __forwardOnlyW(self, S0, S1, drop_prob):
        """Search mode: every candidate output of every incoming edge is averaged."""
        s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
        states, offset = [s0, s1], 0
        for i in range(self._steps):
            clist = []
            for j, h in enumerate(states):
                # MixedOp with weights=None returns a list of all op outputs.
                clist += self._ops[offset + j](h, None, None)
            if self.training and drop_prob > 0.:
                xlist = [drop_path(x, math.pow(drop_prob, 1. / len(states))) for x in clist]
            else:
                xlist = clist
            # Mean of all candidates scaled by 2 (presumably to match a
            # discrete node, which sums two inputs).
            state = sum(xlist) * 2 / len(xlist)
            offset += len(states)
            states.append(state)
        return torch.cat(states[-self._multiplier:], dim=1)
class InferCell(nn.Module):
    """Fixed (discrete) cell instantiated from a genotype.

    Each genotype step is a pair of ``(op_name, input_index)`` entries; node
    ``i`` is the sum of its two ops applied to the referenced states.
    """

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(InferCell, self).__init__()

        if reduction_prev is None:
            self.preprocess0 = Identity()
        elif reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if reduction:
            step_ops, concat = genotype.reduce, genotype.reduce_concat
        else:
            step_ops, concat = genotype.normal, genotype.normal_concat
        self._steps = len(step_ops)
        self._concat = concat
        self._multiplier = len(concat)
        self._ops = nn.ModuleList()
        self._indices = []
        for operations in step_ops:
            for name, index in operations:
                stride = 2 if reduction and index < 2 else 1
                # With reduction_prev None, preprocess0 is an Identity (no
                # channel projection), so ops reading state 0 take
                # C_prev_prev input channels instead of C.
                C_in = C_prev_prev if (reduction_prev is None and index == 0) else C
                self._ops.append(OPS[name](C_in, C, stride, True))
                self._indices.append(index)

    def extra_repr(self):
        return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))

    def forward(self, S0, S1, drop_prob):
        """Compute every node and concatenate the states listed in ``_concat``.

        Drop-path is applied (training only, ``drop_prob > 0``) to each op
        output except those produced by ``Identity`` ops.
        """
        s0 = self.preprocess0(S0)
        s1 = self.preprocess1(S1)
        states = [s0, s1]
        for i in range(self._steps):
            # Ops and indices are stored flat, two per step.
            h1 = states[self._indices[2 * i]]
            h2 = states[self._indices[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            states.append(h1 + h2)
        return torch.cat([states[i] for i in self._concat], dim=1)