Skip to content

Instantly share code, notes, and snippets.

@nilsleh
Created February 6, 2026 01:28
Show Gist options
  • Select an option

  • Save nilsleh/a4803ca9a8e8c1675c2fc7dd7fecc48a to your computer and use it in GitHub Desktop.

Select an option

Save nilsleh/a4803ca9a8e8c1675c2fc7dd7fecc48a to your computer and use it in GitHub Desktop.
Sparse Satellite-Track interpolation
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
# Output grid as (height, width); passed as `size` to every resampler in run_demo.
TARGET_SIZE = 128, 128
# Raw synthetic field as (height, width); default `shape` of make_dummy_track.
RAW_SHAPE = 180, 150
def sparse_resize(data, size, mode="bilinear", *, min_weight=1e-4, eps=1e-8):
    """Resize a field with missing values by mask-normalized interpolation.

    Pixels rejected by ``torch.isfinite`` (NaN and +/-inf) are treated as
    missing. The data, with missing pixels set to zero, and the validity mask
    are interpolated with the same kernel, and the first is divided by the
    second. Each output pixel is therefore a weighted mean of only the valid
    inputs it touches, so sparse samples keep their magnitude instead of being
    diluted toward zero.

    Args:
        data: Tensor accepted by ``F.interpolate`` (e.g. ``[N, C, H, W]`` for
            a 2-D ``size``) with missing pixels marked as NaN/inf.
        size: Output spatial size, forwarded to ``F.interpolate``.
        mode: ``F.interpolate`` mode. ``align_corners=False`` is only passed
            for the interpolating modes that accept it, so ``"nearest"`` and
            ``"area"`` work too.
        min_weight: Output pixels whose interpolated mask weight is below
            this value are set to NaN (too little valid support).
        eps: Lower bound on the divisor, guarding against division by zero.

    Returns:
        The resized tensor; pixels without enough valid support are NaN.
    """
    valid = torch.isfinite(data)
    mask = valid.float()
    # masked_fill (unlike nan_to_num) also zeroes +/-inf. nan_to_num would map
    # inf to +/-3.4e38, which the kernel would smear into valid neighbours
    # even though the mask gives that pixel zero weight.
    data_zeroed = data.masked_fill(~valid, 0.0)

    # F.interpolate raises if align_corners is given for nearest/area modes.
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        interp_kwargs = {"align_corners": False}
    else:
        interp_kwargs = {}

    data_resized = F.interpolate(data_zeroed, size=size, mode=mode, **interp_kwargs)
    mask_resized = F.interpolate(mask, size=size, mode=mode, **interp_kwargs)

    # Clamp rather than add eps so well-supported pixels are not biased.
    output = data_resized / mask_resized.clamp_min(eps)
    # Also catches negative weights from bicubic overshoot.
    output[mask_resized < min_weight] = float("nan")
    return output
def adaptive_nan_pool(data, output_size, *, min_coverage=0.05, eps=1e-8):
    """Downsample by averaging only the valid pixels in each pooling window.

    Pixels rejected by ``torch.isfinite`` (NaN and +/-inf) are excluded. The
    zero-filled data and the validity mask are both passed through
    ``F.adaptive_avg_pool2d``. Their ratio is the mean of the valid pixels in
    each window.

    Args:
        data: Tensor accepted by ``F.adaptive_avg_pool2d`` with missing pixels
            marked as NaN/inf.
        output_size: Target spatial size, forwarded to ``adaptive_avg_pool2d``.
        min_coverage: Windows whose valid-pixel fraction is below this value
            become NaN. The default 0.05 matches the previous hard-coded
            threshold.
        eps: Lower bound on the divisor, guarding against division by zero.

    Returns:
        The pooled tensor; under-covered windows are NaN.
    """
    valid = torch.isfinite(data)
    # masked_fill also zeroes +/-inf; nan_to_num would turn inf into +/-3.4e38
    # and dominate the window sum despite the pixel's zero mask weight.
    data_zeroed = data.masked_fill(~valid, 0.0)
    data_pooled = F.adaptive_avg_pool2d(data_zeroed, output_size)
    mask_pooled = F.adaptive_avg_pool2d(valid.float(), output_size)
    # Clamp rather than add eps so the valid mean is not biased.
    output = data_pooled / mask_pooled.clamp_min(eps)
    output[mask_pooled < min_coverage] = float("nan")
    return output
