einops cheat sheet
(GitHub gist avinab-neogy/70737b144a134f7d26fe30f14133add2 by @avinab-neogy, created October 29, 2025)

Einops & PyTorch Tensor Operations Cheatsheet

A comprehensive guide to common tensor manipulation patterns using einops and PyTorch.


Table of Contents

  1. Einops Basics
  2. Reshaping & Rearranging
  3. Reduction Operations
  4. Broadcasting & Repeating
  5. Normalization Patterns
  6. Indexing & Gathering
  7. Classification & Metrics
  8. Probability & Sampling
  9. Advanced Operations
  10. Einsum Reference

Einops Basics

Import Statement

import einops
import torch
import torch as t
from einops import rearrange, reduce, repeat
from torch import einsum

Key Concepts

  • Axes: Named dimensions in your pattern (e.g., 'batch height width')
  • Composition: Group axes with parentheses (h w) to combine them
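  • Decomposition: Split one axis into several with parentheses on the input side, e.g. '(h1 h2) w -> h1 h2 w' where h = h1 * h2; give the factor sizes as keyword arguments (h1=2) — einops can infer at most one missing factor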

Reshaping & Rearranging

Pattern: Flatten a Tensor

# 2D → 1D
flat = rearrange(matrix, 'h w -> (h w)')
# Example: (3, 4) → (12,)

# 3D → 1D
flat = rearrange(tensor, 'b h w -> (b h w)')
# Example: (2, 3, 4) → (24,)

Pattern: 1D to 2D Matrix

# Equal dimensions
matrix = rearrange(t.arange(9), '(h w) -> h w', h=3, w=3)
# [0,1,2,3,4,5,6,7,8] → [[0,1,2],[3,4,5],[6,7,8]]

# Variable dimensions
matrix = rearrange(t.arange(12), '(h w) -> h w', h=3)
# Auto-calculates w=4

Pattern: Transpose

# Simple transpose
transposed = rearrange(matrix, 'h w -> w h')

# Batch transpose
transposed = rearrange(batch, 'b h w -> b w h')

Pattern: Channel Reordering

# NCHW → NHWC (PyTorch to TensorFlow format)
nhwc = rearrange(images, 'n c h w -> n h w c')

# NHWC → NCHW
nchw = rearrange(images, 'n h w c -> n c h w')

Pattern: Stack Images in a Row

# (batch, channels, height, width) → (channels, height, batch*width)
stacked = rearrange(images, 'b c h w -> c h (b w)')

# Example: 6 images (28×28) → one row (168 pixels wide)

Pattern: Stack Images in a Grid

# (b*b, c, h, w) → (c, b*h, b*w)
grid = rearrange(images, '(b1 b2) c h w -> c (b1 h) (b2 w)', b1=2)

# Example: 4 images in 2×2 grid

Pattern: Patch Extraction

# Extract non-overlapping patches from image
patches = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=16, p2=16)

# Example: (3, 224, 224) → (196, 768) for 16×16 patches

Reduction Operations

Pattern: Sum Along Dimension

# Sum rows (across columns)
row_sums = reduce(matrix, 'h w -> h', 'sum')

# Sum columns (across rows)
col_sums = reduce(matrix, 'h w -> w', 'sum')

# Sum everything
total = reduce(matrix, 'h w -> ', 'sum')

Pattern: Mean/Average

# Average each row
row_means = reduce(matrix, 'h w -> h', 'mean')

# Per-sample, per-channel mean over the spatial axes (h, w) — the same
# reduction as global average pooling below. (Averaging over the batch
# axis instead would be 'b c h w -> c h w'.)
batch_avg = reduce(batch, 'b c h w -> b c', 'mean')

# Global average pooling
global_avg = reduce(features, 'b c h w -> b c', 'mean')

Pattern: Max/Min Operations

# Max per row
row_max = reduce(matrix, 'h w -> h', 'max')

# Max pooling (2×2)
pooled = reduce(image, 'c (h h2) (w w2) -> c h w', 'max', h2=2, w2=2)

Pattern: Standard Deviation

# Use PyTorch function in reduce
std_devs = reduce(data, '(weeks days) -> weeks', t.std, days=7)

Pattern: Group and Aggregate

# Weekly averages from daily data
weekly = reduce(temps, '(weeks days) -> weeks', 'mean', days=7)

# Hourly to daily
daily = reduce(hourly, '(days hours) -> days', 'mean', hours=24)

# Monthly from daily (30 days/month)
monthly = reduce(daily, '(months days) -> months', 'sum', days=30)

Broadcasting & Repeating

Pattern: Repeat Elements

# Repeat each element n times
repeated = repeat(vector, 'w -> (w n)', n=3)
# [1, 2, 3] → [1, 1, 1, 2, 2, 2, 3, 3, 3]

# Tile entire array
tiled = repeat(vector, 'w -> (n w)', n=3)
# [1, 2, 3] → [1, 2, 3, 1, 2, 3, 1, 2, 3]

Pattern: Broadcast to Higher Dimensions

# Add batch dimension
batched = repeat(image, 'c h w -> b c h w', b=32)

# Copy across channels
multi_channel = repeat(grayscale, 'h w -> c h w', c=3)

Pattern: Expand for Broadcasting

# Expand weekly stats to daily
daily_stats = repeat(weekly, 'weeks -> (weeks days)', days=7)

# Example: [21, 24] → [21,21,21,21,21,21,21, 24,24,24,24,24,24,24]

Pattern: Manual Broadcasting

# Prepare for element-wise operations
a = repeat(vector_a, 'n -> n 1')      # (n, 1)
b = repeat(vector_b, 'm -> 1 m')      # (1, m)
result = a * b                         # (n, m) - outer product

Normalization Patterns

Pattern: Row Normalization (L2)

def normalize_rows(matrix):
    """Return `matrix` with every row scaled to unit L2 norm.

    Each row is divided by its own norm (an (m, 1) column that broadcasts
    across the columns). An all-zero row has norm 0 and comes out as NaN.
    """
    return matrix.div(matrix.norm(dim=1, keepdim=True))

# Equivalent with einops
norms = reduce(matrix ** 2, 'm n -> m 1', 'sum') ** 0.5
normalized = matrix / norms

Pattern: Column Normalization

def normalize_cols(matrix):
    """Return `matrix` with every column scaled to unit L2 norm.

    Each column is divided by its own norm (a (1, n) row that broadcasts
    down the rows). An all-zero column has norm 0 and comes out as NaN.
    """
    return matrix.div(matrix.norm(dim=0, keepdim=True))

Pattern: Z-Score Normalization (Per Group)

# Normalize daily temps by weekly stats
def normalize_by_week(temps):
    """Z-score daily values against the statistics of their own week.

    temps: 1-D tensor of length weeks*7, stored week after week.
    Each day has its week's mean subtracted and is divided by its week's
    standard deviation (as computed by t.std). The result has temps' shape.
    """
    days = 7
    week_mean = reduce(temps, '(w d) -> w', 'mean', d=days)
    week_std = reduce(temps, '(w d) -> w', t.std, d=days)

    # Broadcast each weekly statistic back onto the seven days it covers.
    per_day_mean, per_day_std = (
        repeat(stat, 'w -> (w d)', d=days) for stat in (week_mean, week_std)
    )
    return (temps - per_day_mean) / per_day_std

Pattern: Batch Normalization

# Normalize across batch dimension
def batch_norm_manual(x):
    """Standardize each feature column using statistics taken over the batch.

    x: (batch, features). Uses Tensor.std's default estimator and adds 1e-8
    to the standard deviation (not the variance) to avoid division by zero.
    No learnable scale/shift and no running statistics.
    """
    centered = x - x.mean(dim=0, keepdim=True)
    spread = x.std(dim=0, keepdim=True) + 1e-8
    return centered / spread

Pattern: Layer Normalization

# Normalize across feature dimension
def layer_norm_manual(x):
    """Standardize each sample (row) using statistics over its features.

    x: (batch, features). Uses Tensor.std's default estimator and adds 1e-8
    to the standard deviation to avoid division by zero. No learnable
    scale/shift.
    """
    centered = x - x.mean(dim=1, keepdim=True)
    spread = x.std(dim=1, keepdim=True) + 1e-8
    return centered / spread

Indexing & Gathering

Pattern: Basic Indexing

# Select specific rows
rows = matrix[indices]  # indices: (k,) → output: (k, n)

# Select specific elements per row
gathered = matrix[torch.arange(m), column_indices]
# column_indices: (m,) → output: (m,)

Pattern: torch.gather

# Gather from each row
# indices shape: (batch, k)
gathered = torch.gather(scores, dim=1, index=indices)

# Example: Get top-3 scores per sample
top3_indices = scores.topk(3, dim=1).indices  # (batch, 3)
top3_values = torch.gather(scores, 1, top3_indices)

Pattern: Select Diagonal Elements

# Get diagonal
diag = matrix.diagonal()

# Get offset diagonal
diag = matrix.diagonal(offset=1)

Pattern: Fancy Indexing with Broadcasting

# Create index grids
row_idx = torch.arange(m).unsqueeze(1)  # (m, 1)
col_idx = torch.arange(n).unsqueeze(0)  # (1, n)

# Index with conditions
mask = (row_idx + col_idx) % 2 == 0
selected = matrix[mask]

Pattern: Batch Indexing

# Select one element per batch item
batch_indices = torch.arange(batch_size)
class_indices = predictions.argmax(dim=1)
selected = logits[batch_indices, class_indices]

# Equivalent with gather
selected = torch.gather(logits, 1, class_indices.unsqueeze(1)).squeeze(1)

Classification & Metrics

Pattern: Argmax - Get Predictions

# Get predicted class (highest score)
predictions = scores.argmax(dim=1)  # (batch, n_classes) → (batch,)

# Get top-k predictions
topk_preds = scores.topk(k=5, dim=1)
# Returns: values and indices

Pattern: Accuracy

def accuracy(scores, targets):
    """
    Fraction of samples whose highest-scoring class equals the target.

    scores: (batch, n_classes)
    targets: (batch,)
    Returns: a Python float
    """
    hits = scores.argmax(dim=1).eq(targets)
    return hits.float().mean().item()

Pattern: Top-K Accuracy

def topk_accuracy(scores, targets, k=5):
    """Fraction of samples whose target is among their k highest-scoring classes.

    scores: (batch, n_classes); targets: (batch,). Returns a Python float.
    """
    top_classes = scores.topk(k, dim=1).indices  # (batch, k)
    # Line each sample's target up against all k of its top predictions.
    target_grid = targets.unsqueeze(1).expand_as(top_classes)
    found = top_classes.eq(target_grid).any(dim=1)
    return found.float().mean().item()

Pattern: Confusion Matrix

def confusion_matrix(preds, targets, num_classes):
    """
    Count (true class, predicted class) pairs with an explicit loop.

    preds, targets: (batch,) integer class indices (tensors or sequences)
    Returns: (num_classes, num_classes) int64 tensor; entry [i, j] is the
    number of samples with target i that were predicted as j
    (rows = true class, columns = predicted class).
    """
    cm = torch.zeros(num_classes, num_classes, dtype=torch.int64)
    # Loop names deliberately avoid `t`, the module's torch alias.
    # .tolist() yields plain Python ints, so indexing the CPU matrix works
    # even when the inputs live on another device.
    true_list = torch.as_tensor(targets).tolist()
    pred_list = torch.as_tensor(preds).tolist()
    for true_cls, pred_cls in zip(true_list, pred_list):
        cm[true_cls, pred_cls] += 1
    return cm

# Vectorized version
def confusion_matrix_vectorized(preds, targets, num_classes):
    """Loop-free confusion matrix (rows = true class, columns = prediction).

    Each (target, pred) pair is encoded as the flat index
    target * num_classes + pred. torch.bincount counts the flat indices,
    and the counts are reshaped to (num_classes, num_classes).
    NOTE(review): values must be non-negative ints below num_classes; this
    is not checked, and an out-of-range pred can be counted in the wrong cell.
    """
    flat_index = targets * num_classes + preds
    counts = torch.bincount(flat_index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

Pattern: Per-Class Accuracy

def per_class_accuracy(scores, targets):
    """
    Accuracy for each class separately (per-class recall).

    scores: (batch, n_classes)
    targets: (batch,) class indices
    Returns: (n_classes,) float32 tensor on scores' device. Entry c is the
    fraction of samples with target c that were predicted as c, or NaN if
    class c never occurs in targets.
    """
    predictions = scores.argmax(dim=1)
    num_classes = scores.shape[1]
    # Build the NaN placeholder on the same device as the per-class values.
    # A CPU tensor here would make torch.stack fail for CUDA inputs.
    missing = torch.full((), float('nan'), device=predictions.device)

    accuracies = []
    for c in range(num_classes):
        mask = targets == c
        if mask.any():
            accuracies.append((predictions[mask] == c).float().mean())
        else:
            accuracies.append(missing)
    return torch.stack(accuracies)

Probability & Sampling

Pattern: Softmax

# Manual softmax
def softmax(logits, dim=-1):
    """Numerically stable softmax along `dim`.

    Subtracting each slice's maximum leaves the result mathematically
    unchanged but keeps torch.exp from overflowing on large logits.
    """
    peak = logits.max(dim=dim, keepdim=True).values
    weights = torch.exp(logits - peak)
    total = weights.sum(dim=dim, keepdim=True)
    return weights / total

# Built-in
probs = torch.softmax(logits, dim=-1)

Pattern: Categorical Sampling

def sample_categorical(probs, n_samples):
    """
    Draw indices from a categorical distribution by inverse-CDF sampling.

    probs: (k,) non-negative weights, normally a probability distribution.
        Any positive total works, because the uniforms are scaled by it.
    Returns: (n_samples,) int64 tensor of indices in [0, k), on probs' device.
    """
    cumsum = probs.cumsum(dim=0)
    if not cumsum.is_floating_point():
        cumsum = cumsum.to(torch.get_default_dtype())
    # Scale uniforms by the total mass. Rounding in cumsum (or weights that
    # don't sum to exactly 1) then can't push a draw past the last bucket.
    u = torch.rand(n_samples, dtype=cumsum.dtype, device=cumsum.device) * cumsum[-1]
    # Number of buckets whose cumulative mass is <= u. right=True means a
    # zero-probability category (cumsum equal to its predecessor) is never chosen.
    samples = torch.searchsorted(cumsum, u, right=True)
    # Guard against the product rounding up to exactly cumsum[-1].
    return samples.clamp_(max=probs.shape[0] - 1)

Pattern: Gumbel-Softmax (Differentiable Sampling)

def gumbel_softmax(logits, temperature=1.0):
    """Relaxed, differentiable sample from the categorical given by `logits`.

    Adds Gumbel noise -log(-log(U)), U ~ torch.rand_like(logits), to the
    logits, then applies a temperature-scaled softmax over the last axis.
    The 1e-10 offsets keep both logarithms finite at the edges of U's range.
    """
    uniform = torch.rand_like(logits)
    inner = -torch.log(uniform + 1e-10)
    gumbel = -torch.log(inner + 1e-10)
    perturbed = logits + gumbel
    return torch.softmax(perturbed / temperature, dim=-1)

Pattern: Multinomial Sampling

# Sample with replacement
samples = torch.multinomial(probs, num_samples=10, replacement=True)

# Sample without replacement
samples = torch.multinomial(probs, num_samples=5, replacement=False)

Advanced Operations

Pattern: Cosine Similarity Matrix

def cosine_similarity_matrix(matrix, eps=1e-8):
    """
    matrix: (m, n)
    eps: lower bound applied to each row norm before dividing
    Returns: (m, m) where entry (i, j) is the cosine similarity of rows i and j.

    A row whose norm is below eps is divided by eps rather than normalized,
    so an all-zero row yields similarities of 0 instead of NaN.
    """
    # Normalize rows; the clamp prevents 0 / 0 for all-zero rows.
    norms = matrix.norm(dim=1, keepdim=True).clamp_min(eps)
    normalized = matrix / norms

    # All pairwise dot products of unit rows = cosine similarities.
    return normalized @ normalized.T

Pattern: Pairwise Distance Matrix

def pairwise_distances(x):
    """
    x: (n, d)
    Returns: (n, n) matrix of Euclidean distances between the rows of x.
    """
    # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * x_i·x_j
    sq_norms = (x ** 2).sum(dim=1, keepdim=True)  # (n, 1)
    sq_dists = sq_norms + sq_norms.T - 2 * x @ x.T
    # Floating-point cancellation can leave tiny negative values (notably on
    # the diagonal); clamp them so sqrt yields 0 instead of NaN.
    return sq_dists.clamp_min(0).sqrt()

Pattern: Batched Matrix Multiplication

# (batch, n, k) @ (batch, k, m) → (batch, n, m)
result = torch.bmm(A, B)

# With broadcasting: (batch, n, k) @ (k, m) → (batch, n, m)
result = A @ B

Pattern: Einstein Summation (einsum)

# Matrix multiplication
C = torch.einsum('ij,jk->ik', A, B)

# Batch matrix multiplication
C = torch.einsum('bij,bjk->bik', A, B)

# Outer product
outer = torch.einsum('i,j->ij', a, b)

# Dot product
dot = torch.einsum('i,i->', a, b)

# Trace
trace = torch.einsum('ii->', A)

Pattern: Cross Entropy Loss

def cross_entropy(logits, targets):
    """
    Mean negative log-likelihood of the target classes.

    logits: (batch, n_classes)
    targets: (batch,) - class indices
    """
    log_probs = torch.log_softmax(logits, dim=1)
    rows = torch.arange(len(targets))
    picked = log_probs[rows, targets]  # log p(target) for each sample
    return -picked.mean()

# With label smoothing
def cross_entropy_smooth(logits, targets, smoothing=0.1):
    """Cross-entropy against a label-smoothed target distribution.

    The target class gets probability 1 - smoothing, and the other
    n_classes - 1 classes share `smoothing` equally. The soft targets are
    built under torch.no_grad().
    """
    num_classes = logits.shape[1]
    log_probs = torch.log_softmax(logits, dim=1)

    with torch.no_grad():
        off_value = smoothing / (num_classes - 1)
        soft_targets = torch.full_like(log_probs, off_value)
        soft_targets.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)

    return -(soft_targets * log_probs).sum(dim=1).mean()

Pattern: Attention Mechanism

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute softmax(Q K^T / sqrt(d_k)) V, optionally masked.

    Q, K, V: (batch, seq_len, d_model)
    mask: broadcastable to the score matrix; positions where mask == 0 get
        a score of -inf. A query row that is entirely masked yields NaN weights.
    Returns: (output, attention weights)
    """
    scale = Q.shape[-1] ** 0.5
    logits = Q @ K.transpose(-2, -1) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, float('-inf'))
    weights = torch.softmax(logits, dim=-1)
    return weights @ V, weights

Einsum Reference

Common Patterns

Matrix Operations

# Matrix transpose
AT = einsum('ij->ji', A)

# Matrix-vector product
y = einsum('ij,j->i', A, x)

# Matrix-matrix product
C = einsum('ik,kj->ij', A, B)

# Hadamard (element-wise) product
C = einsum('ij,ij->ij', A, B)

Batch Operations

# Batch matrix multiply
C = einsum('bij,bjk->bik', A, B)

# Batch outer product
outer = einsum('bi,bj->bij', a, b)

Reductions

# Sum all elements
total = einsum('ij->', A)

# Sum rows
row_sums = einsum('ij->i', A)

# Sum columns
col_sums = einsum('ij->j', A)

# Trace
trace = einsum('ii->', A)

Advanced

# Bilinear form: x^T A y
result = einsum('i,ij,j->', x, A, y)

# Tensor contraction
result = einsum('ijk,jkl->il', A, B)

# Batch trace
traces = einsum('bii->b', batch_matrices)

Quick Reference Table

Operation Einops PyTorch Einsum
Flatten rearrange(x, 'h w -> (h w)') x.flatten() N/A
Transpose rearrange(x, 'h w -> w h') x.T einsum('ij->ji', x)
Sum rows reduce(x, 'h w -> h', 'sum') x.sum(dim=1) einsum('ij->i', x)
Mean reduce(x, 'h w -> h', 'mean') x.mean(dim=1) N/A
Matmul N/A A @ B einsum('ik,kj->ij', A, B)
Outer product einops.einsum(a, b, 'i, j -> i j') a.unsqueeze(1) * b einsum('i,j->ij', a, b)
Broadcast repeat(x, 'n -> n m', m=5) x.unsqueeze(1).expand(-1, 5) N/A

Common Tips

Dimension Matching

# Always check shapes match for operations
assert matrix.shape[1] == vector.shape[0], "Incompatible dimensions"

# Use keepdim=True for broadcasting
mean = x.mean(dim=1, keepdim=True)  # (n, 1) instead of (n,)
normalized = x / mean  # Broadcasting works!

Numerical Stability

# Bad: softmax can overflow
bad_probs = torch.exp(logits) / torch.exp(logits).sum()

# Good: subtract max first
logits_shifted = logits - logits.max(dim=-1, keepdim=True).values
probs = torch.exp(logits_shifted) / torch.exp(logits_shifted).sum(dim=-1, keepdim=True)

# Best: use built-in
probs = torch.softmax(logits, dim=-1)

Memory Efficiency

# Operator form: every step allocates a new (temporary) tensor
result = (((x ** 2).sum() ** 0.5) / x.size(0)) * 2

# Method-chaining form: often more readable, but it allocates the same
# temporaries as the line above — chaining by itself saves no memory
result = x.square().sum().sqrt().div(x.size(0)).mul(2)

# Actual saving: in-place ops (trailing underscore) overwrite x's storage
# instead of allocating a new result. They destroy the old value of x, and
# autograd can raise an error if that value is needed for the backward pass.
x.mul_(2)  # In-place multiplication

Debugging Tips

# Print shapes at each step
print(f"Input shape: {x.shape}")
x = rearrange(x, 'b c h w -> b (c h w)')
print(f"After rearrange: {x.shape}")

# Check for NaN/Inf
assert not torch.isnan(x).any(), "NaN detected!"
assert not torch.isinf(x).any(), "Inf detected!"

# Validate probability distributions: every slice along the last dim sums to 1.
# ones_like matches the sums' shape, dtype and device, so the check also works
# for float64, CUDA, and batched (>2-D) probability tensors.
row_sums = probs.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums))

Practice Templates

Template: Aggregation Pattern

def aggregate_data(data, group_size, reduction='mean'):
    """
    Reduce consecutive, non-overlapping groups of `group_size` elements.

    data: 1D tensor of length n (divisible by group_size)
    reduction: an einops reduction name ('mean', 'sum', 'max', ...) or a
        callable such as t.std
    Returns: (n // group_size,) tensor with one aggregated value per group
    """
    # Pass the group length as a named axis size instead of splicing it
    # into the pattern string.
    pattern = '(groups size) -> groups'
    return reduce(data, pattern, reduction, size=group_size)

Template: Normalization Pattern

def normalize_by_group(data, group_size):
    """Z-score each element against its own group of `group_size` elements.

    data: 1D tensor whose length is divisible by group_size. Consecutive
    blocks of group_size elements form the groups. Each element has its
    group's mean subtracted and is divided by its group's t.std.
    """
    # Per-group statistics, in the order (mean, std).
    stats = [
        reduce(data, '(g k) -> g', how, k=group_size)
        for how in ('mean', t.std)
    ]
    # Stretch each per-group statistic back to one value per element.
    mean_full, std_full = (repeat(s, 'g -> (g k)', k=group_size) for s in stats)
    return (data - mean_full) / std_full

Template: Pairwise Operation

def pairwise_operation(vectors, operation='dot', eps=1e-8):
    """Compute a pairwise operation between all rows of `vectors`.

    vectors: (n, d)
    operation: 'dot' (inner products), 'euclidean' (L2 distances) or
        'cosine' (cosine similarities)
    eps: lower bound on row norms for 'cosine', so all-zero rows give 0
        instead of NaN
    Returns: (n, n) tensor
    Raises: ValueError if `operation` is not one of the names above.
    """
    if operation == 'dot':
        return vectors @ vectors.T
    if operation == 'euclidean':
        sq_norms = (vectors ** 2).sum(dim=1, keepdim=True)
        sq_dists = sq_norms + sq_norms.T - 2 * vectors @ vectors.T
        # Rounding can make near-zero squared distances slightly negative;
        # clamp so sqrt returns 0 rather than NaN.
        return sq_dists.clamp_min(0).sqrt()
    if operation == 'cosine':
        normalized = vectors / vectors.norm(dim=1, keepdim=True).clamp_min(eps)
        return normalized @ normalized.T
    raise ValueError(
        f"unknown operation {operation!r}; expected 'dot', 'euclidean' or 'cosine'"
    )

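Additional Resources

  • einops documentation and tutorials: https://einops.rocks
  • PyTorch tensor API reference: https://pytorch.org/docs/stable/torch.html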
Created: October 2025 | Keep this handy for quick reference during tensor operations!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment