Binary Classification Loss Functions in PyTorch: BCEWithLogitsLoss vs. CrossEntropyLoss
When constructing a binary classification model in PyTorch, three primary configurations exist for the final layer, activation, and loss function: torch.nn.Linear + torch.sigmoid + torch.nn.BCELoss; torch.nn.Linear + torch.nn.BCEWithLogitsLoss; and torch.nn.Linear (with output dimension of 2) + torch.nn.CrossEntropyLoss. The key distinction is that BCEWithLogitsLoss incorporates a sigmoid operation internally, whereas CrossEntropyLoss incorporates a softmax operation.
Critical differences are detailed below:
- Output and Target Shape Requirements: For a batch of size N and C classes, BCELoss and BCEWithLogitsLoss require the model output (probabilities for BCELoss, logits for BCEWithLogitsLoss) and the target tensor to have the same shape (N, C). In contrast, CrossEntropyLoss expects logits of shape (N, C) but a target tensor of shape (N,) containing integer class indices. - Target Data Type:
CrossEntropyLoss targets must be of type torch.LongTensor. For the BCE family (BCELoss, BCEWithLogitsLoss), targets are torch.FloatTensor with values between 0 and 1, representing probabilities. - Mathematical Formulation: The BCE losses compute
-[target * log(sigmoid(logits)) + (1 - target) * log(1 - sigmoid(logits))] per element (note the leading negation) and then sum or average. CrossEntropyLoss is essentially Softmax followed by NLLLoss. Its calculation for a sample is -logits[target_class] + log(sum(exp(logits[i]) for i in classes)), which originates from information theory concepts like cross-entropy and KL divergence.
Code Examples and Implementations
Understanding CrossEntropyLoss and NLLLoss:
import torch
import torch.nn as nn
import math
# Demo: nn.CrossEntropyLoss vs. a manual computation vs. LogSoftmax + NLLLoss.
criterion = nn.CrossEntropyLoss(reduction='none')
logits = torch.randn(2, 3)  # 2 samples, 3 classes
targets = torch.LongTensor([0, 2])  # Class labels for the two samples
loss_val = criterion(logits, targets)
print(f'CrossEntropyLoss per sample: {loss_val}')
# Manual calculation: loss_i = -x[target_i] + log(sum_j exp(x[j]))
for idx, (sample_logits, target) in enumerate(zip(logits, targets)):
    x_class = sample_logits[target.item()]
    log_sum_exp = math.log(sum(math.exp(val) for val in sample_logits))
    manual_loss = -x_class + log_sum_exp
    print(f'Sample {idx} manual loss: {manual_loss}')
# Demonstrating the equivalence to NLLLoss + LogSoftmax
nll = nn.NLLLoss(reduction='none')
log_probs = nn.LogSoftmax(dim=-1)(logits)
nll_loss_val = nll(log_probs, targets)
print(f'NLLLoss after LogSoftmax: {nll_loss_val}')
Custom Loss Function Implementations:
import torch
import torch.nn as nn
import torch.nn.functional as F
# A simplified Symmetric Cross Entropy (SCE) Loss example
class SymmetricCELoss(nn.Module):
    """Symmetric Cross Entropy: SCE = alpha * CE + beta * RCE.

    The reverse cross-entropy (RCE) term swaps the roles of prediction
    and target: RCE = -sum(p_predicted * log(p_target)). The one-hot
    target is clamped away from zero so the log stays finite.
    """

    def __init__(self, num_classes, alpha=1.0, beta=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.beta = beta
        self.standard_cross_entropy = nn.CrossEntropyLoss()
        # Lower clamp bound; keeps log() finite on the zero one-hot entries.
        self.eps = 1e-8

    def forward(self, predictions, labels):
        """Return the scalar SCE loss for logits `predictions` and integer `labels`."""
        # Standard cross-entropy term (mean over the batch).
        ce_term = self.standard_cross_entropy(predictions, labels)
        # Reverse cross-entropy: probabilities vs. clamped one-hot targets.
        probs = F.softmax(predictions, dim=1).clamp(min=self.eps, max=1.0)
        one_hot = F.one_hot(labels, self.num_classes).float().clamp(min=self.eps, max=1.0)
        # RCE = - Σ p_predicted * log(p_target), averaged over the batch.
        rce_term = (-(probs * one_hot.log()).sum(dim=1)).mean()
        # Combined loss: SCE = α * CE + β * RCE
        return self.alpha * ce_term + self.beta * rce_term
# Usage demonstration on a hand-picked batch (2 samples, 3 classes).
example_logits = torch.tensor([[10.0, 5.0, -6.0], [8.0, 8.0, 8.0]])
example_targets = torch.tensor([0, 2])
sce_criterion = SymmetricCELoss(num_classes=3)
loss_value = sce_criterion(example_logits, example_targets)
print(f'Symmetric CE Loss: {loss_value}')
Equivalent Implementations of CrossEntropyLoss:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Implementation 1: Using LogSoftmax + NLLLoss
class CustomCE_v1(nn.Module):
    """Per-sample cross-entropy built from LogSoftmax followed by NLLLoss."""

    def __init__(self):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.nll_loss = nn.NLLLoss(reduction='none')

    def forward(self, model_output, target_indices):
        """Return a (N,) tensor of per-sample losses for (N, C) logits."""
        # log(softmax(x)) then NLL picks out -log p[target] per sample.
        return self.nll_loss(self.log_softmax(model_output), target_indices)
# Implementation 2: Manual calculation using one-hot encoding
class CustomCE_v2(nn.Module):
    """Per-sample cross-entropy computed manually with one-hot encoding.

    loss_i = - Σ_c onehot(target_i)[c] * log_softmax(output_i)[c]
    """
    # Fix: dropped the unused `batch_size` local and the boilerplate empty
    # __init__ (nn.Module's default __init__ is sufficient).

    def forward(self, model_output, target_indices):
        """Return a (N,) tensor of per-sample losses.

        Args:
            model_output: raw logits of shape (N, C).
            target_indices: LongTensor of shape (N,) with class indices.
        """
        num_classes = model_output.shape[-1]
        # One-hot mask selects exactly the target class's log-probability.
        target_one_hot = F.one_hot(target_indices, num_classes=num_classes).float()
        log_softmax_probs = torch.log_softmax(model_output, dim=-1)
        # Sum over classes: - Σ (target_one_hot * log_softmax_probs)
        return -torch.sum(target_one_hot * log_softmax_probs, dim=1)
# Verification: both custom implementations match the built-in loss.
logits_tensor = torch.rand(4, 3)
target_tensor = torch.LongTensor([1, 2, 1, 0])
reference_loss = nn.CrossEntropyLoss(reduction='none')
print('Custom v1:', CustomCE_v1()(logits_tensor, target_tensor))
print('Custom v2:', CustomCE_v2()(logits_tensor, target_tensor))
print('Reference:', reference_loss(logits_tensor, target_tensor))