Fading Coder

One Final Commit for the Last Sprint

Home > Tech > Content

Implementing Deep Convolutional GANs for Face Image Generation

Tech 1

Environment Setup and Dependencies

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import random

# Seed every RNG source the script touches so runs are reproducible.
SEED = 999
random.seed(SEED)
np.random.seed(SEED)  # numpy is imported and used for plotting; seed it as well
torch.manual_seed(SEED)
# Force deterministic cuDNN kernel selection (slower, but reproducible on GPU).
torch.backends.cudnn.deterministic = True

Configuration Parameters

DATA_PATH = 'data/faces'   # dataset root; ImageFolder expects class subfolders beneath it
BATCH_SIZE = 128           # training mini-batch size
IMG_SIZE = 64              # images resized/cropped to 64x64 (matches the conv stacks below)
LATENT_DIM = 100           # length of the generator's input noise vector z
GEN_FEATURES = 64          # base feature-map count for the generator
DISC_FEATURES = 64         # base feature-map count for the discriminator
EPOCHS = 5                 # number of full passes over the dataset
LEARNING_RATE = 0.0002     # Adam learning rate, per the DCGAN paper
BETA1 = 0.5                # Adam beta1, per the DCGAN paper (the 0.9 default destabilizes GANs)

Data Preparation

# Image pre-processing: resize the short side to IMG_SIZE, centre-crop to a
# square, convert to a tensor, and normalize each RGB channel to [-1, 1]
# (matching the generator's Tanh output range).
transform_pipeline = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load dataset. ImageFolder requires DATA_PATH to contain at least one
# subdirectory of images; the class labels it produces are ignored here.
face_dataset = dset.ImageFolder(root=DATA_PATH, transform=transform_pipeline)

# Create data loader: reshuffled each epoch, 4 worker processes for decoding.
data_loader = torch.utils.data.DataLoader(
    face_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)

# Device configuration: prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Model Architecture

Weight Initialization

