# Indexing exercises
# Source: GitHub gist KeAWang/719edff81d4fe1c11c125ef5b4933b38
# Author: @KeAWang, created May 12, 2025 01:00
# %%
import numpy as np
"""Column-wise Sorting
Spec: sort every column of X using the pre-computed indices in idx
Desired result shape: (M, N)
"""
M, N = 6, 4
X = np.random.randn(M, N)
idx = np.argsort(X, axis=1) # shape (M, N)
# NOTE: idx is an argsort along axis=1, so the reordering happens *within each
# row* (i.e. along the column axis): X_sorted[m, n] = X[m, idx[m, n]].
# Pair every row number with that row's permutation of column indices.
row_ids = np.arange(M)[:, None]  # (M, 1), broadcasts against idx (M, N)
X_sorted = X[row_ids, idx]
assert X_sorted.shape == (M, N)
# Cross-check against the dedicated gather helper.
assert np.array_equal(X_sorted, np.take_along_axis(X, idx, axis=-1))
# %%
""" Per batch point lookup
Spec: from each matrix in A, pick exactly one element using row_idx and col_idx
"""
B, N = 5, 7
A = np.random.randn(B, N, N)
row_idx = np.random.randint(0, N, size=B) # (B,)
col_idx = np.random.randint(0, N, size=B) # (B,)
# Answer: A_subset[b, 0, 0] = A[b, row_idx[b], col_idx[b]]
# Three aligned (B,) index arrays select one scalar per batch entry ...
picked = A[np.arange(B), row_idx, col_idx]  # (B,)
# ... and the two trailing singleton axes keep the result matrix-shaped.
A_subset = picked.reshape(B, 1, 1)
assert A_subset.shape == (B, 1, 1)
# %%
""" Gather Arbitrary Timesteps
Spec: collect K timesteps per sequence
"""
B, T, D = 3, 10, 4
X = np.random.randn(B, T, D)
K = 6
timesteps = np.random.randint(0, T, size=(B, K))
# Answer: X_subset[b, k, d] = X[b, timesteps[b, k], d]
# Index the first two axes with broadcast (B, 1) and (B, K) arrays; the
# feature axis D is left untouched and comes along as the trailing dimension.
batch_ids = np.arange(B)[:, None]  # (B, 1)
X_subset = X[batch_ids, timesteps]  # (B, K, D)
assert X_subset.shape == (B, K, D)
# Cross-check against take_along_axis with the index broadcast over D.
assert np.array_equal(X_subset, np.take_along_axis(X, timesteps[:, :, None], axis=-2))
# %%
""" Two-Level Table Lookup
# Spec: first choose rows with outer_idx, then choose columns with inner_idx
"""
N, D = 12, 8
X = np.arange(N * D).reshape(N, D)
B, M, K = 2, 4, 5
outer_idx = np.random.randint(0, N, size=(B, M)) # rows
inner_idx = np.random.randint(0, D, size=(B, M, K)) # columns
# Level 1: outer[b, m, d] = X[outer_idx[b, m], d]
outer = np.take(X, outer_idx, axis=0)
assert outer.shape == (B, M, D)
# Level 2: inner[b, m, k] = outer[b, m, inner_idx[b, m, k]]
# Explicit (B, 1, 1) and (1, M, 1) grids pin the batch/row position while
# inner_idx (B, M, K) chooses the column.
b_grid = np.arange(B)[:, None, None]
m_grid = np.arange(M)[None, :, None]
inner = outer[b_grid, m_grid, inner_idx]
assert inner.shape == (B, M, K)
# %%
""" Batch diagonal
# Spec: get the diagonal of a batch of matrices
"""
B, N = 4, 6
X = np.random.randn(B, N, N)
# Answer: X_diag[b, i] = X[b, i, i]
# Keep the batch axis as a plain slice and feed the same index vector to both
# matrix axes, so row and column advance together along the diagonal.
diag_pos = np.arange(N)
X_diag = X[:, diag_pos, diag_pos]
assert X_diag.shape == (B, N)
# %%
""" Diagonal Replacement
# Spec: copy X into Y, but replace each matrix’s diagonal with v
"""
B, N = 3, 2
X = np.random.randn(B, N, N)
v = np.random.randn(B, N) # one value per diagonal position
# Answer: Y[b, i, i] = v[b, i]
# Work on a copy so X itself is left untouched.
Y = np.copy(X)
# The same index vector on both matrix axes addresses the diagonal of every
# batch entry; the selected region has shape (B, N), matching v exactly.
diag_pos = np.arange(N)
Y[:, diag_pos, diag_pos] = v
# Each matrix's diagonal must now equal the corresponding row of v ...
for b in range(B):
    assert np.array_equal(np.diagonal(Y[b]), v[b, :])
