Deep Convolutional GAN for Anime Portrait Synthesis
DCGAN extends the foundational GAN architecture by replacing fully-connected layers with convolutional operations, enabling more stable training for high-resolution image generation. The generator utilizes transposed convolutions to upsample latent vectors into RGB images, while the discriminator employs strided convolutions to classify input authenticity.
This implementation utilizes the Anime Faces dataset containing 70,171 images scaled to 96×96 pixels. The objective involves training adversarial networks where the generator learns to map 100-dimensional Gaussian noise to realistic anime portraits, while the discriminator distinguishes between generated samples and authentic dataset images.
Dataset Preparation
Download and extract the dataset to the working directory:
# Fetch the zipped Anime Faces dataset and unpack it into ./anime_data.
from download import download

url = "https://download.mindspore.cn/dataset/Faces/faces.zip"
path = download(url, "./anime_data", kind="zip", replace=True)
The extracted directory structure contains individual JPEG files numbered sequentially from 0 to 70170. Initialize training hyperparameters and construct the data pipeline:
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import numpy as np
# Training hyperparameters (lr/beta1 follow the DCGAN paper's recommendations).
batch_size = 128          # images per training batch
img_resolution = 64       # side length (pixels) after resize + center crop
channel_count = 3         # RGB channels
latent_dim = 100          # size of the generator's input noise vector
gen_base_features = 64    # feature-map width of the generator's last conv stage
disc_base_features = 64   # feature-map width of the discriminator's first conv stage
max_epochs = 3            # full passes over the dataset
lr = 0.0002               # Adam learning rate for both networks
beta1 = 0.5               # Adam first-moment decay (0.5 per the DCGAN paper)
def build_dataloader(root_path):
    """Build the training pipeline: decode, resize/crop to 64x64, CHW, scale to [0, 1].

    NOTE(review): pixels are scaled to [0, 1] here while the generator's final
    Tanh emits [-1, 1]; confirm this mismatch is intentional before changing it.
    """
    data = ds.ImageFolderDataset(root_path, num_parallel_workers=4,
                                 shuffle=True, decode=True)
    ops_chain = [
        vision.Resize(img_resolution),
        vision.CenterCrop(img_resolution),
        vision.HWC2CHW(),
        lambda x: (x / 255.0).astype(np.float32),
    ]
    data = data.project('image')
    data = data.map(ops_chain, 'image')
    return data.batch(batch_size)


train_data = build_dataloader('./anime_data/faces')
Visualize a subset of training samples to verify preprocessing:
import matplotlib.pyplot as plt
def display_batch(data):
    """Show the first 30 images of a batch in a 3x10 grid (no axes)."""
    plt.figure(figsize=(10, 3), dpi=140)
    for slot, image in enumerate(data[0][:30]):
        plt.subplot(3, 10, slot + 1)
        plt.axis('off')
        # CHW -> HWC so matplotlib can render the array.
        plt.imshow(image.transpose(1, 2, 0))
    plt.show()