def initialize_weights(model):
    """DCGAN custom weight init, meant for use via `net.apply(initialize_weights)`.

    Conv / ConvTranspose layers get N(0, 0.02) weights; BatchNorm layers get
    N(1, 0.02) weights and zero bias. Any other module is left untouched.
    """
    layer_name = type(model).__name__
    if 'Conv' in layer_name:
        nn.init.normal_(model.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_name:
        nn.init.normal_(model.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(model.bias.data, 0)

Generator Network

class FaceGenerator(nn.Module):
    """DCGAN generator.

    Maps a latent batch of shape (N, latent_dim, 1, 1) to an image batch of
    shape (N, channels, 64, 64) with values in [-1, 1] (Tanh output), via
    five transposed convolutions (1x1 -> 4 -> 8 -> 16 -> 32 -> 64).
    """

    def __init__(self, latent_dim=None, features=None, channels=3):
        """
        Args:
            latent_dim: length of the input noise vector. Defaults to the
                module-level LATENT_DIM, so existing `FaceGenerator()`
                callers are unaffected.
            features: base feature-map count. Defaults to GEN_FEATURES.
            channels: output image channels (3 = RGB).
        """
        super(FaceGenerator, self).__init__()
        # Resolve defaults lazily so the module constants are only required
        # when the caller does not supply explicit values.
        if latent_dim is None:
            latent_dim = LATENT_DIM
        if features is None:
            features = GEN_FEATURES
        self.layers = nn.Sequential(
            # (latent_dim, 1, 1) -> (features*8, 4, 4)
            nn.ConvTranspose2d(latent_dim, features * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features * 8),
            nn.ReLU(True),

            # -> (features*4, 8, 8)
            nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(True),

            # -> (features*2, 16, 16)
            nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),

            # -> (features, 32, 32)
            nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(True),

            # -> (channels, 64, 64), squashed into [-1, 1]
            nn.ConvTranspose2d(features, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input_tensor):
        """Run the transposed-conv stack on a (N, latent_dim, 1, 1) batch."""
        return self.layers(input_tensor)

# Instantiate the generator on the selected device and apply the custom
# DCGAN weight initialization to every submodule.
generator = FaceGenerator().to(device)
generator.apply(initialize_weights)

Discriminator Network

class FaceDiscriminator(nn.Module):
    """DCGAN discriminator.

    Maps an image batch of shape (N, channels, 64, 64) to a (N, 1, 1, 1)
    tensor of probabilities in [0, 1] (Sigmoid output): the estimated
    probability that each input image is real.
    """

    def __init__(self, features=None, channels=3):
        """
        Args:
            features: base feature-map count. Defaults to the module-level
                DISC_FEATURES, so existing `FaceDiscriminator()` callers
                are unaffected.
            channels: input image channels (3 = RGB).
        """
        super(FaceDiscriminator, self).__init__()
        # Resolve the default lazily so the module constant is only required
        # when the caller does not supply an explicit value.
        if features is None:
            features = DISC_FEATURES
        self.layers = nn.Sequential(
            # (channels, 64, 64) -> (features, 32, 32); no BatchNorm on the
            # input layer, per the DCGAN guidelines.
            nn.Conv2d(channels, features, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (features*2, 16, 16)
            nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (features*4, 8, 8)
            nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (features*8, 4, 4)
            nn.Conv2d(features * 4, features * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (1, 1, 1) probability
            nn.Conv2d(features * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input_tensor):
        """Run the conv stack on a (N, channels, 64, 64) batch."""
        return self.layers(input_tensor)

# Instantiate the discriminator on the selected device and apply the custom
# DCGAN weight initialization to every submodule.
discriminator = FaceDiscriminator().to(device)
discriminator.apply(initialize_weights)

Training Loop

loss_function = nn.BCELoss()  # binary cross-entropy over the discriminator's sigmoid output

# Fixed batch of latent vectors, reused every time samples are rendered so
# the generator's progress can be compared on identical inputs.
static_noise = torch.randn(64, LATENT_DIM, 1, 1, device=device)

# Separate Adam optimizers for each network (betas per the DCGAN paper).
disc_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
gen_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))

# Training statistics collected during the loop below.
generator_losses = []      # per-iteration generator loss
discriminator_losses = []  # per-iteration discriminator loss (real + fake)
generated_images = []      # image grids rendered from static_noise

for epoch in range(EPOCHS):
    for batch_idx, (real_images, _) in enumerate(data_loader):
        batch_size = real_images.size(0)  # last batch may be smaller than BATCH_SIZE
        real_images = real_images.to(device)
        
        # --- Discriminator step: maximize log(D(x)) + log(1 - D(G(z))) ---
        # Real half: all-ones labels; gradients accumulate into D.
        discriminator.zero_grad()
        real_labels = torch.full((batch_size,), 1.0, device=device)
        real_output = discriminator(real_images).view(-1)
        real_loss = loss_function(real_output, real_labels)
        real_loss.backward()
        
        # Fake half: all-zeros labels. detach() blocks gradients from
        # flowing back into the generator during the discriminator update.
        noise_vector = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
        fake_images = generator(noise_vector)
        fake_labels = torch.full((batch_size,), 0.0, device=device)
        fake_output = discriminator(fake_images.detach()).view(-1)
        fake_loss = loss_function(fake_output, fake_labels)
        fake_loss.backward()
        
        disc_loss = real_loss + fake_loss  # combined value is for logging only
        disc_optimizer.step()
        
        # --- Generator step: train G so the freshly-updated discriminator
        # labels its fakes as real (non-saturating GAN loss).
        generator.zero_grad()
        gen_labels = torch.full((batch_size,), 1.0, device=device)
        gen_output = discriminator(fake_images).view(-1)
        gen_loss = loss_function(gen_output, gen_labels)
        gen_loss.backward()
        gen_optimizer.step()
        
        # Record per-iteration losses for the plots below.
        generator_losses.append(gen_loss.item())
        discriminator_losses.append(disc_loss.item())
        
        # Periodically render a grid from the fixed noise batch so the
        # generator's progress can be inspected on identical latent inputs.
        if batch_idx % 400 == 0:
            with torch.no_grad():
                sample_images = generator(static_noise).detach().cpu()
            generated_images.append(vutils.make_grid(sample_images, padding=2, normalize=True))

Results Visualization

# Plot the per-iteration generator and discriminator losses.
plt.figure(figsize=(10, 5))
plt.plot(generator_losses, label='Generator')
plt.plot(discriminator_losses, label='Discriminator')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Side-by-side comparison: a grid of real training images (left) vs the
# last grid sampled from static_noise during training (right).
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis('off')
plt.title('Real Images')
# real_images still holds the final batch from the training loop; show up to 64.
real_display = vutils.make_grid(real_images[:64], padding=5, normalize=True)
plt.imshow(np.transpose(real_display.cpu(), (1, 2, 0)))  # CHW -> HWC for imshow

plt.subplot(1, 2, 2)
plt.axis('off')
plt.title('Generated Images')
plt.imshow(np.transpose(generated_images[-1], (1, 2, 0)))  # most recent sample grid
plt.show()

Related Articles

Understanding Strong and Weak References in Java

Strong References Strong references are the most prevalent type of object referencing in Java. When an object has a strong reference pointing to it, the garbage collector will not reclaim its memory. F...

Comprehensive Guide to SSTI Explained with Payload Bypass Techniques

Introduction Server-Side Template Injection (SSTI) is a vulnerability in web applications where user input is improperly handled within the template engine and executed on the server. This exploit can r...

Implement Image Upload Functionality for Django Integrated TinyMCE Editor

Django’s Admin panel is highly user-friendly, and pairing it with TinyMCE, an effective rich text editor, simplifies content management significantly. Combining the two is particularly useful for bloggi...

Leave a Comment

Anonymous

◎Feel free to join the discussion and share your thoughts.