Implementing JPEG Compression for Image Processing in Python
JPEG Compression Fundamentals
JPEG (Joint Photographic Experts Group) compression is a lossy algorithm that reduces image file sizes by exploiting human visual perception characteristics. The human eye is more sensitive to brightness variations than color changes, allowing the compression to preserve luminance information while discarding some chrominance data.
How JPEG Compression Works
The compression process divides the image into 8x8 pixel blocks, applies the discrete cosine transform (DCT) to each block, and quantizes the resulting coefficients. Higher compression ratios produce smaller files but introduce visible artifacts such as blocking effects and color bleeding.
A quality setting between 70 and 80 (on the 1-100 scale) typically provides a good balance between file size reduction and visual quality.
Integrating JPEG Compression into Image Degradation Simulation
Adding JPEG compression to image degradation pipelines offers several advantages for research and development.
Benefits for Simulation Accuracy
Real-world images suffer from various degradations including lens aberrations, atmospheric scattering, and compression artifacts. Simulating JPEG compression alongside other degradation factors produces more realistic synthetic training data for image restoration models.
JPEG compression also reduces memory footprint and accelerates processing when handling large datasets containing thousands of images.
Implementation Trade-offs
Combining JPEG compression with other degradation factors may compound quality loss, potentially reducing the effectiveness of training data. Additionally, the added artifact patterns make model interpretation more challenging.
OpenCV Implementation
The OpenCV library provides straightforward functions for encoding and decoding JPEG images.
Encoding and Decoding Functions
import cv2
import numpy as np
def apply_opencv_jpeg_compression(image_path, quality):
    """
    Apply JPEG compression using OpenCV.

    The image is read from disk, encoded into an in-memory JPEG buffer at
    the requested quality, and decoded again, so the returned array carries
    the compression artifacts.

    Args:
        image_path: Path to input image
        quality: JPEG quality (1-100)
    Returns:
        Compressed and decompressed image as float32
    Raises:
        FileNotFoundError: If the image cannot be read from ``image_path``.
        RuntimeError: If OpenCV reports that the JPEG encoding failed.
    """
    img = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # The encoder parameter list must contain integers only.
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
    ok, encoded_buffer = cv2.imencode('.jpg', img, encode_params)
    if not ok:
        raise RuntimeError(f"JPEG encoding failed for: {image_path}")
    decoded_img = cv2.imdecode(encoded_buffer, cv2.IMREAD_COLOR)
    return np.float32(decoded_img)
# Example usage: compress the same input at a low and a high quality
# setting, then write both results to disk.
source_path = 'input.png'
compressed_20 = apply_opencv_jpeg_compression(source_path, 20)
compressed_80 = apply_opencv_jpeg_compression(source_path, 80)
for out_name, result in (('output_20.png', compressed_20),
                         ('output_80.png', compressed_80)):
    cv2.imwrite(out_name, result)
PyTorch-Based Differentiable JPEG Compression
The DiffJPEG implementation provides a differentiable version of JPEG compression that is suitable for deep learning pipelines. It supports batch processing, and its results differ slightly from those produced by OpenCV.
Core Components
import torch
import torch.nn as nn
import numpy as np
import itertools
from torch.nn import functional as F
# Quantization tables
# 8x8 luminance table; the literal is transposed before use.
luminance_table = np.transpose(np.array(
    [[16, 11, 10, 16, 24, 40, 51, 61],
     [12, 12, 14, 19, 26, 58, 60, 55],
     [14, 13, 16, 24, 40, 57, 69, 56],
     [14, 17, 22, 29, 51, 87, 80, 62],
     [18, 22, 37, 56, 68, 109, 103, 77],
     [24, 35, 55, 64, 81, 104, 113, 92],
     [49, 64, 78, 87, 103, 121, 120, 101],
     [72, 92, 95, 98, 112, 100, 103, 99]],
    dtype=np.float32))

# 8x8 chrominance table: every entry is 99 except the top-left 4x4 corner.
chrominance_table = np.full((8, 8), 99, dtype=np.float32)
chrominance_table[:4, :4] = np.transpose(np.array(
    [[17, 18, 24, 47],
     [18, 21, 26, 66],
     [24, 26, 56, 99],
     [47, 66, 99, 99]]))

