Source code for fusionlab.encoders.alexnet.tfalexnet
import tensorflow as tf
[docs]
class TFAlexNet(tf.keras.Model):
def __init__(self):
super().__init__()
self.features = tf.keras.Sequential([
tf.keras.layers.ZeroPadding2D(2),
tf.keras.layers.Conv2D(64, kernel_size=11, strides=4),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
tf.keras.layers.Conv2D(192, kernel_size=5, padding='same'),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
tf.keras.layers.Conv2D(384, kernel_size=3, padding='same'),
tf.keras.layers.ReLU(),
tf.keras.layers.Conv2D(256, kernel_size=3, padding='same'),
tf.keras.layers.ReLU(),
tf.keras.layers.Conv2D(256, kernel_size=3, padding='same'),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2)
])
[docs]
def call(self, inputs):
return self.features(inputs)
if __name__ == '__main__':
inputs = tf.random.normal((1, 224, 224, 3))
output = TFAlexNet()(inputs)
shape = output.shape
print(shape)
assert shape[1:3] == [6, 6]