Source code for fusionlab.classification.lstm
import torch
from torch import nn
from fusionlab.classification.base import RNNClassificationModel
[docs]
class LSTMClassifier(RNNClassificationModel):
    """Sequence classifier built from an LSTM encoder and a linear head.

    The constructor only creates the two submodules; the forward pass is
    inherited from ``RNNClassificationModel``, which is not defined in this
    file. NOTE(review): the base class presumably reads ``self.encoder`` and
    ``self.head`` by these exact names -- confirm before renaming them.

    Args:
        cin: Size of each input feature vector (``input_size`` of the LSTM).
        cout: Number of output units produced by the linear head.
        hidden_size: Width of the LSTM hidden state, which is also the
            input width of the linear head. Defaults to 512.
    """

    def __init__(self, cin, cout, hidden_size=512):
        super().__init__()
        # Recurrent encoder; batch_first=True means inputs are laid out as
        # (batch, seq, feature).
        recurrent = nn.LSTM(cin, hidden_size, batch_first=True)
        self.encoder = recurrent
        # Projection from the hidden representation to the output units.
        projection = nn.Linear(in_features=hidden_size, out_features=cout)
        self.head = projection
if __name__ == '__main__':
    # Smoke test: push one random batch through a tiny model and verify
    # that the result has one row per sample and one column per class.
    batch, steps, features = 1, 5, 3
    n_classes = 2

    sample = torch.randn(batch, steps, features)
    classifier = LSTMClassifier(cin=features, hidden_size=4, cout=n_classes)
    prediction = classifier(sample)

    assert tuple(prediction.shape) == (batch, n_classes)