import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers
# ResNet50
# Ref:
# https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py
# https://github.com/raghakot/keras-resnet/blob/master/README.md
class Identity(layers.Layer):
    """Pass-through layer: ``call`` returns its input untouched.

    Serves as the no-op shortcut in ``Bottleneck`` and the no-op activation in
    ``ConvBlock``. ``training`` is accepted only so callers can forward it
    uniformly; it has no effect.
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs, training=None):
        del training  # unused; kept for a uniform call signature
        return inputs
class ConvBlock(Model):
    """Conv2D -> BatchNormalization -> optional ReLU.

    Args:
        cout: Number of output channels (``filters`` of the Conv2D).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        activation: If True, apply ReLU after batch norm; otherwise the batch
            norm output is returned unchanged (via ``Identity``).
        padding: If True use ``'same'`` padding, else ``'valid'``.
    """

    def __init__(self, cout, kernel_size=3, stride=1, activation=True, padding=True):
        super().__init__()
        self.conv = layers.Conv2D(cout, kernel_size, stride,
                                  padding='same' if padding else 'valid')
        self.bn = layers.BatchNormalization()
        self.act = layers.ReLU() if activation else Identity()

    def call(self, inputs, training=None):
        """Run conv -> BN -> activation.

        Args:
            inputs: Input tensor for the Conv2D.
            training: Forwarded to BatchNormalization (batch statistics vs.
                moving averages).
        """
        x = self.conv(inputs)
        # ``training`` must be a keyword: Keras 3's Layer.__call__ only accepts
        # tensors positionally, so a positional Python bool (e.g. True) raises
        # ValueError there.
        x = self.bn(x, training=training)
        return self.act(x)
class Bottleneck(Model):
    """ResNet bottleneck residual block: 1x1 -> kxk -> 1x1 convs plus shortcut.

    The stride is applied on the middle ``kernel_size`` convolution; the first
    1x1 conv always uses stride 1. The referenced keras-applications ResNet50
    strides the first 1x1 conv instead. Placing it on the 3x3 conv is the
    layout usually called "ResNet v1.5".

    Args:
        dims: ``(dim1, dim2, dim3)`` output channels of the three convs.
        kernel_size: Kernel size of the middle convolution.
        stride: If falsy (default ``None``), the shortcut is ``Identity`` and
            the middle conv uses stride 1. The input must then already have
            ``dim3`` channels for the residual add to line up. If truthy, the
            shortcut is a 1x1 projection (``ConvBlock`` without activation)
            with this stride, and the middle conv uses the same stride. Pass
            ``stride=1`` to project channels without downsampling.
    """

    def __init__(self, dims, kernel_size=3, stride=None):
        super().__init__()
        dim1, dim2, dim3 = dims
        # Creation order is kept (conv1, conv2, conv3, act, skip) so
        # auto-generated layer names match earlier versions.
        self.conv1 = ConvBlock(dim1, kernel_size=1)
        self.conv2 = ConvBlock(dim2, kernel_size=kernel_size,
                               stride=stride if stride else 1)
        self.conv3 = ConvBlock(dim3, kernel_size=1, activation=False)
        self.act = layers.ReLU()
        if stride:
            self.skip = ConvBlock(dim3, kernel_size=1, stride=stride,
                                  activation=False)
        else:
            self.skip = Identity()

    def call(self, inputs, training=None):
        """Return ``ReLU(main_path(inputs) + shortcut(inputs))``.

        ``training`` is passed by keyword to every sub-block. Keras 3 rejects
        non-tensor positional arguments such as a Python bool.
        """
        shortcut = self.skip(inputs, training=training)
        x = self.conv1(inputs, training=training)
        x = self.conv2(x, training=training)
        x = self.conv3(x, training=training)
        return self.act(x + shortcut)
class TFResNet50V1(Model):
    """ResNet-50 backbone without pooling or classification head.

    Layout:
        conv1: 7x7 stride-2 ``ConvBlock`` (64 ch) + 3x3 stride-2 max-pool.
        conv2..conv5: 3, 4, 6 and 3 ``Bottleneck`` blocks with output widths
            256, 512, 1024 and 2048.

    ``call`` returns the conv5 feature map (2048 channels). The overall
    spatial stride is 32, e.g. a 224x224 input yields a 7x7 output.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Sequential([
            ConvBlock(64, 7, stride=2),
            layers.MaxPool2D(3, strides=2, padding='same'),
        ])
        # Each stage gives its first block an explicit stride, so that block
        # gets a projection shortcut mapping incoming channels to dims[-1].
        # conv2 uses stride 1 because the stem already downsampled.
        self.conv2 = self._make_stage([64, 64, 256], 3, stride=1)
        self.conv3 = self._make_stage([128, 128, 512], 4, stride=2)
        self.conv4 = self._make_stage([256, 256, 1024], 6, stride=2)
        self.conv5 = self._make_stage([512, 512, 2048], 3, stride=2)

    @staticmethod
    def _make_stage(dims, num_blocks, stride):
        """Build one stage: a projecting first block, then identity blocks."""
        first = Bottleneck(dims, 3, stride=stride)
        rest = [Bottleneck(dims, 3) for _ in range(num_blocks - 1)]
        return Sequential([first, *rest])

    def call(self, inputs, training=None):
        """Run all five stages in order.

        ``training`` is passed by keyword. Keras 3 rejects non-tensor
        positional arguments such as a Python bool.
        """
        x = inputs
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x, training=training)
        return x
if __name__ == '__main__':
    # Smoke test: push one random channels-last batch through each building
    # block and check the resulting shape.
    inputs = tf.random.normal((1, 224, 224, 128))

    def _check(label, output, expected):
        """Print ``label`` with the output shape and assert it matches."""
        got = output.shape
        print(label, got)
        assert got == expected

    _check("Bottleneck",
           Bottleneck([64, 64, 128])(inputs),
           (1, 224, 224, 128))
    _check("Bottleneck first conv for aligh dims",
           Bottleneck([128, 128, 256], stride=1)(inputs),
           (1, 224, 224, 256))
    _check("Bottleneck downsample",
           Bottleneck([64, 64, 128], stride=2)(inputs),
           (1, 112, 112, 128))
    _check("Identity",
           Identity()(inputs),
           (1, 224, 224, 128))
    _check("TFResNet50V1",
           TFResNet50V1()(inputs),
           (1, 7, 7, 2048))