def make_dummy_track(shape=RAW_SHAPE, n_tracks=10, track_width=0.5, seed=7):
    """Synthesize a NaN-filled field that is observed only along a few tracks.

    A smooth sum of sinusoids is sampled along ``n_tracks`` tilted lines.
    Each line gets a random x-offset in ``[0, W)`` and a tilt in
    ``[-0.25, 0.25)`` about the middle row, both drawn from a generator
    seeded with ``seed``. A pixel is on a track when its column lies within
    ``track_width`` of the line. Track pixels get Gaussian noise
    (sigma 0.04); all other pixels are NaN.

    Returns:
        float32 tensor of shape ``[1, 1, H, W]`` where ``(H, W) = shape``.
    """
    gen = np.random.default_rng(seed)
    n_rows, n_cols = shape
    rows, cols = np.mgrid[0:n_rows, 0:n_cols]

    # Smooth background signal that the tracks sample.
    background = (
        0.6 * np.sin(2 * np.pi * cols / n_cols)
        + 0.4 * np.cos(2 * np.pi * rows / n_rows)
        + 0.2 * np.sin(2 * np.pi * (cols + rows) / (0.75 * n_cols))
    )

    # RNG draw order (offsets, then tilts, then noise) fixes the output per seed.
    offsets = gen.uniform(0, n_cols, size=n_tracks)
    slopes = gen.uniform(-0.25, 0.25, size=n_tracks)

    centre_row = n_rows / 2.0
    on_track = np.zeros((n_rows, n_cols), dtype=bool)
    for offset, slope in zip(offsets, slopes):
        centre_col = offset + slope * (rows - centre_row)
        on_track |= np.abs(cols - centre_col) <= track_width

    # Noise is drawn for the full grid, then only the on-track entries are used.
    noise = gen.normal(0, 0.04, size=background.shape)
    field = np.full((n_rows, n_cols), np.nan, dtype=np.float32)
    field[on_track] = background[on_track] + noise[on_track]
    return torch.from_numpy(field)[None, None]  # [1, 1, H, W]
def run_demo(output_path="dummy_interpolation.png"):
    """Plot a side-by-side comparison of resampling methods on dummy track data.

    Builds a sparse field with ``make_dummy_track`` and resamples it to
    ``TARGET_SIZE`` four ways: nearest, plain bilinear (labelled as the NaN
    failure case), ``sparse_resize`` and ``adaptive_nan_pool``. Saves a 1x5
    panel figure (original plus the four results); each panel is annotated
    with its fraction of finite pixels.

    Args:
        output_path: Destination of the saved PNG. Defaults to the previously
            hard-coded ``"dummy_interpolation.png"``.
    """
    import copy  # local import: only needed for the colormap copy below

    data = make_dummy_track()
    methods = {
        f"Original\n{RAW_SHAPE}": data,
        f"Nearest\n→{TARGET_SIZE}": F.interpolate(data, TARGET_SIZE, mode="nearest"),
        "Bilinear (NaNs)\nFailure": F.interpolate(
            data, TARGET_SIZE, mode="bilinear", align_corners=False
        ),
        "Sparse Resize\nMask-weighted": sparse_resize(data, TARGET_SIZE),
        "Adaptive Pool\nValid avg": adaptive_nan_pool(data, TARGET_SIZE),
    }

    fig, axes = plt.subplots(1, len(methods), figsize=(26, 10))

    # Shared colour limits (mean +/- 2.5 std of the valid raw samples) keep
    # all panels directly comparable.
    valid_data = data[torch.isfinite(data)]
    mean, std = valid_data.mean(), valid_data.std()
    vmin, vmax = (mean - 2.5 * std).item(), (mean + 2.5 * std).item()

    # Copy before set_bad: plt.cm.RdBu_r is a shared module-level object, and
    # mutating it in place would leak the grey "bad" colour into every other
    # figure that uses it (some Matplotlib versions also warn about this).
    cmap = copy.copy(plt.cm.RdBu_r)
    cmap.set_bad("gray", 0.5)  # NaN pixels render as semi-transparent grey

    for ax, (title, result) in zip(axes, methods.items()):
        img = result[0, 0].detach().cpu().numpy()
        im = ax.imshow(img, origin="lower", cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=20)
        ax.axis("off")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cov = torch.isfinite(result).float().mean().item()
        ax.text(
            0.5, -0.1, f"Coverage: {cov:.2%}",
            transform=ax.transAxes, ha="center", fontsize=12,
        )

    fig.suptitle("Sparse Altimetry Track Interpolation Methods", fontsize=16)
    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    print(f"Saved {output_path}")
    plt.close(fig)  # close this specific figure, not whatever is "current"
if __name__ == "__main__":
    # Render and save the comparison figure only when run as a script, not on import.
    run_demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment