Use black for lib/models
This commit is contained in:
		| @@ -7,161 +7,280 @@ from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|    | ||||
|   def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): | ||||
|     super(ConvBNReLU, self).__init__() | ||||
|     if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|     else       : self.avg = None | ||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) | ||||
|     if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|     else       : self.bn  = None | ||||
|     if has_relu: self.relu = nn.ReLU(inplace=True) | ||||
|     else       : self.relu = None | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     if self.avg : out = self.avg( inputs ) | ||||
|     else        : out = inputs | ||||
|     conv = self.conv( out ) | ||||
|     if self.bn  : out = self.bn( conv ) | ||||
|     else        : out = conv | ||||
|     if self.relu: out = self.relu( out ) | ||||
|     else        : out = out | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         else: | ||||
|             out = out | ||||
|  | ||||
|     return out | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|   num_conv  = 2 | ||||
|   expansion = 1 | ||||
|   def __init__(self, iCs, stride): | ||||
|     super(ResNetBasicblock, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) | ||||
|     assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs) | ||||
|      | ||||
|     self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3,      1, 1, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|     residual_in = iCs[0] | ||||
|     if stride == 2: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) | ||||
|       residual_in = iCs[2] | ||||
|     elif iCs[0] != iCs[2]: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     #self.out_dim  = max(residual_in, iCs[2]) | ||||
|     self.out_dim  = iCs[2] | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     basicblock = self.conv_a(inputs) | ||||
|     basicblock = self.conv_b(basicblock) | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = residual + basicblock | ||||
|     return F.relu(out, inplace=True) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|   expansion = 4 | ||||
|   num_conv  = 3 | ||||
|   def __init__(self, iCs, stride): | ||||
|     super(ResNetBottleneck, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) | ||||
|     assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs) | ||||
|     self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1,      1, 0, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1,      1, 0, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|     residual_in = iCs[0] | ||||
|     if stride == 2: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False) | ||||
|       residual_in     = iCs[3] | ||||
|     elif iCs[0] != iCs[3]: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False) | ||||
|       residual_in     = iCs[3] | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     #self.out_dim = max(residual_in, iCs[3]) | ||||
|     self.out_dim = iCs[3] | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     bottleneck = self.conv_1x1(inputs) | ||||
|     bottleneck = self.conv_3x3(bottleneck) | ||||
|     bottleneck = self.conv_1x4(bottleneck) | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = residual + bottleneck | ||||
|     return F.relu(out, inplace=True) | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferCifarResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual | ||||
|     ): | ||||
|         super(InferCifarResNet, self).__init__() | ||||
|  | ||||
|   def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual): | ||||
|     super(InferCifarResNet, self).__init__() | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks) | ||||
|  | ||||
|     #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|     if block_name == 'ResNetBasicblock': | ||||
|       block = ResNetBasicblock | ||||
|       assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' | ||||
|       layer_blocks = (depth - 2) // 6 | ||||
|     elif block_name == 'ResNetBottleneck': | ||||
|       block = ResNetBottleneck | ||||
|       assert (depth - 2) % 9 == 0, 'depth should be one of 164' | ||||
|       layer_blocks = (depth - 2) // 9 | ||||
|     else: | ||||
|       raise ValueError('invalid block : {:}'.format(block_name)) | ||||
|     assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks) | ||||
|         self.message = ( | ||||
|             "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     xchannels[0], | ||||
|                     xchannels[1], | ||||
|                     3, | ||||
|                     1, | ||||
|                     1, | ||||
|                     False, | ||||
|                     has_avg=False, | ||||
|                     has_bn=True, | ||||
|                     has_relu=True, | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         last_channel_idx = 1 | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     out_channel = module.out_dim | ||||
|                     for iiL in range(iL + 1, layer_blocks): | ||||
|                         last_channel_idx += num_conv | ||||
|                     self.xchannels[last_channel_idx] = module.out_dim | ||||
|                     break | ||||
|  | ||||
|     self.message     = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) | ||||
|     self.num_classes = num_classes | ||||
|     self.xchannels   = xchannels | ||||
|     self.layers      = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) | ||||
|     last_channel_idx = 1 | ||||
|     for stage in range(3): | ||||
|       for iL in range(layer_blocks): | ||||
|         num_conv = block.num_conv  | ||||
|         iCs      = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1] | ||||
|         stride   = 2 if stage > 0 and iL == 0 else 1 | ||||
|         module   = block(iCs, stride) | ||||
|         last_channel_idx += num_conv | ||||
|         self.xchannels[last_channel_idx] = module.out_dim | ||||
|         self.layers.append  ( module ) | ||||
|         self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride) | ||||
|         if iL + 1 == xblocks[stage]: # reach the maximum depth | ||||
|           out_channel = module.out_dim | ||||
|           for iiL in range(iL+1, layer_blocks): | ||||
|             last_channel_idx += num_conv | ||||
|           self.xchannels[last_channel_idx] = module.out_dim | ||||
|           break | ||||
|    | ||||
|     self.avgpool    = nn.AvgPool2d(8) | ||||
|     self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|      | ||||
|     self.apply(initialize_resnet) | ||||
|     if zero_init_residual: | ||||
|       for m in self.modules(): | ||||
|         if isinstance(m, ResNetBasicblock): | ||||
|           nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|         elif isinstance(m, ResNetBottleneck): | ||||
|           nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     x = inputs | ||||
|     for i, layer in enumerate(self.layers): | ||||
|       x = layer( x ) | ||||
|     features = self.avgpool(x) | ||||
|     features = features.view(features.size(0), -1) | ||||
|     logits   = self.classifier(features) | ||||
|     return features, logits | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
|   | ||||
| @@ -7,144 +7,257 @@ from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|    | ||||
|   def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): | ||||
|     super(ConvBNReLU, self).__init__() | ||||
|     if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|     else       : self.avg = None | ||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) | ||||
|     if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|     else       : self.bn  = None | ||||
|     if has_relu: self.relu = nn.ReLU(inplace=True) | ||||
|     else       : self.relu = None | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     if self.avg : out = self.avg( inputs ) | ||||
|     else        : out = inputs | ||||
|     conv = self.conv( out ) | ||||
|     if self.bn  : out = self.bn( conv ) | ||||
|     else        : out = conv | ||||
|     if self.relu: out = self.relu( out ) | ||||
|     else        : out = out | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         else: | ||||
|             out = out | ||||
|  | ||||
|     return out | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|   num_conv  = 2 | ||||
|   expansion = 1 | ||||
|   def __init__(self, inplanes, planes, stride): | ||||
|     super(ResNetBasicblock, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|      | ||||
|     self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_b = ConvBNReLU(  planes, planes, 3,      1, 1, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|     if stride == 2: | ||||
|       self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) | ||||
|     elif inplanes != planes: | ||||
|       self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     self.out_dim  = planes | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     basicblock = self.conv_a(inputs) | ||||
|     basicblock = self.conv_b(basicblock) | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = residual + basicblock | ||||
|     return F.relu(out, inplace=True) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|   expansion = 4 | ||||
|   num_conv  = 3 | ||||
|   def __init__(self, inplanes, planes, stride): | ||||
|     super(ResNetBottleneck, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     self.conv_1x1 = ConvBNReLU(inplanes, planes, 1,      1, 0, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_3x3 = ConvBNReLU(  planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1,      1, 0, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|     if stride == 2: | ||||
|       self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False) | ||||
|     elif inplanes != planes*self.expansion: | ||||
|       self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     self.out_dim = planes*self.expansion | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|  | ||||
|     bottleneck = self.conv_1x1(inputs) | ||||
|     bottleneck = self.conv_3x3(bottleneck) | ||||
|     bottleneck = self.conv_1x4(bottleneck) | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = residual + bottleneck | ||||
|     return F.relu(out, inplace=True) | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferDepthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual): | ||||
|         super(InferDepthCifarResNet, self).__init__() | ||||
|  | ||||
|   def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual): | ||||
|     super(InferDepthCifarResNet, self).__init__() | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks) | ||||
|  | ||||
|     #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|     if block_name == 'ResNetBasicblock': | ||||
|       block = ResNetBasicblock | ||||
|       assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' | ||||
|       layer_blocks = (depth - 2) // 6 | ||||
|     elif block_name == 'ResNetBottleneck': | ||||
|       block = ResNetBottleneck | ||||
|       assert (depth - 2) % 9 == 0, 'depth should be one of 164' | ||||
|       layer_blocks = (depth - 2) // 9 | ||||
|     else: | ||||
|       raise ValueError('invalid block : {:}'.format(block_name)) | ||||
|     assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks) | ||||
|         self.message = ( | ||||
|             "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.channels = [16] | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     planes, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     break | ||||
|  | ||||
|     self.message     = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) | ||||
|     self.num_classes = num_classes | ||||
|     self.layers      = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) | ||||
|     self.channels    = [16] | ||||
|     for stage in range(3): | ||||
|       for iL in range(layer_blocks): | ||||
|         iC       = self.channels[-1] | ||||
|         planes = 16 * (2**stage) | ||||
|         stride   = 2 if stage > 0 and iL == 0 else 1 | ||||
|         module   = block(iC, planes, stride) | ||||
|         self.channels.append( module.out_dim ) | ||||
|         self.layers.append  ( module ) | ||||
|         self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride) | ||||
|         if iL + 1 == xblocks[stage]: # reach the maximum depth | ||||
|           break | ||||
|    | ||||
|     self.avgpool    = nn.AvgPool2d(8) | ||||
|     self.classifier = nn.Linear(self.channels[-1], num_classes) | ||||
|      | ||||
|     self.apply(initialize_resnet) | ||||
|     if zero_init_residual: | ||||
|       for m in self.modules(): | ||||
|         if isinstance(m, ResNetBasicblock): | ||||
|           nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|         elif isinstance(m, ResNetBottleneck): | ||||
|           nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.channels[-1], num_classes) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     x = inputs | ||||
|     for i, layer in enumerate(self.layers): | ||||
|       x = layer( x ) | ||||
|     features = self.avgpool(x) | ||||
|     features = features.view(features.size(0), -1) | ||||
|     logits   = self.classifier(features) | ||||
|     return features, logits | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
|   | ||||
| @@ -7,154 +7,271 @@ from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|    | ||||
|   def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): | ||||
|     super(ConvBNReLU, self).__init__() | ||||
|     if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|     else       : self.avg = None | ||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) | ||||
|     if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|     else       : self.bn  = None | ||||
|     if has_relu: self.relu = nn.ReLU(inplace=True) | ||||
|     else       : self.relu = None | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     if self.avg : out = self.avg( inputs ) | ||||
|     else        : out = inputs | ||||
|     conv = self.conv( out ) | ||||
|     if self.bn  : out = self.bn( conv ) | ||||
|     else        : out = conv | ||||
|     if self.relu: out = self.relu( out ) | ||||
|     else        : out = out | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         else: | ||||
|             out = out | ||||
|  | ||||
|     return out | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|   num_conv  = 2 | ||||
|   expansion = 1 | ||||
|   def __init__(self, iCs, stride): | ||||
|     super(ResNetBasicblock, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) | ||||
|     assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs) | ||||
|      | ||||
|     self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3,      1, 1, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|     residual_in = iCs[0] | ||||
|     if stride == 2: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) | ||||
|       residual_in = iCs[2] | ||||
|     elif iCs[0] != iCs[2]: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     #self.out_dim  = max(residual_in, iCs[2]) | ||||
|     self.out_dim  = iCs[2] | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     basicblock = self.conv_a(inputs) | ||||
|     basicblock = self.conv_b(basicblock) | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = residual + basicblock | ||||
|     return F.relu(out, inplace=True) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|   expansion = 4 | ||||
|   num_conv  = 3 | ||||
|   def __init__(self, iCs, stride): | ||||
|     super(ResNetBottleneck, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) | ||||
|     assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs) | ||||
|     self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1,      1, 0, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1,      1, 0, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|     residual_in = iCs[0] | ||||
|     if stride == 2: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False) | ||||
|       residual_in     = iCs[3] | ||||
|     elif iCs[0] != iCs[3]: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False) | ||||
|       residual_in     = iCs[3] | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     #self.out_dim = max(residual_in, iCs[3]) | ||||
|     self.out_dim = iCs[3] | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     bottleneck = self.conv_1x1(inputs) | ||||
|     bottleneck = self.conv_3x3(bottleneck) | ||||
|     bottleneck = self.conv_1x4(bottleneck) | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = residual + bottleneck | ||||
|     return F.relu(out, inplace=True) | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferWidthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual): | ||||
|         super(InferWidthCifarResNet, self).__init__() | ||||
|  | ||||
|   def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual): | ||||
|     super(InferWidthCifarResNet, self).__init__() | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|     #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|     if block_name == 'ResNetBasicblock': | ||||
|       block = ResNetBasicblock | ||||
|       assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' | ||||
|       layer_blocks = (depth - 2) // 6 | ||||
|     elif block_name == 'ResNetBottleneck': | ||||
|       block = ResNetBottleneck | ||||
|       assert (depth - 2) % 9 == 0, 'depth should be one of 164' | ||||
|       layer_blocks = (depth - 2) // 9 | ||||
|     else: | ||||
|       raise ValueError('invalid block : {:}'.format(block_name)) | ||||
|         self.message = ( | ||||
|             "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     xchannels[0], | ||||
|                     xchannels[1], | ||||
|                     3, | ||||
|                     1, | ||||
|                     1, | ||||
|                     False, | ||||
|                     has_avg=False, | ||||
|                     has_bn=True, | ||||
|                     has_relu=True, | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         last_channel_idx = 1 | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|     self.message     = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) | ||||
|     self.num_classes = num_classes | ||||
|     self.xchannels   = xchannels | ||||
|     self.layers      = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) | ||||
|     last_channel_idx = 1 | ||||
|     for stage in range(3): | ||||
|       for iL in range(layer_blocks): | ||||
|         num_conv = block.num_conv  | ||||
|         iCs      = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1] | ||||
|         stride   = 2 if stage > 0 and iL == 0 else 1 | ||||
|         module   = block(iCs, stride) | ||||
|         last_channel_idx += num_conv | ||||
|         self.xchannels[last_channel_idx] = module.out_dim | ||||
|         self.layers.append  ( module ) | ||||
|         self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride) | ||||
|    | ||||
|     self.avgpool    = nn.AvgPool2d(8) | ||||
|     self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|      | ||||
|     self.apply(initialize_resnet) | ||||
|     if zero_init_residual: | ||||
|       for m in self.modules(): | ||||
|         if isinstance(m, ResNetBasicblock): | ||||
|           nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|         elif isinstance(m, ResNetBottleneck): | ||||
|           nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     x = inputs | ||||
|     for i, layer in enumerate(self.layers): | ||||
|       x = layer( x ) | ||||
|     features = self.avgpool(x) | ||||
|     features = features.view(features.size(0), -1) | ||||
|     logits   = self.classifier(features) | ||||
|     return features, logits | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
|   | ||||
| @@ -7,164 +7,318 @@ from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|    | ||||
|   num_conv  = 1 | ||||
|   def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): | ||||
|     super(ConvBNReLU, self).__init__() | ||||
|     if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|     else       : self.avg = None | ||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) | ||||
|     if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|     else       : self.bn  = None | ||||
|     if has_relu: self.relu = nn.ReLU(inplace=True) | ||||
|     else       : self.relu = None | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     if self.avg : out = self.avg( inputs ) | ||||
|     else        : out = inputs | ||||
|     conv = self.conv( out ) | ||||
|     if self.bn  : out = self.bn( conv ) | ||||
|     else        : out = conv | ||||
|     if self.relu: out = self.relu( out ) | ||||
|     else        : out = out | ||||
|     num_conv = 1 | ||||
|  | ||||
|     return out | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         else: | ||||
|             out = out | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|   num_conv  = 2 | ||||
|   expansion = 1 | ||||
|   def __init__(self, iCs, stride): | ||||
|     super(ResNetBasicblock, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) | ||||
|     assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs) | ||||
|      | ||||
|     self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3,      1, 1, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|     residual_in = iCs[0] | ||||
|     if stride == 2: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False) | ||||
|       residual_in = iCs[2] | ||||
|     elif iCs[0] != iCs[2]: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     #self.out_dim  = max(residual_in, iCs[2]) | ||||
|     self.out_dim  = iCs[2] | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     basicblock = self.conv_a(inputs) | ||||
|     basicblock = self.conv_b(basicblock) | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = residual + basicblock | ||||
|     return F.relu(out, inplace=True) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|   expansion = 4 | ||||
|   num_conv  = 3 | ||||
|   def __init__(self, iCs, stride): | ||||
|     super(ResNetBottleneck, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) | ||||
|     assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs) | ||||
|     self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1,      1, 0, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1,      1, 0, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|     residual_in = iCs[0] | ||||
|     if stride == 2: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False) | ||||
|       residual_in     = iCs[3] | ||||
|     elif iCs[0] != iCs[3]: | ||||
|       self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) | ||||
|       residual_in     = iCs[3] | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     #self.out_dim = max(residual_in, iCs[3]) | ||||
|     self.out_dim = iCs[3] | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     bottleneck = self.conv_1x1(inputs) | ||||
|     bottleneck = self.conv_3x3(bottleneck) | ||||
|     bottleneck = self.conv_1x4(bottleneck) | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = residual + bottleneck | ||||
|     return F.relu(out, inplace=True) | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferImagenetResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         block_name, | ||||
|         layers, | ||||
|         xblocks, | ||||
|         xchannels, | ||||
|         deep_stem, | ||||
|         num_classes, | ||||
|         zero_init_residual, | ||||
|     ): | ||||
|         super(InferImagenetResNet, self).__init__() | ||||
|  | ||||
|   def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual): | ||||
|     super(InferImagenetResNet, self).__init__() | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "BasicBlock": | ||||
|             block = ResNetBasicblock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = ResNetBottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == len( | ||||
|             layers | ||||
|         ), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks) | ||||
|  | ||||
|     #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|     if block_name == 'BasicBlock': | ||||
|       block = ResNetBasicblock | ||||
|     elif block_name == 'Bottleneck': | ||||
|       block = ResNetBottleneck | ||||
|     else: | ||||
|       raise ValueError('invalid block : {:}'.format(block_name)) | ||||
|     assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks) | ||||
|         self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format( | ||||
|             sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         if not deep_stem: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[0], | ||||
|                         xchannels[1], | ||||
|                         7, | ||||
|                         2, | ||||
|                         3, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ) | ||||
|                 ] | ||||
|             ) | ||||
|             last_channel_idx = 1 | ||||
|         else: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[0], | ||||
|                         xchannels[1], | ||||
|                         3, | ||||
|                         2, | ||||
|                         1, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ), | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[1], | ||||
|                         xchannels[2], | ||||
|                         3, | ||||
|                         1, | ||||
|                         1, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ), | ||||
|                 ] | ||||
|             ) | ||||
|             last_channel_idx = 2 | ||||
|         self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) | ||||
|         for stage, layer_blocks in enumerate(layers): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     out_channel = module.out_dim | ||||
|                     for iiL in range(iL + 1, layer_blocks): | ||||
|                         last_channel_idx += num_conv | ||||
|                     self.xchannels[last_channel_idx] = module.out_dim | ||||
|                     break | ||||
|         assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format( | ||||
|             last_channel_idx, len(self.xchannels) | ||||
|         ) | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|     self.message     = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks) | ||||
|     self.num_classes = num_classes | ||||
|     self.xchannels   = xchannels | ||||
|     if not deep_stem: | ||||
|       self.layers      = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] ) | ||||
|       last_channel_idx = 1 | ||||
|     else: | ||||
|       self.layers      = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|                                          ,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) | ||||
|       last_channel_idx = 2 | ||||
|     self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) ) | ||||
|     for stage, layer_blocks in enumerate(layers): | ||||
|       for iL in range(layer_blocks): | ||||
|         num_conv = block.num_conv  | ||||
|         iCs      = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1] | ||||
|         stride   = 2 if stage > 0 and iL == 0 else 1 | ||||
|         module   = block(iCs, stride) | ||||
|         last_channel_idx += num_conv | ||||
|         self.xchannels[last_channel_idx] = module.out_dim | ||||
|         self.layers.append  ( module ) | ||||
|         self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride) | ||||
|         if iL + 1 == xblocks[stage]: # reach the maximum depth | ||||
|           out_channel = module.out_dim | ||||
|           for iiL in range(iL+1, layer_blocks): | ||||
|             last_channel_idx += num_conv | ||||
|           self.xchannels[last_channel_idx] = module.out_dim | ||||
|           break | ||||
|     assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels)) | ||||
|     self.avgpool    = nn.AdaptiveAvgPool2d((1,1)) | ||||
|     self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|      | ||||
|     self.apply(initialize_resnet) | ||||
|     if zero_init_residual: | ||||
|       for m in self.modules(): | ||||
|         if isinstance(m, ResNetBasicblock): | ||||
|           nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|         elif isinstance(m, ResNetBottleneck): | ||||
|           nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     x = inputs | ||||
|     for i, layer in enumerate(self.layers): | ||||
|       x = layer( x ) | ||||
|     features = self.avgpool(x) | ||||
|     features = features.view(features.size(0), -1) | ||||
|     logits   = self.classifier(features) | ||||
|     return features, logits | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
|   | ||||
| @@ -4,119 +4,171 @@ | ||||
| # MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018 | ||||
| from torch import nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils    import parse_channel_info | ||||
| from ..SharedUtils import parse_channel_info | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|   def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True): | ||||
|     super(ConvBNReLU, self).__init__() | ||||
|     padding = (kernel_size - 1) // 2 | ||||
|     self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False) | ||||
|     if has_bn: self.bn = nn.BatchNorm2d(out_planes) | ||||
|     else     : self.bn = None | ||||
|     if has_relu: self.relu = nn.ReLU6(inplace=True) | ||||
|     else       : self.relu = None | ||||
|    | ||||
|   def forward(self, x): | ||||
|     out = self.conv( x ) | ||||
|     if self.bn:   out = self.bn  ( out ) | ||||
|     if self.relu: out = self.relu( out ) | ||||
|     return out | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_planes, | ||||
|         out_planes, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         groups, | ||||
|         has_bn=True, | ||||
|         has_relu=True, | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         padding = (kernel_size - 1) // 2 | ||||
|         self.conv = nn.Conv2d( | ||||
|             in_planes, | ||||
|             out_planes, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             groups=groups, | ||||
|             bias=False, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(out_planes) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU6(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         if self.bn: | ||||
|             out = self.bn(out) | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class InvertedResidual(nn.Module): | ||||
|   def __init__(self, channels, stride, expand_ratio, additive): | ||||
|     super(InvertedResidual, self).__init__() | ||||
|     self.stride = stride | ||||
|     assert stride in [1, 2], 'invalid stride : {:}'.format(stride) | ||||
|     assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels) | ||||
|     def __init__(self, channels, stride, expand_ratio, additive): | ||||
|         super(InvertedResidual, self).__init__() | ||||
|         self.stride = stride | ||||
|         assert stride in [1, 2], "invalid stride : {:}".format(stride) | ||||
|         assert len(channels) in [2, 3], "invalid channels : {:}".format(channels) | ||||
|  | ||||
|     if len(channels) == 2: | ||||
|       layers = [] | ||||
|     else: | ||||
|       layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)] | ||||
|     layers.extend([ | ||||
|       # dw | ||||
|       ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]), | ||||
|       # pw-linear | ||||
|       ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False), | ||||
|     ]) | ||||
|     self.conv = nn.Sequential(*layers) | ||||
|     self.additive = additive | ||||
|     if self.additive and channels[0] != channels[-1]: | ||||
|       self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False) | ||||
|     else: | ||||
|       self.shortcut = None | ||||
|     self.out_dim  = channels[-1] | ||||
|         if len(channels) == 2: | ||||
|             layers = [] | ||||
|         else: | ||||
|             layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)] | ||||
|         layers.extend( | ||||
|             [ | ||||
|                 # dw | ||||
|                 ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]), | ||||
|                 # pw-linear | ||||
|                 ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False), | ||||
|             ] | ||||
|         ) | ||||
|         self.conv = nn.Sequential(*layers) | ||||
|         self.additive = additive | ||||
|         if self.additive and channels[0] != channels[-1]: | ||||
|             self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False) | ||||
|         else: | ||||
|             self.shortcut = None | ||||
|         self.out_dim = channels[-1] | ||||
|  | ||||
|   def forward(self, x): | ||||
|     out = self.conv(x) | ||||
|     # if self.additive: return additive_func(out, x) | ||||
|     if self.shortcut: return out + self.shortcut(x) | ||||
|     else            : return out | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         # if self.additive: return additive_func(out, x) | ||||
|         if self.shortcut: | ||||
|             return out + self.shortcut(x) | ||||
|         else: | ||||
|             return out | ||||
|  | ||||
|  | ||||
| class InferMobileNetV2(nn.Module): | ||||
|   def __init__(self, num_classes, xchannels, xblocks, dropout): | ||||
|     super(InferMobileNetV2, self).__init__() | ||||
|     block = InvertedResidual | ||||
|     inverted_residual_setting = [ | ||||
|       # t, c,  n, s | ||||
|       [1, 16 , 1, 1], | ||||
|       [6, 24 , 2, 2], | ||||
|       [6, 32 , 3, 2], | ||||
|       [6, 64 , 4, 2], | ||||
|       [6, 96 , 3, 1], | ||||
|       [6, 160, 3, 2], | ||||
|       [6, 320, 1, 1], | ||||
|     ] | ||||
|     assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks)) | ||||
|     for block_num, ir_setting in zip(xblocks, inverted_residual_setting): | ||||
|       assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting) | ||||
|     xchannels = parse_channel_info(xchannels) | ||||
|     #for i, chs in enumerate(xchannels): | ||||
|     #  if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs) | ||||
|     self.xchannels = xchannels | ||||
|     self.message     = 'InferMobileNetV2 : xblocks={:}'.format(xblocks) | ||||
|     # building first layer | ||||
|     features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)] | ||||
|     last_channel_idx = 1 | ||||
|     def __init__(self, num_classes, xchannels, xblocks, dropout): | ||||
|         super(InferMobileNetV2, self).__init__() | ||||
|         block = InvertedResidual | ||||
|         inverted_residual_setting = [ | ||||
|             # t, c,  n, s | ||||
|             [1, 16, 1, 1], | ||||
|             [6, 24, 2, 2], | ||||
|             [6, 32, 3, 2], | ||||
|             [6, 64, 4, 2], | ||||
|             [6, 96, 3, 1], | ||||
|             [6, 160, 3, 2], | ||||
|             [6, 320, 1, 1], | ||||
|         ] | ||||
|         assert len(inverted_residual_setting) == len( | ||||
|             xblocks | ||||
|         ), "invalid number of layers : {:} vs {:}".format( | ||||
|             len(inverted_residual_setting), len(xblocks) | ||||
|         ) | ||||
|         for block_num, ir_setting in zip(xblocks, inverted_residual_setting): | ||||
|             assert block_num <= ir_setting[2], "{:} vs {:}".format( | ||||
|                 block_num, ir_setting | ||||
|             ) | ||||
|         xchannels = parse_channel_info(xchannels) | ||||
|         # for i, chs in enumerate(xchannels): | ||||
|         #  if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs) | ||||
|         self.xchannels = xchannels | ||||
|         self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks) | ||||
|         # building first layer | ||||
|         features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)] | ||||
|         last_channel_idx = 1 | ||||
|  | ||||
|     # building inverted residual blocks | ||||
|     for stage, (t, c, n, s) in enumerate(inverted_residual_setting): | ||||
|       for i in range(n): | ||||
|         stride = s if i == 0 else 1 | ||||
|         additv = True if i > 0 else False | ||||
|         module = block(self.xchannels[last_channel_idx], stride, t, additv) | ||||
|         features.append(module) | ||||
|         self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c) | ||||
|         last_channel_idx += 1 | ||||
|         if i + 1 == xblocks[stage]: | ||||
|           out_channel = module.out_dim | ||||
|           for iiL in range(i+1, n): | ||||
|             last_channel_idx += 1 | ||||
|           self.xchannels[last_channel_idx][0] = module.out_dim | ||||
|           break | ||||
|     # building last several layers | ||||
|     features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1)) | ||||
|     assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels)) | ||||
|     # make it nn.Sequential | ||||
|     self.features = nn.Sequential(*features) | ||||
|         # building inverted residual blocks | ||||
|         for stage, (t, c, n, s) in enumerate(inverted_residual_setting): | ||||
|             for i in range(n): | ||||
|                 stride = s if i == 0 else 1 | ||||
|                 additv = True if i > 0 else False | ||||
|                 module = block(self.xchannels[last_channel_idx], stride, t, additv) | ||||
|                 features.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format( | ||||
|                     stage, | ||||
|                     i, | ||||
|                     n, | ||||
|                     len(features), | ||||
|                     self.xchannels[last_channel_idx], | ||||
|                     stride, | ||||
|                     t, | ||||
|                     c, | ||||
|                 ) | ||||
|                 last_channel_idx += 1 | ||||
|                 if i + 1 == xblocks[stage]: | ||||
|                     out_channel = module.out_dim | ||||
|                     for iiL in range(i + 1, n): | ||||
|                         last_channel_idx += 1 | ||||
|                     self.xchannels[last_channel_idx][0] = module.out_dim | ||||
|                     break | ||||
|         # building last several layers | ||||
|         features.append( | ||||
|             ConvBNReLU( | ||||
|                 self.xchannels[last_channel_idx][0], | ||||
|                 self.xchannels[last_channel_idx][1], | ||||
|                 1, | ||||
|                 1, | ||||
|                 1, | ||||
|             ) | ||||
|         ) | ||||
|         assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format( | ||||
|             last_channel_idx, len(self.xchannels) | ||||
|         ) | ||||
|         # make it nn.Sequential | ||||
|         self.features = nn.Sequential(*features) | ||||
|  | ||||
|     # building classifier | ||||
|     self.classifier = nn.Sequential( | ||||
|       nn.Dropout(dropout), | ||||
|       nn.Linear(self.xchannels[last_channel_idx][1], num_classes), | ||||
|     ) | ||||
|         # building classifier | ||||
|         self.classifier = nn.Sequential( | ||||
|             nn.Dropout(dropout), | ||||
|             nn.Linear(self.xchannels[last_channel_idx][1], num_classes), | ||||
|         ) | ||||
|  | ||||
|     # weight initialization | ||||
|     self.apply( initialize_resnet ) | ||||
|         # weight initialization | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     features = self.features(inputs) | ||||
|     vectors  = features.mean([2, 3]) | ||||
|     predicts = self.classifier(vectors) | ||||
|     return features, predicts | ||||
|     def forward(self, inputs): | ||||
|         features = self.features(inputs) | ||||
|         vectors = features.mean([2, 3]) | ||||
|         predicts = self.classifier(vectors) | ||||
|         return features, predicts | ||||
|   | ||||
| @@ -8,51 +8,57 @@ from models.cell_infers.cells import InferCell | ||||
|  | ||||
|  | ||||
| class DynamicShapeTinyNet(nn.Module): | ||||
|     def __init__(self, channels: List[int], genotype: Any, num_classes: int): | ||||
|         super(DynamicShapeTinyNet, self).__init__() | ||||
|         self._channels = channels | ||||
|         if len(channels) % 3 != 2: | ||||
|             raise ValueError("invalid number of layers : {:}".format(len(channels))) | ||||
|         self._num_stage = N = len(channels) // 3 | ||||
|  | ||||
|   def __init__(self, channels: List[int], genotype: Any, num_classes: int): | ||||
|     super(DynamicShapeTinyNet, self).__init__() | ||||
|     self._channels = channels | ||||
|     if len(channels) % 3 != 2: | ||||
|       raise ValueError('invalid number of layers : {:}'.format(len(channels))) | ||||
|     self._num_stage = N = len(channels) // 3 | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(channels[0]), | ||||
|         ) | ||||
|  | ||||
|     self.stem = nn.Sequential( | ||||
|                     nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False), | ||||
|                     nn.BatchNorm2d(channels[0])) | ||||
|         # layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|     # layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|         c_prev = channels[0] | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(c_prev, c_curr, 2, True) | ||||
|             else: | ||||
|                 cell = InferCell(genotype, c_prev, c_curr, 1) | ||||
|             self.cells.append(cell) | ||||
|             c_prev = cell.out_dim | ||||
|         self._num_layer = len(self.cells) | ||||
|  | ||||
|     c_prev = channels[0] | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)): | ||||
|       if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True) | ||||
|       else         : cell = InferCell(genotype, c_prev, c_curr, 1) | ||||
|       self.cells.append( cell ) | ||||
|       c_prev = cell.out_dim | ||||
|     self._num_layer = len(self.cells) | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(c_prev, num_classes) | ||||
|  | ||||
|     self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True)) | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(c_prev, num_classes) | ||||
|     def get_message(self) -> Text: | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|   def get_message(self) -> Text: | ||||
|     string = self.extra_repr() | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) | ||||
|     return string | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|     def forward(self, inputs): | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             feature = cell(feature) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     feature = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       feature = cell(feature) | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling( out ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     return out, logits | ||||
|         return out, logits | ||||
|   | ||||
| @@ -6,4 +6,4 @@ from .InferImagenetResNet import InferImagenetResNet | ||||
| from .InferCifarResNet_depth import InferDepthCifarResNet | ||||
| from .InferCifarResNet import InferCifarResNet | ||||
| from .InferMobileNetV2 import InferMobileNetV2 | ||||
| from .InferTinyCellNet import DynamicShapeTinyNet | ||||
| from .InferTinyCellNet import DynamicShapeTinyNet | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| def parse_channel_info(xstring): | ||||
|   blocks = xstring.split(' ') | ||||
|   blocks = [x.split('-') for x in blocks] | ||||
|   blocks = [[int(_) for _ in x] for x in blocks] | ||||
|   return blocks | ||||
|     blocks = xstring.split(" ") | ||||
|     blocks = [x.split("-") for x in blocks] | ||||
|     blocks = [[int(_) for _ in x] for x in blocks] | ||||
|     return blocks | ||||
|   | ||||
		Reference in New Issue
	
	Block a user