upload
This commit is contained in:
		
							
								
								
									
										108
									
								
								pycls/models/common.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								pycls/models/common.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,108 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from pycls.core.config import cfg | ||||
|  | ||||
|  | ||||
def Preprocess(x):
    """Reshape an input batch according to the configured task.

    When ``cfg.TASK`` is ``'jig'`` the input must be 5-D with
    ``cfg.JIGSAW_GRID ** 2`` entries along dim 1; dims 0 and 1 are merged so
    the result is 4-D. For any other task ``x`` is returned unchanged.
    """
    if cfg.TASK != 'jig':
        return x
    assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
    assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
    # Fold the per-sample tile axis into the batch axis.
    batch, tiles, *trailing = x.shape
    return x.view([batch * tiles, *trailing])
|  | ||||
|  | ||||
class Classifier(nn.Module):
    """Prediction head whose layout is chosen by ``cfg.TASK``.

    * ``'jig'``: global average pooling, then a linear layer over the pooled
      features of ``cfg.JIGSAW_GRID ** 2`` consecutive batch entries.
    * ``'col'``: a 1x1 convolution followed by bilinear upsampling.
    * ``'seg'``: an ``ASPP`` module followed by bilinear upsampling.
    * anything else: global average pooling followed by a linear layer.
    """

    def __init__(self, channels, num_classes):
        super(Classifier, self).__init__()
        task = cfg.TASK
        if task == 'col':
            self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
            return
        if task == 'seg':
            self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
            return
        # Remaining tasks share the pooled-linear layout; jigsaw only widens
        # the linear input by the number of tiles.
        in_features = channels
        if task == 'jig':
            self.jig_sq = cfg.JIGSAW_GRID ** 2
            in_features = channels * self.jig_sq
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x, shape):
        """Return predictions for ``x``.

        ``shape`` is the target spatial size for the dense ('col'/'seg')
        tasks and is not used by the pooled tasks.
        """
        task = cfg.TASK
        if task in ['col', 'seg']:
            dense = self.classifier(x)
            return nn.Upsample(shape, mode='bilinear', align_corners=True)(dense)
        pooled = self.pooling(x)
        if task == 'jig':
            # Regroup the tiles of each sample: batch shrinks by jig_sq,
            # channels grow by jig_sq.
            pooled = pooled.view([
                pooled.shape[0] // self.jig_sq,
                pooled.shape[1] * self.jig_sq,
                pooled.shape[2],
                pooled.shape[3],
            ])
        return self.classifier(pooled.view(pooled.size(0), -1))
