xautodl/lib/tf_models/cell_operations.py
2020-01-05 22:19:38 +11:00

121 lines
4.2 KiB
Python

##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import tensorflow as tf
# Names exported by a star-import of this module.
__all__ = [
    'OPS',
    'ResNetBasicblock',
    'SearchSpaceNames',
]
def _relu_conv_bn_builder(kernel_size):
    """Return an OPS-style builder for a ReLUConvBN with a fixed kernel size."""
    return lambda C_in, C_out, stride, affine: ReLUConvBN(
        C_in, C_out, kernel_size, stride, affine
    )


# Registry of candidate operations: operation name -> builder.
# Every builder is called as builder(C_in, C_out, stride, affine) and returns
# a freshly constructed layer. `affine` is accepted by all builders but is not
# forwarded by 'none' and 'skip_connect'.
OPS = {
    'none': (
        lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride)
    ),
    'avg_pool_3x3': (
        lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine)
    ),
    'nor_conv_1x1': _relu_conv_bn_builder(1),
    'nor_conv_3x3': _relu_conv_bn_builder(3),
    'nor_conv_5x5': _relu_conv_bn_builder(5),
    'skip_connect': (
        lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
    ),
}
# Operation names (all of them keys of OPS) forming the NAS-Bench-102 space.
NAS_BENCH_102 = [
    'none',
    'skip_connect',
    'nor_conv_1x1',
    'nor_conv_3x3',
    'avg_pool_3x3',
]

# Search-space name -> list of candidate operation names.
SearchSpaceNames = dict([('nas-bench-102', NAS_BENCH_102)])
class POOLING(tf.keras.layers.Layer):
    """3x3 pooling with an optional channel-matching preprocess.

    When ``C_in != C_out`` the input first passes through a 1x1, stride-1
    ``ReLUConvBN``; the result (or the untouched input when the channel counts
    already match) is then pooled with a 3x3 window, the requested stride and
    ``'same'`` padding.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        stride: stride of the pooling layer.
        mode: ``'avg'`` for average pooling or ``'max'`` for max pooling.
        affine: forwarded to the ``ReLUConvBN`` preprocess.

    Raises:
        ValueError: if ``mode`` is neither ``'avg'`` nor ``'max'``.
    """

    def __init__(self, C_in, C_out, stride, mode, affine):
        super(POOLING, self).__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, affine)
        if mode == 'avg':
            self.op = tf.keras.layers.AvgPool2D((3, 3), strides=stride, padding='same')
        elif mode == 'max':
            self.op = tf.keras.layers.MaxPool2D((3, 3), strides=stride, padding='same')
        else:
            raise ValueError('Invalid mode={:} in POOLING'.format(mode))

    def call(self, inputs, training=None):
        """Apply the optional preprocess, then the pooling layer.

        Args:
            inputs: input tensor.
            training: training-mode flag; forwarded to the preprocess layer,
                which contains batch normalization.

        Returns:
            The pooled tensor.
        """
        x = inputs
        # Compare against None explicitly instead of relying on the
        # truthiness of a layer object.
        if self.preprocess is not None:
            # Forward the flag explicitly so batch norm does not depend on
            # implicit propagation from the enclosing call.
            x = self.preprocess(x, training=training)
        return self.op(x)
class Identity(tf.keras.layers.Layer):
    """Skip connection that falls back to a conv when the shape must change.

    If the channel counts match and the stride is 1, the input is returned
    unchanged. Otherwise a bias-free 3x3 convolution with ``C_out`` filters,
    the given stride and ``'same'`` padding is applied.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        stride: stride of the fallback convolution.
    """

    def __init__(self, C_in, C_out, stride):
        super(Identity, self).__init__()
        shape_preserved = (C_in == C_out) and (stride == 1)
        self.layer = (
            None
            if shape_preserved
            else tf.keras.layers.Conv2D(C_out, 3, stride, padding='same', use_bias=False)
        )

    def call(self, inputs, training):
        """Return ``inputs`` as-is, or convolved when a layer was created.

        ``training`` is accepted for interface uniformity and is not used.
        """
        if self.layer is None:
            return inputs
        return self.layer(inputs)
