Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Last active September 22, 2025 19:09
Show Gist options
  • Select an option

  • Save calebrob6/78625123f1e6eb9d44f165b914b34b9e to your computer and use it in GitHub Desktop.

Select an option

Save calebrob6/78625123f1e6eb9d44f165b914b34b9e to your computer and use it in GitHub Desktop.
Create embeddings from an input GeoTIFF using DINOv3 and save results as an output GeoTIFF.
import argparse
import math
import os
import time
from typing import List, Optional, Sequence, Tuple
import numpy as np
import rasterio
import rasterio.windows
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision import transforms
from tqdm import tqdm
# Spatial downsampling factor of the backbone: the ViT-L/16 model emits one
# feature token per 16x16 pixel patch of the input.
TOKEN_STRIDE = 16
# Embedding dimension of each output token.
# NOTE(review): presumably matches the ViT-L hidden size of the loaded
# DINOv3 checkpoint -- confirm against the model config.
NUM_FEATS = 1024
def _list_dict_to_dict_list(samples):
"""Convert a list of dictionaries to a dictionary of lists.
Args:
samples: a list of dictionaries
Returns:
a dictionary of lists
"""
collated = dict()
for sample in samples:
for key, value in sample.items():
if key not in collated:
collated[key] = []
collated[key].append(value)
return collated
def stack_samples(samples):
    """Stack a list of samples along a new axis.

    Useful for forming a mini-batch of samples to pass to
    :class:`torch.utils.data.DataLoader`.

    Args:
        samples: list of samples

    Returns:
        a single sample
    """
    grouped = _list_dict_to_dict_list(samples)
    # Tensor-valued entries are stacked into a batch dimension; everything
    # else (e.g. plain ints like x/y offsets) stays a Python list.
    return {
        key: torch.stack(values) if isinstance(values[0], torch.Tensor) else values
        for key, values in grouped.items()
    }
class GridGeoSampler(Sampler):
    """Samples (tile_index, y, x, patch_size) chips on a regular grid.

    For each selected tile, chips are generated at `stride` spacing, with an
    extra final row/column flush against the bottom/right edge so the whole
    tile is covered. If a tile is smaller than the patch in either dimension,
    a single chip anchored at the origin is emitted instead.
    """

    def __init__(
        self,
        image_fns: List[List[str]],
        image_fn_indices: List[int],
        patch_size: int = 256,
        stride: int = 256,
    ):
        """
        Args:
            image_fns: groups of co-registered image filenames; the first
                file in each group defines the tile's extent
            image_fn_indices: indices into ``image_fns`` to sample from
            patch_size: side length in pixels of each square chip
            stride: step in pixels between chip origins
        """
        self.image_fn_indices = image_fn_indices
        self.patch_size = patch_size
        # tuples of the form (i, y, x, patch_size) that index into a TileDataset
        self.indices = []
        for i in self.image_fn_indices:
            with rasterio.open(image_fns[i][0]) as f:
                height, width = f.height, f.width
            # BUGFIX: was `and`, which let a patch exceeding only ONE
            # dimension fall through to the grid loops below, where
            # `height - patch_size` (or width) is negative and produces
            # windows with negative offsets. Use `or` so any oversized
            # dimension yields the single origin-anchored chip.
            if patch_size > height or patch_size > width:
                self.indices.append((i, 0, 0, self.patch_size))
            else:
                # Regular grid, plus a last row/column flush with the edges.
                for y in list(range(0, height - patch_size, stride)) + [
                    height - patch_size
                ]:
                    for x in list(range(0, width - patch_size, stride)) + [
                        width - patch_size
                    ]:
                        self.indices.append((i, y, x, self.patch_size))
        self.num_chips = len(self.indices)

    def __iter__(self):
        # Yield chips in the deterministic row-major order they were built in.
        for index in self.indices:
            yield index

    def __len__(self):
        return self.num_chips
class TileDataset(Dataset):
    """Dataset of windowed reads over image tiles (and optional masks).

    Indexed by (tile_index, y, x, patch_size) tuples, typically produced by
    :class:`GridGeoSampler`.
    """

    def __init__(
        self,
        image_fns: List[List[str]],
        mask_fns: Optional[List[str]],
        transforms=None,
        sanity_check=True,
    ):
        self.image_fns = image_fns
        self.mask_fns = mask_fns
        if mask_fns is not None:
            assert len(image_fns) == len(mask_fns)
        # Check to make sure that all the image and mask tile pairs are the same size
        # as a sanity check
        if sanity_check and mask_fns is not None:
            print("Running sanity check on dataset...")
            for image_group, mask_fn in zip(image_fns, mask_fns):
                with rasterio.open(image_group[0]) as src:
                    image_shape = src.shape
                with rasterio.open(mask_fn) as src:
                    mask_shape = src.shape
                assert image_shape[0] == mask_shape[0]
                assert image_shape[1] == mask_shape[1]
        self.transforms = transforms

    def __len__(self):
        return len(self.image_fns)

    def __getitem__(self, index: Tuple[int, int, int, int]):
        i, y, x, patch_size = index
        sample = {
            "y": y,
            "x": x,
        }
        window = rasterio.windows.Window(x, y, patch_size, patch_size)
        # Read every co-registered layer of tile i and stack along the band axis.
        layers = [
            rasterio.open(fn).read(window=window) for fn in ()
        ]  # placeholder removed below
        layers = []
        for fn in self.image_fns[i]:
            with rasterio.open(fn) as src:
                layers.append(src.read(window=window))
        sample["image"] = torch.from_numpy(np.concatenate(layers, axis=0)).float()
        # Load the corresponding mask window, if masks were provided.
        if self.mask_fns is not None:
            with rasterio.open(self.mask_fns[i]) as src:
                sample["mask"] = torch.from_numpy(src.read(window=window)).long()
        if self.transforms is not None:
            sample["image"] = self.transforms(sample["image"])
        return sample
class BatchedDinoWrapper(nn.Module):
    """Wraps a DINOv3 backbone to return flattened token features per image."""

    def __init__(self, layers: Optional[Sequence[int]] = None):
        super().__init__()
        # Which transformer blocks to read; defaults to all 24 (forward only
        # returns the last one of the selection).
        self.layers = list(layers) if layers is not None else list(range(24))
        # keep the backbone in self.backbone; don't wrap it yet
        self.backbone = torch.hub.load(
            "dinov3",
            "dinov3_vitl16",
            source="local",
            weights="dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth",
        ).eval()

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        """
        x: (N, 3, H, W)
        returns: (N, HW, C) features from the last selected layer
        """
        if x.dim() == 3:
            # Promote a single unbatched image to a batch of one.
            x = x.unsqueeze(0)
        assert x.dim() == 4 and x.size(1) == 3, "Expected (N,3,H,W)"
        layer_outputs = self.backbone.get_intermediate_layers(
            x, n=self.layers, reshape=True, norm=True
        )
        # Only the last requested layer is returned.
        feats = layer_outputs[-1]
        assert feats.dim() == 4 and feats.size(0) == x.size(0), (
            f"Unexpected feats shape: {feats.shape}"
        )
        n, c, h, w = feats.shape
        # (N, C, h, w) -> (N, h*w, C): one row of C features per token.
        return feats.reshape(n, c, h * w).transpose(1, 2).contiguous()