|  | ||||
|  | ||||
class ASPP(nn.Module):
    """Multi-branch dilated-convolution head producing per-pixel class scores.

    Branches (each conv -> batch norm -> ReLU, all emitting ``out_channels``):
      * ``aspp1``: 1x1 convolution.
      * ``aspp2`` (and ``aspp3``/``aspp4`` when three rates are given):
        3x3 convolutions dilated by the corresponding entry of ``rates``.
      * ``aspp5``: 1x1 convolution applied to the globally average-pooled
        input, then bilinearly resized back to the input's spatial size.

    The branch outputs are concatenated along the channel axis and mapped to
    ``num_classes`` channels by ``classifier``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by every branch.
        num_classes: number of output channels.
        rates: sequence of 1 or 3 dilation rates.
    """

    def __init__(self, in_channels, out_channels, num_classes, rates):
        super(ASPP, self).__init__()
        assert len(rates) in [1, 3], \
            'ASPP expects 1 or 3 dilation rates, got {}'.format(len(rates))
        self.rates = rates
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.aspp1 = self._branch(in_channels, out_channels, 1)
        self.aspp2 = self._branch(in_channels, out_channels, 3, rates[0])
        if len(self.rates) == 3:
            self.aspp3 = self._branch(in_channels, out_channels, 3, rates[1])
            self.aspp4 = self._branch(in_channels, out_channels, 3, rates[2])
        self.aspp5 = self._branch(in_channels, out_channels, 1)
        # Concatenated input: one channel group per dilated branch plus the
        # 1x1 branch and the global-pooling branch.
        self.classifier = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, num_classes, 1)
        )

    @staticmethod
    def _branch(in_channels, out_channels, kernel_size, dilation=1):
        """Build one conv -> batch norm -> ReLU branch."""
        # Padding equal to the dilation keeps the spatial size for 3x3
        # kernels; 1x1 kernels need no padding.
        padding = dilation if kernel_size > 1 else 0
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      dilation=dilation, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return class scores with the same spatial size as ``x``."""
        branches = [self.aspp1(x), self.aspp2(x)]
        if len(self.rates) == 3:
            branches += [self.aspp3(x), self.aspp4(x)]
        pooled = self.aspp5(self.global_pooling(x))
        # Functional resize instead of constructing an nn.Upsample module on
        # every call; same size/mode/align_corners as before.
        pooled = nn.functional.interpolate(
            pooled, size=(x.shape[2], x.shape[3]), mode='bilinear',
            align_corners=True)
        branches.append(pooled)
        return self.classifier(torch.cat(branches, 1))
							
								
								
									
										634
									
								
								pycls/models/nas/genotypes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										634
									
								
								pycls/models/nas/genotypes.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,634 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """NAS genotypes (adopted from DARTS).""" | ||||
|  | ||||
| from collections import namedtuple | ||||
|  | ||||
|  | ||||
# A cell description: ``normal`` / ``reduce`` hold (op_name, input_index)
# pairs, ``normal_concat`` / ``reduce_concat`` hold the state indices that
# are concatenated to form the cell output.
Genotype = namedtuple(
    'Genotype',
    ['normal', 'normal_concat', 'reduce', 'reduce_concat'],
)


# NASNet ops
NASNET_OPS = [
    'skip_connect', 'conv_3x1_1x3', 'conv_7x1_1x7', 'dil_conv_3x3',
    'avg_pool_3x3', 'max_pool_3x3', 'max_pool_5x5', 'max_pool_7x7',
    'conv_1x1', 'conv_3x3',
    'sep_conv_3x3', 'sep_conv_5x5', 'sep_conv_7x7',
]

# ENAS ops
ENAS_OPS = [
    'skip_connect',
    'sep_conv_3x3', 'sep_conv_5x5',
    'avg_pool_3x3', 'max_pool_3x3',
]

# AmoebaNet ops
AMOEBA_OPS = [
    'skip_connect',
    'sep_conv_3x3', 'sep_conv_5x5', 'sep_conv_7x7',
    'avg_pool_3x3', 'max_pool_3x3',
    'dil_sep_conv_3x3', 'conv_7x1_1x7',
]

# NAO ops
NAO_OPS = [
    'skip_connect',
    'conv_1x1', 'conv_3x3', 'conv_3x1_1x3', 'conv_7x1_1x7',
    'max_pool_2x2', 'max_pool_3x3', 'max_pool_5x5',
    'avg_pool_2x2', 'avg_pool_3x3', 'avg_pool_5x5',
]

# PNAS ops
PNAS_OPS = [
    'sep_conv_3x3', 'sep_conv_5x5', 'sep_conv_7x7',
    'conv_7x1_1x7', 'skip_connect',
    'avg_pool_3x3', 'max_pool_3x3',
    'dil_conv_3x3',
]

# DARTS ops
DARTS_OPS = [
    'none',
    'max_pool_3x3', 'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3', 'sep_conv_5x5',
    'dil_conv_3x3', 'dil_conv_5x5',
]
|  | ||||
|  | ||||
# Layout used below: every line of a ``normal`` / ``reduce`` list holds the
# two (op_name, input_index) edges that feed one intermediate node.

NASNet = Genotype(
    normal=[
        ('sep_conv_5x5', 1), ('sep_conv_3x3', 0),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 0),
        ('avg_pool_3x3', 1), ('skip_connect', 0),
        ('avg_pool_3x3', 0), ('avg_pool_3x3', 0),
        ('sep_conv_3x3', 1), ('skip_connect', 1),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 1), ('sep_conv_7x7', 0),
        ('max_pool_3x3', 1), ('sep_conv_7x7', 0),
        ('avg_pool_3x3', 1), ('sep_conv_5x5', 0),
        ('skip_connect', 3), ('avg_pool_3x3', 2),
        ('sep_conv_3x3', 2), ('max_pool_3x3', 1),
    ],
    reduce_concat=[4, 5, 6],
)


PNASNet = Genotype(
    normal=[
        ('sep_conv_5x5', 0), ('max_pool_3x3', 0),
        ('sep_conv_7x7', 1), ('max_pool_3x3', 1),
        ('sep_conv_5x5', 1), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 4), ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0), ('skip_connect', 1),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 0), ('max_pool_3x3', 0),
        ('sep_conv_7x7', 1), ('max_pool_3x3', 1),
        ('sep_conv_5x5', 1), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 4), ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0), ('skip_connect', 1),
    ],
    reduce_concat=[2, 3, 4, 5, 6],
)


AmoebaNet = Genotype(
    normal=[
        ('avg_pool_3x3', 0), ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_5x5', 2),
        ('sep_conv_3x3', 0), ('avg_pool_3x3', 3),
        ('sep_conv_3x3', 1), ('skip_connect', 1),
        ('skip_connect', 0), ('avg_pool_3x3', 1),
    ],
    normal_concat=[4, 5, 6],
    reduce=[
        ('avg_pool_3x3', 0), ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0), ('sep_conv_7x7', 2),
        ('sep_conv_7x7', 0), ('avg_pool_3x3', 1),
        ('max_pool_3x3', 0), ('max_pool_3x3', 1),
        ('conv_7x1_1x7', 0), ('sep_conv_3x3', 5),
    ],
    reduce_concat=[3, 4, 6],
)


DARTS_V1 = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('skip_connect', 0), ('sep_conv_3x3', 1),
        ('skip_connect', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('skip_connect', 2),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('max_pool_3x3', 1),
        ('skip_connect', 2), ('max_pool_3x3', 0),
        ('max_pool_3x3', 0), ('skip_connect', 2),
        ('skip_connect', 2), ('avg_pool_3x3', 0),
    ],
    reduce_concat=[2, 3, 4, 5],
)


DARTS_V2 = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('skip_connect', 0),
        ('skip_connect', 0), ('dil_conv_3x3', 2),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('max_pool_3x3', 1),
        ('skip_connect', 2), ('max_pool_3x3', 1),
        ('max_pool_3x3', 0), ('skip_connect', 2),
        ('skip_connect', 2), ('max_pool_3x3', 1),
    ],
    reduce_concat=[2, 3, 4, 5],
)
|  | ||||
# Layout used below: every line of a ``normal`` / ``reduce`` list holds the
# two (op_name, input_index) edges that feed one intermediate node.
# The concat fields are explicit lists (not ``range`` objects) so that these
# genotypes have the same field types as the list-based genotypes in this
# module and print/compare/serialize the same way.

PDARTS = Genotype(
    normal=[
        ('skip_connect', 0), ('dil_conv_3x3', 1),
        ('skip_connect', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 0), ('dil_conv_5x5', 4),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('avg_pool_3x3', 0), ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 0), ('dil_conv_5x5', 2),
        ('max_pool_3x3', 0), ('dil_conv_3x3', 1),
        ('dil_conv_3x3', 1), ('dil_conv_5x5', 3),
    ],
    reduce_concat=[2, 3, 4, 5],
)

PCDARTS_C10 = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('skip_connect', 0),
        ('sep_conv_3x3', 0), ('dil_conv_3x3', 1),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 1),
        ('avg_pool_3x3', 0), ('dil_conv_3x3', 1),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('sep_conv_5x5', 1), ('max_pool_3x3', 0),
        ('sep_conv_5x5', 1), ('sep_conv_5x5', 2),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
    ],
    reduce_concat=[2, 3, 4, 5],
)

PCDARTS_IN1K = Genotype(
    normal=[
        ('skip_connect', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('skip_connect', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1), ('dil_conv_5x5', 4),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('sep_conv_3x3', 0), ('skip_connect', 1),
        ('dil_conv_5x5', 2), ('max_pool_3x3', 1),
        ('sep_conv_3x3', 2), ('sep_conv_3x3', 1),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 3),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_IMAGENET_CLS = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 2),
        ('sep_conv_5x5', 1), ('sep_conv_3x3', 0),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('skip_connect', 1),
        ('max_pool_3x3', 0), ('dil_conv_5x5', 2),
        ('max_pool_3x3', 0), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4), ('dil_conv_5x5', 3),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_IMAGENET_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4), ('sep_conv_5x5', 2),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_IMAGENET_COL = Genotype(
    normal=[
        ('skip_connect', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('skip_connect', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 2),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0), ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0), ('sep_conv_5x5', 3),
        ('max_pool_3x3', 0), ('sep_conv_3x3', 4),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_IMAGENET_JIG = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1), ('sep_conv_5x5', 0),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 1),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_IMAGENET22K_CLS = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('skip_connect', 0),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('max_pool_3x3', 1),
        ('dil_conv_5x5', 2), ('max_pool_3x3', 0),
        ('dil_conv_5x5', 3), ('dil_conv_5x5', 2),
        ('dil_conv_5x5', 4), ('dil_conv_5x5', 3),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_IMAGENET22K_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('sep_conv_5x5', 1),
        ('dil_conv_5x5', 2), ('sep_conv_5x5', 0),
        ('dil_conv_5x5', 3), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4), ('sep_conv_3x3', 3),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_IMAGENET22K_COL = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 2), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 3), ('sep_conv_3x3', 0),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('skip_connect', 1),
        ('dil_conv_5x5', 2), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 3), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 4), ('sep_conv_5x5', 1),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_IMAGENET22K_JIG = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 4),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('sep_conv_5x5', 0), ('skip_connect', 1),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 2),
        ('sep_conv_5x5', 0), ('sep_conv_5x5', 3),
        ('sep_conv_5x5', 0), ('sep_conv_5x5', 4),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_CITYSCAPES_SEG = Genotype(
    normal=[
        ('skip_connect', 0), ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('sep_conv_3x3', 0), ('avg_pool_3x3', 1),
        ('avg_pool_3x3', 1), ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 2), ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 4), ('sep_conv_5x5', 2),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_CITYSCAPES_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('sep_conv_5x5', 1),
        ('sep_conv_5x5', 2), ('sep_conv_5x5', 1),
        ('sep_conv_5x5', 3), ('dil_conv_5x5', 2),
        ('sep_conv_5x5', 2), ('sep_conv_5x5', 0),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_CITYSCAPES_COL = Genotype(
    normal=[
        ('dil_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('skip_connect', 0), ('sep_conv_5x5', 2),
        ('dil_conv_3x3', 3), ('skip_connect', 0),
        ('skip_connect', 0), ('sep_conv_3x3', 1),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('avg_pool_3x3', 1), ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1), ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1), ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1), ('skip_connect', 4),
    ],
    reduce_concat=[2, 3, 4, 5],
)

UNNAS_CITYSCAPES_JIG = Genotype(
    normal=[
        ('dil_conv_5x5', 1), ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 0), ('dil_conv_5x5', 1),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('avg_pool_3x3', 0), ('skip_connect', 1),
        ('dil_conv_5x5', 1), ('dil_conv_5x5', 2),
        ('dil_conv_5x5', 2), ('dil_conv_5x5', 0),
        ('dil_conv_5x5', 3), ('dil_conv_5x5', 2),
    ],
    reduce_concat=[2, 3, 4, 5],
)
|  | ||||
|  | ||||
# Supported genotypes, keyed by the name used to select them.
# 'custom' maps to None: no predefined genotype is attached to that key.
GENOTYPES = dict(
    nas=NASNet,
    pnas=PNASNet,
    amoeba=AmoebaNet,
    darts_v1=DARTS_V1,
    darts_v2=DARTS_V2,
    pdarts=PDARTS,
    pcdarts_c10=PCDARTS_C10,
    pcdarts_in1k=PCDARTS_IN1K,
    unnas_imagenet_cls=UNNAS_IMAGENET_CLS,
    unnas_imagenet_rot=UNNAS_IMAGENET_ROT,
    unnas_imagenet_col=UNNAS_IMAGENET_COL,
    unnas_imagenet_jig=UNNAS_IMAGENET_JIG,
    unnas_imagenet22k_cls=UNNAS_IMAGENET22K_CLS,
    unnas_imagenet22k_rot=UNNAS_IMAGENET22K_ROT,
    unnas_imagenet22k_col=UNNAS_IMAGENET22K_COL,
    unnas_imagenet22k_jig=UNNAS_IMAGENET22K_JIG,
    unnas_cityscapes_seg=UNNAS_CITYSCAPES_SEG,
    unnas_cityscapes_rot=UNNAS_CITYSCAPES_ROT,
    unnas_cityscapes_col=UNNAS_CITYSCAPES_COL,
    unnas_cityscapes_jig=UNNAS_CITYSCAPES_JIG,
    custom=None,
)
							
								
								
									
										337
									
								
								pycls/models/nas/nas.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										337
									
								
								pycls/models/nas/nas.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,337 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """NAS network (adopted from DARTS).""" | ||||
|  | ||||
| from torch.autograd import Variable | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| import pycls.core.logging as logging | ||||
|  | ||||
| from pycls.core.config import cfg | ||||
| from pycls.models.common import Preprocess | ||||
| from pycls.models.common import Classifier | ||||
| from pycls.models.nas.genotypes import GENOTYPES | ||||
| from pycls.models.nas.genotypes import Genotype | ||||
| from pycls.models.nas.operations import FactorizedReduce | ||||
| from pycls.models.nas.operations import OPS | ||||
| from pycls.models.nas.operations import ReLUConvBN | ||||
| from pycls.models.nas.operations import Identity | ||||
|  | ||||
|  | ||||
| logger = logging.get_logger(__name__) | ||||
|  | ||||
|  | ||||
def drop_path(x, drop_prob):
    """Drop path (ported from DARTS).

    With probability ``drop_prob`` each entry along dim 0 of ``x`` is zeroed
    as a whole; the entries that are kept are divided by ``1 - drop_prob``.
    ``x`` is modified in place and also returned. When ``drop_prob`` is not
    positive, ``x`` is returned untouched.

    Args:
        x: 4-D tensor; the mask has shape ``(x.size(0), 1, 1, 1)`` and is
            broadcast over the remaining dims.
        drop_prob: probability of dropping a sample.

    Returns:
        The same tensor object ``x``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Allocate the mask on x's own device and dtype so the function works
        # on CPU and with non-float32 inputs, not only CUDA float tensors.
        mask = torch.empty(
            x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device
        ).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
|  | ||||
|  | ||||
class Cell(nn.Module):
    """NAS cell (ported from DARTS).

    The two inputs are first mapped to ``C`` channels, then the ops listed
    in the genotype are applied pairwise: each step sums the outputs of two
    ops and appends the sum to the list of states. The states selected by
    the genotype's concat indices are concatenated along the channel axis.
    """

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))

        # Use FactorizedReduce on the older input when the previous cell was
        # a reduction cell, a 1x1 ReLUConvBN otherwise.
        self.preprocess0 = (
            FactorizedReduce(C_prev_prev, C)
            if reduction_prev
            else ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        )
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        edges = genotype.reduce if reduction else genotype.normal
        concat = genotype.reduce_concat if reduction else genotype.normal_concat
        op_names, indices = zip(*edges)
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        """Instantiate one op per (name, input index) edge."""
        assert len(op_names) == len(indices)
        self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)

        # In a reduction cell only ops reading one of the two cell inputs
        # (index < 2) use stride 2.
        built = [
            OPS[name](C, 2 if reduction and index < 2 else 1, True)
            for name, index in zip(op_names, indices)
        ]
        self._ops = nn.ModuleList()
        self._ops += built
        self._indices = indices

    def forward(self, s0, s1, drop_prob):
        """Run the cell on the two previous states and return the concat output."""
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for step in range(self._steps):
            first, second = 2 * step, 2 * step + 1
            in_a = states[self._indices[first]]
            in_b = states[self._indices[second]]
            op_a = self._ops[first]
            op_b = self._ops[second]
            out_a = op_a(in_a)
            out_b = op_b(in_b)
            # Drop path is applied only while training and never to the
            # output of an Identity op.
            if self.training and drop_prob > 0.:
                if not isinstance(op_a, Identity):
                    out_a = drop_path(out_a, drop_prob)
                if not isinstance(op_b, Identity):
                    out_b = drop_path(out_b, drop_prob)
            states.append(out_a + out_b)
        return torch.cat([states[i] for i in self._concat], dim=1)
|  | ||||
|  | ||||
class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classification head; layer sizes assume an 8x8 input.

    Note that the first layer is an in-place ReLU, so the tensor passed to
    ``forward`` is modified.
    """

    def __init__(self, C, num_classes):
        super(AuxiliaryHeadCIFAR, self).__init__()
        layers = [
            nn.ReLU(inplace=True),
            # 8x8 -> 2x2
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 2x2 kernel collapses the 2x2 map to 1x1
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
|  | ||||
|  | ||||
class AuxiliaryHeadImageNet(nn.Module):
    """Auxiliary classification head for the ImageNet-style network.

    The pooling (kernel 5, stride 2) and the 2x2 convolution must reduce the
    feature map to 1x1 for the 768-wide linear layer to fit, which holds for
    7x7 or 8x8 inputs. NOTE(review): the original note said "assuming input
    size 14x14", which would give a 4x4 map here; confirm against the caller.

    The first layer is an in-place ReLU, so the tensor passed to ``forward``
    is modified.
    """

    def __init__(self, C, num_classes):
        super(AuxiliaryHeadImageNet, self).__init__()
        layers = [
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
            # Commenting it out for consistency with the experiments in the paper.
            # nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
|  | ||||
|  | ||||
class NetworkCIFAR(nn.Module):
    """CIFAR network (ported from DARTS).

    A 3x3 convolutional stem is followed by `layers` cells built from
    `genotype`. The cells at indices `layers//3` and `2*layers//3` are
    created with `reduction=True`, and the working channel count doubles at
    each of them. When `auxiliary` is set, an auxiliary head is attached to
    the output of the cell at index `2*layers//3`.

    Args:
        C: base channel count of the cells.
        num_classes: number of output classes.
        layers: number of cells.
        auxiliary: whether to build (and, in training mode, use) the
            auxiliary head.
        genotype: cell description handed to every `Cell`.
    """

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkCIFAR, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary
        # forward() reads this attribute on every call. Give it a safe
        # default (no drop path) so the network works before a caller
        # installs a schedule; callers may still overwrite it freely.
        self.drop_path_prob = 0.0

        stem_multiplier = 3
        C_curr = stem_multiplier*C
        self.stem = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers//3, 2*layers//3]:
                # Reduction cell: double the channel count.
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
            if i == 2*layers//3:
                # Channel count seen by the auxiliary head.
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        """Run the network.

        Returns:
            `(logits, logits_aux)` when the auxiliary head is enabled and the
            module is in training mode, otherwise just `logits`.
        """
        input = Preprocess(input)
        logits_aux = None
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2*self._layers//3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        # The classifier also receives the spatial size of the (preprocessed) input.
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits

    def _loss(self, input, target, return_logits=False):
        """Compute the loss for a batch.

        NOTE(review): `self._criterion` is not defined in this class and must
        be attached by the caller. The result of `self(input)` is passed to
        it unchanged, including the `(logits, logits_aux)` tuple form.

        Returns:
            `(loss, logits)` if `return_logits` is true, otherwise `loss`.
        """
        logits = self(input)
        loss = self._criterion(logits, target)

        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        """Run one backward pass and optimizer step on a batch.

        NOTE(review): `self.optimizer` and `self.get_weights()` are not
        defined in this class and must be provided by the caller. Gradients
        are not zeroed here. `shared` is accepted but unused.

        Args:
            input: input batch.
            target: targets handed to the criterion.
            args: object exposing `grad_clip`; a value of 0 disables clipping.
            shared: unused.
            return_grad: also return the flattened gradient vector.

        Returns:
            `(logits, loss)` or, with `return_grad`, `(logits, loss, grad)`.
        """
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()

        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt
|  | ||||
|  | ||||
class NetworkImageNet(nn.Module):
    """ImageNet network (ported from DARTS).

    Two strided convolutional stems (`stem0` with two stride-2 convolutions,
    `stem1` with one more) are followed by `layers` cells built from
    `genotype`. Cells at the indices in `reduction_layers` are created with
    `reduction=True` and double the working channel count; for the 'seg'
    task only the cell at `layers//3` is a reduction cell, otherwise the
    cells at `layers//3` and `2*layers//3` are. When `auxiliary` is set, an
    auxiliary head is attached to the output of the cell at `2*layers//3`.

    Args:
        C: base channel count of the cells.
        num_classes: number of output classes.
        layers: number of cells.
        auxiliary: whether to build (and, in training mode, use) the
            auxiliary head.
        genotype: cell description handed to every `Cell`.
    """

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkImageNet, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary
        # forward() reads this attribute on every call. Give it a safe
        # default (no drop path) so the network works before a caller
        # installs a schedule; callers may still overwrite it freely.
        self.drop_path_prob = 0.0

        self.stem0 = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        self.stem1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        C_prev_prev, C_prev, C_curr = C, C, C

        self.cells = nn.ModuleList()
        # stem1 is strided, so the first cell is told its predecessor reduced.
        reduction_prev = True
        reduction_layers = [layers//3] if cfg.TASK == 'seg' else [layers//3, 2*layers//3]
        for i in range(layers):
            if i in reduction_layers:
                # Reduction cell: double the channel count.
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                # Channel count seen by the auxiliary head.
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        """Run the network.

        Returns:
            `(logits, logits_aux)` when the auxiliary head is enabled and the
            module is in training mode, otherwise just `logits`.
        """
        input = Preprocess(input)
        logits_aux = None
        s0 = self.stem0(input)
        s1 = self.stem1(s0)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        # The classifier also receives the spatial size of the (preprocessed) input.
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits

    def _loss(self, input, target, return_logits=False):
        """Compute the loss for a batch.

        NOTE(review): `self._criterion` is not defined in this class and must
        be attached by the caller. The result of `self(input)` is passed to
        it unchanged, including the `(logits, logits_aux)` tuple form.

        Returns:
            `(loss, logits)` if `return_logits` is true, otherwise `loss`.
        """
        logits = self(input)
        loss = self._criterion(logits, target)

        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        """Run one backward pass and optimizer step on a batch.

        NOTE(review): `self.optimizer` and `self.get_weights()` are not
        defined in this class and must be provided by the caller. Gradients
        are not zeroed here. `shared` is accepted but unused.

        Args:
            input: input batch.
            target: targets handed to the criterion.
            args: object exposing `grad_clip`; a value of 0 disables clipping.
            shared: unused.
            return_grad: also return the flattened gradient vector.

        Returns:
            `(logits, loss)` or, with `return_grad`, `(logits, loss, grad)`.
        """
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()

        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt
|  | ||||
|  | ||||
class NAS(nn.Module):
    """NAS net wrapper (delegates to nets from DARTS).

    Reads its whole configuration from the global `cfg`: the dataset names
    select the network constructor, `cfg.NAS.GENOTYPE` selects a predefined
    genotype (or 'custom', in which case `cfg.NAS.CUSTOM_GENOTYPE` supplies
    the four genotype fields), and width / depth / auxiliary head come from
    `cfg.NAS`.
    """

    def __init__(self):
        assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Testing on {} is not supported'.format(cfg.TEST.DATASET)
        assert cfg.NAS.GENOTYPE in GENOTYPES, \
            'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
        super(NAS, self).__init__()
        logger.info('Constructing NAS: {}'.format(cfg.NAS))
        # Use a custom or predefined genotype
        if cfg.NAS.GENOTYPE == 'custom':
            genotype = Genotype(
                normal=cfg.NAS.CUSTOM_GENOTYPE[0],
                normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
                reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
                reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
            )
        else:
            genotype = GENOTYPES[cfg.NAS.GENOTYPE]
        # Determine the network constructor for dataset
        if 'cifar' in cfg.TRAIN.DATASET:
            net_ctor = NetworkCIFAR
        else:
            net_ctor = NetworkImageNet
        # Construct the network
        self.net_ = net_ctor(
            C=cfg.NAS.WIDTH,
            num_classes=cfg.MODEL.NUM_CLASSES,
            layers=cfg.NAS.DEPTH,
            auxiliary=cfg.NAS.AUX,
            genotype=genotype
        )
        # Drop path probability (set / annealed based on epoch)
        self.net_.drop_path_prob = 0.0

    def set_drop_path_prob(self, drop_path_prob):
        """Set the drop path probability used by the wrapped network."""
        self.net_.drop_path_prob = drop_path_prob

    def forward(self, x):
        """Run the wrapped network on `x` and return its output."""
        # Call the module instance rather than .forward() directly so that
        # any forward hooks registered on the wrapped network still run.
        return self.net_(x)
							
								
								
									
										219
									
								
								pycls/models/nas/operations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										219
									
								
								pycls/models/nas/operations.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,219 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
|  | ||||
| """NAS ops (adopted from DARTS).""" | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.autograd import Variable | ||||
|  | ||||
def _relu_conv_bn(C, kernel_size, stride, padding, affine):
    """Build ReLU -> square conv (C -> C, no bias) -> batch norm."""
    return nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(C, C, kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(C, affine=affine),
    )


def _relu_split_conv_bn(C, k, stride, affine):
    """Build ReLU -> 1xk conv -> kx1 conv -> batch norm (C -> C, no bias).

    The stride is applied along the width by the first convolution and
    along the height by the second one.
    """
    half = k // 2
    return nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(C, C, (1, k), stride=(1, stride), padding=(0, half), bias=False),
        nn.Conv2d(C, C, (k, 1), stride=(stride, 1), padding=(half, 0), bias=False),
        nn.BatchNorm2d(C, affine=affine),
    )


# Registry of candidate operations. Every entry is a factory taking
# (C, stride, affine) and returning a freshly constructed module.
OPS = {
    # Parameter-free ops
    'none': lambda C, stride, affine: Zero(stride),
    'noise': lambda C, stride, affine: NoiseOp(stride, 0., 1.),
    # Average pooling
    'avg_pool_2x2': lambda C, stride, affine: nn.AvgPool2d(
        2, stride=stride, padding=0, count_include_pad=False),
    'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(
        3, stride=stride, padding=1, count_include_pad=False),
    'avg_pool_5x5': lambda C, stride, affine: nn.AvgPool2d(
        5, stride=stride, padding=2, count_include_pad=False),
    # Max pooling
    'max_pool_2x2': lambda C, stride, affine: nn.MaxPool2d(2, stride=stride, padding=0),
    'max_pool_3x3': lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
    'max_pool_5x5': lambda C, stride, affine: nn.MaxPool2d(5, stride=stride, padding=2),
    'max_pool_7x7': lambda C, stride, affine: nn.MaxPool2d(7, stride=stride, padding=3),
    # Skip connection (needs a reduction when strided)
    'skip_connect': lambda C, stride, affine: (
        FactorizedReduce(C, C, affine=affine) if stride != 1 else Identity()),
    # Plain convolutions
    'conv_1x1': lambda C, stride, affine: _relu_conv_bn(C, 1, stride, 0, affine),
    'conv_3x3': lambda C, stride, affine: _relu_conv_bn(C, 3, stride, 1, affine),
    # Separable convolutions
    'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
    # Dilated convolutions
    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
    'dil_sep_conv_3x3': lambda C, stride, affine: DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
    # Spatially factorized convolutions
    'conv_3x1_1x3': lambda C, stride, affine: _relu_split_conv_bn(C, 3, stride, affine),
    'conv_7x1_1x7': lambda C, stride, affine: _relu_split_conv_bn(C, 7, stride, affine),
}
|  | ||||
class NoiseOp(nn.Module):
    """Operation that replaces its input with Gaussian noise.

    The input only determines the shape, dtype and device of the output;
    its values are never used. For `stride != 1` the output has the spatial
    size of the input subsampled by `stride` along the last two dimensions.

    Args:
        stride: spatial subsampling factor applied to the last two dims.
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
    """

    def __init__(self, stride, mean, std):
        super(NoiseOp, self).__init__()
        self.stride = stride
        self.mean = mean
        self.std = std

    def forward(self, x, block_input=False):
        """Return fresh noise shaped like the (optionally subsampled) input.

        `block_input` is kept for interface compatibility. Zeroing the input
        has no effect on the result because only its shape is consulted, so
        the multiplication is skipped.
        """
        if self.stride != 1:
            x = x[:, :, ::self.stride, ::self.stride]
        # new_empty keeps dtype/device and yields a tensor detached from the
        # autograd graph, replacing the deprecated Variable(x.data.new(...)).
        return x.new_empty(x.size()).normal_(self.mean, self.std)
|  | ||||
class ReLUConvBN(nn.Module):
    """ReLU, then a bias-free convolution, then batch normalization.

    Args:
        C_in: input channels.
        C_out: output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        affine: whether the batch norm has learnable affine parameters.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(ReLUConvBN, self).__init__()
        conv = nn.Conv2d(
            C_in, C_out, kernel_size,
            stride=stride, padding=padding, bias=False,
        )
        norm = nn.BatchNorm2d(C_out, affine=affine)
        self.op = nn.Sequential(nn.ReLU(inplace=False), conv, norm)

    def forward(self, x):
        """Apply the ReLU-conv-BN stack to `x`."""
        out = self.op(x)
        return out
|  | ||||
|  | ||||
class DilConv(nn.Module):
    """ReLU, dilated depthwise conv, 1x1 pointwise conv, batch norm.

    Args:
        C_in: input channels (also the number of depthwise groups).
        C_out: output channels produced by the pointwise convolution.
        kernel_size: depthwise kernel size.
        stride: depthwise stride.
        padding: depthwise padding.
        dilation: depthwise dilation.
        affine: whether the batch norm has learnable affine parameters.
    """

    def __init__(
        self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
    ):
        super(DilConv, self).__init__()
        # Layers are created in execution order; their positions in the
        # Sequential determine the parameter names (op.<idx>.*).
        layers = [
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=C_in, bias=False,
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        ]
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the dilated separable convolution stack to `x`."""
        out = self.op(x)
        return out
|  | ||||
|  | ||||
class SepConv(nn.Module):
    """Two stacked depthwise-separable convolution stages.

    Each stage is ReLU -> depthwise conv -> 1x1 pointwise conv -> batch norm.
    The first stage applies `stride` and keeps `C_in` channels; the second
    uses stride 1 and maps to `C_out` channels.

    Args:
        C_in: input channels.
        C_out: output channels.
        kernel_size: depthwise kernel size (both stages).
        stride: stride of the first depthwise convolution.
        padding: depthwise padding (both stages).
        affine: whether the batch norms have learnable affine parameters.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()

        def depthwise(step):
            # Per-channel spatial convolution over C_in channels.
            return nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=step,
                padding=padding, groups=C_in, bias=False,
            )

        def pointwise(channels):
            # 1x1 convolution mixing channels.
            return nn.Conv2d(C_in, channels, kernel_size=1, padding=0, bias=False)

        # Layers are created in execution order so that positions in the
        # Sequential (and thus parameter names) are unchanged.
        first_stage = [
            nn.ReLU(inplace=False),
            depthwise(stride),
            pointwise(C_in),
            nn.BatchNorm2d(C_in, affine=affine),
        ]
        second_stage = [
            nn.ReLU(inplace=False),
            depthwise(1),
            pointwise(C_out),
            nn.BatchNorm2d(C_out, affine=affine),
        ]
        self.op = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Apply both separable convolution stages to `x`."""
        out = self.op(x)
        return out
|  | ||||
|  | ||||
class DilSepConv(nn.Module):
    """Two stacked dilated depthwise-separable convolution stages.

    Each stage is ReLU -> dilated depthwise conv -> 1x1 pointwise conv ->
    batch norm. The first stage applies `stride` and keeps `C_in` channels;
    the second uses stride 1 and maps to `C_out` channels.

    Args:
        C_in: input channels.
        C_out: output channels.
        kernel_size: depthwise kernel size (both stages).
        stride: stride of the first depthwise convolution.
        padding: depthwise padding (both stages).
        dilation: depthwise dilation (both stages).
        affine: whether the batch norms have learnable affine parameters.
    """

    def __init__(
        self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
    ):
        super(DilSepConv, self).__init__()

        def dilated_depthwise(step):
            # Per-channel dilated spatial convolution over C_in channels.
            return nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=step,
                padding=padding, dilation=dilation, groups=C_in, bias=False,
            )

        def pointwise(channels):
            # 1x1 convolution mixing channels.
            return nn.Conv2d(C_in, channels, kernel_size=1, padding=0, bias=False)

        # Layers are created in execution order so that positions in the
        # Sequential (and thus parameter names) are unchanged.
        first_stage = [
            nn.ReLU(inplace=False),
            dilated_depthwise(stride),
            pointwise(C_in),
            nn.BatchNorm2d(C_in, affine=affine),
        ]
        second_stage = [
            nn.ReLU(inplace=False),
            dilated_depthwise(1),
            pointwise(C_out),
            nn.BatchNorm2d(C_out, affine=affine),
        ]
        self.op = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Apply both dilated separable convolution stages to `x`."""
        out = self.op(x)
        return out
