# Source: xautodl/lib/nas_rnn/model_search.py
# Last modified: 2019-02-01 01:27:38 +11:00
# Listing info: 105 lines, 3.5 KiB, Python

import copy, torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
from .genotypes import PRIMITIVES, STEPS, CONCAT, Genotype
from .basemodel import DARTSCell, RNNModel
class DARTSCellSearch(DARTSCell):
    """Search-time DARTS recurrent cell that blends every candidate op on every edge.

    Intermediate node ``i`` (``STEPS`` nodes in total) receives one edge from
    each earlier state (the initial state plus nodes ``0..i-1``). On each edge
    the activations named in ``PRIMITIVES`` are mixed with weights taken from
    ``arch_probs``; the ``'none'`` primitive contributes nothing and is skipped.
    """

    def __init__(self, ninp, nhid, dropouth, dropoutx):
        super().__init__(ninp, nhid, dropouth, dropoutx, genotype=None)
        # Normalises the initial state and each intermediate node; affine=False
        # means the normalisation has no learnable scale/shift.
        self.bn = nn.BatchNorm1d(nhid, affine=False)
        # When True, ``cell`` skips edge/op pairs whose weight is not > 0.
        # NOTE(review): presumably meant for sparse (e.g. hard-sampled)
        # ``arch_probs``; it only saves work when many weights are zero.
        self.check_zero = False

    def set_check(self, check_zero):
        """Enable or disable skipping of non-positive edge/op weights in ``cell``."""
        self.check_zero = check_zero

    def cell(self, x, h_prev, x_mask, h_mask, arch_probs):
        """Run one recurrent step of the mixed (search-time) cell.

        Args:
            x, h_prev, x_mask, h_mask: forwarded to ``_compute_init_state``;
                ``h_mask`` is additionally multiplied into the states while
                ``self.training`` is True.
            arch_probs: 2-D tensor of mixing weights. Rows are edges laid out
                node by node (node ``i`` owns ``i + 1`` consecutive rows);
                column ``k`` corresponds to ``PRIMITIVES[k]``.

        Returns:
            The mean of the last ``CONCAT`` states, where the state list is the
            initial state followed by one state per intermediate node.
        """
        s0 = self.bn(self._compute_init_state(x, h_prev, x_mask, h_mask))
        # Host-side copy (one device transfer per call) so the zero test below
        # is plain Python.
        probs_list = arch_probs.cpu().tolist() if self.check_zero else None

        offset = 0  # first row of arch_probs that belongs to the current node
        states = s0.unsqueeze(0)
        for i in range(STEPS):
            if self.training:
                masked_states = states * h_mask.unsqueeze(0)
            else:
                masked_states = states
            ch = masked_states.view(-1, self.nhid).mm(self._Ws[i])
            ch = ch.view(i + 1, -1, 2 * self.nhid)
            c, h = torch.split(ch, self.nhid, dim=-1)
            c = c.sigmoid()

            s = torch.zeros_like(s0)
            for k, name in enumerate(PRIMITIVES):
                if name == 'none':
                    # The zero op adds nothing to the mixture.
                    continue
                fn = self._get_activation(name)
                if self.check_zero:
                    # Decide which incoming edges use this op *before* computing
                    # the activation, so ops with no positive weight cost nothing.
                    local_rows = [
                        j for j in range(i + 1) if probs_list[offset + j][k] > 0
                    ]
                    if not local_rows:
                        continue
                    weights = arch_probs[[offset + j for j in local_rows], k]
                    unweighted = (states + c * (fn(h) - states))[local_rows, :, :]
                else:
                    weights = arch_probs[offset:offset + i + 1, k]
                    unweighted = states + c * (fn(h) - states)
                s += torch.sum(weights.unsqueeze(-1).unsqueeze(-1) * unweighted, dim=0)
            s = self.bn(s)
            states = torch.cat([states, s.unsqueeze(0)], 0)
            offset += i + 1
        return torch.mean(states[-CONCAT:], dim=0)
class RNNModelSearch(RNNModel):
    """Recurrent model used during architecture search.

    Extends ``RNNModel`` with a learnable architecture matrix ``arch_weights``
    of shape ``(num_edges, len(PRIMITIVES))``, where ``num_edges`` is
    ``1 + 2 + ... + STEPS``. It also provides helpers to separate network
    weights from architecture weights and to decode a discrete ``Genotype``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # Private deep copy of the positional constructor arguments.
        self._args = copy.deepcopy(args)
        # Node i (0-based) has i + 1 incoming edges.
        num_edges = sum(range(1, STEPS + 1))
        self.arch_weights = nn.Parameter(torch.Tensor(num_edges, len(PRIMITIVES)))
        nn.init.normal_(self.arch_weights, 0, 0.001)

    def base_parameters(self):
        """Return the non-architecture parameters as one flat list.

        Order: lockdrop, encoder, rnns, decoder.
        """
        modules = (self.lockdrop, self.encoder, self.rnns, self.decoder)
        return [param for module in modules for param in module.parameters()]

    def arch_parameters(self):
        """Return the architecture parameters as a single-element list."""
        return [self.arch_weights]

    def genotype(self):
        """Decode ``arch_weights`` into a discrete ``Genotype``.

        The weights are row-softmaxed. For each intermediate node, the
        incoming edge with the largest single op weight is kept, together with
        its highest-weighted op. Ties go to the lower index.

        NOTE(review): no primitive is excluded here, so if ``PRIMITIVES``
        contains ``'none'`` a gene entry can be ``('none', j)``. Confirm that
        consumers of the returned genotype accept that.
        """
        def _parse(probs):
            gene = []
            start = 0
            for node in range(STEPS):
                stop = start + node + 1
                block = probs[start:stop]
                # max() returns the first maximal item, i.e. the lowest index on ties.
                src = max(range(node + 1), key=lambda r: max(block[r]))
                row = block[src]
                op_idx = max(range(len(row)), key=lambda col: row[col])
                gene.append((PRIMITIVES[op_idx], src))
                start = stop
            return gene

        with torch.no_grad():
            probs = F.softmax(self.arch_weights, dim=-1).cpu().numpy()
            gene = _parse(probs)
        return Genotype(recurrent=gene, concat=list(range(STEPS + 1)[-CONCAT:]))