
Generative Adversarial Networks (GANs) are a type of generative model that consists of two neural networks: a generator and a discriminator. These two networks are trained together in a competitive process where the generator tries to create realistic data, and the discriminator tries to distinguish between real and fake data.

5.1 GAN Architecture

Key Components of GANs:

  1. Generator (G):
    • The generator takes random noise as input and generates synthetic data (e.g., images).
    • The goal of the generator is to produce data that is indistinguishable from real data to fool the discriminator.
  2. Discriminator (D):
    • The discriminator is a binary classifier: given an input that is either real data or generated (fake) data, it outputs the probability that the input is real.
    • The goal of the discriminator is to correctly distinguish real data from fake data.
  3. Adversarial Process:
    • The two networks are in competition: the generator is trained to produce realistic data to fool the discriminator, while the discriminator is trained to correctly classify real and fake data.
    • This is a two-player minimax game: as the discriminator gets better at spotting fakes, the generator is pushed to produce more realistic data, and vice versa.
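Using the same symbols as the loss functions below, this minimax game is commonly written as a single value function \(V(D, G)\) that the discriminator maximizes and the generator minimizes. Here \(p_{\text{data}}\) (the distribution of real data) and \(p_z\) (the noise distribution) are notation introduced for this sketch:

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]\]

For a fixed generator whose samples follow a distribution \(p_g\), the discriminator that maximizes \(V\) is

\[D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)},\]

so once \(p_g = p_{\text{data}}\), the best the discriminator can do is output \(1/2\) for every input. That point is the equilibrium of the game.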

GAN Loss Functions:

GANs use two loss functions:

  1. Generator Loss:
    • The generator wants the discriminator to classify its fake data as real. Its loss is small when \(D(G(z))\) is close to 1 and grows as the discriminator becomes confident that the sample is fake.

      \[\mathcal{L}_G = -\log(D(G(z)))\]

      Where \(G(z)\) is the generator’s output (fake data) from random noise \(z\), and \(D(G(z))\) is the discriminator’s output for the fake data.

  2. Discriminator Loss:
    • The discriminator aims to assign high probability to real data and low probability to fake data, that is, to maximize \(\log(D(x)) + \log(1 - D(G(z)))\). Its loss is the negative of that sum:

      \[\mathcal{L}_D = -\left[ \log(D(x)) + \log(1 - D(G(z))) \right]\]

      Where \(x\) is the real data, and \(G(z)\) is the fake data generated by the generator.
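To make the two formulas concrete, here is a small pure-Python sketch. The helper names `bce`, `mean`, `generator_loss`, and `discriminator_loss` are illustrative, not part of any library. Each loss is averaged over a batch of discriminator outputs, mirroring how `nn.BCELoss` averages by default. The sketch shows that \(\mathcal{L}_G\) is binary cross-entropy with target 1, and that \(\mathcal{L}_D\) is the sum of two binary cross-entropy terms, which is what the PyTorch code in this section computes.

```python
import math

def bce(p: float, target: float) -> float:
    """Binary cross-entropy for one prediction p in (0, 1) and a 0/1 target."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def mean(values):
    return sum(values) / len(values)

def generator_loss(d_fake):
    """L_G = -log D(G(z)), averaged over discriminator outputs on fake samples."""
    return mean([-math.log(p) for p in d_fake])

def discriminator_loss(d_real, d_fake):
    """L_D = -[log D(x) + log(1 - D(G(z)))], each term averaged over its batch."""
    real_term = mean([-math.log(p) for p in d_real])
    fake_term = mean([-math.log(1 - p) for p in d_fake])
    return real_term + fake_term

# L_G is BCE with target 1 ("real"): this is why the PyTorch code
# passes real_labels when computing g_loss.
d_fake = [0.2, 0.6, 0.9]
assert math.isclose(generator_loss(d_fake), mean([bce(p, 1.0) for p in d_fake]))

# The generator's loss shrinks as it fools the discriminator more.
print(f"{generator_loss([0.1]):.3f}")  # D rejects the fake: 2.303
print(f"{generator_loss([0.9]):.3f}")  # D is fooled: 0.105

# At equilibrium D outputs 1/2 everywhere: L_D = 2 log 2 and L_G = log 2.
print(f"{discriminator_loss([0.5], [0.5]):.3f}")  # 1.386
print(f"{generator_loss([0.5]):.3f}")  # 0.693
```

Because the PyTorch code adds the real and fake BCE terms, a `d_loss` near 1.386 together with a `g_loss` near 0.693 is consistent with a discriminator that can no longer tell real from fake. In practice the losses usually fluctuate rather than settle there, so treat them as a rough diagnostic rather than a measure of sample quality.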

5.2 GAN Training Process

The training process of GANs involves alternating between updating the generator and the discriminator:

  1. Train the Discriminator:
    • The discriminator is trained to maximize its ability to distinguish between real and fake data.
    • This involves feeding real data \(x\) and fake data \(G(z)\) to the discriminator and maximizing \(\log(D(x)) + \log(1 - D(G(z)))\), i.e., pushing \(D(x)\) toward 1 and \(D(G(z))\) toward 0. The generator's parameters are held fixed during this step.
  2. Train the Generator:
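    • The generator is trained to fool the discriminator, i.e., to make it believe the fake data is real. In the minimax formulation this means minimizing \(\log(1 - D(G(z)))\), but that objective saturates (its gradients become very small) when the discriminator confidently rejects fakes, as it often does early in training. In practice, and in the code below, the generator instead minimizes \(-\log(D(G(z)))\), the generator loss defined above, which has the same goal but provides stronger gradients.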
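Why the code minimizes \(-\log(D(G(z)))\) rather than the minimax term \(\log(1 - D(G(z)))\) can be checked numerically. Write the discriminator's output on a fake sample as \(D(G(z)) = \sigma(a)\), where \(a\) is the pre-sigmoid logit (an assumption that matches the `nn.Sigmoid()` output layer in the code). Then \(\log(1 - \sigma(a))\) has derivative \(-\sigma(a)\) with respect to \(a\), while \(-\log \sigma(a)\) has derivative \(-(1 - \sigma(a))\). The sketch below, using hypothetical helper names, estimates both derivatives with a central finite difference:

```python
import math

def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))

def saturating_loss(a: float) -> float:
    """Minimax generator loss log(1 - D(G(z))), with D(G(z)) = sigmoid(a)."""
    return math.log(1.0 - sigmoid(a))

def non_saturating_loss(a: float) -> float:
    """Non-saturating generator loss -log D(G(z))."""
    return -math.log(sigmoid(a))

def numeric_grad(f, a: float, h: float = 1e-5) -> float:
    """Central finite-difference estimate of df/da."""
    return (f(a + h) - f(a - h)) / (2 * h)

# a = -5 models a discriminator that confidently rejects the fake sample.
for a in (-5.0, 0.0, 5.0):
    d = sigmoid(a)
    g_sat = numeric_grad(saturating_loss, a)
    g_ns = numeric_grad(non_saturating_loss, a)
    print(f"D(G(z))={d:.4f}  saturating grad={g_sat:+.4f}  non-saturating grad={g_ns:+.4f}")
# D(G(z))=0.0067  saturating grad=-0.0067  non-saturating grad=-0.9933
# D(G(z))=0.5000  saturating grad=-0.5000  non-saturating grad=-0.5000
# D(G(z))=0.9933  saturating grad=-0.9933  non-saturating grad=-0.0067
```

Both losses push \(D(G(z))\) upward, but when the discriminator easily spots a fake (\(D(G(z)) \approx 0.0067\)), the minimax term passes back a gradient of only about \(-0.0067\) through the logit, while the non-saturating loss passes back about \(-0.9933\).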

Implementing a GAN with PyTorch

Here’s how to implement a simple GAN in PyTorch to generate handwritten digits (MNIST dataset).

Python Code:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Generator model: maps a noise vector z to a 1x28x28 image
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28*28),
            nn.Tanh()  # outputs in [-1, 1], the same range as the normalized data
        )
    
    def forward(self, z):
        return self.main(z).view(-1, 1, 28, 28)

# Discriminator model: maps a 1x28x28 image to the probability that it is real
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability in (0, 1), as nn.BCELoss expects
        )
    
    def forward(self, x):
        return self.main(x.view(-1, 28*28))

# Hyperparameters
batch_size = 64
learning_rate = 0.0002
latent_dim = 100
epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),              # pixels in [0, 1]
    transforms.Normalize([0.5], [0.5])  # (x - 0.5) / 0.5 scales to [-1, 1]
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Initialize generator and discriminator
G = Generator(latent_dim).to(device)
D = Discriminator().to(device)

# Loss and optimizers (betas=(0.5, 0.999) is a common Adam setting for GANs)
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Training the GAN
for epoch in range(epochs):
    for real_images, _ in train_loader:
        real_images = real_images.to(device)
        current_batch_size = real_images.size(0)  # the last batch can be smaller
        
        # Target labels: 1 = real, 0 = fake
        real_labels = torch.ones(current_batch_size, 1, device=device)
        fake_labels = torch.zeros(current_batch_size, 1, device=device)
        
        # Train Discriminator: push D(x) toward 1 and D(G(z)) toward 0
        optimizer_D.zero_grad()
        outputs = D(real_images)
        d_loss_real = criterion(outputs, real_labels)
        
        z = torch.randn(current_batch_size, latent_dim, device=device)
        fake_images = G(z)
        outputs = D(fake_images.detach())  # detach: no gradients flow into G here
        d_loss_fake = criterion(outputs, fake_labels)
        
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()
        
        # Train Generator: minimize -log D(G(z)) by labeling fakes as real
        optimizer_G.zero_grad()
        z = torch.randn(current_batch_size, latent_dim, device=device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        
        g_loss.backward()
        optimizer_G.step()
    
    # Losses from the last batch of the epoch
    print(f'Epoch [{epoch+1}/{epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

# Generate and visualize fake images
with torch.no_grad():
    z = torch.randn(64, latent_dim, device=device)
    fake_images = G(z).cpu()  # move to CPU for plotting
    grid = vutils.make_grid(fake_images, normalize=True)  # rescale to [0, 1] for display
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.show()
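The `transforms.Normalize([0.5], [0.5])` step computes \((x - 0.5) / 0.5\) for each pixel, mapping the \([0, 1]\) range produced by `transforms.ToTensor()` onto \([-1, 1]\), the same range as the generator's `nn.Tanh()` output. The pure-Python sketch below (helper names are illustrative) shows that mapping and its inverse, which is one way to bring generated images back to \([0, 1]\) for display. `make_grid(..., normalize=True)` in the code instead rescales the grid using its own minimum and maximum values.

```python
def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-pixel arithmetic performed by transforms.Normalize([mean], [std])."""
    return (pixel - mean) / std

def denormalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping, e.g. from Tanh output in [-1, 1] back to [0, 1]."""
    return value * std + mean

for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```

If the data were left in \([0, 1]\) while the generator outputs values in \([-1, 1]\), the discriminator could separate real from fake by pixel range alone, which is why the two ranges are matched.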

Output:

  • Training Output: As the GAN trains, the discriminator loss (d_loss) and generator loss (g_loss) from the last batch of each epoch are printed. Exact values vary from run to run because of random initialization and sampling; the output has this form:
    Epoch [1/50], d_loss: 0.5432, g_loss: 0.7856
    Epoch [2/50], d_loss: 0.4643, g_loss: 0.9123
    ...
    Epoch [50/50], d_loss: 0.3456, g_loss: 1.2234

  • Generated Samples: After training, the generator produces new images of handwritten digits that should resemble the MNIST training data.

Summary:

  • GANs consist of two neural networks: a generator that creates data and a discriminator that distinguishes between real and fake data.
  • The generator is trained to fool the discriminator, while the discriminator is trained to correctly classify real vs. fake data.
  • The training process is a minimax game in which the two networks improve together. GANs are widely used for generating realistic images and have also been applied to video and text.
