A comprehensive guide to common tensor manipulation patterns using einops and PyTorch.
- Einops Basics
- Reshaping & Rearranging
- Reduction Operations
- Broadcasting & Repeating
- Normalization Patterns
- Indexing & Gathering
- Classification & Metrics
- Probability & Sampling
- Advanced Operations
- Einsum Reference
import torch
import einops
from einops import rearrange, reduce, repeat
- Axes: Named dimensions in your pattern (e.g., 'batch height width')
- Composition: Group axes with parentheses (h w) to combine them
- Decomposition: Split axes with values (h1 h2) where h = h1 * h2 (see the round-trip sketch below)
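As a quick illustrative check (the toy shapes here are chosen for demonstration and are not from the original examples), composition and decomposition round-trip cleanly:
# Compose h and w into one axis, then split it back apart
x = torch.arange(24).reshape(2, 3, 4)                  # axes: b h w
merged = rearrange(x, 'b h w -> b (h w)')              # composition: (2, 3, 4) → (2, 12)
restored = rearrange(merged, 'b (h w) -> b h w', h=3)  # decomposition: back to (2, 3, 4)
assert torch.equal(x, restored)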
# 2D → 1D
flat = rearrange(matrix, 'h w -> (h w)')
# Example: (3, 4) → (12,)
# 3D → 1D
flat = rearrange(tensor, 'b h w -> (b h w)')
# Example: (2, 3, 4) → (24,)
# Equal dimensions
matrix = rearrange(torch.arange(9), '(h w) -> h w', h=3, w=3)
# [0,1,2,3,4,5,6,7,8] → [[0,1,2],[3,4,5],[6,7,8]]
# Variable dimensions
matrix = rearrange(torch.arange(12), '(h w) -> h w', h=3)
# Auto-calculates w=4
# Simple transpose
transposed = rearrange(matrix, 'h w -> w h')
# Batch transpose
transposed = rearrange(batch, 'b h w -> b w h')
# NCHW → NHWC (PyTorch to TensorFlow format)
nhwc = rearrange(images, 'n c h w -> n h w c')
# NHWC → NCHW
nchw = rearrange(images, 'n h w c -> n c h w')
# (batch, channels, height, width) → (channels, height, batch*width)
stacked = rearrange(images, 'b c h w -> c h (b w)')
# Example: 6 images (28×28) → one row (168 pixels wide)
# (b*b, c, h, w) → (c, b*h, b*w)
grid = rearrange(images, '(b1 b2) c h w -> c (b1 h) (b2 w)', b1=2)
# Example: 4 images in a 2×2 grid
# Extract non-overlapping patches from image
patches = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=16, p2=16)
# Example: (3, 224, 224) → (196, 768) for 16×16 patches
# Sum rows (across columns)
row_sums = reduce(matrix, 'h w -> h', 'sum')
# Sum columns (across rows)
col_sums = reduce(matrix, 'h w -> w', 'sum')
# Sum everything
total = reduce(matrix, 'h w -> ', 'sum')
# Average each row
row_means = reduce(matrix, 'h w -> h', 'mean')
# Batch-wise average
batch_avg = reduce(batch, 'b c h w -> b c', 'mean')
# Global average pooling
global_avg = reduce(features, 'b c h w -> b c', 'mean')
# Max per row
row_max = reduce(matrix, 'h w -> h', 'max')
# Max pooling (2×2)
pooled = reduce(image, 'c (h h2) (w w2) -> c h w', 'max', h2=2, w2=2)
# Use a PyTorch function in reduce
std_devs = reduce(data, '(weeks days) -> weeks', torch.std, days=7)
# Weekly averages from daily data
weekly = reduce(temps, '(weeks days) -> weeks', 'mean', days=7)
# Hourly to daily
daily = reduce(hourly, '(days hours) -> days', 'mean', hours=24)
# Monthly from daily (30 days/month)
monthly = reduce(daily, '(months days) -> months', 'sum', days=30)
# Repeat each element n times
repeated = repeat(vector, 'w -> (w n)', n=3)
# [1, 2, 3] → [1, 1, 1, 2, 2, 2, 3, 3, 3]
# Tile entire array
tiled = repeat(vector, 'w -> (n w)', n=3)
# [1, 2, 3] → [1, 2, 3, 1, 2, 3, 1, 2, 3]
# Add batch dimension
batched = repeat(image, 'c h w -> b c h w', b=32)
# Copy across channels
multi_channel = repeat(grayscale, 'h w -> c h w', c=3)
# Expand weekly stats to daily
daily_stats = repeat(weekly, 'weeks -> (weeks days)', days=7)
# Example: [21, 24] → [21,21,21,21,21,21,21, 24,24,24,24,24,24,24]
# Prepare for element-wise operations
a = repeat(vector_a, 'n -> n 1') # (n, 1)
b = repeat(vector_b, 'm -> 1 m') # (1, m)
result = a * b  # (n, m) - outer product
def normalize_rows(matrix):
    """Each row has L2 norm = 1"""
    norms = matrix.norm(dim=1, keepdim=True)  # (m, 1)
    return matrix / norms
# Equivalent with einops
norms = reduce(matrix ** 2, 'm n -> m 1', 'sum') ** 0.5
normalized = matrix / norms
def normalize_cols(matrix):
"""Each column has L2 norm = 1"""
norms = matrix.norm(dim=0, keepdim=True) # (1, n)
return matrix / norms# Normalize daily temps by weekly stats
def normalize_by_week(temps):
    """temps: (weeks*7,)"""
    avg = reduce(temps, '(w d) -> w', 'mean', d=7)
    std = reduce(temps, '(w d) -> w', torch.std, d=7)
    # Expand to daily
    avg_daily = repeat(avg, 'w -> (w d)', d=7)
    std_daily = repeat(std, 'w -> (w d)', d=7)
    return (temps - avg_daily) / std_daily
# Normalize across batch dimension
def batch_norm_manual(x):
    """x: (batch, features)"""
    mean = x.mean(dim=0, keepdim=True)
    std = x.std(dim=0, keepdim=True)
    return (x - mean) / (std + 1e-8)
# Normalize across feature dimension
def layer_norm_manual(x):
    """x: (batch, features)"""
    mean = x.mean(dim=1, keepdim=True)
    std = x.std(dim=1, keepdim=True)
    return (x - mean) / (std + 1e-8)
# Select specific rows
rows = matrix[indices] # indices: (k,) → output: (k, n)
# Select specific elements per row
gathered = matrix[torch.arange(m), column_indices]
# column_indices: (m,) → output: (m,)
# Gather from each row
# indices shape: (batch, k)
gathered = torch.gather(scores, dim=1, index=indices)
# Example: Get top-3 scores per sample
top3_indices = scores.topk(3, dim=1).indices # (batch, 3)
top3_values = torch.gather(scores, 1, top3_indices)
# Get diagonal
diag = matrix.diagonal()
# Get offset diagonal
diag = matrix.diagonal(offset=1)
# Create index grids
row_idx = torch.arange(m).unsqueeze(1) # (m, 1)
col_idx = torch.arange(n).unsqueeze(0) # (1, n)
# Index with conditions
mask = (row_idx + col_idx) % 2 == 0
selected = matrix[mask]
# Select one element per batch item
batch_indices = torch.arange(batch_size)
class_indices = predictions.argmax(dim=1)
selected = logits[batch_indices, class_indices]
# Equivalent with gather
selected = torch.gather(logits, 1, class_indices.unsqueeze(1)).squeeze(1)
# Get predicted class (highest score)
predictions = scores.argmax(dim=1) # (batch, n_classes) → (batch,)
# Get top-k predictions
topk_preds = scores.topk(k=5, dim=1)
# Returns: values and indices
def accuracy(scores, targets):
    """
    scores: (batch, n_classes)
    targets: (batch,)
    """
    predictions = scores.argmax(dim=1)
    correct = (predictions == targets).float()
    return correct.mean().item()
def topk_accuracy(scores, targets, k=5):
"""True if target in top-k predictions"""
_, topk_preds = scores.topk(k, dim=1) # (batch, k)
targets_expanded = targets.unsqueeze(1).expand_as(topk_preds)
correct = (topk_preds == targets_expanded).any(dim=1).float()
return correct.mean().item()def confusion_matrix(preds, targets, num_classes):
"""
preds, targets: (batch,)
Returns: (num_classes, num_classes)
"""
cm = torch.zeros(num_classes, num_classes, dtype=torch.int64)
for t, p in zip(targets, preds):
cm[t, p] += 1
return cm
# Vectorized version
def confusion_matrix_vectorized(preds, targets, num_classes):
stacked = num_classes * targets + preds
return torch.bincount(stacked, minlength=num_classes**2).reshape(num_classes, num_classes)def per_class_accuracy(scores, targets):
"""Accuracy for each class separately"""
predictions = scores.argmax(dim=1)
num_classes = scores.shape[1]
accuracies = []
for c in range(num_classes):
mask = (targets == c)
if mask.sum() > 0:
class_acc = (predictions[mask] == c).float().mean()
accuracies.append(class_acc)
else:
accuracies.append(torch.tensor(float('nan')))
return torch.stack(accuracies)# Manual softmax
def softmax(logits, dim=-1):
    exp_logits = torch.exp(logits - logits.max(dim=dim, keepdim=True).values)
    return exp_logits / exp_logits.sum(dim=dim, keepdim=True)
# Built-in
probs = torch.softmax(logits, dim=-1)
def sample_categorical(probs, n_samples):
"""
probs: (k,) - probability distribution
Returns: (n_samples,) - sampled indices
"""
cumsum = probs.cumsum(dim=0)
random_vals = torch.rand(n_samples, 1)
samples = (random_vals > cumsum).sum(dim=1)
return samplesdef gumbel_softmax(logits, temperature=1.0):
"""Differentiable sampling"""
gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
y = logits + gumbel_noise
return torch.softmax(y / temperature, dim=-1)# Sample with replacement
samples = torch.multinomial(probs, num_samples=10, replacement=True)
# Sample without replacement
samples = torch.multinomial(probs, num_samples=5, replacement=False)
def cosine_similarity_matrix(matrix):
"""
matrix: (m, n)
Returns: (m, m) where entry (i,j) is cosine similarity
"""
# Normalize rows
norms = matrix.norm(dim=1, keepdim=True)
normalized = matrix / norms
# Compute all pairwise dot products
return normalized @ normalized.Tdef pairwise_distances(x):
"""
x: (n, d)
Returns: (n, n) matrix of Euclidean distances
"""
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x·y
x_norm = (x ** 2).sum(dim=1, keepdim=True)
y_norm = x_norm.T
distances = x_norm + y_norm - 2 * x @ x.T
return distances.sqrt()# (batch, n, k) @ (batch, k, m) → (batch, n, m)
result = torch.bmm(A, B)
# With broadcasting: (batch, n, k) @ (k, m) → (batch, n, m)
result = A @ B
# Matrix multiplication
C = torch.einsum('ij,jk->ik', A, B)
# Batch matrix multiplication
C = torch.einsum('bij,bjk->bik', A, B)
# Outer product
outer = torch.einsum('i,j->ij', a, b)
# Dot product
dot = torch.einsum('i,i->', a, b)
# Trace
trace = torch.einsum('ii->', A)
def cross_entropy(logits, targets):
"""
logits: (batch, n_classes)
targets: (batch,) - class indices
"""
log_probs = torch.log_softmax(logits, dim=1)
return -log_probs[torch.arange(len(targets)), targets].mean()
# With label smoothing
def cross_entropy_smooth(logits, targets, smoothing=0.1):
n_classes = logits.shape[1]
log_probs = torch.log_softmax(logits, dim=1)
with torch.no_grad():
true_dist = torch.zeros_like(log_probs)
true_dist.fill_(smoothing / (n_classes - 1))
true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
return (-true_dist * log_probs).sum(dim=1).mean()def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q, K, V: (batch, seq_len, d_model)
"""
d_k = Q.shape[-1]
scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
return attn_weights @ V, attn_weights# Matrix transpose
AT = einsum('ij->ji', A)
# Matrix-vector product
y = einsum('ij,j->i', A, x)
# Matrix-matrix product
C = einsum('ik,kj->ij', A, B)
# Hadamard (element-wise) product
C = einsum('ij,ij->ij', A, B)
# Batch matrix multiply
C = einsum('bij,bjk->bik', A, B)
# Batch outer product
outer = einsum('bi,bj->bij', a, b)
# Sum all elements
total = einsum('ij->', A)
# Sum rows
row_sums = einsum('ij->i', A)
# Sum columns
col_sums = einsum('ij->j', A)
# Trace
trace = einsum('ii->', A)
# Bilinear form: x^T A y
result = einsum('i,ij,j->', x, A, y)
# Tensor contraction
result = einsum('ijk,jkl->il', A, B)
# Batch trace
traces = einsum('bii->b', batch_matrices)
| Operation | Einops | PyTorch | Einsum |
|---|---|---|---|
| Flatten | rearrange(x, 'h w -> (h w)') | x.flatten() | N/A |
| Transpose | rearrange(x, 'h w -> w h') | x.T | einsum('ij->ji', x) |
| Sum rows | reduce(x, 'h w -> h', 'sum') | x.sum(dim=1) | einsum('ij->i', x) |
| Mean rows | reduce(x, 'h w -> h', 'mean') | x.mean(dim=1) | N/A |
| Matmul | N/A | A @ B | einsum('ik,kj->ij', A, B) |
| Outer product | einsum('i,j->ij', a, b) | a.unsqueeze(1) * b | einsum('i,j->ij', a, b) |
| Broadcast | repeat(x, 'n -> n m', m=5) | x.unsqueeze(1).expand(-1, 5) | N/A |
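As an illustrative sanity check (the test tensors below are arbitrary and only for demonstration), a few rows of this table can be verified to agree:
# Verify that the Einops, PyTorch, and einsum forms produce the same results
x = torch.randn(3, 4)
assert torch.equal(rearrange(x, 'h w -> (h w)'), x.flatten())
assert torch.allclose(reduce(x, 'h w -> h', 'sum'), x.sum(dim=1))
assert torch.allclose(torch.einsum('ij->i', x), x.sum(dim=1))
a, b = torch.randn(3), torch.randn(5)
assert torch.allclose(torch.einsum('i,j->ij', a, b), a.unsqueeze(1) * b)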
# Always check shapes match for operations
assert matrix.shape[1] == vector.shape[0], "Incompatible dimensions"
# Use keepdim=True for broadcasting
mean = x.mean(dim=1, keepdim=True) # (n, 1) instead of (n,)
normalized = x / mean  # Broadcasting works!
# Bad: softmax can overflow
bad_probs = torch.exp(logits) / torch.exp(logits).sum()
# Good: subtract max first
logits_shifted = logits - logits.max(dim=-1, keepdim=True).values
probs = torch.exp(logits_shifted) / torch.exp(logits_shifted).sum(dim=-1, keepdim=True)
# Best: use built-in
probs = torch.softmax(logits, dim=-1)
# Hard to read: deeply nested expression
result = (((x ** 2).sum() ** 0.5) / x.size(0)) * 2
# Better: chained method calls read left to right
result = x.square().sum().sqrt().div(x.size(0)).mul(2)
# To avoid extra intermediate tensors, use in-place ops when safe
x.mul_(2)  # In-place multiplication
# Print shapes at each step
print(f"Input shape: {x.shape}")
x = rearrange(x, 'b c h w -> b (c h w)')
print(f"After rearrange: {x.shape}")
# Check for NaN/Inf
assert not torch.isnan(x).any(), "NaN detected!"
assert not torch.isinf(x).any(), "Inf detected!"
# Validate probability distributions
assert torch.allclose(probs.sum(dim=-1), torch.ones(probs.shape[0]))
def aggregate_data(data, group_size, reduction='mean'):
"""
data: 1D tensor of length n (divisible by group_size)
Returns: aggregated values per group
"""
return reduce(data, f'(groups {group_size}) -> groups', reduction)def normalize_by_group(data, group_size):
"""Subtract group mean, divide by group std"""
# Compute stats
means = reduce(data, f'(g {group_size}) -> g', 'mean')
stds = reduce(data, f'(g {group_size}) -> g', t.std)
# Expand back
means_expanded = repeat(means, f'g -> (g {group_size})')
stds_expanded = repeat(stds, f'g -> (g {group_size})')
# Normalize
return (data - means_expanded) / stds_expandeddef pairwise_operation(vectors, operation='dot'):
"""Compute pairwise operations between all vectors"""
if operation == 'dot':
return vectors @ vectors.T
elif operation == 'euclidean':
sq_norms = (vectors ** 2).sum(dim=1, keepdim=True)
distances = sq_norms + sq_norms.T - 2 * vectors @ vectors.T
return distances.sqrt()
elif operation == 'cosine':
normalized = vectors / vectors.norm(dim=1, keepdim=True)
return normalized @ normalized.T- Einops Documentation: https://einops.rocks/
- PyTorch Docs: https://pytorch.org/docs/
- Einsum Guide: https://ajcr.net/Basic-guide-to-einsum/
Created: October 2025 | Keep this handy for quick reference during tensor operations!