class Zero(tf.keras.layers.Layer):
    """Operation whose output is built from a zero tensor.

    The input is replaced by ``tf.zeros_like(inputs)``. To reach the target
    shape, that zero tensor is additionally passed through:

    * a bias-free 1x1 convolution (with the given stride) when
      ``C_in != C_out``;
    * an average pooling with window ``(stride, stride)`` when the channel
      counts match but ``stride != 1``;
    * nothing when neither the channel count nor the resolution changes.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        stride: spatial stride of the operation.
    """

    def __init__(self, C_in, C_out, stride):
        super(Zero, self).__init__()
        same_channels = C_in == C_out
        if same_channels and stride == 1:
            self.layer = None
        elif same_channels:
            # Strides are left as None, exactly as in the pooling call below.
            self.layer = tf.keras.layers.AvgPool2D((stride, stride), None, padding="same")
        else:
            self.layer = tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False)

    def call(self, inputs, training):
        """Return the (optionally reshaped) zero tensor.

        ``training`` is accepted for interface uniformity and is not used.
        """
        zeros = tf.zeros_like(inputs)
        return zeros if self.layer is None else self.layer(zeros)
class ReLUConvBN(tf.keras.layers.Layer):
    """ReLU -> bias-free Conv2D -> BatchNormalization, applied in that order.

    Args:
        C_in: number of input channels (stored on the instance as ``C_in``;
            the convolution itself is only given ``C_out``).
        C_out: number of convolution filters.
        kernel_size: convolution kernel size.
        strides: convolution strides.
        affine: used for both ``center`` and ``scale`` of the batch norm.
    """

    def __init__(self, C_in, C_out, kernel_size, strides, affine):
        super(ReLUConvBN, self).__init__()
        self.C_in = C_in
        self.relu = tf.keras.activations.relu
        self.conv = tf.keras.layers.Conv2D(
            filters=C_out,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            use_bias=False,
        )
        self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine)

    def call(self, inputs, training):
        """Run the ReLU/conv/batch-norm pipeline.

        Args:
            inputs: input tensor.
            training: training-mode flag handed to the batch norm layer.

        Returns:
            The normalized convolution output.
        """
        activated = self.relu(inputs)
        convolved = self.conv(activated)
        return self.bn(convolved, training)
class ResNetBasicblock(tf.keras.layers.Layer):
    """Residual block: two 3x3 ``ReLUConvBN`` layers plus a shortcut branch.

    The main branch is ``ReLUConvBN(3x3, stride)`` followed by
    ``ReLUConvBN(3x3, 1)``. The shortcut branch is:

    * average pooling followed by a bias-free 1x1 convolution when
      ``stride == 2``;
    * a 1x1 ``ReLUConvBN`` when ``stride == 1`` and ``inplanes != planes``;
    * the unchanged input otherwise.

    Both branches are summed with a Keras ``Add`` layer.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: 1 or 2.
        affine: forwarded to every ``ReLUConvBN`` in the block.

    Raises:
        AssertionError: if ``stride`` is neither 1 nor 2.
    """

    def __init__(self, inplanes, planes, stride, affine=True):
        super(ResNetBasicblock, self).__init__()
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, affine)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, affine)
        if stride == 2:
            self.downsample = tf.keras.Sequential([
                tf.keras.layers.AvgPool2D((stride, stride), None, padding="same"),
                tf.keras.layers.Conv2D(planes, 1, 1, padding='same', use_bias=False)])
        elif inplanes != planes:
            self.downsample = ReLUConvBN(inplanes, planes, 1, stride, affine)
        else:
            self.downsample = None
        self.addition = tf.keras.layers.Add()
        # Bookkeeping attributes describing the block.
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def call(self, inputs, training=None):
        """Compute ``shortcut(inputs) + conv_b(conv_a(inputs))``.

        Args:
            inputs: input tensor.
            training: training-mode flag, forwarded to every sub-layer that
                is invoked with it (main branch and shortcut alike).

        Returns:
            The element-wise sum of the shortcut and main-branch outputs.
        """
        basicblock = self.conv_a(inputs, training=training)
        basicblock = self.conv_b(basicblock, training=training)
        if self.downsample is not None:
            # The shortcut may be a ReLUConvBN containing batch norm, so it
            # gets the same explicit flag as the main branch.
            residual = self.downsample(inputs, training=training)
        else:
            residual = inputs
        return self.addition([residual, basicblock])