Fading Coder

One Final Commit for the Last Sprint


Implementation of Convolution Kernels for Convolutional Neural Networks


2D Cross-Correlation Calculation

import torch
from torch import nn

def compute_2d_cross_corr(input_tensor, kernel):
    """Compute the 2D cross-correlation of a 2D input with a 2D kernel (no padding, stride 1)"""
    kernel_h, kernel_w = kernel.shape
    # A "valid" cross-correlation shrinks each dimension by (kernel size - 1)
    output_h = input_tensor.shape[0] - kernel_h + 1
    output_w = input_tensor.shape[1] - kernel_w + 1
    output = torch.zeros((output_h, output_w))
    for row in range(output_h):
        for col in range(output_w):
            # Multiply the window under the kernel element-wise, then sum
            output[row, col] = (input_tensor[row:row + kernel_h, col:col + kernel_w] * kernel).sum()
    return output
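Written out as a formula, the loop computes each output entry as a weighted sum over one window of the input. For an input X of shape n_h × n_w and a kernel K of shape k_h × k_w (notation introduced here for illustration):

```latex
Y[i, j] = \sum_{a=0}^{k_h - 1} \sum_{b=0}^{k_w - 1} X[i + a,\, j + b]\, K[a, b],
\qquad 0 \le i \le n_h - k_h, \quad 0 \le j \le n_w - k_w
```

This is why the output has shape (n_h − k_h + 1) × (n_w − k_w + 1), matching output_h and output_w in the code.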

Example Usage

test_input = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
test_kernel = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
compute_2d_cross_corr(test_input, test_kernel)

Output:

tensor([[19., 25.],
        [37., 43.]])
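By hand, the top-left entry is 0·0 + 1·1 + 3·2 + 4·3 = 19. As a further check, PyTorch's built-in torch.nn.functional.conv2d also computes a cross-correlation (it does not flip the kernel), so it should agree with the loop-based function. A minimal sketch, assuming only the example tensors above:

```python
import torch
import torch.nn.functional as F

def compute_2d_cross_corr(input_tensor, kernel):
    """Same loop-based 2D cross-correlation as in the article."""
    kernel_h, kernel_w = kernel.shape
    output_h = input_tensor.shape[0] - kernel_h + 1
    output_w = input_tensor.shape[1] - kernel_w + 1
    output = torch.zeros((output_h, output_w))
    for row in range(output_h):
        for col in range(output_w):
            output[row, col] = (input_tensor[row:row + kernel_h, col:col + kernel_w] * kernel).sum()
    return output

test_input = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
test_kernel = torch.tensor([[0.0, 1.0], [2.0, 3.0]])

# F.conv2d expects (batch, channels, H, W) input and (out_ch, in_ch, kH, kW) weights,
# so add two leading singleton dimensions to each tensor, then squeeze them away.
builtin = F.conv2d(test_input[None, None], test_kernel[None, None]).squeeze()
manual = compute_2d_cross_corr(test_input, test_kernel)

print(builtin)
print(torch.allclose(builtin, manual))  # True
```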

Custom Convolution Layer Implementation

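class CustomConv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        # Learnable kernel of shape kernel_size, e.g. (1, 2), plus a scalar bias
        self.weight = nn.Parameter(torch.rand(kernel_size))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Cross-correlate the 2D input with the kernel, then add the bias
        return compute_2d_cross_corr(x, self.weight) + self.bias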
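A quick sanity check of the custom layer, assuming a (1, 2) kernel and an all-ones 6×8 input chosen purely for illustration. The output shape follows (6 − 1 + 1) × (8 − 2 + 1), and autograd reaches both parameters. Because every input pixel is 1, the gradient of the summed output with respect to each kernel weight (and to the bias) is just the number of output positions, 6 × 7 = 42.

```python
import torch
from torch import nn

def compute_2d_cross_corr(input_tensor, kernel):
    """Same loop-based 2D cross-correlation as in the article."""
    kernel_h, kernel_w = kernel.shape
    output_h = input_tensor.shape[0] - kernel_h + 1
    output_w = input_tensor.shape[1] - kernel_w + 1
    output = torch.zeros((output_h, output_w))
    for row in range(output_h):
        for col in range(output_w):
            output[row, col] = (input_tensor[row:row + kernel_h, col:col + kernel_w] * kernel).sum()
    return output

class CustomConv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(kernel_size))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return compute_2d_cross_corr(x, self.weight) + self.bias

layer = CustomConv2D(kernel_size=(1, 2))
x = torch.ones((6, 8))          # illustrative all-ones input
out = layer(x)
print(out.shape)                # torch.Size([6, 7])

# Backpropagate the sum of all outputs; each parameter touches all 42 outputs.
out.sum().backward()
print(layer.weight.grad)        # tensor([[42., 42.]])
print(layer.bias.grad)          # tensor([42.])
```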

Image Edge Detection Demo

First, create a 6×8 grayscale test image that is white (1) everywhere except for a black (0) band in columns 2 through 5 (zero-based):

edge_test_img = torch.ones((6, 8))
edge_test_img[:, 2:6] = 0
edge_test_img

Output:

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

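Define a 1×2 kernel for detecting vertical edges. At each position it computes a pixel minus its right-hand neighbour, so the result is zero wherever the two pixels are equal: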

edge_kernel = torch.tensor([[1.0, -1.0]])

Run the cross-correlation operation:

edge_detection_result = compute_2d_cross_corr(edge_test_img, edge_kernel)
edge_detection_result

Output:

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

Positive values mark transitions from white (1) to black (0), while negative values mark transitions from black (0) to white (1).
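Because the 1×2 kernel only compares horizontally adjacent pixels, it responds to vertical edges only. A small check of that claim, transposing the test image so its edges become horizontal (the transpose experiment is an illustration added here, not part of the original walkthrough):

```python
import torch

def compute_2d_cross_corr(input_tensor, kernel):
    """Same loop-based 2D cross-correlation as in the article."""
    kernel_h, kernel_w = kernel.shape
    output_h = input_tensor.shape[0] - kernel_h + 1
    output_w = input_tensor.shape[1] - kernel_w + 1
    output = torch.zeros((output_h, output_w))
    for row in range(output_h):
        for col in range(output_w):
            output[row, col] = (input_tensor[row:row + kernel_h, col:col + kernel_w] * kernel).sum()
    return output

edge_test_img = torch.ones((6, 8))
edge_test_img[:, 2:6] = 0
edge_kernel = torch.tensor([[1.0, -1.0]])
vertical = compute_2d_cross_corr(edge_test_img, edge_kernel)

# Transposing the image (8x6) turns its vertical edges into horizontal ones.
# Every row of the transposed image is constant, so the 1x2 kernel sees no change.
rotated_result = compute_2d_cross_corr(edge_test_img.t(), edge_kernel)
print(rotated_result.abs().sum())  # tensor(0.)

# Transposing the kernel as well (2x1, comparing vertical neighbours)
# recovers the edges: the result is the transpose of the original output.
horizontal = compute_2d_cross_corr(edge_test_img.t(), edge_kernel.t())
print(torch.equal(horizontal, vertical.t()))  # True
```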

Learning a Convolution Kernel

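Instead of designing a kernel by hand, we can learn one from data: given only the input image and the target output produced by edge_kernel, we train a convolution layer by gradient descent on the squared error until its kernel reproduces that output.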
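The training loop that follows minimizes the summed squared error between the layer's output and the target. Writing X for the input image, Y for the target, K for the 1×2 kernel and η for the learning rate (symbols introduced here for illustration), one step of the loop corresponds to:

```latex
\begin{aligned}
\hat{Y}[i, j] &= \sum_{a} \sum_{b} X[i + a,\, j + b]\, K[a, b],
\qquad L(K) = \sum_{i, j} \bigl(\hat{Y}[i, j] - Y[i, j]\bigr)^2 \\
\frac{\partial L}{\partial K[a, b]} &= 2 \sum_{i, j} \bigl(\hat{Y}[i, j] - Y[i, j]\bigr)\, X[i + a,\, j + b] \\
K &\leftarrow K - \eta\, \frac{\partial L}{\partial K}, \qquad \eta = 3 \times 10^{-2}
\end{aligned}
```

Autograd computes the gradient in the second line when loss.sum().backward() is called; the third line is the manual weight update.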

# Initialize a 2D convolution layer with 1 input channel, 1 output channel, (1,2) kernel size, no bias
conv_model = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)

# Reshape inputs and targets to PyTorch's 4D format: (batch, channels, height, width)
reshaped_input = edge_test_img.reshape((1, 1, 6, 8))
reshaped_target = edge_detection_result.reshape((1, 1, 6, 7))
learning_rate = 3e-2

# Training loop for 10 epochs
for epoch in range(10):
    predicted_output = conv_model(reshaped_input)
    loss = (predicted_output - reshaped_target) ** 2
    conv_model.zero_grad()
    loss.sum().backward()
    # Update kernel weights
    conv_model.weight.data[:] -= learning_rate * conv_model.weight.grad
    # Print loss every 2 epochs
    if (epoch + 1) % 2 == 0:
        print(f'epoch {epoch+1}, loss {loss.sum():.3f}')

Training output (the exact values depend on the random initial kernel):

epoch 2, loss 11.296
epoch 4, loss 1.912
epoch 6, loss 0.328
epoch 8, loss 0.058
epoch 10, loss 0.011
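The manual weight.data update above is plain gradient descent, so the same training can be sketched with torch.optim.SGD. This variant is an assumption of style rather than the article's code: it fixes a seed, recomputes the target with torch.nn.functional.conv2d (which performs the same cross-correlation), and runs 50 iterations instead of 10 (an arbitrary choice so the weights settle). Its loss values will not match the log above.

```python
import torch
from torch import nn

torch.manual_seed(0)  # illustrative seed so reruns start from the same random kernel

# Recreate the data: the 6x8 test image and the output of the [1, -1] kernel.
edge_test_img = torch.ones((6, 8))
edge_test_img[:, 2:6] = 0
edge_kernel = torch.tensor([[1.0, -1.0]])
X = edge_test_img.reshape((1, 1, 6, 8))                        # (batch, channels, H, W)
Y = nn.functional.conv2d(X, edge_kernel.reshape((1, 1, 1, 2)))  # shape (1, 1, 6, 7)

conv_model = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
optimizer = torch.optim.SGD(conv_model.parameters(), lr=3e-2)  # no momentum: w -= lr * grad

losses = []
for step in range(50):
    optimizer.zero_grad()
    loss = ((conv_model(X) - Y) ** 2).sum()  # same summed squared error as the manual loop
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(conv_model.weight.data.reshape((1, 2)))  # close to [[1., -1.]]
```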

Inspect the learned kernel weights:

conv_model.weight.data.reshape((1, 2))

Output:

tensor([[ 0.9871, -0.9780]])

The learned kernel is very close to the predefined edge_kernel, [1, -1].
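Why does training land near [1, −1]? With no bias, each output is linear in the two kernel weights, so minimizing the summed squared error is an ordinary least-squares problem. Its exact minimizer is [1, −1], because the target was generated by that kernel and the 42×2 matrix of image windows has full rank. A sketch of that view, using Tensor.unfold and torch.linalg.lstsq (tools chosen here for illustration; gradient descent approaches the same point iteratively):

```python
import torch

def compute_2d_cross_corr(input_tensor, kernel):
    """Same loop-based 2D cross-correlation as in the article."""
    kernel_h, kernel_w = kernel.shape
    output_h = input_tensor.shape[0] - kernel_h + 1
    output_w = input_tensor.shape[1] - kernel_w + 1
    output = torch.zeros((output_h, output_w))
    for row in range(output_h):
        for col in range(output_w):
            output[row, col] = (input_tensor[row:row + kernel_h, col:col + kernel_w] * kernel).sum()
    return output

edge_test_img = torch.ones((6, 8))
edge_test_img[:, 2:6] = 0
edge_kernel = torch.tensor([[1.0, -1.0]])
target = compute_2d_cross_corr(edge_test_img, edge_kernel)  # shape (6, 7)

# Each 1x2 window becomes one row of a design matrix:
# unfold(dim=1, size=2, step=1) gives shape (6, 7, 2); flatten to (42, 2).
A = edge_test_img.unfold(1, 2, 1).reshape(-1, 2)
b = target.reshape(-1, 1)  # same row-major order as the windows

solution = torch.linalg.lstsq(A, b).solution  # shape (2, 1)
print(solution.reshape(1, 2))                 # approximately [[1., -1.]]
```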

Key Takeaways

  • The core computation of 2D convolutional layers is 2D cross-correlation, typically followed by adding a bias term.
  • Predefined convolution kernels can detect specific image features such as edges.
  • Convolution kernel parameters can be automatically learned from training data using standard backpropagation.
  • Deeper convolutional networks can be constructed to detect a broader range of input features.
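The last point can be made concrete with a small stack of layers. In this sketch the channel counts (4 and 8), the 3×3 kernels and the ReLU between them are illustrative assumptions. Each unpadded 3×3 layer shrinks height and width by 2, while the input region that each final output value depends on grows from 3×3 to 5×5.

```python
import torch
from torch import nn

# Hypothetical two-layer stack; sizes are arbitrary choices for illustration.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),  # 1 input channel -> 4 feature maps, 6x8 -> 4x6
    nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3),  # 4 -> 8 feature maps, 4x6 -> 2x4
)

x = torch.ones((1, 1, 6, 8))  # (batch, channels, height, width), same size as the test image
y = model(x)
print(y.shape)  # torch.Size([1, 8, 2, 4])
```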