# ... and every off-diagonal entry must be unchanged from X.
off_diag = ~np.eye(N, dtype=bool)
assert np.array_equal(Y[:, off_diag], X[:, off_diag])
# %%
""" Reverse-and-tile sequence
Spec: reverse X along the time dimension and repeat the reversed block R times
"""
B, T, D = 2, 5, 3
R = 4
X = np.random.randn(B, T, D)
# Answer:
# A negative-stride slice walks the time axis backwards.
X_reversed = X[:, ::-1, :]
# Stack R copies of the reversed block end-to-end along the time axis.
X_repeated = np.concatenate([X_reversed] * R, axis=1)
assert X_repeated.shape == (B, R * T, D)
# %%
""" Segment-wise Argmax Positions
Spec: for each batch, segments lists the start indices of contiguous segments in X
* **Segments encode boundaries.**
For each batch *i* (row in `segments`), the integers in `segments[i]` give the **starting indices** of *S* contiguous, non-overlapping time windows inside `X[i]`. They are sorted in ascending order and satisfy
`0 = segments[i,0] < segments[i,1] < … < segments[i,S-1] < T`.
* **Implicit end-points.**
A window runs from its start index up to—but **not including**—the next start index, except for the last window, which runs to `T` (the end of the sequence).
```
start = segments[i, s]
end = segments[i, s+1] # or T if s is the last segment
window = X[i, start:end]
```
* **Task.**
Within every such window, locate the position of the maximum value **in the original time coordinate system** (i.e. 0 … T-1). Collect those positions into `idx`, so that
```
idx.shape == (B, S)
# idx[i, s] is an integer in [0, T) giving the arg-max within window s of batch i
```
No need to return the max values themselves, just the positions.
"""
B, T = 3, 12
X = np.random.randn(B, T)
# Row i of `segments` holds the start index of every segment in sequence i.
segments = np.array([
    [0, 4, 9],
    [0, 5, 10],
    [0, 3, 8],
])  # shape (B, S) with S = 3
# Idea: view each of the S segments as a masked copy of the full length-T
# sequence, then take a single argmax over time for all of them at once.
# A segment ends where the next one starts; the last one ends at T.
ends = np.empty_like(segments)  # (B, S)
ends[:, :-1] = segments[:, 1:]
ends[:, -1] = T
# Time grid that broadcasts over batch and segment axes.
t = np.arange(T).reshape(1, 1, T)  # (1, 1, T)
# Per-segment boundaries with a trailing axis to broadcast over time.
start = np.expand_dims(segments, axis=-1)  # (B, S, 1)
end = np.expand_dims(ends, axis=-1)  # (B, S, 1)
# True exactly at the time steps that fall inside the half-open window.
mask = np.logical_and(start <= t, t < end)  # (B, S, T)
# Positions outside a window become -inf so they can never win the argmax.
X_seg = np.where(mask, X[:, np.newaxis, :], -np.inf)  # (B, S, T)
# The argmax runs over the full time axis, so positions are already absolute.
idx = X_seg.argmax(axis=-1)  # (B, S)
# %%
""" One-Hot Scatter
# Spec: build a one-hot tensor Y from class_ids
"""
B, T = 3, 5
num_classes = 8
class_ids = np.random.randint(0, num_classes, size=(B, T))
# Answer: Y[b, t, class_ids[b, t]] = 1
Y = np.zeros((B, T, num_classes))
# Broadcast (B, 1) and (1, T) position grids against class_ids (B, T) so
# exactly one slot per (b, t) is switched on.
b_grid = np.arange(B)[:, None]
t_grid = np.arange(T)[None, :]
Y[b_grid, t_grid, class_ids] = 1.
# %%
""" Pairwise Distance Extraction
# Spec: gather distances for multiple (i, j) pairs
"""
N = 20
Dmat = np.random.rand(N, N)
# A raw random matrix is not symmetric; average it with its transpose so that
# Dmat[i, j] == Dmat[j, i] actually holds.
Dmat = (Dmat + Dmat.T) / 2 # symmetric matrix
B = 15
pairs = np.random.randint(0, N, size=(B, 2)) # (i, j) per pair
# Answer: dists[k] = Dmat[pairs[k, 0], pairs[k, 1]]
# Two aligned (B,) index vectors pick one entry per pair; no extra broadcast
# axis (and no trailing squeeze) is needed.
dists = Dmat[pairs[:, 0], pairs[:, 1]]
assert dists.shape == (B,)
# %%
# [GitHub page footer] Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment