Update TAS abd FBV2 for NAS-Bench
This commit is contained in:
		| @@ -5,13 +5,14 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import OPS | ||||
|  | ||||
| from models.cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # Cell for NAS-Bench-201 | ||||
| class InferCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, genotype, C_in, C_out, stride): | ||||
|   def __init__(self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True): | ||||
|     super(InferCell, self).__init__() | ||||
|  | ||||
|     self.layers  = nn.ModuleList() | ||||
| @@ -24,9 +25,9 @@ class InferCell(nn.Module): | ||||
|       cur_innod = [] | ||||
|       for (op_name, op_in) in node_info: | ||||
|         if op_in == 0: | ||||
|           layer = OPS[op_name](C_in , C_out, stride, True, True) | ||||
|           layer = OPS[op_name](C_in , C_out, stride, affine, track_running_stats) | ||||
|         else: | ||||
|           layer = OPS[op_name](C_out, C_out,      1, True, True) | ||||
|           layer = OPS[op_name](C_out, C_out,      1, affine, track_running_stats) | ||||
|         cur_index.append( len(self.layers) ) | ||||
|         cur_innod.append( op_in ) | ||||
|         self.layers.append( layer ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user