|  | ||||
|  | ||||
class Identity(nn.Module):
    """Pass-through operation: returns its input object unchanged."""

    # No state of its own, so the inherited nn.Module constructor suffices.

    def forward(self, x):
        """Return `x` itself (no copy is made)."""
        return x
|  | ||||
|  | ||||
class Zero(nn.Module):
    """Operation whose output is the input multiplied by zero.

    For `stride != 1` the last two dimensions are first subsampled by
    `stride`, so the output has the reduced spatial size.

    Args:
        stride: spatial subsampling factor applied to the last two dims.
    """

    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        """Return `x` (subsampled when strided) multiplied by 0."""
        step = self.stride
        if step != 1:
            x = x[:, :, ::step, ::step]
        return x.mul(0.)
|  | ||||
|  | ||||
class FactorizedReduce(nn.Module):
    """Stride-2 reduction built from two offset 1x1 convolutions.

    After a ReLU, one stride-2 1x1 convolution samples the input directly
    and a second one samples it shifted by one pixel in both spatial
    directions. Each produces `C_out // 2` channels; the halves are
    concatenated along the channel dimension and batch-normalized.

    Args:
        C_in: input channels.
        C_out: output channels; must be even.
        affine: whether the batch norm has learnable affine parameters.
    """

    def __init__(self, C_in, C_out, affine=True):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        half = C_out // 2
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, half, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, half, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        # One zero row/column at the bottom/right keeps the shifted view the
        # same size as the unshifted one.
        self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

    def forward(self, x):
        """Return the reduced, batch-normalized feature map."""
        activated = self.relu(x)
        # Pad bottom/right by one, then drop the first row/column: a
        # one-pixel diagonal shift of the activation.
        shifted = self.pad(activated)[:, :, 1:, 1:]
        halves = [self.conv_1(activated), self.conv_2(shifted)]
        return self.bn(torch.cat(halves, dim=1))
		Reference in New Issue
	
	Block a user