Fading Coder

One Final Commit for the Last Sprint


Diffusion Models Explained and Implemented with PyTorch

Tech · May 8

Generative modeling has evolved through several milestones before reaching diffusion-based approaches. Understanding variational autoencoders, generative adversarial networks, and their limitations clarifies why denoising diffusion probabilistic models (DDPMs) emerged.

Variational Autoencoders

VAEs encode inputs into a distribution over latent space characterized by mean and variance vectors. During training, samples drawn from this Gaussian guide a decoder to reconstruct the input. Loss combines reconstruction error with KL divergence to regularize the latent space. While VAEs support diverse output generation via random latent sampling, they often produce blurry results.
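The KL term has a closed form for diagonal Gaussians against a standard-normal prior. A minimal sketch (the helper name `gaussian_kl` is illustrative, not from any library):

```python
import math

def gaussian_kl(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and the standard normal,
    summed over latent dimensions (closed form for diagonal Gaussians)."""
    return sum(
        -0.5 * (1.0 + lv - m * m - math.exp(lv))
        for m, lv in zip(mu, logvar)
    )

# A latent that matches the prior exactly incurs zero KL penalty:
print(gaussian_kl([0.0, 0.0], [0.0, 0.0]))
```

Each latent dimension contributes independently, which is why the penalty pulls every encoded mean toward zero and every variance toward one.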

Generative Adversarial Networks

GANs train two networks jointly: a generator that maps random noise to data, and a discriminator that distinguishes real from synthetic samples. Their adversarial loop pushes the generator toward producing increasingly realistic outputs. Sampling involves feeding fresh noise into the trained generator. GANs yield sharp images but tend to lack diversity and may collapse to limited modes.

Motivation for Denoising Diffusion Probabilistic Models

Neither VAEs nor GANs simultaneously achieve high fidelity and diversity. DDPMs address this gap by iteratively transforming noise into data through a learned reverse process, blending strengths of both architectures.

Denoising Diffusion Probabilistic Model Mechanics

A DDPM defines a fixed forward noising process and a learnable reverse denoising process. Starting from clean data, the forward pass incrementally corrupts it with Gaussian noise across discrete steps governed by a Markov chain. The reverse pass learns to invert this corruption, starting from pure noise to generate new samples.

Forward Process

Let ( T ) denote the total number of steps. At each step ( t ), Gaussian noise is added per a variance schedule ( \beta_t ), chosen small enough to keep the corruption gradual. Given an image ( x_0 ), the noisy version at step ( t ) can be sampled directly via the reparameterization trick, ( x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon ), where ( \bar{\alpha}_t = \prod_{s=1}^{t}(1 - \beta_s) ) and ( \epsilon \sim \mathcal{N}(0, I) ):

import torch

def forward_noise(x0, t, beta_seq):
    # alpha_bar_t is the cumulative product of (1 - beta) up to step t
    alpha_bar_t = torch.cumprod(1 - beta_seq, dim=0)[t]
    eps = torch.randn_like(x0)
    # closed-form sample of x_t, returned with the noise (the training target)
    return torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * eps, eps

The Markov property ensures that only the previous step influences the current state, simplifying computation and enabling direct sampling of intermediate noise levels.
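This direct access to any noise level can be seen numerically. A quick sketch under an assumed linear schedule, using the same closed form as `forward_noise` above:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # illustrative linear schedule
alpha_bar = torch.cumprod(1 - betas, dim=0)    # signal weight at each step

x0 = torch.randn(1, 3, 32, 32)                 # stand-in for a normalized image
for t in [0, 499, 999]:
    eps = torch.randn_like(x0)
    # sample x_t at an arbitrary step without simulating the chain
    xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
    print(t, alpha_bar[t].item())
```

The printed signal weight ( \bar{\alpha}_t ) shrinks toward zero as ( t ) grows, so late steps are dominated by noise.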

Reverse Process

Reversal requires estimating the conditional distribution ( p_{\theta}(x_{t-1} | x_t) ). A neural network is trained to predict the noise that was added at each step, from which the mean of the reverse distribution is derived. The variance can remain fixed or be learned. The objective becomes predicting the noise component given a noisy observation and its timestep: ( L = \mathbb{E}_{t, x_0, \epsilon} \| \epsilon - \hat{\epsilon}_{\theta}(x_t, t) \|^2 ).

Training Procedure

For each training iteration:

  1. Randomly pick a timestep ( t ) for every image.
  2. Apply forward noise to obtain ( x_t ) and the true noise ( \epsilon ).
  3. Feed ( x_t ) and ( t ) into a U-Net that outputs estimated noise ( \hat{\epsilon} ).
  4. Optimize using MSE between ( \epsilon ) and ( \hat{\epsilon} ).

The U-Net incorporates timestep embeddings so the same weights handle varying noise levels.
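The four steps above can be sketched end to end. The tiny conv net here is a placeholder that ignores ( t ), standing in for the timestep-conditioned U-Net built later:

```python
import torch
from torch import nn

# Toy stand-in for the U-Net: a conv net that ignores the timestep.
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                    nn.Conv2d(16, 3, 3, padding=1))
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1 - betas, dim=0)

x0 = torch.randn(8, 3, 32, 32)                 # stand-in batch of images
t = torch.randint(0, T, (8,))                  # step 1: random timesteps
eps = torch.randn_like(x0)                     # step 2: true noise
ab = alpha_bar[t].view(-1, 1, 1, 1)            # broadcast over C, H, W
xt = torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * eps
loss = nn.functional.mse_loss(net(xt), eps)    # steps 3-4: predict, compare
opt.zero_grad(); loss.backward(); opt.step()
```

One gradient step per batch is all there is to it; the real model differs only in feeding ( t ) into the network.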

Sampling

Starting from Gaussian noise ( x_T ), apply the learned reverse process iteratively until ( x_0 ) is reached. This generates a new data sample.

Noise Schedule Parameters

Parameters such as ( \beta ), ( \bar{\alpha} ), and derived terms control noise magnitude per step. Total steps ( T ) are typically large (e.g., 1000). Linear schedules may cause uneven noise; cosine schedules yield smoother transitions.
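The difference between the two schedules is easy to see by computing ( \bar{\alpha}_t ) under each. A plain-Python sketch (function names are illustrative; the cosine form follows the standard squared-cosine construction):

```python
import math

def alpha_bar_linear(T, b0=1e-4, b1=0.02):
    # cumulative product of (1 - beta) under a linear beta schedule
    out, prod = [], 1.0
    for i in range(T):
        beta = b0 + (b1 - b0) * i / (T - 1)
        prod *= 1 - beta
        out.append(prod)
    return out

def alpha_bar_cosine(T, s=0.008):
    # squared-cosine alpha_bar, normalized so alpha_bar(0) = 1
    f = lambda t: math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
    return [f(t) / f(0) for t in range(1, T + 1)]

T = 1000
lin, cosd = alpha_bar_linear(T), alpha_bar_cosine(T)
# The linear schedule destroys signal early; cosine retains more mid-way.
print(lin[T // 2], cosd[T // 2])
```

Halfway through the chain the cosine schedule still preserves a substantial fraction of the signal, while the linear one has already crushed it, which is the unevenness the text refers to.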

PyTorch Implementation

Time Embedding

Sinusoidal position embedding injects temporal context into the model:

import math
import torch
from torch import nn
from einops.layers.torch import Rearrange

class SinPositionEmbed(nn.Module):
    def __init__(self, dim, scale=10000):
        super().__init__()
        self.dim = dim
        self.scale = scale

    def forward(self, t):
        # transformer-style sinusoidal embedding of the timestep batch
        half = self.dim // 2
        freq = torch.exp(torch.arange(half, device=t.device) *
                         (-math.log(self.scale) / (half - 1)))
        angles = t[:, None] * freq[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

Enhanced Convolution Block

A ConvNeXt-inspired block enriches feature extraction:

class ConvNeXtUnit(nn.Module):
    def __init__(self, inc, outc, factor=2, time_emb_dim=None, groups=8):
        super().__init__()
        self.time_proj = nn.Sequential(
            nn.GELU(), nn.Linear(time_emb_dim, inc)
        ) if time_emb_dim else None

        self.dw_conv = nn.Conv2d(inc, inc, 7, padding=3, groups=inc)
        self.core = nn.Sequential(
            nn.GroupNorm(1, inc),
            nn.Conv2d(inc, outc * factor, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, outc * factor),
            nn.Conv2d(outc * factor, outc, 3, padding=1)
        )
        self.skip = nn.Conv2d(inc, outc, 1) if inc != outc else nn.Identity()

    def forward(self, x, t_emb=None):
        y = self.dw_conv(x)
        if self.time_proj is not None and t_emb is not None:
            # broadcast the projected time embedding over spatial dims
            y = y + self.time_proj(t_emb)[:, :, None, None]
        y = self.core(y)
        return y + self.skip(x)

Down and Up Sampling

Spatial resolution changes use reshaping and convolution:

class DownRes(nn.Module):
    def __init__(self, cin, cout=None):
        super().__init__()
        # space-to-depth first (halves H and W, quadruples channels),
        # then a 1x1 conv mixes the stacked channels
        self.net = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
            nn.Conv2d(cin * 4, cout or cin, 1)
        )

    def forward(self, x):
        return self.net(x)

class UpRes(nn.Module):
    def __init__(self, cin, cout=None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(cin, cout or cin, 3, padding=1)
        )

    def forward(self, x):
        return self.net(x)
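DownRes is intended to perform a space-to-depth rearrangement before its 1×1 convolution; torch's built-in PixelUnshuffle expresses the same move without the einops dependency. A shape check of a down/up pair built that way:

```python
import torch
from torch import nn

# PixelUnshuffle(2) performs 'b c (h 2) (w 2) -> b (c 4) h w':
# resolution halves, channels quadruple, nothing is discarded by pooling.
x = torch.randn(1, 8, 32, 32)

down = nn.Sequential(nn.PixelUnshuffle(2), nn.Conv2d(8 * 4, 16, 1))
up = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                   nn.Conv2d(16, 8, 3, padding=1))

y = down(x)   # (1, 16, 16, 16)
z = up(y)     # (1, 8, 32, 32)
print(y.shape, z.shape)
```

Unlike strided convolution or pooling, the rearrangement is lossless; the 1×1 conv then decides how to compress the stacked pixels.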

U-Net Backbone with Dual Residual Paths

Integrating the above blocks yields a symmetrical encoder–decoder:

class DualResUNet(nn.Module):
    def __init__(self, base_dim, ch_mul=(1,2,4,8), img_ch=3):
        super().__init__()
        self.start_conv = nn.Conv2d(img_ch, base_dim, 7, padding=3)
        dims = [base_dim] + [base_dim * m for m in ch_mul]
        pairs = [(dims[i], dims[i+1]) for i in range(len(dims)-1)]

        pos_emb = SinPositionEmbed(base_dim)
        t_emb_dim = base_dim * 4
        self.t_mlp = nn.Sequential(
            pos_emb, nn.Linear(base_dim, t_emb_dim), nn.GELU(),
            nn.Linear(t_emb_dim, t_emb_dim)
        )

        self.down_path = nn.ModuleList()
        self.up_path = nn.ModuleList()

        for idx, (ci, co) in enumerate(pairs):
            last = idx == len(pairs)-1
            self.down_path.append(nn.ModuleList([
                ConvNeXtUnit(ci, ci, time_emb_dim=t_emb_dim),
                ConvNeXtUnit(ci, ci, time_emb_dim=t_emb_dim),
                DownRes(ci, co) if not last else nn.Conv2d(ci, co, 3, padding=1)
            ]))

        mid = dims[-1]
        self.mid1 = ConvNeXtUnit(mid, mid, time_emb_dim=t_emb_dim)
        self.mid2 = ConvNeXtUnit(mid, mid, time_emb_dim=t_emb_dim)

        for idx, (ci, co) in enumerate(reversed(pairs)):
            last = idx == len(pairs)-1
            self.up_path.append(nn.ModuleList([
                ConvNeXtUnit(co + ci, co, time_emb_dim=t_emb_dim),
                ConvNeXtUnit(co + ci, co, time_emb_dim=t_emb_dim),
                UpRes(co, ci) if not last else nn.Conv2d(co, ci, 3, padding=1)
            ]))

        self.exit_res = ConvNeXtUnit(base_dim*2, base_dim, time_emb_dim=t_emb_dim)
        self.exit_conv = nn.Conv2d(base_dim, img_ch, 1)

    def forward(self, x, t):
        x = self.start_conv(x)
        skip_store = [x]
        t_vec = self.t_mlp(t)

        for b1, b2, down in self.down_path:
            x = b1(x, t_vec); skip_store.append(x)
            x = b2(x, t_vec); skip_store.append(x)
            x = down(x)

        x = self.mid1(x, t_vec)
        x = self.mid2(x, t_vec)

        for b1, b2, up in self.up_path:
            x = torch.cat([x, skip_store.pop()], dim=1)
            x = b1(x, t_vec)
            x = torch.cat([x, skip_store.pop()], dim=1)
            x = b2(x, t_vec)
            x = up(x)

        x = torch.cat([x, skip_store[0]], dim=1)
        x = self.exit_res(x, t_vec)
        return self.exit_conv(x)

Diffusion Engine

Noise schedules and sampling logic encapsulate the full process:

def extract(vals, t, shape):
    # Gather per-sample schedule values for a batch of timesteps and
    # reshape them to broadcast against a batch of images.
    out = vals.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(shape) - 1)))

class ImgDiffuser(nn.Module):
    def __init__(self, net, im_sz, steps=1000, sched='cosine'):
        super().__init__()
        self.net = net
        self.im_sz = im_sz
        self.steps = steps
        betas = self._make_betas(sched, steps)
        alphas = 1 - betas
        # Buffers move with the module across devices and are saved in
        # state_dict, but receive no gradients.
        self.register_buffer('betas', betas.float())
        self.register_buffer('alphas', alphas.float())
        self.register_buffer('alpha_bar', torch.cumprod(alphas, 0).float())

    def _make_betas(self, kind, n):
        if kind == 'cosine':
            # Cosine schedule: define alpha_bar via a squared cosine,
            # then recover the n betas from consecutive ratios.
            t = torch.linspace(0, 1, n + 1)
            f = torch.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
            alpha_bar = f / f[0]
            betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
            return betas.clamp(1e-8, 0.999)
        return torch.linspace(1e-4, 0.02, n)

    def corrupt(self, x0, t):
        e = torch.randn_like(x0)
        ab = extract(self.alpha_bar, t, x0.shape)
        return torch.sqrt(ab)*x0 + torch.sqrt(1-ab)*e, e

    @torch.no_grad()
    def sample_step(self, x, ti):
        # ti is a Python int; the network expects a batched tensor timestep
        t = torch.full((x.shape[0],), ti, device=x.device, dtype=torch.long)
        pred_e = self.net(x, t)
        ab = self.alpha_bar[ti]
        beta = self.betas[ti]
        mean = (x - beta * pred_e / torch.sqrt(1 - ab)) / torch.sqrt(self.alphas[ti])
        if ti == 0:
            return mean
        ab_prev = self.alpha_bar[ti - 1]
        var = (1 - ab_prev) / (1 - ab) * beta
        return mean + torch.sqrt(var) * torch.randn_like(x)

    @torch.no_grad()
    def generate(self, bs=16, ch=3):
        # ch must match the image channels the network was built with
        device = self.betas.device
        x = torch.randn(bs, ch, self.im_sz, self.im_sz, device=device)
        for ti in reversed(range(self.steps)):
            x = self.sample_step(x, ti)
        return x

Training iterates over random timesteps, applies corruption, and minimizes prediction error of injected noise. Sampling runs the reverse process to synthesize images from noise.