# Pull one batch as NumPy arrays and preview the first 30 images.
sample_batch = next(train_data.create_tuple_iterator(output_numpy=True))
display_batch(sample_batch)
Network Architecture
Initialize weights from a normal distribution (μ=0, σ=0.02) as specified in the original DCGAN paper. The generator architecture progressively upsamples the latent vector through five transposed convolutional layers, with batch normalization and ReLU activations applied throughout. The final layer uses Tanh activation to constrain output values between [-1, 1].
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal
# Conv weights ~ N(0, 0.02) and BatchNorm gammas ~ N(1, 0.02), as prescribed
# by the original DCGAN paper.
weight_init = Normal(mean=0.0, sigma=0.02)
gamma_init = Normal(mean=1.0, sigma=0.02)
class Generator(nn.Cell):
    """DCGAN generator.

    Maps an (N, latent_dim, 1, 1) noise tensor through five transposed
    convolutions to an (N, 3, 64, 64) image; the final Tanh bounds the
    output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Stage 1: project the latent vector to a 4x4 map of 512 channels.
        layers = [
            nn.Conv2dTranspose(latent_dim, gen_base_features * 8, 4, 1,
                               pad_mode='valid', weight_init=weight_init),
            nn.BatchNorm2d(gen_base_features * 8, gamma_init=gamma_init),
            nn.ReLU(),
        ]
        # Stages 2-4: three stride-2 upsamplings, halving channels each time
        # (512 -> 256 -> 128 -> 64).
        width = gen_base_features * 8
        for _ in range(3):
            layers += [
                nn.Conv2dTranspose(width, width // 2, 4, 2,
                                   pad_mode='pad', padding=1, weight_init=weight_init),
                nn.BatchNorm2d(width // 2, gamma_init=gamma_init),
                nn.ReLU(),
            ]
            width //= 2
        # Stage 5: final upsample to RGB, bounded by Tanh.
        layers += [
            nn.Conv2dTranspose(width, channel_count, 4, 2,
                               pad_mode='pad', padding=1, weight_init=weight_init),
            nn.Tanh(),
        ]
        self.main = nn.SequentialCell(layers)

    def construct(self, z):
        return self.main(z)


generator = Generator()
The discriminator implements a mirrored architecture using standard convolutions with LeakyReLU (α=0.2) activations. Batch normalization stabilizes training across the four convolutional layers, culminating in a single scalar output through Sigmoid activation representing the probability of input authenticity.
class Discriminator(nn.Cell):
    """DCGAN discriminator.

    Classifies an (N, 3, 64, 64) input as real or generated via strided
    convolutions with LeakyReLU(0.2); the Sigmoid head yields a per-sample
    probability of authenticity.
    """

    def __init__(self):
        super().__init__()
        # First stage has no batch norm, per the DCGAN recipe.
        layers = [
            nn.Conv2d(channel_count, disc_base_features, 4, 2,
                      pad_mode='pad', padding=1, weight_init=weight_init),
            nn.LeakyReLU(0.2),
        ]
        # Three downsampling stages doubling channels (64 -> 128 -> 256 -> 512).
        width = disc_base_features
        for _ in range(3):
            layers += [
                nn.Conv2d(width, width * 2, 4, 2,
                          pad_mode='pad', padding=1, weight_init=weight_init),
                nn.BatchNorm2d(width * 2, gamma_init=gamma_init),
                nn.LeakyReLU(0.2),
            ]
            width *= 2
        # Collapse the remaining 4x4 map to a single logit per sample.
        layers.append(nn.Conv2d(width, 1, 4, 1,
                                pad_mode='valid', weight_init=weight_init))
        self.features = nn.SequentialCell(layers)
        self.classifier = nn.Sigmoid()

    def construct(self, x):
        logits = self.features(x)
        logits = logits.view(logits.shape[0], -1)
        return self.classifier(logits)


discriminator = Discriminator()
Training Configuration
Binary cross-entropy loss measures the adversarial objective, with separate Adam optimizers for each network component:
# Mean-reduced binary cross-entropy as the adversarial objective.
bce_loss = nn.BCELoss(reduction='mean')
# Separate Adam optimizers so generator and discriminator update independently.
optimizer_g = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_d = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
# Prefix the optimizers' state-parameter names so they do not collide.
optimizer_g.update_parameters_name('opt_g.')
optimizer_d.update_parameters_name('opt_d.')
Define the adversarial training logic using MindSpore's automatic differentiation:
def forward_g(real_imgs, valid_labels):
    """Generator forward pass.

    Samples fresh noise (batch size taken from real_imgs), generates fakes,
    and returns the BCE loss against the "real" labels plus the fakes
    themselves (reused later for the discriminator step).
    """
    z = ops.standard_normal((real_imgs.shape[0], latent_dim, 1, 1))
    fakes = generator(z)
    loss = bce_loss(discriminator(fakes), valid_labels)
    return loss, fakes
def forward_d(real_imgs, synthetic_imgs, valid_labels, fake_labels):
    """Discriminator forward pass: mean of the real-batch and fake-batch BCE losses."""
    loss_real = bce_loss(discriminator(real_imgs), valid_labels)
    loss_fake = bce_loss(discriminator(synthetic_imgs), fake_labels)
    return (loss_real + loss_fake) / 2
# value_and_grad differentiates each forward fn w.r.t. its optimizer's
# parameter list; has_aux=True excludes forward_g's synthetic images from
# differentiation (only the first output is treated as the loss).
grad_g_fn = ms.value_and_grad(forward_g, None, optimizer_g.parameters, has_aux=True)
grad_d_fn = ms.value_and_grad(forward_d, None, optimizer_d.parameters)
@ms.jit
def train_batch(imgs):
    """Run one adversarial step on a batch: update G first, then D.

    Returns (generator loss, discriminator loss, generated images).
    """
    n = imgs.shape[0]
    real_tags = ops.ones((n, 1), ms.float32)
    fake_tags = ops.zeros((n, 1), ms.float32)
    # Generator step: fool the discriminator into labelling fakes as real.
    (g_loss, fakes), g_grads = grad_g_fn(imgs, real_tags)
    optimizer_g(g_grads)
    # Discriminator step: reuse this batch's fakes alongside the real images.
    d_loss, d_grads = grad_d_fn(imgs, fakes, real_tags, fake_tags)
    optimizer_d(d_grads)
    return g_loss, d_loss, fakes
Execute the training loop, capturing metrics and periodic samples:
# Training loop: alternate G/D updates every batch; after each epoch, freeze
# the generator, snapshot a grid of samples, and checkpoint both networks.
history_g = []
history_d = []
progress_snapshots = []
total_steps = train_data.get_dataset_size()
for epoch in range(max_epochs):
    generator.set_train()
    discriminator.set_train()
    for step, (images,) in enumerate(train_data.create_tuple_iterator()):
        loss_gen, loss_disc, generated = train_batch(images)
        # Log every 100 batches and on the final batch of the epoch.
        if step % 100 == 0 or step == total_steps - 1:
            print(f'Epoch [{epoch+1}/{max_epochs}] Batch [{step+1}/{total_steps}] '
                  f'D_Loss: {loss_disc.asnumpy():.4f} G_Loss: {loss_gen.asnumpy():.4f}')
        history_g.append(float(loss_gen.asnumpy()))
        history_d.append(float(loss_disc.asnumpy()))
    # Inference mode so BatchNorm uses running statistics while sampling.
    generator.set_train(False)
    fixed_noise = ops.standard_normal((batch_size, latent_dim, 1, 1))
    # NCHW -> NHWC for later matplotlib rendering.
    epoch_samples = generator(fixed_noise).transpose(0, 2, 3, 1).asnumpy()
    progress_snapshots.append(epoch_samples)
    ms.save_checkpoint(generator, f"./checkpoint_gen_{epoch+1}.ckpt")
    ms.save_checkpoint(discriminator, f"./checkpoint_disc_{epoch+1}.ckpt")
Results and Visualization
Plot the convergence behavior of both networks:
# Plot the per-iteration loss curves recorded during training.
plt.figure(figsize=(10, 5))
plt.plot(history_g, label='Generator', color='steelblue')
plt.plot(history_d, label='Discriminator', color='coral')
plt.xlabel('Iterations')
plt.ylabel('Adversarial Loss')
plt.legend()
plt.title('Training Dynamics')
plt.show()
Generate an animated progression showing the generator's improvement across epochs:
def animate_training(sequence):
    """Save a GIF showing the generator's per-epoch sample grids.

    Each snapshot in *sequence* is an (N, H, W, C) array of generated images;
    the first 24 are tiled into a 3x8 composite and rendered as one animation
    frame per epoch.
    """
    # Bug fix: `animation` was referenced but never imported anywhere in the
    # script, raising NameError at the ArtistAnimation call below.
    from matplotlib import animation

    fig = plt.figure(figsize=(8, 3), dpi=120)
    frames = []
    for snapshot in sequence:
        rows = []
        for row_idx in range(3):
            start = row_idx * 8
            rows.append(np.concatenate(snapshot[start:start + 8], axis=1))
        # Clip to [0, 1] so out-of-range generator outputs still display.
        composite = np.clip(np.concatenate(rows, axis=0), 0, 1)
        plt.axis('off')
        frames.append([plt.imshow(composite)])
    anim = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)
    anim.save('dcgan_evolution.gif', writer='pillow', fps=1)


animate_training(progress_snapshots)
Load the trained generator for inference on new random seeds:
# Restore the final-epoch generator weights and sample a fresh batch of faces.
ms.load_checkpoint("./checkpoint_gen_3.ckpt", generator)
inference_noise = ops.standard_normal((batch_size, latent_dim, 1, 1))
# NCHW -> NHWC so matplotlib can render the samples.
final_output = generator(inference_noise).transpose(0, 2, 3, 1).asnumpy()
fig = plt.figure(figsize=(8, 3), dpi=120)
# Tile the first 24 samples into a 3x8 grid; clip to the displayable [0, 1] range.
grid_rows = [np.concatenate(final_output[i*8:(i+1)*8], axis=1) for i in range(3)]
final_grid = np.clip(np.concatenate(grid_rows, axis=0), 0, 1)
plt.axis('off')
plt.imshow(final_grid)
plt.show()