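"""Cells for differentiable architecture search (DARTS-style supernet).

MixedOp blends the candidate operations on a single edge, SearchCell is the
searchable cell that mixes operations with architecture weights during the
search phase, and InferCell instantiates a fixed cell from a discovered
genotype for final training and evaluation.
"""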
import math
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from .construct_utils import drop_path
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN


class MixedOp(nn.Module):
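  """A mixture of the candidate operations on one edge of the cell DAG.

  forward() dispatches on its arguments: with op_name set, only that named
  operation runs; with weights set, it returns the weighted sum over all
  candidate operations; with both None, it returns the list of every
  candidate operation's output.
  """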

  def __init__(self, C, stride, PRIMITIVES):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    self.name2idx = {}
    for idx, primitive in enumerate(PRIMITIVES):
      op = OPS[primitive](C, C, stride, False)
      self._ops.append(op)
      assert primitive not in self.name2idx, '{:} has already been added'.format(primitive)
      self.name2idx[primitive] = idx

  def forward(self, x, weights, op_name):
    if op_name is None:
      if weights is None:
        return [op(x) for op in self._ops]
      else:
        return sum(w * op(x) for w, op in zip(weights, self._ops))
    else:
      op_index = self.name2idx[op_name]
      return self._ops[op_index](x)
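
# Illustrative MixedOp usage (hypothetical channel count and primitive names;
# the real primitive list comes from the surrounding search-space definition):
#   mixed  = MixedOp(C=16, stride=1, PRIMITIVES=['skip_connect', 'sep_conv_3x3'])
#   alphas = torch.softmax(torch.zeros(2), dim=-1)
#   y = mixed(x, alphas, None)            # weighted mixture of both candidates
#   y = mixed(x, None, 'sep_conv_3x3')    # run a single named operation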


class SearchCell(nn.Module):
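  """The searchable cell: a DAG whose edges are MixedOps.

  Node i (of `steps` intermediate nodes) is connected to the two cell inputs
  and to every earlier node, giving sum(2 + i for i in range(steps)) MixedOps
  in total. The output concatenates the last `multiplier` node outputs along
  the channel dimension.
  """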

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
    super(SearchCell, self).__init__()
    self.reduction = reduction
    self.PRIMITIVES = deepcopy(PRIMITIVES)

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier
    self._use_residual = use_residual

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride, self.PRIMITIVES)
        self._ops.append(op)

  def extra_repr(self):
    return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
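    """Dispatch on `modes`: (None, 'normal') mixes operations with the given
    architecture weights, (None, 'only_W') averages all candidate operations
    so only the operation weights matter, and a non-None modes[0] is treated
    as a fixed genotype and evaluated discretely."""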
    if modes[0] is None:
      if modes[1] == 'normal':
        output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
      elif modes[1] == 'only_W':
        output = self.__forwardOnlyW(S0, S1, drop_prob)
      else:
        raise ValueError('invalid search mode : {:}'.format(modes[1]))
    else:
      test_genotype = modes[0]
      if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat
      else             : operations, concats = test_genotype.normal, test_genotype.normal_concat
      s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
      states, offset = [s0, s1], 0
      assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
      for i, (opA, opB) in enumerate(operations):
        A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
        B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
        state = A + B
        offset += len(states)
        states.append(state)
      output = torch.cat([states[i] for i in concats], dim=1)
    if self._use_residual and S1.size() == output.size():
      return S1 + output
    else: return output

  def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j], None)
        if self.training and drop_prob > 0.:
          x = drop_path(x, math.pow(drop_prob, 1./len(states)))
        clist.append(x)
      # combine node i's candidate inputs, weighted by the learned connection
      # distribution projected through the adjacency pattern
      connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
      state = sum(w * node for w, node in zip(connection, clist))
      offset += len(states)
      states.append(state)
    return torch.cat(states[-self._multiplier:], dim=1)

  def __forwardOnlyW(self, S0, S1, drop_prob):
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        # weights=None and op_name=None: MixedOp returns all candidate outputs
        xs = self._ops[offset+j](h, None, None)
        clist += xs
      if self.training and drop_prob > 0.:
        xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist]
      else: xlist = clist
      state = sum(xlist) * 2 / len(xlist)
      offset += len(states)
      states.append(state)
    return torch.cat(states[-self._multiplier:], dim=1)


class InferCell(nn.Module):
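  """A fixed cell instantiated from a discovered genotype.

  Every intermediate node sums the outputs of its two chosen operations, and
  the cell output concatenates the nodes listed in the genotype's concat
  field along the channel dimension.
  """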

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(InferCell, self).__init__()
    print(C_prev_prev, C_prev, C)

    if reduction_prev is None:
      self.preprocess0 = Identity()
    elif reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

    if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
    else        : step_ops, concat = genotype.normal, genotype.normal_concat
    self._steps = len(step_ops)
    self._concat = concat
    self._multiplier = len(concat)
    self._ops = nn.ModuleList()
    self._indices = []
    for operations in step_ops:
      for name, index in operations:
        stride = 2 if reduction and index < 2 else 1
        if reduction_prev is None and index == 0:
          # input 0 bypasses preprocessing, so it still has C_prev_prev channels
          op = OPS[name](C_prev_prev, C, stride, True)
        else:
          op = OPS[name](C, C, stride, True)
        self._ops.append(op)
        self._indices.append(index)

  def extra_repr(self):
    return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, S0, S1, drop_prob):
    s0 = self.preprocess0(S0)
    s1 = self.preprocess1(S1)

    states = [s0, s1]
    for i in range(self._steps):
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      if self.training and drop_prob > 0.:
        # identity (skip) connections are never dropped
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)

      state = h1 + h2
      states += [state]
    output = torch.cat([states[i] for i in self._concat], dim=1)
    return output
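
# Illustrative InferCell wiring (hypothetical sizes; `Genotype` stands for the
# namedtuple used by the surrounding repo, holding (op_name, input_index) pairs
# per step as consumed above):
#   genotype = Genotype(
#       normal=[(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)),
#               (('skip_connect', 0), ('sep_conv_3x3', 2))],
#       normal_concat=[2, 3],
#       reduce=[(('max_pool_3x3', 0), ('sep_conv_3x3', 1)),
#               (('skip_connect', 2), ('max_pool_3x3', 0))],
#       reduce_concat=[2, 3])
#   cell = InferCell(genotype, C_prev_prev=48, C_prev=48, C=16,
#                    reduction=False, reduction_prev=False)
#   out = cell(s0, s1, drop_prob=0.2)  # out channels = len(normal_concat) * C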