102 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			102 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
 | |
| from torch import nn
 | |
| from .initialization import initialize_resnet
 | |
| 
 | |
| 
 | |
class ConvBNReLU(nn.Module):
  """Bias-free 2D convolution followed by batch normalization and ReLU6.

  The convolution is padded with ``(kernel_size - 1) // 2`` on each side.

  Args:
    in_planes: number of input channels of the convolution.
    out_planes: number of output channels of the convolution and batch norm.
    kernel_size: convolution kernel size (default 3).
    stride: convolution stride (default 1).
    groups: number of convolution groups (default 1).
  """

  def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
    super().__init__()
    # Submodules are registered in conv -> bn -> relu order.
    self.conv = nn.Conv2d(
      in_planes,
      out_planes,
      kernel_size,
      stride,
      (kernel_size - 1) // 2,
      groups=groups,
      bias=False,
    )
    self.bn = nn.BatchNorm2d(out_planes)
    self.relu = nn.ReLU6(inplace=True)

  def forward(self, x):
    """Apply convolution, batch norm and ReLU6 to ``x`` and return the result."""
    return self.relu(self.bn(self.conv(x)))
 | |
| 
 | |
| 
 | |
class InvertedResidual(nn.Module):
  """Inverted residual block built from ConvBNReLU layers.

  The block is an optional 1x1 expansion (skipped when ``expand_ratio == 1``),
  a depthwise ConvBNReLU (``groups == hidden_dim``) carrying the stride, and a
  1x1 convolution plus batch norm with no activation. The input is added to
  the output only when ``stride == 1`` and ``inp == oup``.

  Args:
    inp: number of input channels.
    oup: number of output channels.
    stride: stride of the depthwise convolution; must be 1 or 2.
    expand_ratio: multiplier applied to ``inp`` to get the hidden width.

  Raises:
    ValueError: if ``stride`` is neither 1 nor 2.
  """

  def __init__(self, inp, oup, stride, expand_ratio):
    super(InvertedResidual, self).__init__()
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if stride not in (1, 2):
      raise ValueError('stride should be 1 or 2, but got {:}'.format(stride))
    self.stride = stride

    hidden_dim = int(round(inp * expand_ratio))
    # A skip connection is only shape-compatible when nothing is resized.
    self.use_res_connect = self.stride == 1 and inp == oup

    layers = []
    if expand_ratio != 1:
      # pw: 1x1 expansion to hidden_dim channels
      layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
    layers.extend([
      # dw: depthwise convolution (one group per channel) carrying the stride
      ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
      # pw-linear: 1x1 projection followed by batch norm, no activation
      nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
      nn.BatchNorm2d(oup),
    ])
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    """Return ``x + conv(x)`` when the residual is enabled, else ``conv(x)``."""
    if self.use_res_connect:
      return x + self.conv(x)
    else:
      return self.conv(x)
 | |
| 
 | |
| 
 | |
class MobileNetV2(nn.Module):
  """MobileNetV2 backbone plus a dropout/linear classifier.

  Args:
    num_classes: output size of the final linear layer.
    width_mult: multiplier applied to the channel count of every layer.
    input_channel: channel count of the first convolution before scaling.
    last_channel: channel count of the last 1x1 convolution before scaling;
      it is scaled by ``max(1.0, width_mult)``.
    block_name: name of the building block; only 'InvertedResidual' is accepted.
    dropout: dropout probability used in the classifier.

  Raises:
    ValueError: if ``block_name`` is not 'InvertedResidual'.
  """

  def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout):
    super(MobileNetV2, self).__init__()
    if block_name == 'InvertedResidual':
      block = InvertedResidual
    else:
      raise ValueError('invalid block name : {:}'.format(block_name))
    inverted_residual_setting = [
      # t (expand ratio), c (output channels), n (repeats), s (stride of first repeat)
      [1, 16 , 1, 1],
      [6, 24 , 2, 2],
      [6, 32 , 3, 2],
      [6, 64 , 4, 2],
      [6, 96 , 3, 1],
      [6, 160, 3, 2],
      [6, 320, 1, 1],
    ]

    # building first layer
    # Separate locals track the running width so the ``input_channel``
    # argument is still intact when the summary message is built below.
    current_channel = int(input_channel * width_mult)
    self.last_channel = int(last_channel * max(1.0, width_mult))
    features = [ConvBNReLU(3, current_channel, stride=2)]
    # building inverted residual blocks
    for t, c, n, s in inverted_residual_setting:
      output_channel = int(c * width_mult)
      for i in range(n):
        # only the first block of each stage applies the stage stride
        stride = s if i == 0 else 1
        features.append(block(current_channel, output_channel, stride, expand_ratio=t))
        current_channel = output_channel
    # building last several layers
    features.append(ConvBNReLU(current_channel, self.last_channel, kernel_size=1))
    # make it nn.Sequential
    self.features = nn.Sequential(*features)

    # building classifier
    self.classifier = nn.Sequential(
      nn.Dropout(dropout),
      nn.Linear(self.last_channel, num_classes),
    )
    # Report the constructor arguments as given (unscaled).
    self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, input_channel, last_channel, block_name, dropout)

    # weight initialization
    self.apply( initialize_resnet )

  def get_message(self):
    """Return the configuration summary string built in ``__init__``."""
    return self.message

  def forward(self, inputs):
    """Run the backbone and classifier.

    Returns:
      A tuple ``(features, predicts)``: the output of ``self.features`` and
      the classifier output computed from ``features`` averaged over
      dimensions 2 and 3.
    """
    features = self.features(inputs)
    vectors  = features.mean([2, 3])
    predicts = self.classifier(vectors)
    return features, predicts
 |