Tensor Manipulation in PyTorch: Splitting, Expanding, and Modifying Operations
Introduction
This article covers tensor manipulation operations in PyTorch: splitting (split, unbind, chunk), expanding (repeat, cat, stack), and modifying (indexing and slicing, gather, scatter).
Experimental Environment
This series of experiments uses the following environment setup:
conda create -n DL python==3.11
conda activate DL
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
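After installing, a quick sanity check like the following confirms that PyTorch imports and shows whether it can see a CUDA GPU. This is a minimal sketch that only reads version information:

```python
import torch

# Print the installed PyTorch version and whether a CUDA GPU is visible.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # torch.version.cuda is the CUDA version this PyTorch build targets.
    print("CUDA build:", torch.version.cuda)
    print("GPU:", torch.cuda.get_device_name(0))
```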
PyTorch Data Structures
Tensor
A Tensor is PyTorch's primary data structure for multi-dimensional data. Like a multi-dimensional array, it stores numerical values and supports operations on them.
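For illustration, a few common ways to create tensors (the shapes and values are arbitrary examples):

```python
import torch

# From nested Python lists; the dtype is inferred as int64.
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
# Factory functions for common patterns.
zeros = torch.zeros(2, 3)   # 2x3 filled with 0.0 (float32)
ones = torch.ones(2, 3)     # 2x3 filled with 1.0
seq = torch.arange(6)       # tensor([0, 1, 2, 3, 4, 5])
rand = torch.rand(2, 3)     # uniform samples in [0, 1)
print(a, zeros, ones, seq, rand, sep="\n")
```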
Dimensions
A tensor's dimensions are its axes, and the number of axes is its rank. In PyTorch, size() returns the length of each dimension, and dim() returns the number of axes.
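A small sketch of these queries on an arbitrary 2x3x4 tensor; the shape attribute and numel() are standard additions not mentioned above:

```python
import torch

t = torch.zeros(2, 3, 4)
print(t.size())   # torch.Size([2, 3, 4])
print(t.shape)    # same value as t.size()
print(t.size(1))  # 3: length of dimension 1
print(t.dim())    # 3: number of axes
print(t.numel())  # 24: total number of elements
```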

Data Types
PyTorch tensors support various data types:
torch.float32 or torch.float: 32-bit floating-point tensor.
torch.float64 or torch.double: 64-bit floating-point tensor.
torch.float16 or torch.half: 16-bit floating-point tensor.
torch.int8: 8-bit integer tensor.
torch.int16 or torch.short: 16-bit integer tensor.
torch.int32 or torch.int: 32-bit integer tensor.
torch.int64 or torch.long: 64-bit integer tensor.
torch.bool: Boolean tensor storing True or False.
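The sketch below shows default dtype inference and conversion between types, assuming the default floating-point dtype has not been changed; the values are arbitrary:

```python
import torch

x = torch.tensor([1, 2, 3])                       # integers default to torch.int64
y = torch.tensor([1.0, 2.0, 3.0])                 # floats default to torch.float32
z = torch.tensor([1, 2, 3], dtype=torch.float16)  # dtype set explicitly
print(x.dtype, y.dtype, z.dtype)  # torch.int64 torch.float32 torch.float16

# Converting to a different dtype returns a new tensor; x is unchanged.
as_double = x.to(torch.float64)   # equivalent to x.double()
as_bool = x > 1                   # comparisons produce torch.bool
print(as_double.dtype, as_bool.dtype)  # torch.float64 torch.bool
```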
GPU Acceleration
Tensors can leverage GPU acceleration for parallel computation, speeding up model training.
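A minimal device-agnostic sketch, assuming the code should fall back to the CPU when no CUDA GPU is present:

```python
import torch

# Use the GPU if one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.rand(1000, 1000, device=device)  # created directly on the device
b = torch.rand(1000, 1000).to(device)      # created on the CPU, then copied
c = a @ b                                  # runs on the device holding a and b
print(c.device)

# Move the result back to the CPU, e.g. before calling c_cpu.numpy().
c_cpu = c.cpu()
```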
Tensor Mathematical Operations
PyTorch provides extensive functions for tensor operations, including mathematical computations, statistical calculations, reshaping, indexing, and slicing. These operations can run on the GPU to take advantage of its parallelism.
Vector Operations
Covers addition, subtraction, multiplication, division, scalar multiplication, dot product, outer product, norms, and broadcasting.
Matrix Operations
Includes basic operations, transpose, determinant, trace, adjugate matrix, inverse, eigenvalues, and eigenvectors.
Vector Norms, Matrix Norms, and Spectral Radius
Explains various norms (0, 1, 2, p, infinity) and spectral radius.
1D Convolution Operations
Discusses stride, padding, wide/narrow/equal-width convolutions, and correlation vs. convolution.
2D Convolution Operations
Covers the mathematical principles of 2D convolution.
High-Dimensional Tensors
Explores multiplication and convolution operations for 4D and 5D tensors.
Tensor Statistical Calculations
Detailed explanations of statistical computations on tensors.
Tensor Operations
Tensor Reshaping
Operations for changing tensor shape.
Indexing and Slicing
Methods to access specific elements or subsets of tensors.
Tensor Modification
Splitting Tensors
split
Splits a tensor along a specified dimension into chunks of a given size. If the length of that dimension is not divisible by the size, the last chunk is smaller.
import torch
matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
part1, part2 = matrix.split(2, dim=1)
print(part1) # Output: tensor([[1, 2], [4, 5]])
print(part2) # Output: tensor([[3], [6]])
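split also accepts a list of section sizes instead of a single size. A short sketch reusing the same matrix:

```python
import torch

matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
# A list gives the exact width of each piece; the sizes must sum to matrix.size(1).
left, right = matrix.split([1, 2], dim=1)
print(left)   # Output: tensor([[1], [4]])
print(right)  # Output: tensor([[2, 3], [5, 6]])

# torch.split is the function form of the same operation.
pieces = torch.split(matrix, 1, dim=0)
print(len(pieces))  # 2: one 1x3 piece per row
```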
unbind
Removes a specified dimension and returns a tuple of all slices along it.
import torch
matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
row1, row2 = matrix.unbind(dim=0)
print(row1) # Output: tensor([1, 2, 3])
print(row2) # Output: tensor([4, 5, 6])
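Unbinding along dim=1 instead returns the columns. The sketch also contrasts the result shape with split, which keeps the split dimension:

```python
import torch

matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
# One 1-D tensor per column.
columns = matrix.unbind(dim=1)
print(len(columns))  # 3
print(columns[0])    # Output: tensor([1, 4])

# split keeps the dimension (size 1); unbind removes it.
print(matrix.split(1, dim=1)[0].shape)  # torch.Size([2, 1])
print(columns[0].shape)                 # torch.Size([2])
```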

chunk
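Splits a tensor into a specified number of chunks along a given dimension. All chunks have the same size except possibly the last, which is smaller when the length is not evenly divisible.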
import torch
vector = torch.tensor([[1, 2, 3, 4, 5, 6]])
segments = vector.chunk(3, dim=1)
for segment in segments:
    print(segment)  # Output: tensor([[1, 2]]), then tensor([[3, 4]]), then tensor([[5, 6]])
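When the length is not evenly divisible, chunk uses a chunk size of ceil(length / chunks), so the last chunk is smaller and fewer chunks than requested may come back. torch.tensor_split, shown for comparison, always returns exactly the requested number of pieces:

```python
import torch

# 5 elements into 2 chunks: chunk size ceil(5 / 2) = 3, so the last chunk has 2.
uneven = torch.arange(5).chunk(2)
print(uneven)       # (tensor([0, 1, 2]), tensor([3, 4]))

# 6 elements into 4 chunks: chunk size ceil(6 / 4) = 2, giving only 3 chunks.
parts = torch.arange(6).chunk(4)
print(len(parts))   # 3

# tensor_split returns exactly 4 pieces whose sizes differ by at most 1.
exact = torch.arange(6).tensor_split(4)
print(exact)        # (tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5]))
```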

Expanding Tensors
repeat
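Tiles a tensor by repeating the whole tensor the given number of times along each dimension. For example, repeat(1, 2) keeps the rows and doubles the columns.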
import torch
matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
repeated_horizontally = matrix.repeat(1, 2)
print(repeated_horizontally) # Output: tensor([[1, 2, 3, 1, 2, 3], [4, 5, 6, 4, 5, 6]])
repeated_both = matrix.repeat(2, 2)
print(repeated_both) # Output: tensor([[1, 2, 3, 1, 2, 3], [4, 5, 6, 4, 5, 6], [1, 2, 3, 1, 2, 3], [4, 5, 6, 4, 5, 6]])
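repeat is easy to confuse with two related methods. This sketch compares them on the same matrix: repeat_interleave repeats individual elements, and expand broadcasts size-1 dimensions as a view without copying memory.

```python
import torch

matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])

# repeat tiles the whole tensor: each row's block appears twice.
tiled = matrix.repeat(1, 2)
print(tiled)        # Output: tensor([[1, 2, 3, 1, 2, 3], [4, 5, 6, 4, 5, 6]])

# repeat_interleave repeats each element next to itself.
interleaved = matrix.repeat_interleave(2, dim=1)
print(interleaved)  # Output: tensor([[1, 1, 2, 2, 3, 3], [4, 4, 5, 5, 6, 6]])

# expand stretches a size-1 dimension without copying data.
column = torch.tensor([[1], [2]])  # shape (2, 1)
expanded = column.expand(2, 3)     # shape (2, 3), a view of column
print(expanded)     # Output: tensor([[1, 1, 1], [2, 2, 2]])
```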

cat
Concatenates multiple tensors along an existing dimension. All other dimensions must match.
import torch
matrix_a = torch.tensor([[1, 2, 3], [4, 5, 6]])
matrix_b = torch.tensor([[7, 8, 9], [10, 11, 12]])
concatenated = torch.cat((matrix_a, matrix_b), dim=0)
print(concatenated) # Output: tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
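Concatenating along dim=1 instead places the tensors side by side. Sizes along the concatenation dimension itself may differ, as the second example shows:

```python
import torch

matrix_a = torch.tensor([[1, 2, 3], [4, 5, 6]])
matrix_b = torch.tensor([[7, 8, 9], [10, 11, 12]])

# dim=1: join side by side; the number of rows must match.
side_by_side = torch.cat((matrix_a, matrix_b), dim=1)
print(side_by_side)        # Output: tensor([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]])
print(side_by_side.shape)  # torch.Size([2, 6])

# dim=0 with a different number of rows: only the column counts must match.
extra_row = torch.tensor([[0, 0, 0]])
taller = torch.cat((matrix_a, extra_row), dim=0)
print(taller.shape)        # torch.Size([3, 3])
```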
stack
Stacks multiple tensors of the same shape along a new dimension.
import torch
matrix_a = torch.tensor([[1, 2, 3], [4, 5, 6]])
matrix_b = torch.tensor([[7, 8, 9], [10, 11, 12]])
stacked = torch.stack((matrix_a, matrix_b), dim=0)
print(stacked) # Output: tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
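The position of the new axis is set by dim. The sketch below compares resulting shapes and checks that stack matches unsqueezing each input followed by cat:

```python
import torch

matrix_a = torch.tensor([[1, 2, 3], [4, 5, 6]])
matrix_b = torch.tensor([[7, 8, 9], [10, 11, 12]])

# The new axis can be inserted at any position from 0 to matrix_a.dim().
front = torch.stack((matrix_a, matrix_b), dim=0)
back = torch.stack((matrix_a, matrix_b), dim=2)
print(front.shape)  # torch.Size([2, 2, 3])
print(back.shape)   # torch.Size([2, 3, 2])

# stack == unsqueeze each input at dim, then concatenate along dim.
via_cat = torch.cat((matrix_a.unsqueeze(0), matrix_b.unsqueeze(0)), dim=0)
print(torch.equal(via_cat, front))  # True
```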

Modifying Tensors
Using Indexing and Slicing for Modification
Modify specific elements or subsets of a tensor using indexing and slicing.
import torch
matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
matrix[0, 1] = 9 # Modify element at row 0, column 1 to 9
print(matrix) # Output: tensor([[1, 9, 3], [4, 5, 6]])
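Slices and Boolean masks can be assigned to as well. A short sketch on a fresh copy of the same matrix; the last two lines show that basic indexing returns a view, so writing through it changes the original:

```python
import torch

matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
matrix[:, 0] = 0                      # set all of column 0 to 0
matrix[1, 1:] = torch.tensor([7, 8])  # row 1, columns 1 onward
print(matrix)                         # Output: tensor([[0, 2, 3], [0, 7, 8]])
matrix[matrix > 6] = -1               # every element greater than 6
print(matrix)                         # Output: tensor([[0, 2, 3], [0, -1, -1]])

row = matrix[0]  # basic indexing returns a view of row 0
row[0] = 100     # so this also changes matrix[0, 0]
print(matrix)                         # Output: tensor([[100, 2, 3], [0, -1, -1]])
```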
gather
Collects values from an input tensor along a specified dimension at the positions given by an index tensor. For dim=1, out[i][j] = input[i][index[i][j]].
import torch
matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])
index_tensor = torch.tensor([[0, 0, 1], [1, 0, 0]])
collected = torch.gather(matrix, 1, index_tensor)
print(collected) # Output: tensor([[1, 1, 2], [5, 4, 4]])
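A common use of gather is picking one entry per row, for example each sample's score for its target class. The scores and targets below are hypothetical values for illustration:

```python
import torch

# Hypothetical scores for 2 samples over 3 classes, and each sample's target class.
scores = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
targets = torch.tensor([1, 0])

# gather requires index to have as many dimensions as scores,
# so targets goes from shape (2,) to (2, 1); squeeze(1) undoes that afterwards.
picked = scores.gather(1, targets.unsqueeze(1)).squeeze(1)
print(picked)  # tensor([0.7000, 0.5000])

# Same result with advanced indexing: row i, column targets[i].
same = scores[torch.arange(2), targets]
print(torch.equal(picked, same))  # True
```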
scatter
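Writes values from a source tensor into a copy of the target tensor at the positions given by an index tensor. For dim=1, out[i][index[i][j]] = src[i][j]. The in-place variant is scatter_.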
import torch
empty_matrix = torch.zeros(2, 4)
index_tensor = torch.tensor([[0, 1], [2, 3]])
value_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
scattered = empty_matrix.scatter(1, index_tensor, value_tensor)
print(scattered) # Output: tensor([[1., 2., 0., 0.], [0., 0., 3., 4.]])
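Two common patterns built on scatter, sketched with made-up labels and values: one-hot encoding (torch.nn.functional.one_hot is a built-in alternative), and scatter_add for accumulating values that share an index:

```python
import torch

# One-hot encoding: write 1.0 at column labels[i] of row i.
labels = torch.tensor([2, 0, 1])
one_hot = torch.zeros(3, 3).scatter(1, labels.unsqueeze(1), 1.0)
print(one_hot)  # Output: tensor([[0., 0., 1.], [1., 0., 0.], [0., 1., 0.]])

# With repeated indices, plain scatter keeps only one of the written values
# (which one is not guaranteed); scatter_add sums them instead.
index = torch.tensor([0, 1, 0, 1, 2])
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
sums = torch.zeros(3).scatter_add(0, index, values)
print(sums)     # Output: tensor([4., 6., 5.])
```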