# Wrap both tables as nn.Parameter tensors.
luminance_table = nn.Parameter(torch.from_numpy(luminance_table))
chrominance_table = nn.Parameter(torch.from_numpy(chrominance_table))
def soft_round(x):
    """Differentiable rounding approximation.

    Returns ``round(x)`` plus the cube of the remaining fractional offset,
    i.e. ``round(x) + (x - round(x)) ** 3``.
    """
    nearest = torch.round(x)
    residual = x - nearest
    return nearest + residual ** 3
def quality_to_compression_factor(quality):
    """Convert quality setting to compression factor.

    The mapping is continuous at quality 50, where the factor is 1.0.
    Lower qualities give factors above 1 and higher qualities give
    factors below 1, reaching 0 at quality 100.

    Args:
        quality: JPEG quality in the range (0, 100]. A plain number or a
            single-element tensor.
    Returns:
        The compression factor, of the same kind as ``quality``.
    Raises:
        ValueError: If ``quality`` is not in the range (0, 100].
    """
    # Reject values that would divide by zero or give a negative factor.
    if not 0 < quality <= 100:
        raise ValueError(f"quality must be in (0, 100], got {quality}")
    if quality < 50:
        scaled = 5000.0 / quality
    else:
        scaled = 200.0 - quality * 2
    # Both branches are percentages and must be divided by 100; without
    # this the low-quality branch would be 100 times too large.
    return scaled / 100.0
Color Space Conversion and Downsampling
class RGBToYCbCr(nn.Module):
    """Transform RGB to YCbCr color space.

    The input is channel-first (batch x 3 x height x width); the output is
    channel-last (batch x height x width x 3).
    """

    def __init__(self):
        super().__init__()
        coefficients = np.array(
            [[0.299, 0.587, 0.114],
             [-0.168736, -0.331264, 0.5],
             [0.5, -0.418688, -0.081312]],
            dtype=np.float32)
        # Cb and Cr are offset by 128; Y has no offset.
        self.bias = nn.Parameter(torch.tensor([0., 128., 128.]))
        # Transposed so tensordot contracts over the input's RGB axis.
        self.weight = nn.Parameter(torch.from_numpy(coefficients.T))

    def forward(self, x):
        """Convert a channel-first RGB batch to channel-last YCbCr."""
        channels_last = x.permute(0, 2, 3, 1)
        ycbcr = torch.tensordot(channels_last, self.weight, dims=1)
        ycbcr = ycbcr + self.bias
        return ycbcr.view(channels_last.shape)
class ChromaDownsample(nn.Module):
    """Reduce chroma channel resolution by half"""

    def __init__(self):
        super().__init__()

    def forward(self, image):
        """
        Split a YCbCr batch into planes and halve the chroma planes.

        Args:
            image: batch x height x width x 3
        Returns:
            y, cb, cr channels with cb/cr at half resolution
        """
        channel_first = image.permute(0, 3, 1, 2).clone()

        def pooled(index):
            # Keep a singleton channel axis for avg_pool2d, then drop it.
            plane = channel_first[:, index:index + 1, :, :]
            halved = F.avg_pool2d(plane, kernel_size=2, stride=2,
                                  count_include_pad=False)
            return halved.permute(0, 2, 3, 1).squeeze(3)

        luma = image[:, :, :, 0]
        return luma, pooled(1), pooled(2)
