Binary Classification Loss Functions in PyTorch: BCEWithLogitsLoss vs. CrossEntropyLoss

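When building a binary classification model in PyTorch, there are three common configurations for the final layer, activation, and loss function: torch.nn.Linear + torch.sigmoid + torch.nn.BCELoss; torch.nn.Linear + torch.nn.BCEWithLogitsLoss; and torch.nn.Linear (with an output dimension of 2) + torch.nn.CrossEntropyLoss. The key distinction is that BCEWithLogitsLoss applies a sigmoid internally, whereas CrossEntropyLoss applies a log-softmax internally, so in both of these cases the model should output raw logits with no activation of its own. The second configuration is generally preferred over the first: according to the PyTorch documentation, fusing the sigmoid and the loss into one operation is more numerically stable than applying them separately.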
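As a rough illustration of how the three configurations relate, the sketch below evaluates all of them on the same data. The tensor names (z, y, two_logits) are made up for this example, and the two-logit tensor is constructed by hand, with the class-0 logit fixed at 0, rather than produced by a trained layer. Under that assumption, softmax over [0, z] gives sigmoid(z) for class 1, so all three losses should agree up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

z = torch.randn(5, 1)                      # raw logits, as from a Linear layer with 1 output
y = torch.randint(0, 2, (5, 1)).float()    # binary targets as floats, same shape as z

# Configuration 1: explicit sigmoid followed by BCELoss
loss_bce = nn.BCELoss()(torch.sigmoid(z), y)

# Configuration 2: BCEWithLogitsLoss applies the sigmoid internally
loss_bce_logits = nn.BCEWithLogitsLoss()(z, y)

# Configuration 3: two logits per sample with CrossEntropyLoss.
# Assumption for this sketch: the class-0 logit is fixed at 0.
two_logits = torch.cat([torch.zeros_like(z), z], dim=1)   # shape (5, 2)
loss_ce = nn.CrossEntropyLoss()(two_logits, y.squeeze(1).long())

print(loss_bce.item(), loss_bce_logits.item(), loss_ce.item())
```

Note that the targets change form between configurations: float tensors shaped like the logits for the BCE losses, integer class indices of shape (N,) for CrossEntropyLoss.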

Critical differences are detailed below:

  • Output and Target Shape Requirements: For a batch of size N and C outputs, BCELoss and BCEWithLogitsLoss require the model output and the target tensor to have the same shape, e.g. (N, C), or (N, 1) for a single binary output. BCELoss expects probabilities (after sigmoid), while BCEWithLogitsLoss expects raw logits. In contrast, CrossEntropyLoss expects logits of shape (N, C) and, in its usual form, a target tensor of shape (N,) containing integer class indices.
  • Target Data Type: When CrossEntropyLoss targets are class indices, they must have dtype torch.long (int64). Recent PyTorch versions also accept floating-point class probabilities of shape (N, C) as targets. For the BCE family (BCELoss, BCEWithLogitsLoss), targets are floating-point tensors with values between 0 and 1, representing probabilities.
  • Mathematical Formulation: The BCE losses compute -(target * log(sigmoid(logits)) + (1 - target) * log(1 - sigmoid(logits))) per element and then sum or average the result, depending on the reduction argument. CrossEntropyLoss is essentially LogSoftmax followed by NLLLoss. Its value for a single sample is -logits[target_class] + log(sum(exp(logits[i]) for i in classes)), which comes from the information-theoretic definition of cross-entropy and is closely related to KL divergence.
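The shape, dtype, and formula differences listed above can be checked with a short sketch. The sizes N and C and the random targets are arbitrary choices for illustration. The manual BCE uses F.logsigmoid, relying on the identity log(1 - sigmoid(x)) = logsigmoid(-x).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 4, 3                                  # arbitrary batch size and output count
logits = torch.randn(N, C)

# BCE family: float targets with the same shape as the logits
bce_targets = torch.randint(0, 2, (N, C)).float()
bce = nn.BCEWithLogitsLoss(reduction='none')(logits, bce_targets)   # shape (N, C)

# Manual per-element BCE, including the leading minus sign
manual_bce = -(bce_targets * F.logsigmoid(logits)
               + (1 - bce_targets) * F.logsigmoid(-logits))

# CrossEntropyLoss: integer class indices of shape (N,), dtype torch.int64
ce_targets = torch.randint(0, C, (N,))
ce = nn.CrossEntropyLoss(reduction='none')(logits, ce_targets)      # shape (N,)

print(bce.shape, ce.shape, ce_targets.dtype)
print(torch.allclose(bce, manual_bce, atol=1e-6))
```

With reduction='none', the BCE loss keeps one value per element, while CrossEntropyLoss returns one value per sample.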

Code Examples and Implementations

Understanding CrossEntropyLoss and NLLLoss:

import torch
import torch.nn as nn
import math

celoss = nn.CrossEntropyLoss(reduction='none')
logits = torch.randn(2, 3)  # 2 samples, 3 classes
targets = torch.LongTensor([0, 2])  # Class labels for the two samples
loss_val = celoss(logits, targets)
print(f'CrossEntropyLoss per sample: {loss_val}')

# Manual calculation
for idx in range(logits.shape[0]):
    sample_logits = logits[idx]
    class_idx = targets[idx].item()
    x_class = sample_logits[class_idx]
    log_sum_exp = math.log(sum(math.exp(val) for val in sample_logits))
    manual_loss = -x_class + log_sum_exp
    print(f'Sample {idx} manual loss: {manual_loss}')

# Demonstrating the equivalence to NLLLoss + LogSoftmax
nll = nn.NLLLoss(reduction='none')
log_softmax = nn.LogSoftmax(dim=-1)
log_probs = log_softmax(logits)
nll_loss_val = nll(log_probs, targets)
print(f'NLLLoss after LogSoftmax: {nll_loss_val}')
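
The per-sample loop above can also be written without Python-level iteration. The following is a minimal sketch, not the library's internal implementation. It uses torch.gather to pick each sample's target logit and torch.logsumexp, which is documented as numerically stabilized, for the normalizing term.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 3)           # 2 samples, 3 classes
targets = torch.tensor([0, 2])       # class indices, dtype int64

# Select the logit of the target class for every sample: shape (2,)
target_logits = logits.gather(1, targets.unsqueeze(1)).squeeze(1)

# loss = -x[class] + log(sum(exp(x))), computed for the whole batch at once
vectorized_loss = -target_logits + torch.logsumexp(logits, dim=1)

reference = nn.CrossEntropyLoss(reduction='none')(logits, targets)
print(vectorized_loss)
print(reference)
```

For large logits, exponentiating each value directly with math.exp can overflow, which is one reason to prefer logsumexp in practice.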

Custom Loss Function Implementations:

import torch
import torch.nn as nn
import torch.nn.functional as F

# A simplified Symmetric Cross Entropy (SCE) Loss example
class SymmetricCELoss(nn.Module):
    def __init__(self, num_classes, alpha=1.0, beta=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.beta = beta
        self.standard_cross_entropy = nn.CrossEntropyLoss()
        self.eps = 1e-8

    def forward(self, predictions, labels):
        # Standard Cross-Entropy (CE)
        ce_loss = self.standard_cross_entropy(predictions, labels)
        
        # Reverse Cross-Entropy (RCE)
        predicted_probs = F.softmax(predictions, dim=1)
        predicted_probs = torch.clamp(predicted_probs, min=self.eps, max=1.0)
        
        encoded_labels = F.one_hot(labels, self.num_classes).float()
        encoded_labels = torch.clamp(encoded_labels, min=self.eps, max=1.0)
        
        # RCE = - Σ p_predicted * log(p_target)
        rce_term = -torch.sum(predicted_probs * torch.log(encoded_labels), dim=1)
        rce_loss = rce_term.mean()
        
        # Combined loss: SCE = α * CE + β * RCE
        total_loss = self.alpha * ce_loss + self.beta * rce_loss
        return total_loss

# Usage
pred = torch.tensor([[10.0, 5.0, -6.0], [8.0, 8.0, 8.0]])
true_labels = torch.tensor([0, 2])
sceloss = SymmetricCELoss(num_classes=3)
loss_value = sceloss(pred, true_labels)
print(f'Symmetric CE Loss: {loss_value}')
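
One way to see what the reverse term in the class above does: with one-hot labels clamped to [eps, 1], log(label) is 0 for the target class and log(eps) for every other class, so RCE reduces to -log(eps) * (1 - p_target). The sketch below checks this on random logits. It assumes the predicted probabilities stay above eps, so the clamp on predictions is inactive, and the variable names are illustrative only.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-8                                   # same clamp value as in SymmetricCELoss
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 1])

probs = F.softmax(logits, dim=1)
one_hot = torch.clamp(F.one_hot(labels, 3).float(), min=eps, max=1.0)

# RCE as written in the class: -sum_k p_k * log(q_k)
rce_full = -torch.sum(probs * torch.log(one_hot), dim=1)

# Closed form: only non-target classes contribute, each scaled by -log(eps)
p_target = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
rce_closed = -math.log(eps) * (1.0 - p_target)

print(rce_full)
print(rce_closed)
```

Read this way, the reverse term is a scaled penalty on the probability mass assigned to the wrong classes, and the choice of eps sets that scale.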

Equivalent Implementations of CrossEntropyLoss:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Implementation 1: Using LogSoftmax + NLLLoss
class CustomCE_v1(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.nll_loss = nn.NLLLoss(reduction='none')
    
    def forward(self, model_output, target_indices):
        log_probabilities = self.log_softmax(model_output)
        return self.nll_loss(log_probabilities, target_indices)

# Implementation 2: Manual calculation using one-hot encoding
class CustomCE_v2(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, model_output, target_indices):
        batch_size, num_classes = model_output.shape
        # Create one-hot encoding of targets
        target_one_hot = F.one_hot(target_indices, num_classes=num_classes).float()
        # Compute softmax probabilities and their log
        log_softmax_probs = torch.log_softmax(model_output, dim=-1)
        # Sum over classes: - Σ (target_one_hot * log_softmax_probs)
        per_sample_loss = -torch.sum(target_one_hot * log_softmax_probs, dim=1)
        return per_sample_loss

# Verification
logits_tensor = torch.rand(4, 3)
target_tensor = torch.LongTensor([1, 2, 1, 0])

ce1 = CustomCE_v1()
ce2 = CustomCE_v2()
ref_loss = nn.CrossEntropyLoss(reduction='none')

print('Custom v1:', ce1(logits_tensor, target_tensor))
print('Custom v2:', ce2(logits_tensor, target_tensor))
print('Reference:', ref_loss(logits_tensor, target_tensor))
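
Comparing printed tensors by eye is error-prone, so a programmatic check may be useful. The sketch below mirrors the LogSoftmax + NLLLoss approach using the functional API and compares it with nn.CrossEntropyLoss through torch.allclose. It also shows that the default reduction ('mean') is the average of the per-sample values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits_tensor = torch.rand(4, 3)
target_tensor = torch.tensor([1, 2, 1, 0])

# Per-sample loss via log-softmax + NLL, the same idea as CustomCE_v1
per_sample = F.nll_loss(F.log_softmax(logits_tensor, dim=-1),
                        target_tensor, reduction='none')

# Reference values without and with reduction
ref_none = nn.CrossEntropyLoss(reduction='none')(logits_tensor, target_tensor)
ref_mean = nn.CrossEntropyLoss()(logits_tensor, target_tensor)   # default reduction='mean'

print(torch.allclose(per_sample, ref_none, atol=1e-6))
print(torch.allclose(per_sample.mean(), ref_mean, atol=1e-6))
```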