def parse_args() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the embedding script.

    Returns:
        the configured :class:`argparse.ArgumentParser` (caller invokes
        ``parse_args()`` on it)
    """
    parser = argparse.ArgumentParser(description="Embed script arguments")
    parser.add_argument(
        "--input_fn",
        type=str,
        required=True,
        help="Path to the input file",
    )
    parser.add_argument(
        "--output_fn",
        type=str,
        required=True,
        help="Path to the output file",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size to use for inference (default: 16)",
    )
    parser.add_argument(
        "--gpus",
        nargs="+",
        type=int,
        # BUGFIX: default was the bare int 0. With nargs="+" the parsed value
        # is always a list when the flag is given, so the default must also be
        # a list -- `nn.DataParallel(..., device_ids=args.gpus)` iterates it.
        default=[0],
        help="GPU ids to use",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Whether to overwrite the output file if it exists",
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        default=1024,
        help="Patch size to use for inference (default: 1024)",
    )
    parser.add_argument(
        "--padding",
        type=int,
        default=32,
        help="Padding to use for inference (default: 32)",
    )
    parser.add_argument(
        "--resize_factor",
        type=int,
        default=1,
        help="Resize factor to use for inference (default: 1)",
    )
    parser.add_argument(
        "--pca",
        action="store_true",
        help="Whether to run PCA on the features and save the first 3 components",
    )
    return parser
def main(args: argparse.Namespace):
    """Embed an input GeoTIFF with DINOv3 and write the result as a GeoTIFF.

    Tiles the input into overlapping patches, runs the backbone over them on
    GPU, assembles the per-token features into one (H/16, W/16, 1024) array,
    and writes it as a georeferenced float32 GeoTIFF. Optionally also writes
    a 3-band uint8 PCA visualization of the features.

    Args:
        args: parsed command-line arguments (see :func:`parse_args`)

    Raises:
        FileNotFoundError: if ``args.input_fn`` does not exist
        ValueError: if the input is not .tif/.vrt, or smaller than the patch
        FileExistsError: if ``args.output_fn`` exists without --overwrite
    """
    input_image_fn = args.input_fn
    if not os.path.exists(input_image_fn):
        raise FileNotFoundError(f"Input file {input_image_fn} does not exist.")
    if not (
        input_image_fn.lower().endswith(".tif")
        or input_image_fn.lower().endswith(".vrt")
    ):
        raise ValueError("Input file must be a .tif or .vrt file")
    output_fn = args.output_fn
    if os.path.exists(output_fn) and not args.overwrite:
        raise FileExistsError(
            f"Output file {output_fn} already exists. Use --overwrite to overwrite it."
        )
    if os.path.exists(output_fn):
        print(
            f"WARNING: Output file {output_fn} already exists and will be overwritten."
        )
    patch_size = args.patch_size
    padding = args.padding
    # Require patch_size to be a power of two.
    # NOTE(review): `assert` is stripped under `python -O`; a raised
    # ValueError would be more robust here.
    assert int(math.log(patch_size, 2)) == math.log(patch_size, 2)
    # Patches overlap by `padding` pixels on each side; only the interior of
    # each patch is kept when assembling the output (see the write loop below).
    stride = patch_size - padding * 2
    with rasterio.open(input_image_fn) as f:
        input_height, input_width = f.shape
        # Keep the source profile for CRS/geotransform of the output.
        profile = f.profile
    if patch_size > input_height or patch_size > input_width:
        raise ValueError(
            f"Patch size {patch_size} is larger than image dimensions {input_height}x{input_width}"
        )
    print(f"Input size: {input_height} x {input_width}")
    print(f"Using patch size {patch_size} with padding {padding} and stride {stride}")
    print(f"Starting inference with batch size {args.batch_size} on GPUs {args.gpus}")
    # Per-patch preprocessing: scale to [0,1], optionally upsample by
    # resize_factor, then normalize.
    # NOTE(review): the mean/std presumably match the SAT-493M pretraining
    # statistics -- confirm against the DINOv3 release.
    augs = transforms.Compose(
        [
            # transforms.ToTensor(),
            transforms.Lambda(lambda x: x / 255.0),
            transforms.Resize(
                (
                    args.patch_size * args.resize_factor,
                    args.patch_size * args.resize_factor,
                )
            ),
            transforms.Normalize(mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143)),
        ]
    )
    dataset = TileDataset([[input_image_fn]], mask_fns=None, transforms=augs)
    sampler = GridGeoSampler(
        [[input_image_fn]], [0], patch_size=args.patch_size, stride=stride
    )
    dataloader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=16,
        collate_fn=stack_samples,
    )
    # Load model
    device = torch.device("cuda")
    model = BatchedDinoWrapper().to(device)
    model = nn.DataParallel(model, device_ids=args.gpus)
    # One feature vector per TOKEN_STRIDE x TOKEN_STRIDE pixel token.
    output_height = math.ceil(input_height / TOKEN_STRIDE) * args.resize_factor
    output_width = math.ceil(input_width / TOKEN_STRIDE) * args.resize_factor
    print(f"Output size: {output_height} x {output_width} x {NUM_FEATS}")
    # NOTE: this buffer is output_height * output_width * 1024 float32 values
    # held in RAM; it can be very large for big inputs.
    output = np.zeros((output_height, output_width, NUM_FEATS), dtype=np.float32)
    # Patch and padding sizes expressed in token units.
    downscaled_patch_size = patch_size // TOKEN_STRIDE * args.resize_factor
    downsampled_padding_size = padding // TOKEN_STRIDE * args.resize_factor
    tic = time.time()
    for batch in tqdm(dataloader, desc="Running model"):
        images = batch["image"].to(device)
        x_coords = batch["x"]
        y_coords = batch["y"]
        batch_size = images.shape[0]
        with torch.inference_mode(), torch.amp.autocast("cuda"):
            features = model(images)
        # (N, HW, C) -> (N, h, w, C): restore the 2-D token grid per patch.
        features = (
            features.cpu()
            .numpy()
            .reshape(
                batch_size, downscaled_patch_size, downscaled_patch_size, NUM_FEATS
            )
        )
        for i in range(batch_size):
            height, width, _ = features[i].shape
            # Patch origin converted from pixel to token coordinates.
            y = int(y_coords[i]) // TOKEN_STRIDE * args.resize_factor
            x = int(x_coords[i]) // TOKEN_STRIDE * args.resize_factor
            # Write only the interior of each patch, trimming the padding
            # tokens from every side to avoid patch-boundary artifacts.
            output[
                y + downsampled_padding_size : y + height - downsampled_padding_size,
                x + downsampled_padding_size : x + width - downsampled_padding_size,
            ] = features[i][
                downsampled_padding_size:-downsampled_padding_size,
                downsampled_padding_size:-downsampled_padding_size,
            ]
    print(f"Finished running model in {time.time() - tic:0.2f} seconds")
    # Output georeferencing: same CRS, pixel size scaled by the token stride.
    # NOTE(review): integer division below -- a resize_factor that does not
    # divide TOKEN_STRIDE evenly would yield a wrong geotransform; confirm
    # intended resize_factor values.
    new_profile = {
        "driver": "GTiff",
        "height": output_height,
        "width": output_width,
        "count": NUM_FEATS,
        "dtype": "float32",
        "crs": profile["crs"],
        "transform": profile["transform"]
        * rasterio.Affine.scale(
            TOKEN_STRIDE // args.resize_factor, TOKEN_STRIDE // args.resize_factor
        ),
        "compress": "lzw",
        "predictor": 3,  # floating-point predictor for LZW compression
        "nodata": 0,
        "blockxsize": 512,
        "blockysize": 512,
        "tiled": True,
        "interleave": "pixel",
        "BIGTIFF": "YES",
    }
    tic = time.time()
    with rasterio.open(output_fn, "w", **new_profile) as f:
        # HWC -> CHW band order expected by rasterio.
        f.write(output.transpose(2, 0, 1))
    print(f"Wrote output to {output_fn} in {time.time() - tic:0.2f} seconds")
    if args.pca:
        print("Running PCA on features and saving first 3 components")
        tic = time.time()
        pca = PCA(n_components=3, whiten=True)
        x_pca = pca.fit_transform(output.reshape(-1, NUM_FEATS))
        x_pca = x_pca.reshape(output_height, output_width, -1)
        x_pca = torch.from_numpy(x_pca)
        # Squash the whitened components into (0, 1) for visualization.
        x_pca = F.sigmoid(x_pca * 2.0).numpy()
        x_pca = (x_pca * 255.0).astype(np.uint8).transpose(2, 0, 1)
        print(f"PCA summed explained variance: {sum(pca.explained_variance_ratio_)}")
        # NOTE(review): str.replace swaps the FIRST ".tif" occurrence anywhere
        # in the path, not necessarily the file extension.
        pca_output_fn = output_fn.replace(".tif", "_pca.tif")
        pca_profile = new_profile.copy()
        pca_profile["count"] = 3
        pca_profile["dtype"] = "uint8"
        pca_profile["predictor"] = 2  # horizontal differencing for integer data
        with rasterio.open(pca_output_fn, "w", **pca_profile) as f:
            f.write(x_pca)
        print(
            f"Wrote PCA output to {pca_output_fn} in {time.time() - tic:0.2f} seconds"
        )
if __name__ == "__main__":
    # Parse CLI arguments and run the embedding pipeline.
    main(parse_args().parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment