update GDAS
This commit is contained in:
		| @@ -47,35 +47,17 @@ class SearchCell(nn.Module): | ||||
|     return nodes[-1] | ||||
|  | ||||
|   # GDAS | ||||
|   def forward_gdas(self, inputs, alphas, _tau): | ||||
|     avoid_zero = 0 | ||||
|     while True: | ||||
|       gumbels = -torch.empty_like(alphas).exponential_().log() | ||||
|       logits  = (alphas.log_softmax(dim=1) + gumbels) / _tau | ||||
|       probs   = nn.functional.softmax(logits, dim=1) | ||||
|       index   = probs.max(-1, keepdim=True)[1] | ||||
|       one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|       hardwts = one_h - probs.detach() + probs | ||||
|       if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): | ||||
|         continue # avoid the numerical error | ||||
|       nodes   = [inputs] | ||||
|       for i in range(1, self.max_nodes): | ||||
|         inter_nodes = [] | ||||
|         for j in range(i): | ||||
|           node_str = '{:}<-{:}'.format(i, j) | ||||
|           weights  = hardwts[ self.edge2index[node_str] ] | ||||
|           argmaxs  = index[ self.edge2index[node_str] ].item() | ||||
|           weigsum  = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) | ||||
|           inter_nodes.append( weigsum ) | ||||
|         nodes.append( sum(inter_nodes) ) | ||||
|       avoid_zero += 1 | ||||
|       if nodes[-1].sum().item() == 0: | ||||
|         if avoid_zero < 10: continue | ||||
|         else: | ||||
|           warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero)) | ||||
|           break | ||||
|       else: | ||||
|         break | ||||
|   def forward_gdas(self, inputs, hardwts, index): | ||||
|     nodes   = [inputs] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       inter_nodes = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         weights  = hardwts[ self.edge2index[node_str] ] | ||||
|         argmaxs  = index[ self.edge2index[node_str] ].item() | ||||
|         weigsum  = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) | ||||
|         inter_nodes.append( weigsum ) | ||||
|       nodes.append( sum(inter_nodes) ) | ||||
|     return nodes[-1] | ||||
|  | ||||
|   # joint | ||||
|   | ||||
| @@ -81,13 +81,21 @@ class TinyNetworkGDAS(nn.Module): | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     while True: | ||||
|       gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() | ||||
|       logits  = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau | ||||
|       probs   = nn.functional.softmax(logits, dim=1) | ||||
|       index   = probs.max(-1, keepdim=True)[1] | ||||
|       one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|       hardwts = one_h - probs.detach() + probs | ||||
|       if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue | ||||
|  | ||||
|     feature = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       if isinstance(cell, SearchCell): | ||||
|         feature = cell.forward_gdas(feature, self.arch_parameters, self.tau) | ||||
|         feature = cell.forward_gdas(feature, hardwts, index) | ||||
|       else: | ||||
|         feature = cell(feature) | ||||
|  | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling( out ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user