Diffusion Models Explained and Implemented with PyTorch
Generative modeling has evolved through several milestones before reaching diffusion-based approaches. Understanding variational autoencoders, generative adversarial networks, and their limitations clarifies why denoising diffusion probabilistic models (DDPMs) emerged.
Variational Autoencoders
VAEs encode each input into a distribution over latent space, characterized by mean and variance vectors. During training, samples drawn from this Gaussian are passed to a decoder that reconstructs the input. The loss combines reconstruction error with a KL-divergence term that regularizes the latent space. While VAEs support diverse output generation via random latent sampling, they often produce blurry results.
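As a rough sketch of that objective, assuming an encoder that returns ( \mu ) and ( \log \sigma^2 ) and a decoder output x_recon (all names here are illustrative, not from a specific library):

import torch
import torch.nn.functional as F

def vae_loss(x, x_recon, mu, logvar):
    recon = F.mse_loss(x_recon, x, reduction='sum')  # reconstruction term
    # KL divergence between N(mu, sigma^2) and the standard normal prior
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl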
Generative Adversarial Networks
GANs train two networks jointly: a generator that maps random noise to data, and a discriminator that distinguishes real from synthetic samples. Their adversarial loop pushes the generator toward producing increasingly realistic outputs. Sampling involves feeding fresh noise into the trained generator. GANs yield sharp images but tend to lack diversity and may collapse to limited modes.
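A minimal sketch of one adversarial update illustrates the loop; the generator G, discriminator D (assumed to output logits of shape (batch, 1)), optimizers, and z_dim are placeholders defined elsewhere:

import torch
import torch.nn.functional as F

def gan_step(G, D, real, z_dim, opt_g, opt_d):
    b = real.size(0)
    ones = torch.ones(b, 1, device=real.device)
    zeros = torch.zeros(b, 1, device=real.device)
    fake = G(torch.randn(b, z_dim, device=real.device))
    # Discriminator update: push real toward 1, generated toward 0
    d_loss = (F.binary_cross_entropy_with_logits(D(real), ones) +
              F.binary_cross_entropy_with_logits(D(fake.detach()), zeros))
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()
    # Generator update: make the discriminator label fakes as real
    g_loss = F.binary_cross_entropy_with_logits(D(fake), ones)
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()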
Motivation for Denoising Diffusion Probabilistic Models
Neither VAEs nor GANs simultaneously achieve high fidelity and diversity. DDPMs address this gap by iteratively transforming noise into data through a learned reverse process, blending strengths of both architectures.
Denoising Diffusion Probabilistic Model Mechanics
A DDPM defines a fixed forward noising process and a learnable reverse denoising process. Starting from clean data, the forward pass incrementally corrupts it with Gaussian noise across discrete steps governed by a Markov chain. The reverse pass learns to invert this corruption, starting from pure noise to generate new samples.
Forward Process
Let ( T ) denote the total number of steps. At each step ( t ), noise is added according to a variance schedule ( \beta_t ), chosen small enough to keep the corruption gradual. Given an image ( x_0 ), the noisy version at any step ( t ) can be sampled in closed form via re-parameterization.
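Specifically, ( x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon ) with ( \epsilon \sim \mathcal{N}(0, I) ) and ( \bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s) ). The snippet below implements this directly: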
import torch

def forward_noise(x0, t, beta_seq):
    # Cumulative product of (1 - beta) gives alpha_bar at step t
    alpha_bar_t = torch.cumprod(1 - beta_seq, dim=0)[t]
    eps = torch.randn_like(x0)
    # Re-parameterization: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    return torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * eps, eps
The Markov property ensures that only the previous step influences the current state, simplifying computation and enabling direct sampling of intermediate noise levels.
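Formally, each transition is ( q(x_t \mid x_{t-1}) = \mathcal{N}\!\left( x_t;\, \sqrt{1 - \beta_t}\, x_{t-1},\, \beta_t I \right) ), and composing these Gaussians yields the closed-form expression for ( q(x_t \mid x_0) ) given above.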
Reverse Process
Reversal requires estimating the conditional distribution ( p_{\theta}(x_{t-1} \mid x_t) ). A neural network parameterizes the mean of this distribution, in practice by predicting the noise that was added at each step; the variance can remain fixed or be learned. The objective becomes predicting the noise component given a noisy observation and its timestep.
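With a noise estimate ( \hat{\epsilon}_\theta(x_t, t) ), the reverse mean takes the closed form ( \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \hat{\epsilon}_\theta(x_t, t) \right) ), where ( \alpha_t = 1 - \beta_t ).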
Training Procedure
For each training iteration:
- Randomly pick a timestep ( t ) for every image.
- Apply forward noise to obtain ( x_t ) and the true noise ( \epsilon ).
- Feed ( x_t ) and ( t ) into a U-Net that outputs estimated noise ( \hat{\epsilon} ).
- Optimize using MSE between ( \epsilon ) and ( \hat{\epsilon} ).
The U-Net incorporates timestep embeddings so the same weights handle varying noise levels.
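Minimizing this MSE corresponds to the simplified DDPM objective ( L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon} \left[ \lVert \epsilon - \hat{\epsilon}_\theta(x_t, t) \rVert^2 \right] ).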
Sampling
Starting from Gaussian noise ( x_T ), apply the learned reverse process iteratively until ( x_0 ) is reached. This generates a new data sample.
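Concretely, each step draws ( x_{t-1} = \mu_\theta(x_t, t) + \sigma_t z ) with ( z \sim \mathcal{N}(0, I) ), where ( \sigma_t^2 = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t ); no noise is added at the final step.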
Noise Schedule Parameters
Parameters such as ( \beta_t ), ( \bar{\alpha}_t ), and derived terms control the noise magnitude at each step. The total number of steps ( T ) is typically large (e.g., 1000). A linear schedule destroys image information relatively quickly, leaving many late steps nearly pure noise; a cosine schedule spreads corruption more evenly across steps.
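The cosine schedule of Nichol and Dhariwal defines ( \bar{\alpha}_t = \frac{f(t)}{f(0)} ) with ( f(t) = \cos^2\!\left( \frac{t/T + s}{1 + s} \cdot \frac{\pi}{2} \right) ) for a small offset ( s ) (e.g., 0.008), then recovers ( \beta_t = 1 - \bar{\alpha}_t / \bar{\alpha}_{t-1} ); this is the form implemented in the diffusion engine below.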
PyTorch Implementation
Time Embedding
Sinusoidal position embedding injects temporal context into the model:
import math
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange  # used by DownRes below

class SinPositionEmbed(nn.Module):
    def __init__(self, dim, scale=10000):
        super().__init__()
        self.dim = dim
        self.scale = scale
    def forward(self, t):
        # Standard sinusoidal embedding: a geometric ladder of frequencies
        half = self.dim // 2
        freq = torch.exp(torch.arange(half, device=t.device) *
                         (-math.log(self.scale) / (half - 1)))
        angles = t[:, None] * freq[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
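A quick shape check (with an illustrative dim of 64) confirms the layout, one embedding row per timestep:

emb = SinPositionEmbed(dim=64)(torch.arange(8))
print(emb.shape)  # torch.Size([8, 64])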
Enhanced Convolution Block
A ConvNeXt-inspired block enriches feature extraction:
class ConvNeXtUnit(nn.Module):
    def __init__(self, inc, outc, factor=2, time_emb_dim=None):
        super().__init__()
        # Project the time embedding to the input channel count
        self.time_proj = nn.Sequential(
            nn.GELU(), nn.Linear(time_emb_dim, inc)
        ) if time_emb_dim is not None else None
        # Depthwise 7x7 convolution, as in ConvNeXt
        self.dw_conv = nn.Conv2d(inc, inc, 7, padding=3, groups=inc)
        self.core = nn.Sequential(
            nn.GroupNorm(1, inc),
            nn.Conv2d(inc, outc * factor, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, outc * factor),
            nn.Conv2d(outc * factor, outc, 3, padding=1)
        )
        # 1x1 projection so the residual matches the output channels
        self.skip = nn.Conv2d(inc, outc, 1) if inc != outc else nn.Identity()
    def forward(self, x, t_emb=None):
        y = self.dw_conv(x)
        if self.time_proj is not None and t_emb is not None:
            y = y + self.time_proj(t_emb)[:, :, None, None]
        y = self.core(y)
        return y + self.skip(x)
Down and Up Sampling
Spatial resolution changes use reshaping and convolution:
class DownRes(nn.Module):
    # Space-to-depth: fold each 2x2 patch into channels, then mix with a 1x1 conv
    def __init__(self, cin, cout=None):
        super().__init__()
        self.net = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
            nn.Conv2d(cin * 4, cout or cin, 1)
        )
    def forward(self, x):
        return self.net(x)
class UpRes(nn.Module):
    # Nearest-neighbor upsampling followed by a 3x3 conv to refine features
    def __init__(self, cin, cout=None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(cin, cout or cin, 3, padding=1)
        )
    def forward(self, x):
        return self.net(x)
U-Net Backbone with Dual Residual Paths
Integrating the above blocks yields a symmetrical encoder–decoder:
class DualResUNet(nn.Module):
def __init__(self, base_dim, ch_mul=(1,2,4,8), img_ch=3):
super().__init__()
        self.channels = img_ch  # image channel count, used by the sampler
        self.start_conv = nn.Conv2d(img_ch, base_dim, 7, padding=3)
dims = [base_dim] + [base_dim * m for m in ch_mul]
pairs = [(dims[i], dims[i+1]) for i in range(len(dims)-1)]
pos_emb = SinPositionEmbed(base_dim)
t_emb_dim = base_dim * 4
self.t_mlp = nn.Sequential(
pos_emb, nn.Linear(base_dim, t_emb_dim), nn.GELU(),
nn.Linear(t_emb_dim, t_emb_dim)
)
self.down_path = nn.ModuleList()
self.up_path = nn.ModuleList()
for idx, (ci, co) in enumerate(pairs):
last = idx == len(pairs)-1
self.down_path.append(nn.ModuleList([
ConvNeXtUnit(ci, ci, time_emb_dim=t_emb_dim),
ConvNeXtUnit(ci, ci, time_emb_dim=t_emb_dim),
DownRes(ci, co) if not last else nn.Conv2d(ci, co, 3, padding=1)
]))
mid = dims[-1]
self.mid1 = ConvNeXtUnit(mid, mid, time_emb_dim=t_emb_dim)
self.mid2 = ConvNeXtUnit(mid, mid, time_emb_dim=t_emb_dim)
for idx, (ci, co) in enumerate(reversed(pairs)):
last = idx == len(pairs)-1
self.up_path.append(nn.ModuleList([
                # Each skip concatenation adds ci channels on top of co
                ConvNeXtUnit(ci + co, co, time_emb_dim=t_emb_dim),
                ConvNeXtUnit(ci + co, co, time_emb_dim=t_emb_dim),
UpRes(co, ci) if not last else nn.Conv2d(co, ci, 3, padding=1)
]))
self.exit_res = ConvNeXtUnit(base_dim*2, base_dim, time_emb_dim=t_emb_dim)
self.exit_conv = nn.Conv2d(base_dim, img_ch, 1)
def forward(self, x, t):
x = self.start_conv(x)
skip_store = [x]
t_vec = self.t_mlp(t)
for b1, b2, down in self.down_path:
x = b1(x, t_vec); skip_store.append(x)
x = b2(x, t_vec); skip_store.append(x)
x = down(x)
x = self.mid1(x, t_vec)
x = self.mid2(x, t_vec)
for b1, b2, up in self.up_path:
x = torch.cat([x, skip_store.pop()], dim=1)
x = b1(x, t_vec)
x = torch.cat([x, skip_store.pop()], dim=1)
x = b2(x, t_vec)
x = up(x)
x = torch.cat([x, skip_store[0]], dim=1)
x = self.exit_res(x, t_vec)
return self.exit_conv(x)
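A quick smoke test with illustrative sizes verifies that the network maps an image batch and per-sample timesteps to a same-shaped noise prediction:

net = DualResUNet(base_dim=64, ch_mul=(1, 2, 4, 8), img_ch=3)
x = torch.randn(2, 3, 64, 64)        # two 64x64 RGB images
t = torch.randint(0, 1000, (2,))     # one timestep index per image
assert net(x, t).shape == x.shape    # predicted noise matches input shape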
Diffusion Engine
Noise schedules and sampling logic encapsulate the full process. A small extract helper gathers the schedule value for each sample's timestep and reshapes it for broadcasting against image tensors:

def extract(vals, t, shape):
    # Pick vals[t] per sample, then reshape to (batch, 1, 1, ...) for broadcasting
    out = vals.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(shape) - 1)))

class ImgDiffuser(nn.Module):
    def __init__(self, net, im_sz, steps=1000, sched='cosine'):
        super().__init__()
        self.net = net
        self.im_sz = im_sz
        self.steps = steps
        # Register schedules as buffers so they move with .to(device) and are
        # saved in checkpoints; registering directly avoids name clashes with
        # plain attributes
        betas = self._make_betas(sched, steps)
        alphas = 1 - betas
        self.register_buffer('betas', betas.float())
        self.register_buffer('alphas', alphas.float())
        self.register_buffer('alpha_bar', torch.cumprod(alphas, 0).float())
    def _make_betas(self, kind, n, s=0.008):
        if kind == 'cosine':
            # Cosine schedule: define alpha_bar via a squared cosine with a
            # small offset s, then recover betas from consecutive ratios
            x = torch.linspace(0, 1, n + 1)
            alpha_bar = torch.cos((x + s) / (1 + s) * math.pi / 2) ** 2
            alpha_bar = alpha_bar / alpha_bar[0]
            betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
            return betas.clamp(1e-4, 0.999)
        return torch.linspace(1e-4, 0.02, n)  # linear fallback
    def corrupt(self, x0, t):
        # q(x_t | x_0): closed-form forward noising using alpha_bar
        e = torch.randn_like(x0)
        ab = extract(self.alpha_bar, t, x0.shape)
        return torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * e, e
    @torch.no_grad()
    def sample_step(self, x, ti):
        # Broadcast the scalar step index to a per-sample tensor for the net
        t = torch.full((x.shape[0],), ti, device=x.device, dtype=torch.long)
        pred_e = self.net(x, t)
        ab = extract(self.alpha_bar, t, x.shape)
        beta = extract(self.betas, t, x.shape)
        alpha = extract(self.alphas, t, x.shape)
        # Posterior mean, computed from the predicted noise
        mean = (1 / torch.sqrt(alpha)) * (x - beta * pred_e / torch.sqrt(1 - ab))
        if ti == 0:
            return mean  # no noise is added at the final step
        ab_prev = extract(self.alpha_bar, t - 1, x.shape)
        var = (1 - ab_prev) / (1 - ab) * beta
        return mean + torch.sqrt(var) * torch.randn_like(x)
    @torch.no_grad()
    def generate(self, bs=16):
        device = next(self.net.parameters()).device
        x = torch.randn(bs, self.net.channels, self.im_sz, self.im_sz, device=device)
        # Walk the reverse chain from pure noise down to x_0
        for ti in reversed(range(self.steps)):
            x = self.sample_step(x, ti)
        return x
Training iterates over random timesteps, applies corruption, and minimizes prediction error of injected noise. Sampling runs the reverse process to synthesize images from noise.
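A minimal training loop might look as follows; it assumes a DataLoader named loader yielding image batches normalized to [-1, 1], and the hyperparameters are illustrative:

import torch.nn.functional as F

net = DualResUNet(base_dim=64)
diffuser = ImgDiffuser(net, im_sz=64, steps=1000, sched='cosine')
opt = torch.optim.Adam(net.parameters(), lr=2e-4)

for epoch in range(50):                    # illustrative epoch count
    for x0 in loader:
        t = torch.randint(0, diffuser.steps, (x0.shape[0],), device=x0.device)
        xt, eps = diffuser.corrupt(x0, t)  # forward-noise the batch
        eps_hat = net(xt, t)               # predict the injected noise
        loss = F.mse_loss(eps_hat, eps)
        opt.zero_grad(); loss.backward(); opt.step()

samples = diffuser.generate(bs=16)         # reverse process from pure noise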