Fading Coder


Implementing JPEG Compression for Image Processing in Python

Tech May 18 4

JPEG Compression Fundamentals

JPEG (Joint Photographic Experts Group) compression is a lossy algorithm that reduces image file sizes by exploiting characteristics of human vision. The eye is more sensitive to changes in brightness than to changes in color. JPEG therefore converts images to the YCbCr color space and keeps the luminance (Y) channel at full resolution. It usually stores the two chrominance channels (Cb and Cr) at reduced resolution.
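The following minimal NumPy sketch makes the luminance/chrominance trade-off concrete with 4:2:0 chroma subsampling. It assumes the image is already a float YCbCr array with even height and width. The helper name `subsample_420` is hypothetical. Y keeps every sample, while Cb and Cr are averaged over 2x2 blocks, so half of the samples remain:

```python
import numpy as np

def subsample_420(ycbcr):
    """Keep Y at full resolution; average Cb and Cr over 2x2 blocks.

    ycbcr: float array of shape (H, W, 3) with even H and W (assumption).
    """
    h, w, _ = ycbcr.shape
    y = ycbcr[:, :, 0]
    # Reshape each chroma plane to (H/2, 2, W/2, 2) and average every 2x2 cell
    cb = ycbcr[:, :, 1].reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))
    cr = ycbcr[:, :, 2].reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))
    return y, cb, cr

rng = np.random.default_rng(0)
img = rng.uniform(0, 255, size=(64, 48, 3))
y, cb, cr = subsample_420(img)

original_samples = img.size                 # 64 * 48 * 3 = 9216
kept_samples = y.size + cb.size + cr.size   # 3072 + 768 + 768 = 4608
print(kept_samples / original_samples)      # 0.5
```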

How JPEG Compression Works

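After color conversion, each channel is divided into 8x8 pixel blocks. Each block is transformed with the discrete cosine transform (DCT), and the resulting coefficients are divided by a quantization table and rounded. Quantization is the lossy step. High-frequency coefficients, to which the eye is less sensitive, use larger divisors and often round to zero. The quantized coefficients are then entropy-coded without further loss. Stronger compression produces smaller files but introduces visible artifacts such as blocking and color bleeding.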
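The steps above can be sketched for a single 8x8 block in NumPy. This is an illustration, not a full encoder. It uses the standard JPEG luminance table at its base scale and a plain orthonormal DCT matrix. The helper names (`dct_matrix`, `encode_block`, `decode_block`) are hypothetical, and entropy coding is skipped. A smooth gradient block ends up with only a handful of non-zero coefficients:

```python
import numpy as np

# Standard JPEG luminance quantization table (base scale, row = vertical frequency)
Q_LUMA = np.array([
    [16, 11, 10, 16, 24, 40, 51, 61],
    [12, 12, 14, 19, 26, 58, 60, 55],
    [14, 13, 16, 24, 40, 57, 69, 56],
    [14, 17, 22, 29, 51, 87, 80, 62],
    [18, 22, 37, 56, 68, 109, 103, 77],
    [24, 35, 55, 64, 81, 104, 113, 92],
    [49, 64, 78, 87, 103, 121, 120, 101],
    [72, 92, 95, 98, 112, 100, 103, 99]], dtype=np.float64)

def dct_matrix(n=8):
    """Orthonormal DCT-II matrix: row u holds the u-th cosine basis vector."""
    k = np.arange(n)
    d = np.cos((2 * k[None, :] + 1) * k[:, None] * np.pi / (2 * n))
    d[0] *= np.sqrt(1 / n)
    d[1:] *= np.sqrt(2 / n)
    return d

D = dct_matrix()

def encode_block(block, table=Q_LUMA):
    """Level-shift, apply the 2-D DCT, then divide by the table and round."""
    coeffs = D @ (block - 128.0) @ D.T
    return np.round(coeffs / table)

def decode_block(q, table=Q_LUMA):
    """Dequantize, apply the inverse DCT, and undo the level shift."""
    return D.T @ (q * table) @ D + 128.0

# A smooth horizontal gradient: only horizontal frequencies are present
block = np.tile(np.linspace(100, 150, 8), (8, 1))
q = encode_block(block)
print(int(np.count_nonzero(q)), "of 64 coefficients are non-zero")
recon = decode_block(q)
print(np.abs(recon - block).max())  # a few gray levels at most for this smooth block
```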

A quality setting of 70 to 80 typically offers a good balance between file size and visual quality. This is on the 1-100 scale used by libjpeg and OpenCV, not a percentage.
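The following sketch shows what the quality number controls. It is modeled on libjpeg's quality scaling (libjpeg itself is C, and the function names here are hypothetical). Quality 50 uses the base quantization table unchanged, 75 roughly halves it, and 10 multiplies it by five. Scaled entries are clamped to the baseline range [1, 255]:

```python
import numpy as np

def quality_to_scale(quality):
    """Map a 1-100 quality setting to a percentage scale for the base table.

    Modeled on libjpeg's jpeg_quality_scaling() (integer arithmetic).
    """
    quality = min(max(int(quality), 1), 100)
    if quality < 50:
        return 5000 // quality
    return 200 - quality * 2

def scale_table(base_table, quality):
    """Scale a quantization table, rounding and clamping entries to [1, 255]."""
    scale = quality_to_scale(quality)
    table = (np.asarray(base_table, dtype=np.int64) * scale + 50) // 100
    return np.clip(table, 1, 255)

corner = [[16, 11], [12, 12]]          # top-left corner of the luminance table
print(scale_table(corner, 75).tolist())   # [[8, 6], [6, 6]]
print(scale_table(corner, 10).tolist())   # [[80, 55], [60, 60]]
print(scale_table(corner, 100).tolist())  # [[1, 1], [1, 1]] (clamped from 0)
```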

Integrating JPEG Compression into Image Degradation Simulation

Adding JPEG compression to image degradation pipelines offers several advantages for research and development.

Benefits for Simulation Accuracy

Real-world images suffer from various degradations including lens aberrations, atmospheric scattering, and compression artifacts. Simulating JPEG compression alongside other degradation factors produces more realistic synthetic training data for image restoration models.

Storing large datasets of thousands of images as JPEG files also reduces disk usage and I/O time. However, decoded images occupy as much memory as uncompressed ones.

Implementation Trade-offs

Stacking JPEG compression on top of other degradations compounds the quality loss. If the combined degradation is too severe, the synthetic training pairs can become less useful. The added artifact patterns also make it harder to tell which degradation a model has learned to handle.
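One way to keep compounded degradation in check is to measure it. PSNR is a common, if imperfect, fidelity metric. The helper below is an illustrative sketch. Comparing each degraded image with its clean source, for example to discard pairs below a chosen threshold, is one assumed way to apply it:

```python
import numpy as np

def psnr(reference, test, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    ref = np.asarray(reference, dtype=np.float64)
    tst = np.asarray(test, dtype=np.float64)
    mse = np.mean((ref - tst) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))

a = np.zeros((8, 8))
b = np.full((8, 8), 10.0)  # constant error of 10 -> MSE = 100
print(round(psnr(a, b), 2))  # 28.13
```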

OpenCV Implementation

OpenCV's cv2.imencode and cv2.imdecode encode an image into an in-memory JPEG buffer and decode it back. This simulates compression without writing files to disk.

Encoding and Decoding Functions

import cv2
import numpy as np

def apply_opencv_jpeg_compression(image_path, quality):
    """
    Apply JPEG compression using OpenCV.
    
    Args:
        image_path: Path to input image
        quality: JPEG quality (1-100)
    
    Returns:
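        Compressed and decompressed BGR image as float32, values in [0, 255]
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
    
    # Encode into an in-memory JPEG buffer, then decode it back to pixels
    success, encoded_buffer = cv2.imencode('.jpg', img, encode_params)
    if not success:
        raise RuntimeError('JPEG encoding failed')
    decoded_img = cv2.imdecode(encoded_buffer, cv2.IMREAD_COLOR)
    
    return np.float32(decoded_img)

# Example usage
compressed_20 = apply_opencv_jpeg_compression('input.png', 20)
compressed_80 = apply_opencv_jpeg_compression('input.png', 80)

# Decoded values are whole numbers in [0, 255], so casting to uint8 is lossless
cv2.imwrite('output_20.png', compressed_20.astype(np.uint8))
cv2.imwrite('output_80.png', compressed_80.astype(np.uint8))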

PyTorch-Based Differentiable JPEG Compression

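DiffJPEG reimplements the JPEG pipeline in PyTorch. It processes batches (optionally on the GPU) and accepts a different quality for each image. It can also replace hard rounding with a differentiable approximation so that gradients flow through the compression step.

Its output differs slightly from OpenCV's because it does not reproduce every libjpeg detail. For example, it scales the quantization tables without rounding them to integers, and it uses simple 2x2 averaging and pixel repetition for chroma resampling. It also omits entropy coding, which is lossless and does not affect the decoded pixels.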

Core Components

import torch
import torch.nn as nn
import numpy as np
import itertools
from torch.nn import functional as F

# Quantization tables
luminance_table = np.array(
    [[16, 11, 10, 16, 24, 40, 51, 61],
     [12, 12, 14, 19, 26, 58, 60, 55],
     [14, 13, 16, 24, 40, 57, 69, 56],
     [14, 17, 22, 29, 51, 87, 80, 62],
     [18, 22, 37, 56, 68, 109, 103, 77],
     [24, 35, 55, 64, 81, 104, 113, 92],
     [49, 64, 78, 87, 103, 121, 120, 101],
     [72, 92, 95, 98, 112, 100, 103, 99]],
    dtype=np.float32).T

chrominance_table = np.empty((8, 8), dtype=np.float32)
chrominance_table.fill(99)
chrominance_table[:4, :4] = np.array(
    [[17, 18, 24, 47], [18, 21, 26, 66], 
     [24, 26, 56, 99], [47, 66, 99, 99]]).T

luminance_table = nn.Parameter(torch.from_numpy(luminance_table))
chrominance_table = nn.Parameter(torch.from_numpy(chrominance_table))


def soft_round(x):
    """Differentiable rounding approximation"""
    return torch.round(x) + (x - torch.round(x)) ** 3


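def quality_to_compression_factor(quality):
    """Convert a 1-100 quality setting to a multiplier for the quantization tables.

    Quality 50 returns 1.0 (the standard tables); lower quality gives larger factors.
    """
    if quality < 50:
        return 5000.0 / quality / 100.0
    return (200.0 - quality * 2) / 100.0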
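soft_round keeps the forward value within 0.125 of torch.round, because the cubic term is largest at half-integers. It also has a non-zero derivative, 3(x - round(x))², whereas hard rounding passes no gradient at all. The following is a quick autograd check; the function is repeated so the snippet runs on its own:

```python
import torch

def soft_round(x):
    """Same approximation as above: round(x) + (x - round(x))**3."""
    return torch.round(x) + (x - torch.round(x)) ** 3

x = torch.tensor([0.2, 1.7], requires_grad=True)
soft_round(x).sum().backward()
print(x.grad)  # tensor([0.1200, 0.2700]), i.e. 3 * (x - round(x))**2

x_hard = torch.tensor([0.2, 1.7], requires_grad=True)
torch.round(x_hard).sum().backward()
print(x_hard.grad)  # tensor([0., 0.]): hard rounding passes no gradient
```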

Color Space Conversion and Downsampling

class RGBToYCbCr(nn.Module):
    """Transform RGB to YCbCr color space"""
    
    def __init__(self):
        super().__init__()
        matrix = np.array(
            [[0.299, 0.587, 0.114],
             [-0.168736, -0.331264, 0.5],
             [0.5, -0.418688, -0.081312]],
            dtype=np.float32).T
        self.bias = nn.Parameter(torch.tensor([0., 128., 128.]))
        self.weight = nn.Parameter(torch.from_numpy(matrix))
    
    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        result = torch.tensordot(x, self.weight, dims=1) + self.bias
        return result.view(x.shape)


class ChromaDownsample(nn.Module):
    """Reduce chroma channel resolution by half"""
    
    def __init__(self):
        super().__init__()
    
    def forward(self, image):
        """
        Args:
            image: batch x height x width x 3
        Returns:
            y, cb, cr channels with cb/cr at half resolution
        """
        img_perm = image.permute(0, 3, 1, 2).clone()
        cb = F.avg_pool2d(img_perm[:, 1, :, :].unsqueeze(1),
                          kernel_size=2, stride=2,
                          count_include_pad=False)
        cr = F.avg_pool2d(img_perm[:, 2, :, :].unsqueeze(1),
                          kernel_size=2, stride=2,
                          count_include_pad=False)
        cb = cb.permute(0, 2, 3, 1).squeeze(3)
        cr = cr.permute(0, 2, 3, 1).squeeze(3)
        return image[:, :, :, 0], cb, cr
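The forward and inverse color matrices should undo each other. The following NumPy check is a sketch outside the PyTorch modules. It copies the matrices from RGBToYCbCr and from the YCbCrToRGB class defined later, then converts random RGB values to YCbCr and back:

```python
import numpy as np

# Matrices as written in RGBToYCbCr and YCbCrToRGB (before the .T transpose)
RGB2YCC = np.array([[0.299, 0.587, 0.114],
                    [-0.168736, -0.331264, 0.5],
                    [0.5, -0.418688, -0.081312]])
YCC2RGB = np.array([[1.0, 0.0, 1.402],
                    [1.0, -0.344136, -0.714136],
                    [1.0, 1.772, 0.0]])
SHIFT = np.array([0.0, 128.0, 128.0])

def rgb_to_ycbcr(rgb):
    # Equivalent to tensordot(x, matrix.T, dims=1) + shift in the PyTorch module
    return rgb @ RGB2YCC.T + SHIFT

def ycbcr_to_rgb(ycc):
    # The PyTorch module adds the bias [0, -128, -128], i.e. subtracts SHIFT
    return (ycc - SHIFT) @ YCC2RGB.T

rng = np.random.default_rng(0)
rgb = rng.uniform(0, 255, size=(4, 4, 3))
err = np.abs(ycbcr_to_rgb(rgb_to_ycbcr(rgb)) - rgb).max()
print(err < 0.01)  # True
```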

Block Processing and DCT

class BlockSplitter(nn.Module):
    """Split image into 8x8 blocks"""
    
    def __init__(self):
        super().__init__()
        self.block_size = 8
    
    def forward(self, image):
        """
        Args:
            image: batch x height x width
        Returns:
            batch x (h*w/64) x 8 x 8
        """
        b, h, w = image.shape[0], image.shape[1], image.shape[2]
        blocks = image.view(b, h // 8, 8, w // 8, 8)
        blocks = blocks.permute(0, 1, 3, 2, 4)
        return blocks.contiguous().view(b, -1, 8, 8)


class ForwardDCT(nn.Module):
    """8x8 Discrete Cosine Transform"""
    
    def __init__(self):
        super().__init__()
        basis = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for i, j, u, v in itertools.product(range(8), repeat=4):
            basis[i, j, u, v] = (np.cos((2 * i + 1) * u * np.pi / 16) *
                                 np.cos((2 * j + 1) * v * np.pi / 16))
        alpha = np.array([1 / np.sqrt(2)] + [1] * 7)
        self.basis = nn.Parameter(torch.from_numpy(basis).float())
        self.scale = nn.Parameter(
            torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
    
    def forward(self, image):
        """Apply DCT to image blocks"""
        centered = image - 128
        return self.scale * torch.tensordot(centered, self.basis, dims=2)
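ForwardDCT expresses the 2-D DCT as a single tensordot against a precomputed 8x8x8x8 basis. The following NumPy sketch, which is not part of the module, checks that this equals the separable form D X Dᵀ with an orthonormal DCT matrix D. Because D is orthonormal, the inverse transform only needs the same cosines transposed:

```python
import itertools
import numpy as np

# 4-D basis and scale built the same way as in ForwardDCT
basis = np.zeros((8, 8, 8, 8))
for i, j, u, v in itertools.product(range(8), repeat=4):
    basis[i, j, u, v] = (np.cos((2 * i + 1) * u * np.pi / 16) *
                         np.cos((2 * j + 1) * v * np.pi / 16))
alpha = np.array([1 / np.sqrt(2)] + [1] * 7)
scale = np.outer(alpha, alpha) * 0.25

# Orthonormal 1-D DCT matrix: D[u, i] = c(u) * cos((2i + 1) u pi / 16)
k = np.arange(8)
D = np.cos((2 * k[None, :] + 1) * k[:, None] * np.pi / 16)
D *= np.where(k == 0, np.sqrt(1 / 8), 0.5)[:, None]

rng = np.random.default_rng(0)
block = rng.uniform(-128, 127, size=(8, 8))   # an already level-shifted block
via_tensordot = scale * np.tensordot(block, basis, axes=2)
via_matrices = D @ block @ D.T
print(np.allclose(via_tensordot, via_matrices))  # True
```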

Quantization Layers

class LuminanceQuantizer(nn.Module):
    """Quantize luminance component"""
    
    def __init__(self, rounding_fn):
        super().__init__()
        self.round = rounding_fn
        self.table = luminance_table
    
    def forward(self, coefficients, factor=1):
        """Apply quantization with given factor"""
        if isinstance(factor, (int, float)):
            scaled = coefficients.float() / (self.table * factor)
        else:
            b = factor.size(0)
            table = self.table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            scaled = coefficients.float() / table
        return self.round(scaled)


class ChrominanceQuantizer(nn.Module):
    """Quantize chrominance components"""
    
    def __init__(self, rounding_fn):
        super().__init__()
        self.round = rounding_fn
        self.table = chrominance_table
    
    def forward(self, coefficients, factor=1):
        """Apply quantization with given factor"""
        if isinstance(factor, (int, float)):
            scaled = coefficients.float() / (self.table * factor)
        else:
            b = factor.size(0)
            table = self.table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            scaled = coefficients.float() / table
        return self.round(scaled)
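When `factor` is a tensor, the quantizers scale the table separately for each image in the batch. The following PyTorch sketch reproduces that broadcasting with random coefficients and a stand-in table; `base` is arbitrary, not a real JPEG table. It shows the shapes involved:

```python
import torch

base = torch.arange(1.0, 65.0).view(8, 8)   # stand-in for an 8x8 quantization table
factor = torch.tensor([2.0, 0.5, 1.0])      # one compression factor per image
b = factor.size(0)

# Same broadcasting as the tensor branch of the quantizers
table = base.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
coeffs = torch.randn(b, 12, 8, 8) * 100     # 12 blocks per image
quantized = torch.round(coeffs / table)

print(table.shape, quantized.shape)  # torch.Size([3, 1, 8, 8]) torch.Size([3, 12, 8, 8])
```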

Decompression Components

class LuminanceDequantizer(nn.Module):
    """Reverse luminance quantization"""
    
    def __init__(self):
        super().__init__()
        self.table = luminance_table
    
    def forward(self, quantized, factor=1):
        if isinstance(factor, (int, float)):
            return quantized * (self.table * factor)
        b = factor.size(0)
        table = self.table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
        return quantized * table


class InverseDCT(nn.Module):
    """Inverse 8x8 DCT"""
    
    def __init__(self):
        super().__init__()
        alpha = np.array([1 / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(
            np.outer(alpha, alpha)).float())
        basis = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for i, j, u, v in itertools.product(range(8), repeat=4):
            basis[i, j, u, v] = (np.cos((2 * u + 1) * i * np.pi / 16) *
                                 np.cos((2 * v + 1) * j * np.pi / 16))
        self.basis = nn.Parameter(torch.from_numpy(basis).float())
    
    def forward(self, coeffs):
        """Apply inverse DCT"""
        scaled = coeffs * self.alpha
        return 0.25 * torch.tensordot(scaled, self.basis, dims=2) + 128


class BlockMerger(nn.Module):
    """Reassemble 8x8 blocks into full image"""
    
    def forward(self, patches, height, width):
        """Merge patches back into image"""
        k = 8
        b = patches.shape[0]
        merged = patches.view(b, height // k, width // k, k, k)
        merged = merged.permute(0, 1, 3, 2, 4)
        return merged.contiguous().view(b, height, width)


class ChromaUpsample(nn.Module):
    """Upsample chroma channels to full resolution"""
    
    def forward(self, y, cb, cr):
        """Upsample and concatenate channels"""
        def upsample(x):
            h, w = x.shape[1:3]
            x = x.unsqueeze(-1).repeat(1, 1, 2, 2)
            return x.view(-1, h * 2, w * 2)
        
        return torch.cat([y.unsqueeze(3),
                          upsample(cb).unsqueeze(3),
                          upsample(cr).unsqueeze(3)], dim=3)


class YCbCrToRGB(nn.Module):
    """Convert YCbCr back to RGB color space"""
    
    def __init__(self):
        super().__init__()
        matrix = np.array(
            [[1., 0., 1.402],
             [1, -0.344136, -0.714136],
             [1, 1.772, 0]],
            dtype=np.float32).T
        self.bias = nn.Parameter(torch.tensor([0, -128., -128.]))
        self.weight = nn.Parameter(torch.from_numpy(matrix))
    
    def forward(self, image):
        """Transform YCbCr to RGB"""
        result = torch.tensordot(image + self.bias, self.weight, dims=1)
        return result.view(image.shape).permute(0, 3, 1, 2)
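BlockSplitter and BlockMerger must be exact inverses, or the decoded image would be scrambled. The following NumPy sketch of the same reshape and permute steps confirms the round trip and the block ordering. The helper names are hypothetical:

```python
import numpy as np

def split_blocks(img):
    """(B, H, W) -> (B, H*W/64, 8, 8), mirroring BlockSplitter."""
    b, h, w = img.shape
    blocks = img.reshape(b, h // 8, 8, w // 8, 8).transpose(0, 1, 3, 2, 4)
    return blocks.reshape(b, -1, 8, 8)

def merge_blocks(blocks, h, w):
    """(B, H*W/64, 8, 8) -> (B, H, W), mirroring BlockMerger."""
    b = blocks.shape[0]
    img = blocks.reshape(b, h // 8, w // 8, 8, 8).transpose(0, 1, 3, 2, 4)
    return img.reshape(b, h, w)

x = np.arange(2 * 16 * 24).reshape(2, 16, 24)
blocks = split_blocks(x)
print(blocks.shape)                                     # (2, 6, 8, 8)
print(np.array_equal(merge_blocks(blocks, 16, 24), x))  # True
print(np.array_equal(blocks[0, 1], x[0, :8, 8:16]))     # True: left to right, then down
```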

Complete Differentiable JPEG Module

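class ChrominanceDequantizer(nn.Module):
    """Reverse chrominance quantization (uses the chrominance table)"""
    
    def __init__(self):
        super().__init__()
        self.table = chrominance_table
    
    def forward(self, quantized, factor=1):
        if isinstance(factor, (int, float)):
            return quantized * (self.table * factor)
        b = factor.size(0)
        table = self.table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
        return quantized * table


class DiffJPEG(nn.Module):
    """
    End-to-end differentiable JPEG compression.
    Slightly different results from OpenCV but supports backpropagation.
    """
    
    def __init__(self, differentiable=True):
        super().__init__()
        rounding = soft_round if differentiable else torch.round
        
        self.color_convert = RGBToYCbCr()
        self.chroma_down = ChromaDownsample()
        self.block_split = BlockSplitter()
        self.dct = ForwardDCT()
        self.y_quant = LuminanceQuantizer(rounding)
        self.c_quant = ChrominanceQuantizer(rounding)
        
        self.y_dequant = LuminanceDequantizer()
        # Chroma coefficients were quantized with the chrominance table,
        # so they must be dequantized with the same table
        self.c_dequant = ChrominanceDequantizer()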
        self.idct = InverseDCT()
        self.block_merge = BlockMerger()
        self.chroma_up = ChromaUpsample()
        self.rgb_convert = YCbCrToRGB()
    
    def compress(self, x, factor):
        """Forward compression pass"""
        y, cb, cr = self.chroma_down(self.color_convert(x * 255))
        
        components = {'y': y, 'cb': cb, 'cr': cr}
        for key in components:
            coeffs = self.block_split(components[key])
            coeffs = self.dct(coeffs)
            if key in ('cb', 'cr'):
                components[key] = self.c_quant(coeffs, factor)
            else:
                components[key] = self.y_quant(coeffs, factor)
        
        return components['y'], components['cb'], components['cr']
    
    def decompress(self, y, cb, cr, h, w, factor):
        """Forward decompression pass"""
        components = {'y': y, 'cb': cb, 'cr': cr}
        for key in components:
            if key in ('cb', 'cr'):
                coeffs = self.c_dequant(components[key], factor)
                ch, cw = int(h / 2), int(w / 2)
            else:
                coeffs = self.y_dequant(components[key], factor)
                ch, cw = h, w
            components[key] = self.block_merge(self.idct(coeffs), ch, cw)
        
        merged = self.chroma_up(components['y'],
                                components['cb'],
                                components['cr'])
        rgb = self.rgb_convert(merged)
        rgb = torch.clamp(rgb, 0, 255)
        return rgb / 255
    
    def forward(self, x, quality):
        """
        Args:
            x: Input image batch, shape (B, C, H, W), RGB, range [0, 1]
            quality: JPEG quality (1-99), a number or a 1-D tensor with one
                value per image; 100 gives a zero factor and divides by zero
        """
        if isinstance(quality, (int, float)):
            factor = quality_to_compression_factor(quality)
        else:
            factor = torch.stack([
                quality_to_compression_factor(q)
                for q in quality
            ])
        
        c, h, w = x.shape[1:]
        h_pad = (16 - h % 16) % 16
        w_pad = (16 - w % 16) % 16
        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant')
        
        y, cb, cr = self.compress(x, factor)
        result = self.decompress(y, cb, cr, h + h_pad, w + w_pad, factor)
        return result[:, :, :h, :w]
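The padding step rounds both dimensions up to a multiple of 16. The reason is that 4:2:0 subsampling halves the chroma planes, and those half-size planes must still split into whole 8x8 blocks. The arithmetic is shown below as a small sketch; the helper name and the `mcu` (minimum coded unit) parameter are illustrative:

```python
def jpeg_padding(h, w, mcu=16):
    """Padding DiffJPEG adds, plus the resulting 8x8 block counts."""
    h_pad = (mcu - h % mcu) % mcu
    w_pad = (mcu - w % mcu) % mcu
    H, W = h + h_pad, w + w_pad
    luma_blocks = (H // 8) * (W // 8)
    chroma_blocks = (H // 2 // 8) * (W // 2 // 8)  # per chroma channel
    return h_pad, w_pad, luma_blocks, chroma_blocks

print(jpeg_padding(100, 150))  # (12, 10, 280, 70)
print(jpeg_padding(64, 64))    # (0, 0, 64, 16)
```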

Usage Example

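import cv2
import numpy as np
import torch

def img_to_tensor(img):
    """Convert an HxWxC numpy image to a CxHxW float tensor"""
    if img.ndim == 2:
        img = np.expand_dims(img, 2)
    tensor = torch.from_numpy(np.ascontiguousarray(img))
    return tensor.permute(2, 0, 1).float()

def tensor_to_img(tensor):
    """Convert a CxHxW tensor in [0, 1] back to a uint8 numpy image"""
    img = tensor.detach().permute(1, 2, 0).cpu().numpy()
    return np.clip(img * 255, 0, 255).round().astype(np.uint8)

# Load the image; OpenCV returns BGR, but DiffJPEG expects RGB in [0, 1]
img = cv2.imread('input.jpg')
if img is None:
    raise FileNotFoundError('Could not read input.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tensor_img = img_to_tensor(img).unsqueeze(0).to(device)

# Apply compression at three quality levels in one batch
compressor = DiffJPEG(differentiable=False).to(device)
qualities = torch.tensor([10, 50, 90], device=device)
batch = tensor_img.repeat(len(qualities), 1, 1, 1)

with torch.no_grad():
    compressed = compressor(batch, qualities)

for i, q in enumerate(qualities):
    # Convert back to BGR before writing with OpenCV
    out = cv2.cvtColor(tensor_to_img(compressed[i]), cv2.COLOR_RGB2BGR)
    cv2.imwrite(f'compressed_q{int(q)}.png', out)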

Results Comparison

Quality | File size reduction (approx.) | Visual quality
90-100  | 30-50%                        | Excellent
70-80   | 50-70%                        | Good
50-60   | 70-85%                        | Moderate
20-30   | 85-95%                        | Visible artifacts