Block Processing and DCT
class BlockSplitter(nn.Module):
    """Split image into 8x8 blocks"""

    def __init__(self):
        super().__init__()
        self.block_size = 8

    def forward(self, image):
        """
        Cut each plane into non-overlapping blocks in row-major order.

        Args:
            image: batch x height x width
        Returns:
            batch x (h*w/64) x 8 x 8
        """
        edge = self.block_size
        batch, height, width = image.shape[:3]
        # Split both spatial axes into (block index, offset inside block).
        tiled = image.view(batch, height // edge, edge, width // edge, edge)
        # Bring the two block-index axes together before flattening them.
        tiled = tiled.permute(0, 1, 3, 2, 4).contiguous()
        return tiled.view(batch, -1, edge, edge)
class ForwardDCT(nn.Module):
    """8x8 Discrete Cosine Transform"""

    def __init__(self):
        super().__init__()
        # cosines[p, f] = cos((2p + 1) * f * pi / 16) for pixel index p and
        # frequency index f.
        cosines = np.array(
            [[np.cos((2 * p + 1) * f * np.pi / 16) for f in range(8)]
             for p in range(8)])
        # basis[i, j, u, v] = cosines[i, u] * cosines[j, v]
        basis = (cosines[:, None, :, None] *
                 cosines[None, :, None, :]).astype(np.float32)
        # Normalization: 1/sqrt(2) for the zero frequency, 1 otherwise.
        alpha = np.array([1 / np.sqrt(2)] + [1] * 7)
        self.basis = nn.Parameter(torch.from_numpy(basis).float())
        self.scale = nn.Parameter(
            torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())

    def forward(self, image):
        """Apply DCT to image blocks"""
        # Shift pixel values by 128 before transforming.
        shifted = image - 128
        projected = torch.tensordot(shifted, self.basis, dims=2)
        return self.scale * projected
Quantization Layers
class LuminanceQuantizer(nn.Module):
    """Quantize luminance component"""

    def __init__(self, rounding_fn):
        super().__init__()
        self.round = rounding_fn
        self.table = luminance_table

    def forward(self, coefficients, factor=1):
        """Apply quantization with given factor.

        ``factor`` is either a plain number applied to the whole batch or a
        tensor holding one factor per batch element.
        """
        values = coefficients.float()
        if not isinstance(factor, (int, float)):
            # Per-sample factors: broadcast each one over its own blocks.
            batch = factor.size(0)
            per_sample = factor.view(batch, 1, 1, 1)
            divisor = self.table.expand(batch, 1, 8, 8) * per_sample
        else:
            divisor = self.table * factor
        return self.round(values / divisor)
class ChrominanceQuantizer(nn.Module):
    """Quantize chrominance components"""

    def __init__(self, rounding_fn):
        super().__init__()
        self.round = rounding_fn
        self.table = chrominance_table

    def _divisor(self, factor):
        """Return the quantization table scaled by ``factor``."""
        if isinstance(factor, (int, float)):
            return self.table * factor
        # Tensor factor: one value per batch element.
        count = factor.size(0)
        return self.table.expand(count, 1, 8, 8) * factor.view(count, 1, 1, 1)

    def forward(self, coefficients, factor=1):
        """Apply quantization with given factor.

        ``factor`` is either a plain number applied to the whole batch or a
        tensor holding one factor per batch element.
        """
        scaled = coefficients.float() / self._divisor(factor)
        return self.round(scaled)
Decompression Components
class LuminanceDequantizer(nn.Module):
    """Reverse luminance quantization"""

    def __init__(self):
        super().__init__()
        self.table = luminance_table

    def forward(self, quantized, factor=1):
        """Multiply quantized coefficients by the factor-scaled table.

        ``factor`` is either a plain number applied to the whole batch or a
        tensor holding one factor per batch element.
        """
        if not isinstance(factor, (int, float)):
            # Per-sample factors: broadcast each one over its own blocks.
            batch = factor.size(0)
            multiplier = (self.table.expand(batch, 1, 8, 8) *
                          factor.view(batch, 1, 1, 1))
        else:
            multiplier = self.table * factor
        return quantized * multiplier
class InverseDCT(nn.Module):
    """Inverse 8x8 DCT"""

    def __init__(self):
        super().__init__()
        # Normalization: 1/sqrt(2) for the zero frequency, 1 otherwise.
        norm = np.array([1 / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(
            torch.from_numpy(np.outer(norm, norm)).float())
        # cosines[f, p] = cos((2p + 1) * f * pi / 16) for frequency index f
        # and pixel index p.
        cosines = np.array(
            [[np.cos((2 * p + 1) * f * np.pi / 16) for p in range(8)]
             for f in range(8)])
        # basis[i, j, u, v] = cosines[i, u] * cosines[j, v]
        basis = (cosines[:, None, :, None] *
                 cosines[None, :, None, :]).astype(np.float32)
        self.basis = nn.Parameter(torch.from_numpy(basis).float())

    def forward(self, coeffs):
        """Apply inverse DCT"""
        weighted = coeffs * self.alpha
        spatial = torch.tensordot(weighted, self.basis, dims=2)
        # Scale by 1/4 and shift the result by 128.
        return 0.25 * spatial + 128
class BlockMerger(nn.Module):
    """Reassemble 8x8 blocks into full image"""

    def forward(self, patches, height, width):
        """Merge patches back into image.

        ``patches`` holds the 8x8 blocks of each batch element in row-major
        order; the result has shape batch x height x width.
        """
        edge = 8
        batch = patches.shape[0]
        rows, cols = height // edge, width // edge
        grid = patches.view(batch, rows, cols, edge, edge)
        # Interleave block rows with in-block rows before flattening.
        grid = grid.permute(0, 1, 3, 2, 4).contiguous()
        return grid.view(batch, height, width)
class ChromaUpsample(nn.Module):
    """Upsample chroma channels to full resolution"""

    @staticmethod
    def _double(plane):
        """Repeat every value of a batch x h x w plane into a 2x2 patch."""
        rows, cols = plane.shape[1:3]
        tiled = plane.unsqueeze(-1).repeat(1, 1, 2, 2)
        return tiled.view(-1, rows * 2, cols * 2)

    def forward(self, y, cb, cr):
        """Upsample and concatenate channels"""
        planes = [y, self._double(cb), self._double(cr)]
        # Stack along a new trailing axis: batch x height x width x 3.
        return torch.stack(planes, dim=3)
class YCbCrToRGB(nn.Module):
    """Convert YCbCr back to RGB color space.

    The input is channel-last (batch x height x width x 3); the output is
    channel-first (batch x 3 x height x width).
    """

    def __init__(self):
        super().__init__()
        coefficients = np.array(
            [[1., 0., 1.402],
             [1, -0.344136, -0.714136],
             [1, 1.772, 0]],
            dtype=np.float32)
        # Added to the input first, removing the 128 offset on Cb and Cr.
        self.bias = nn.Parameter(torch.tensor([0, -128., -128.]))
        # Transposed so tensordot contracts over the input's YCbCr axis.
        self.weight = nn.Parameter(torch.from_numpy(coefficients.T))

    def forward(self, image):
        """Transform YCbCr to RGB"""
        shifted = image + self.bias
        rgb = torch.tensordot(shifted, self.weight, dims=1)
        return rgb.view(image.shape).permute(0, 3, 1, 2)
Complete Differentiable JPEG Module
class _ChrominanceDequantizer(LuminanceDequantizer):
    """Reverse chrominance quantization.

    Reuses LuminanceDequantizer's arithmetic but multiplies by the
    chrominance table, so it mirrors ChrominanceQuantizer.
    """

    def __init__(self):
        super().__init__()
        self.table = chrominance_table


class DiffJPEG(nn.Module):
    """
    End-to-end differentiable JPEG compression.
    Slightly different results from OpenCV but supports backpropagation.

    NOTE: quality 100 maps to a compression factor of 0, which makes the
    quantizers divide by zero; prefer qualities below 100.
    """

    def __init__(self, differentiable=True):
        """
        Args:
            differentiable: Use ``soft_round`` when True, ``torch.round``
                otherwise, as the quantization rounding function.
        """
        super().__init__()
        rounding = soft_round if differentiable else torch.round
        self.color_convert = RGBToYCbCr()
        self.chroma_down = ChromaDownsample()
        self.block_split = BlockSplitter()
        self.dct = ForwardDCT()
        self.y_quant = LuminanceQuantizer(rounding)
        self.c_quant = ChrominanceQuantizer(rounding)
        self.y_dequant = LuminanceDequantizer()
        # Chroma must be dequantized with the same table it was quantized
        # with; using the luminance table here would distort the colors.
        self.c_dequant = _ChrominanceDequantizer()
        self.idct = InverseDCT()
        self.block_merge = BlockMerger()
        self.chroma_up = ChromaUpsample()
        self.rgb_convert = YCbCrToRGB()

    def compress(self, x, factor):
        """Forward compression pass.

        Args:
            x: RGB batch in [0, 1], shape (B, 3, H, W), with H and W
                multiples of 16.
            factor: Compression factor, a number or a per-sample tensor.
        Returns:
            Quantized DCT coefficients for the y, cb and cr planes.
        """
        y, cb, cr = self.chroma_down(self.color_convert(x * 255))
        quantizers = {'y': self.y_quant, 'cb': self.c_quant, 'cr': self.c_quant}
        planes = {'y': y, 'cb': cb, 'cr': cr}
        quantized = {
            key: quantizers[key](self.dct(self.block_split(plane)), factor)
            for key, plane in planes.items()
        }
        return quantized['y'], quantized['cb'], quantized['cr']

    def decompress(self, y, cb, cr, h, w, factor):
        """Forward decompression pass.

        Args:
            y, cb, cr: Quantized coefficients as returned by ``compress``.
            h, w: Height and width of the (padded) image being rebuilt.
            factor: The compression factor that was used in ``compress``.
        Returns:
            RGB batch of shape (B, 3, h, w), clamped to [0, 1].
        """
        luma = self.block_merge(self.idct(self.y_dequant(y, factor)), h, w)
        # The chroma planes were downsampled to half resolution.
        chroma_h, chroma_w = h // 2, w // 2
        chroma = [
            self.block_merge(self.idct(self.c_dequant(plane, factor)),
                             chroma_h, chroma_w)
            for plane in (cb, cr)
        ]
        merged = self.chroma_up(luma, *chroma)
        rgb = torch.clamp(self.rgb_convert(merged), 0, 255)
        return rgb / 255

    def forward(self, x, quality):
        """
        Compress and decompress a batch of images.

        Args:
            x: Input image batch, shape (B, C, H, W), RGB, range [0, 1]
            quality: JPEG quality factor (1-100). Either a single number
                (or single-element tensor) used for the whole batch, or a
                sequence/tensor with one quality per batch element.
        Returns:
            The degraded batch, with the same spatial size as ``x``.
        """
        # A zero-dimensional tensor cannot be iterated; treat it as a number.
        if torch.is_tensor(quality) and quality.dim() == 0:
            quality = quality.item()
        if isinstance(quality, (int, float)):
            factor = quality_to_compression_factor(quality)
        else:
            # as_tensor accepts both tensor and plain-number factors, so
            # lists of numbers work as well as tensors.
            factor = torch.stack([
                torch.as_tensor(quality_to_compression_factor(q),
                                device=x.device)
                for q in quality
            ])
        h, w = x.shape[-2:]
        # Pad to a multiple of 16 so the half-resolution chroma planes
        # still divide into whole 8x8 blocks.
        h_pad = (16 - h % 16) % 16
        w_pad = (16 - w % 16) % 16
        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant')
        y, cb, cr = self.compress(x, factor)
        result = self.decompress(y, cb, cr, h + h_pad, w + w_pad, factor)
        # Crop the padding off again.
        return result[:, :, :h, :w]
Usage Example
def img_to_tensor(img):
    """Convert numpy image to tensor.

    A 2-D image gets a trailing channel axis first. The result is a
    channel-first float32 tensor.
    """
    hwc = img[:, :, np.newaxis] if img.ndim == 2 else img
    chw = torch.from_numpy(np.ascontiguousarray(hwc)).permute(2, 0, 1)
    return chw.float()
def tensor_to_img(tensor):
    """Convert tensor back to numpy image.

    Args:
        tensor: Channel-first image tensor (C x H x W). Values are scaled
            by 255, so inputs are expected to be nominally in [0, 1].
    Returns:
        H x W x C uint8 array with values clipped to [0, 255].
    """
    # detach() first: numpy() raises on tensors that require grad.
    img = tensor.detach().permute(1, 2, 0).cpu().numpy()
    return np.clip(img * 255, 0, 255).astype(np.uint8)
# Load and process image
import cv2

bgr = cv2.imread('input.jpg')
# cv2.imread reports failure by returning None instead of raising.
if bgr is None:
    raise FileNotFoundError("Could not read image: input.jpg")
# OpenCV loads images in BGR order, while DiffJPEG expects RGB input.
img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0
# Fall back to the CPU when no CUDA device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensor_img = img_to_tensor(img).unsqueeze(0).to(device)

# Apply compression at different quality levels
compressor = DiffJPEG(differentiable=False).to(device)
qualities = torch.tensor([10, 50, 90]).to(device)
batch = tensor_img.repeat(len(qualities), 1, 1, 1)
# No gradients are needed here; this also keeps the output free of
# autograd history so it can be converted to a NumPy array.
with torch.no_grad():
    compressed = compressor(batch, qualities)
for i, q in enumerate(qualities):
    rgb_out = tensor_to_img(compressed[i])
    # Convert back to BGR, the channel order cv2.imwrite expects.
    cv2.imwrite(f'compressed_q{int(q)}.png',
                cv2.cvtColor(rgb_out, cv2.COLOR_RGB2BGR))
Results Comparison
| Quality | File Size Reduction | Visual Quality |
|---|---|---|
| 90-100 | 30-50% | Excellent |
| 70-80 | 50-70% | Good |
| 50-60 | 70-85% | Moderate |
| 20-30 | 85-95% | Visible artifacts |