# Ref: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm.auto import tqdm
import numpy as np
from fusionlab.layers import ConvND, BatchNorm, ConvT
[docs]
class Generator(nn.Module):
def __init__(self, c_in, c_out, dim_g, spatial_dims=2):
super().__init__()
self.main = nn.Sequential(
# input is Z, going into a convolution
ConvT(spatial_dims, c_in, dim_g * 8, 4, 1, 0, bias=False),
BatchNorm(spatial_dims,dim_g * 8),
nn.ReLU(True),
# state size. (dim_g*8) x 4 x 4
ConvT(spatial_dims, dim_g * 8, dim_g * 4, 4, 2, 1, bias=False),
BatchNorm(spatial_dims,dim_g * 4),
nn.ReLU(True),
# state size. (dim_g*4) x 8 x 8
ConvT(spatial_dims, dim_g * 4, dim_g * 2, 4, 2, 1, bias=False),
BatchNorm(spatial_dims, dim_g * 2),
nn.ReLU(True),
# state size. (dim_g*2) x 16 x 16
ConvT(spatial_dims, dim_g * 2, dim_g, 4, 2, 1, bias=False),
BatchNorm(spatial_dims, dim_g),
nn.ReLU(True),
# state size. (dim_g) x 32 x 32
ConvT(spatial_dims, dim_g, c_out, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)
[docs]
def forward(self, x):
return self.main(x)
[docs]
class Discriminator(nn.Module):
def __init__(self, c_in, dim_d, spatial_dims=2):
super().__init__()
self.main = nn.Sequential(
# input is (nc) x 64 x 64
ConvND(spatial_dims, c_in, dim_d, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
ConvND(spatial_dims, dim_d, dim_d * 2, 4, 2, 1, bias=False),
BatchNorm(spatial_dims, dim_d * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
ConvND(spatial_dims, dim_d * 2, dim_d * 4, 4, 2, 1, bias=False),
BatchNorm(spatial_dims, dim_d * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
ConvND(spatial_dims, dim_d * 4, dim_d * 8, 4, 2, 1, bias=False),
BatchNorm(spatial_dims, dim_d * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
ConvND(spatial_dims, dim_d * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
[docs]
def forward(self, x):
return self.main(x)
[docs]
class DCGANTrainer:
def __init__(self, generator, discriminator, optim_g, optim_d, loss_fn, device, dim_z, spatial_dims=2):
self.g = generator.to(device)
self.d = discriminator.to(device)
self.optim_g = optim_g
self.optim_d = optim_d
self.loss_fn = loss_fn
self.dim_z = dim_z
self.device = device
self.fixed_noise = None
self.spatial_dims = spatial_dims
[docs]
def train_step(self, real_x):
bs = real_x.size(0)
label_real = torch.full((bs,), 1., dtype=torch.float, device=self.device)
label_fake = torch.full((bs,), 0., dtype=torch.float, device=self.device)
if self.spatial_dims==1:
noise_x = torch.rand(bs, self.dim_z, 1, device=self.device)
elif self.spatial_dims==2:
noise_x = torch.rand(bs, self.dim_z, 1, 1, device=self.device)
elif self.spatial_dims==3:
noise_x = torch.rand(bs, self.dim_z, 1, 1, 1, device=self.device)
# Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
self.d.zero_grad()
# real -> D
output_d_real = self.d(real_x).view(-1)
# noise -> G -> D
fake = self.g(noise_x)
output_g_d = self.d(fake.detach()).view(-1)
# Discriminator loss on real and fake batch
loss_d_real = self.loss_fn(output_d_real, label_real)
loss_d_fake = self.loss_fn(output_g_d, label_fake)
loss_d = loss_d_real.mean() + loss_d_fake.mean()
loss_d_real.backward()
loss_d_fake.backward()
self.optim_d.step()
# Train Generator: minimize log(1-D(G(z)))
self.g.zero_grad()
output_g = self.d(fake).view(-1)
loss_g = self.loss_fn(output_g, label_real)
loss_g.backward()
optimizerG.step()
log = {'loss_d': loss_d.item(), 'loss_g': loss_g.item()}
return log
[docs]
def fit(self, dataloader, epochs):
log = {'loss_d': [], 'loss_g': []}
if self.spatial_dims==1:
fixed_noise = torch.rand(32, self.dim_z, 1, device=self.device)
elif self.spatial_dims==2:
fixed_noise = torch.rand(32, self.dim_z, 1, 1, device=self.device)
elif self.spatial_dims==3:
fixed_noise = torch.rand(32, self.dim_z, 1, 1, 1, device=self.device)
for epoch in tqdm(range(epochs)):
epoch_log = {'loss_d': [], 'loss_g': []}
for i, (x, _) in enumerate(tqdm(dataloader)):
real_x = x.to(self.device)
step_log = self.train_step(real_x)
[epoch_log[k].append(v) for k, v in step_log.items()]
[log[k].append(np.mean(v)) for k, v in epoch_log.items()]
print(f'''[{epoch}/{epochs}] loss_d: {log['loss_d'][-1]} loss_g: {log['loss_g'][-1]}''')
return log
if __name__ == '__main__':
BS = 64
image_size = 64
img_channel = 1
dim_z = 100
dim_g = 64
dim_d = 64
lr = 0.0002
beta1 = 0.5
dataset = torchvision.datasets.MNIST(
root="",
transform=transforms.Compose([
transforms.Resize(64),
transforms.ToTensor(),
transforms.Normalize((0.5), (0.5)),
]),
download=True,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BS,
shuffle=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_g = Generator(dim_z, img_channel, dim_g, spatial_dims=2)
model_d = Discriminator(img_channel, dim_d, spatial_dims=2)
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(model_d.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(model_g.parameters(), lr=lr, betas=(beta1, 0.999))
# DCGAN trainer
trainer = DCGANTrainer(model_g, model_d, optimizerG, optimizerD, criterion, device, dim_z, spatial_dims=2)
log = trainer.fit(dataloader, 3)