Generative AI 5: Generative Adversarial Networks (GANs)
Generative Adversarial Networks (GANs) are a type of generative model that consists of two neural networks: a generator and a discriminator. These two networks are trained simultaneously in a competitive process where the generator tries to create realistic data, and the discriminator tries to distinguish between real and fake data.
5.1 GAN Architecture
Key Components of GANs:
- Generator (G):
- The generator takes random noise as input and generates synthetic data (e.g., images).
- The goal of the generator is to produce data that is indistinguishable from real data to fool the discriminator.
- Discriminator (D):
- The discriminator is a binary classifier that takes in real data and generated (fake) data and tries to classify whether the input is real or fake.
- The goal of the discriminator is to correctly identify real data from fake data.
- Adversarial Process:
- The two networks are in competition: the generator is trained to produce realistic data to fool the discriminator, while the discriminator is trained to correctly classify real and fake data.
- This process is a minimax game where both networks improve over time.
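Formally, this competition can be written as a single value function that the discriminator tries to maximize and the generator tries to minimize (the minimax objective from Goodfellow et al., 2014):
\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log(1 - D(G(z)))\right]\]
The per-network losses below are the practical form of this single objective.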
GAN Loss Functions:
GANs use two loss functions:
- Generator Loss:
The generator aims to minimize the discriminator’s ability to correctly classify fake data as fake; the loss is designed to make the generator’s output look as real as possible.
\[\mathcal{L}_G = -\log(D(G(z)))\]
where \(G(z)\) is the generator’s output (fake data) from random noise \(z\), and \(D(G(z))\) is the discriminator’s output for that fake data.
- Discriminator Loss:
The discriminator aims to maximize the probability of correctly classifying real data and to minimize the probability of misclassifying fake data as real.
\[\mathcal{L}_D = -\left[ \log(D(x)) + \log(1 - D(G(z))) \right]\]
where \(x\) is the real data and \(G(z)\) is the fake data generated by the generator.
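In practice, both losses are just binary cross-entropy with the appropriate targets, which is exactly how the implementation later in this section computes them with nn.BCELoss. Here is a minimal sketch of that correspondence; the tensors d_out_real and d_out_fake are placeholders standing in for discriminator outputs:
import torch
import torch.nn as nn

bce = nn.BCELoss()

# Placeholder discriminator outputs in (0, 1); in a real GAN these
# would come from D(x) and D(G(z))
d_out_real = torch.rand(8, 1) * 0.98 + 0.01
d_out_fake = torch.rand(8, 1) * 0.98 + 0.01
ones, zeros = torch.ones(8, 1), torch.zeros(8, 1)

# Discriminator loss: -[log D(x) + log(1 - D(G(z)))]
loss_D = bce(d_out_real, ones) + bce(d_out_fake, zeros)

# Generator loss: -log D(G(z))  (labels fake data as "real")
loss_G = bce(d_out_fake, ones)

# BCE with these targets matches the formulas above
assert torch.allclose(loss_D, -(torch.log(d_out_real) + torch.log(1 - d_out_fake)).mean())
assert torch.allclose(loss_G, -torch.log(d_out_fake).mean())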
5.2 GAN Training Process
The training process of GANs alternates between updating the discriminator and the generator:
- Train the Discriminator:
- The discriminator is trained to maximize its ability to distinguish between real and fake data.
- This involves feeding real data \(x\) to the discriminator and maximizing \(\log(D(x))\), while also maximizing \(\log(1 - D(G(z)))\) for fake data \(G(z)\).
- Train the Generator:
- The generator is trained to fool the discriminator. The original minimax formulation has it minimize \(\log(1 - D(G(z)))\), but in practice (and in the code below) the non-saturating loss \(\mathcal{L}_G = -\log(D(G(z)))\) is minimized instead, since it gives stronger gradients early in training, when the discriminator easily rejects generated samples. A closed-form result about this game follows the list.
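A useful fact about where this game converges: for a fixed generator, the discriminator that optimizes its objective is known in closed form (Goodfellow et al., 2014):
\[D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}\]
where \(p_g\) is the distribution of the generator’s samples. At the global optimum, \(p_g = p_{\text{data}}\) and \(D^*(x) = \tfrac{1}{2}\) everywhere, i.e., the discriminator can do no better than guessing.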
5.3 Implementing a GAN with PyTorch
Here’s how to implement a simple GAN in PyTorch to generate handwritten digits (MNIST dataset).
Python Code:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Generator model: maps a 100-dim noise vector to a 1x28x28 image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28*28),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized images
        )

    def forward(self, z):
        return self.main(z).view(-1, 1, 28, 28)
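# Optional sanity check: a batch of 4 noise vectors yields 4 images.
# Generator()(torch.randn(4, 100)).shape  ->  torch.Size([4, 1, 28, 28])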
# Discriminator model: binary classifier over flattened 28x28 images
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability that the input is real
        )

    def forward(self, x):
        return self.main(x.view(-1, 28*28))
# Hyperparameters
batch_size = 64
learning_rate = 0.0002
latent_dim = 100
epochs = 50
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # scale pixel values to [-1, 1]
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Initialize generator and discriminator
G = Generator()
D = Discriminator()
# Loss and optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate)
# Training the GAN
for epoch in range(epochs):
    for real_images, _ in train_loader:
        bs = real_images.size(0)  # may be smaller than batch_size in the last batch

        # Real images get label 1, fake images get label 0
        real_labels = torch.ones(bs, 1)
        fake_labels = torch.zeros(bs, 1)

        # Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        optimizer_D.zero_grad()
        outputs = D(real_images)
        d_loss_real = criterion(outputs, real_labels)
        z = torch.randn(bs, latent_dim)
        fake_images = G(z)
        outputs = D(fake_images.detach())  # detach: no gradients flow into G here
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: minimize -log(D(G(z)))
        optimizer_G.zero_grad()
        z = torch.randn(bs, latent_dim)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)  # G wants D to output 1 ("real")
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
# Generate and visualize fake images
import matplotlib.pyplot as plt
import torchvision.utils as vutils

with torch.no_grad():
    z = torch.randn(64, latent_dim)
    fake_images = G(z)

grid = vutils.make_grid(fake_images, normalize=True)
plt.imshow(grid.permute(1, 2, 0))
plt.show()
Output:
- Training Output: As the GAN trains, you will see both the discriminator loss (d_loss) and generator loss (g_loss) printed for each epoch:
Epoch [1/50], d_loss: 0.5432, g_loss: 0.7856
Epoch [2/50], d_loss: 0.4643, g_loss: 0.9123
...
Epoch [50/50], d_loss: 0.3456, g_loss: 1.2234
- Generated Samples: After training, the GAN generates new images of handwritten digits that resemble those in the MNIST dataset.
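To reuse the trained generator without retraining, one option is to save its weights with torch.save; the filename generator.pth below is an arbitrary choice:
# Save the trained generator's weights (filename is an arbitrary choice)
torch.save(G.state_dict(), 'generator.pth')

# Later: rebuild the model, load the weights, and sample new digits
G_loaded = Generator()
G_loaded.load_state_dict(torch.load('generator.pth'))
G_loaded.eval()
with torch.no_grad():
    samples = G_loaded(torch.randn(16, latent_dim))  # 16 new 28x28 digits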
Summary:
- GANs consist of two neural networks: a generator that creates data and a discriminator that distinguishes between real and fake data.
- The generator is trained to fool the discriminator, while the discriminator is trained to correctly classify real vs. fake data.
- The training process is a minimax game, where the two networks improve together. GANs are widely used for generating realistic images, video, and text.