diff --git a/exps/experimental/test-flops.py b/exps/experimental/test-flops.py
new file mode 100644
index 0000000..73df113
--- /dev/null
+++ b/exps/experimental/test-flops.py
@@ -0,0 +1,24 @@
+import sys, time, random, argparse
+from copy import deepcopy
+import torchvision.models as models
+from pathlib import Path
+lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
+if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
+
+from utils import get_model_infos
+#from models.ImageNet_MobileNetV2 import MobileNetV2
+from torchvision.models.mobilenet import MobileNetV2
+
+def main(width_mult):
+  # model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2)
+  model = MobileNetV2(width_mult=width_mult)
+  print(model)
+  flops, params = get_model_infos(model, (2, 3, 224, 224))
+  print('FLOPs : {:}'.format(flops))
+  print('Params : {:}'.format(params))
+  print('-'*50)
+
+
+if __name__ == '__main__':
+  main(1.0)
+  main(1.4)
diff --git a/lib/models/ImageNet_MobileNetV2.py b/lib/models/ImageNet_MobileNetV2.py
new file mode 100644
index 0000000..ec7e341
--- /dev/null
+++ b/lib/models/ImageNet_MobileNetV2.py
@@ -0,0 +1,101 @@
+# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
+from torch import nn
+from .initialization import initialize_resnet
+
+
+class ConvBNReLU(nn.Module):
+  def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+    super(ConvBNReLU, self).__init__()
+    padding = (kernel_size - 1) // 2
+    self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
+    self.bn   = nn.BatchNorm2d(out_planes)
+    self.relu = nn.ReLU6(inplace=True)
+
+  def forward(self, x):
+    out = self.conv( x )
+    out = self.bn  ( out )
+    out = self.relu( out )
+    return out
+
+
+class InvertedResidual(nn.Module):
+  def __init__(self, inp, oup, stride, expand_ratio):
+    super(InvertedResidual, self).__init__()
+    self.stride = stride
+    assert stride in [1, 2]
+
+    hidden_dim = int(round(inp * expand_ratio))
+    self.use_res_connect = self.stride == 1 and inp == oup
+
+    layers = []
+    if expand_ratio != 1:
+      # pw
+      layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+    layers.extend([
+      # dw
+      ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+      # pw-linear
+      nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+      nn.BatchNorm2d(oup),
+    ])
+    self.conv = nn.Sequential(*layers)
+
+  def forward(self, x):
+    if self.use_res_connect:
+      return x + self.conv(x)
+    else:
+      return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+  def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout):
+    super(MobileNetV2, self).__init__()
+    if block_name == 'InvertedResidual':
+      block = InvertedResidual
+    else:
+      raise ValueError('invalid block name : {:}'.format(block_name))
+    inverted_residual_setting = [
+      # t, c, n, s
+      [1, 16 , 1, 1],
+      [6, 24 , 2, 2],
+      [6, 32 , 3, 2],
+      [6, 64 , 4, 2],
+      [6, 96 , 3, 1],
+      [6, 160, 3, 2],
+      [6, 320, 1, 1],
+    ]
+
+    # building first layer
+    in_channels = int(input_channel * width_mult)  # running channel count; keep input_channel intact for the message below
+    self.last_channel = int(last_channel * max(1.0, width_mult))
+    features = [ConvBNReLU(3, in_channels, stride=2)]
+    # building inverted residual blocks
+    for t, c, n, s in inverted_residual_setting:
+      output_channel = int(c * width_mult)
+      for i in range(n):
+        stride = s if i == 0 else 1
+        features.append(block(in_channels, output_channel, stride, expand_ratio=t))
+        in_channels = output_channel
+    # building last several layers
+    features.append(ConvBNReLU(in_channels, self.last_channel, kernel_size=1))
+    # make it nn.Sequential
+    self.features = nn.Sequential(*features)
+
+    # building classifier
+    self.classifier = nn.Sequential(
+      nn.Dropout(dropout),
+      nn.Linear(self.last_channel, num_classes),
+    )
+    self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, input_channel, last_channel, block_name, dropout)
+
+    # weight initialization
+    self.apply( initialize_resnet )
+
+  def get_message(self):
+    return self.message
+
+  def forward(self, inputs):
+    features = self.features(inputs)
+    vectors  = features.mean([2, 3])
+    predicts = self.classifier(vectors)
+    return features, predicts
diff --git a/lib/models/__init__.py b/lib/models/__init__.py
index aa16975..4f7b735 100644
--- a/lib/models/__init__.py
+++ b/lib/models/__init__.py
@@ -110,8 +110,11 @@ def get_imagenet_models(config):
   super_type = getattr(config, 'super_type', 'basic')
   if super_type == 'basic':
     from .ImagenetResNet import ResNet
+    from .ImageNet_MobileNetV2 import MobileNetV2
     if config.arch == 'resnet':
       return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
+    elif config.arch == 'mobilenet_v2':
+      return MobileNetV2(config.class_num, config.width_multi, config.input_channel, config.last_channel, 'InvertedResidual', config.dropout)
     else:
       raise ValueError('invalid arch : {:}'.format( config.arch ))
   elif super_type.startswith('infer'): # NAS searched architecture
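
For reference, a minimal usage sketch of the new class, assuming torch is installed and lib/ is on sys.path (as test-flops.py arranges); the constructor arguments mirror the commented-out call in test-flops.py (1001 classes, width_mult 1.0, stem width 32, last channel 1280, dropout 0.2):

import torch
from models.ImageNet_MobileNetV2 import MobileNetV2

# num_classes, width_mult, input_channel, last_channel, block_name, dropout
net = MobileNetV2(1001, 1.0, 32, 1280, 'InvertedResidual', 0.2)
print(net.get_message())

x = torch.randn(2, 3, 224, 224)
features, logits = net(x)  # forward returns the conv feature map and the logits
print(features.shape)      # expected: torch.Size([2, 1280, 7, 7])
print(logits.shape)        # expected: torch.Size([2, 1001])

Through the new branch in lib/models/__init__.py, the same model is also reachable via get_imagenet_models, given a config with arch == 'mobilenet_v2' that exposes class_num, width_multi, input_channel, last_channel, and dropout.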