# Source code for fusionlab.trainers.trainer

import torch
from tqdm.auto import tqdm
import numpy as np

# ref: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback

class Trainer:
    """Minimal supervised training loop with overridable lifecycle hooks.

    ``fit`` stores the model, data loaders, optimizer and loss function on
    the instance and then alternates one training pass and (when a
    validation loader is available) one validation pass per epoch.  The
    ``on_*`` hooks do nothing by default; subclasses override them to add
    behaviour at the corresponding point of the loop.

    Attributes set by ``fit``:
        model, train_dataloader, val_dataloader, epochs, optimizer, loss_fn:
            the arguments passed to ``fit`` (``model`` after ``.to(device)``).
        train_log / val_log: ``{'loss': [...]}`` with one mean loss per epoch.
    """

    def __init__(self, device):
        """Remember the device the model and every batch are moved to."""
        self.device = device

    def train_step(self, data):
        """Run one optimisation step on a single batch.

        Args:
            data: a batch that unpacks into ``(inputs, target)`` once it has
                been passed through ``_data_to_device``.

        Returns:
            The batch loss as a Python float.
        """
        inputs, target = self._data_to_device(data)
        pred = self.model(inputs)
        loss = self.loss_fn(pred, target)
        loss.backward()
        self.optimizer.step()
        # Clear gradients right after the update so the next batch starts clean.
        self.optimizer.zero_grad()
        return loss.item()

    def val_step(self, data):
        """Compute the loss of a single batch without tracking gradients.

        Args:
            data: a batch that unpacks into ``(inputs, target)`` once it has
                been passed through ``_data_to_device``.

        Returns:
            The batch loss as a Python float.
        """
        inputs, target = self._data_to_device(data)
        with torch.no_grad():
            pred = self.model(inputs)
            loss = self.loss_fn(pred, target)
        return loss.item()

    def train_epoch(self):
        """Train on every batch of ``self.train_dataloader``.

        Returns:
            The mean of the per-batch losses.
        """
        self.model.train()
        batch_losses = [
            self.train_step(batch)
            for batch in tqdm(self.train_dataloader, leave=False)
        ]
        return np.mean(batch_losses)

    def val_epoch(self):
        """Evaluate on every batch of ``self.val_dataloader``.

        Returns:
            The mean of the per-batch losses.
        """
        self.model.eval()
        batch_losses = [
            self.val_step(batch)
            for batch in tqdm(self.val_dataloader, leave=False)
        ]
        return np.mean(batch_losses)

    def on_fit_begin(self):
        """Hook called once by ``fit`` before the first epoch; no-op here."""
        pass

    def on_fit_end(self):
        """Hook called once by ``fit`` after the last epoch; no-op here."""
        pass

    def on_epoch_begin(self):
        """Hook called by ``fit`` at the start of every epoch; no-op here."""
        pass

    def on_epoch_end(self):
        """Hook called by ``fit`` at the end of every epoch; no-op here."""
        pass

    def _data_to_device(self, data):
        """Move a batch to ``self.device``.

        Args:
            data: a tensor, a dict of values, or a list/tuple of values.
                Container values are moved with their own ``.to`` method;
                containers are not traversed recursively.

        Returns:
            The moved tensor, or a new container of the same kind (dict,
            list or tuple) holding the moved values.

        Raises:
            NotImplementedError: if ``data`` is none of the supported types.
        """
        if isinstance(data, torch.Tensor):
            return data.to(self.device)
        if isinstance(data, dict):
            return {key: value.to(self.device) for key, value in data.items()}
        if isinstance(data, (list, tuple)):
            moved = [value.to(self.device) for value in data]
            # Hand tuples back as tuples so the batch keeps its container kind.
            return tuple(moved) if isinstance(data, tuple) else moved
        raise NotImplementedError(
            f"unsupported batch type: {type(data).__name__}"
        )

    def _has_validation_data(self):
        """Return whether a validation pass should run for this epoch."""
        loader = self.val_dataloader
        if loader is None:
            return False
        try:
            return bool(loader)
        except TypeError:
            # Truth-testing falls back to len(); a loader that cannot report
            # a length (e.g. over an unsized iterable dataset) is still usable.
            return True

    def fit(self, model, train_dataloader, val_dataloader, epochs, optimizer, loss_fn):
        """Train ``model`` for ``epochs`` epochs, validating after each one.

        Args:
            model: module to train; moved to ``self.device``.
            train_dataloader: iterable of training batches.
            val_dataloader: iterable of validation batches, or a falsy value
                (e.g. ``None``) to skip validation.
            epochs: number of passes over ``train_dataloader``.
            optimizer: optimizer already bound to the model's parameters.
            loss_fn: callable ``loss_fn(pred, target)`` returning a scalar
                tensor.

        Returns:
            None.  Per-epoch mean losses are stored in
            ``self.train_log['loss']`` and ``self.val_log['loss']``.
        """
        self.model = model.to(self.device)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.epochs = epochs
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.train_log = {'loss': []}
        self.val_log = {'loss': []}

        self.on_fit_begin()
        for epoch in tqdm(range(epochs)):
            self.on_epoch_begin()

            self.train_log['loss'].append(self.train_epoch())
            # Report 1-based epoch numbers so the last line reads [N/N].
            message = (
                f"[{epoch + 1}/{epochs}] "
                f"train_loss: {self.train_log['loss'][-1]:.4f}"
            )

            # Only touch the validation log when a validation pass really
            # ran; indexing the empty list would raise IndexError otherwise.
            if self._has_validation_data():
                self.val_log['loss'].append(self.val_epoch())
                message += f" val_loss: {self.val_log['loss'][-1]:.4f}"

            print(message)
            self.on_epoch_end()
        self.on_fit_end()
        return
if __name__ == "__main__":
    # Manual smoke test: train a tiny classifier on MNIST with Trainer.
    # Downloads the dataset into ./data on first run.
    from abc import ABC, abstractmethod

    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor

    class FakeModel(torch.nn.Module):
        """Tiny conv -> global-average-pool -> linear classifier for the demo."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 3, 3)
            # Collapse each of the 3 feature maps to a single value, then flatten.
            self.pool = torch.nn.Sequential(
                torch.nn.AdaptiveAvgPool2d(1),
                torch.nn.Flatten(),
            )
            self.cls = torch.nn.Linear(3, 10)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            x = self.cls(x)
            return x

    class Metric(ABC):
        """Abstract interface sketch for metrics (not used by the demo yet)."""

        def __init__(self):
            pass

        # The abstract methods take ``self`` so concrete overrides can be
        # called on instances with the same signature.
        @abstractmethod
        def reset(self):
            raise NotImplementedError("reset method is not implemented!")

        @abstractmethod
        def update(self):
            raise NotImplementedError("update method is not implemented!")

        @abstractmethod
        def compute(self):
            raise NotImplementedError("compute method is not implemented!")

    # TODO: add a concrete Accuracy(Metric) implementation.

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # mnist
    train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
    val_dataset = MNIST(root='data/', train=False, transform=ToTensor(), download=True)
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = FakeModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    trainer = Trainer(device)
    trainer.fit(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        epochs=10,
        optimizer=optimizer,
        loss_fn=loss_fn,
    )