"""Source code for fusionlab.layers.squeeze_excitation.tfse: TensorFlow Squeeze-and-Excitation layer."""
import tensorflow as tf
from tensorflow.keras import layers, Sequential
class TFSEModule(layers.Layer):
    """Squeeze-and-Excitation (SE) channel-attention layer.

    Squeezes ``inputs`` to a per-channel descriptor by averaging over axes
    1 and 2, passes it through a bottleneck of two 1x1 convolutions
    (ReLU, then sigmoid), and rescales ``inputs`` channel-wise by the
    resulting gate.

    Args:
        cin: Number of input channels. The gate's second convolution emits
            ``cin`` channels, so this must match the last axis of ``inputs``.
        ratio: Channel reduction ratio of the bottleneck. The hidden width is
            ``int(cin / ratio)``, clamped to at least 1.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``, ``trainable``).
    """

    def __init__(self, cin, ratio=16, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.cin = cin
        self.ratio = ratio
        # Clamp to 1: when cin < ratio the plain division would give 0
        # filters, which produces an unusable Conv2D.
        cout = max(1, int(cin / ratio))
        self.gate = Sequential([
            layers.Conv2D(cout, kernel_size=1),
            layers.ReLU(),
            layers.Conv2D(cin, kernel_size=1),
            layers.Activation(tf.nn.sigmoid),
        ])

    def call(self, inputs):
        """Apply the SE gate to ``inputs``.

        Args:
            inputs: Feature tensor. Reducing over axes (1, 2) and scaling the
                last axis assumes a channels-last layout such as
                (batch, height, width, cin) -- TODO confirm for all callers.

        Returns:
            A tensor of the same shape as ``inputs``, with each channel
            multiplied by a sigmoid gate value in (0, 1).
        """
        # Squeeze: keep the reduced axes as size-1 dims so the 1x1 convs
        # apply and the final product broadcasts back over them.
        squeezed = tf.reduce_mean(inputs, (1, 2), keepdims=True)
        # Excite: per-channel weights from the bottleneck gate.
        weights = self.gate(squeezed)
        return inputs * weights

    def get_config(self):
        """Return the layer config, including ``cin`` and ``ratio``."""
        config = super().get_config()
        config.update({'cin': self.cin, 'ratio': self.ratio})
        return config
if __name__ == '__main__':
    # Smoke test: applying the SE layer must leave the tensor shape unchanged.
    batch, height, width, channels = 1, 224, 224, 256
    sample = tf.random.normal((batch, height, width, channels), 0, 1)
    se_block = TFSEModule(channels)
    result = se_block(sample)
    expected_shape = [batch, height, width, channels]
    assert list(result.shape) == expected_shape