DinoV3 Feature Extraction and Saving Script
#!/usr/bin/env python3
# Tool to extract DinoV3 features from videos and save them in compressed form.
# Compression uses PCA, switching to IncrementalPCA for longer inputs.
import os, sys
import glob
import time
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms.v2 as T
import torchvision.io as io
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA, IncrementalPCA
# ---------------- config ----------------
REPO_DIR = '/abyss/home/dinov3/dinov3-main'
BACKBONE_PATH = '/abyss/home/dinov3/weights/dinov3_vits16_pretrain_lvd1689m-08c60483.pth'
IMG_PATH = '/abyss/home/THUMOS_RAW/raw_videos/TH14_test_set_mp4/0'  # change to your video directory
OUT_DIR_SINGLE = './feat_channels'  # per-channel PNGs
OUT_DIR_TILE = './feat_tiles'       # tiled collages
FULL_FEATURE_PATH = './npz_features_lite_final/test/'
NPZ_FEATURE_PATH = './npz_features_lite_final/test/'
MODEL_NAME = 'dinov3_vits16'
PATCH_SIZE = 16
RESIZE_TO = 448  # square input size; must be divisible by PATCH_SIZE (224 also works)
# collage
GRID_HW = (14, 14)  # tile up to 14*14 = 196 channels
CELL_PX = 16
BATCH_SIZE = 200  # keep divisible by TEMPORAL_POOLING so batch edges align with pooling windows
SAVE_FULL_FEATURES = False
TEMPORAL_POOLING = 4  # average-pool features over this many consecutive frames
PCA_KEEP_RATIO = 0.96  # fraction of variance to keep in PCA
ITERATIVE_THRESHOLD = 4000  # above this many pooled frames, use IncrementalPCA
# ----------------------------------------
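# With RESIZE_TO=448 and PATCH_SIZE=16 the backbone sees a 28x28 patch grid, i.e.
# HW = 784 tokens per frame; the token dim C is typically 384 for the ViT-S
# variant (both values are derived here as a sketch, not read from the checkpoint).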
sys.path.append(REPO_DIR)
os.makedirs(OUT_DIR_SINGLE, exist_ok=True)
os.makedirs(OUT_DIR_TILE, exist_ok=True)
os.makedirs(NPZ_FEATURE_PATH, exist_ok=True)  # output dir for the .npz feature files
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'[info] using device: {device}')
# ---------- utils ----------
class FeaturesModule(nn.Module):
    def __init__(self, _repo_dir, _model_name, _source, _pretrained, _weights, _check_hash):
        super().__init__()
        self.backbone = torch.hub.load(_repo_dir, _model_name, source=_source, pretrained=_pretrained,
                                       weights=_weights, check_hash=_check_hash)

    def forward(self, x):
        # print("running on", x.device)  # with DataParallel you'd see cuda:0, cuda:1, ...
        out = self.backbone.forward_features(x)
        return out["x_norm_patchtokens"]

def minmax01(arr, eps=1e-6):
    mn, mx = float(arr.min()), float(arr.max())
    return (arr - mn) / max(mx - mn, eps)

def save_channel_images(feat_chw, out_dir, prefix='feat'):
    C, H, W = feat_chw.shape
    for c in range(C):
        f = minmax01(feat_chw[c].numpy())
        plt.imsave(str(Path(out_dir) / f"{prefix}_c{c:04d}.png"), f, cmap='gray')

def save_tiled_collage(feat_chw, out_png, grid_hw=(50, 50), cell_px=48):
    C, H, W = feat_chw.shape
    gh, gw = grid_hw
    n = min(gh * gw, C)
    tile = np.zeros((gh * cell_px, gw * cell_px), dtype=np.float32)
    for idx in range(n):
        r, c = divmod(idx, gw)
        f = minmax01(feat_chw[idx].numpy())
        im = Image.fromarray((f * 255).astype(np.uint8)).resize((cell_px, cell_px), Image.BILINEAR)
        tile[r * cell_px:(r + 1) * cell_px, c * cell_px:(c + 1) * cell_px] = np.asarray(im, dtype=np.uint8) / 255.0
    plt.imsave(out_png, tile, cmap='viridis')

def coerce_tokens(obj):
    """
    Return patch+extra tokens as a tensor [B, N_total, C] from various return types:
    - tensor [B, N, C]
    - (tokens, cls)
    - [ (tokens, cls), ... ] or [tokens, ...]
    - or a feature map [B, C, H', W'] -> flattened to [B, N, C]
    """
    if isinstance(obj, list):
        obj = obj[-1]
    if isinstance(obj, (tuple, list)):
        obj = obj[0]
    assert isinstance(obj, torch.Tensor), f"Unexpected output type: {type(obj)}"
    if obj.ndim == 4:  # [B, C, H', W'] -> [B, N, C]
        B, C, H, W = obj.shape
        obj = obj.permute(0, 2, 3, 1).reshape(B, H * W, C)
    return obj
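# A minimal sketch of coerce_tokens on a fake [B, C, H', W'] feature map
# (the shapes are illustrative, not from a real backbone):
#   fmap = torch.randn(2, 384, 14, 14)
#   tok = coerce_tokens(fmap)
#   assert tok.shape == (2, 196, 384)   # [B, H'*W', C]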
# ---------- load backbone ----------
backbone = FeaturesModule(REPO_DIR, MODEL_NAME, 'local', True, BACKBONE_PATH, False).to(device).eval()
if torch.cuda.device_count() > 1:
    print(f"Let's use {torch.cuda.device_count()} GPUs!")
    backbone = nn.DataParallel(backbone).to(device)
# NOTE: the two helpers below are unused fallbacks; they expect the raw hub
# backbone (with .blocks / .get_intermediate_layers), not the FeaturesModule
# wrapper used above.
def try_get_tokens_api(x):
    try:
        last_idx = (len(backbone.blocks) - 1) if hasattr(backbone, "blocks") else (backbone.num_blocks - 1)
        out = backbone.get_intermediate_layers(x, [last_idx], reshape=False, return_class_token=True)
        return coerce_tokens(out)  # -> [B, N_total, C]
    except Exception:
        print(f'Exception in get_intermediate_layers: {sys.exc_info()[0]}')
        return None

def get_tokens_hook(x):
    grabbed = {}
    last_blocks = [n for n, _ in backbone.named_modules() if 'blocks' in n]
    assert last_blocks, "Could not find 'blocks' module."
    last_block_name = sorted(last_blocks, key=len)[-1]

    def hook(_, __, out):
        t = out[0] if isinstance(out, (tuple, list)) else out
        grabbed['t'] = t.detach()

    handle = dict(backbone.named_modules())[last_block_name].register_forward_hook(hook)
    _ = backbone(x)
    handle.remove()
    return coerce_tokens(grabbed['t'])
def temporal_avg_pool_with_tail(x: torch.Tensor, stride: int, idx_offset: int = 0):
    """
    Temporal average pool with tail handling for 2D inputs.
    Args:
        x: [T, F] tensor (T = time, F = feature dim, e.g. D*P)
        stride: window size (int > 0)
        idx_offset: added to all returned indices (useful when pooling per mini-batch)
    Returns:
        pooled: [T_out, F] tensor
        idxs: [T_out] LongTensor of representative frame indices (on x.device)
    """
    assert x.dim() == 2, "x must be [T, F]"
    assert stride > 0, "stride must be > 0"
    T, F = x.shape
    num_full = T // stride
    rem = T % stride
    parts = []
    idxs = []
    # full windows
    if num_full > 0:
        main = x[:num_full * stride]                   # [num_full*stride, F]
        main = main.view(num_full, stride, F).mean(1)  # [num_full, F]
        parts.append(main)
        idxs.append(torch.arange(num_full, device=x.device) * stride + (stride // 2))
    # tail (if any)
    if rem > 0:
        tail = x[num_full * stride:]                   # [rem, F]
        parts.append(tail.mean(0, keepdim=True))       # [1, F]
        idxs.append(torch.tensor([num_full * stride + rem // 2], device=x.device))
    if parts:
        pooled = torch.cat(parts, dim=0)               # [T_out, F]
        idxs = torch.cat(idxs, dim=0)                  # [T_out]
    else:
        pooled = x.new_zeros((0, F))
        idxs = torch.empty(0, dtype=torch.long, device=x.device)
    idxs += idx_offset
    return pooled, idxs
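# A minimal sketch of the tail behaviour: 10 frames with stride 4 give two full
# windows plus a 2-frame tail, with representative indices [2, 6, 9].
#   x = torch.arange(10, dtype=torch.float32).unsqueeze(1)   # [10, 1]
#   pooled, idxs = temporal_avg_pool_with_tail(x, stride=4)
#   # pooled -> [[1.5], [5.5], [8.5]], idxs -> tensor([2, 6, 9])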
def process_mini_batch(batch):
    # batch: [B, 3, H, W], already stacked and on the right device
    with torch.inference_mode():
        out = backbone(batch)  # x_norm_patchtokens from forward_features
    print(f'Input {batch.shape}, Feats {out.shape}')
    return out  # [B, N, C]
def load_video_with_minibatches(
        path,
        pool_stride,
        size=224,
        max_frames=None,
        batch_size=BATCH_SIZE,
        device="cuda"):
    # Decode the whole video up front; an earlier cv2.VideoCapture frame loop
    # was replaced by torchvision.io.read_video.
    vframes, _, _ = io.read_video(path, output_format="TCHW")  # [T, C, H, W], uint8
    if max_frames is not None:
        vframes = vframes[:max_frames]
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    num_runs = (len(vframes) + batch_size - 1) // batch_size
    features = None
    len_data = int(np.ceil(len(vframes) / pool_stride))  # total pooled rows
    w = 0
    full_idx = []
    for i in range(num_runs):
        end = min((i + 1) * batch_size, len(vframes))
        batch = vframes[i * batch_size: end].to(device).float() / 255
        # pre-process batch on GPU (this can also be moved into the feature processor if it fails)
        batch = T.Resize((size, size))(batch)
        batch = T.Normalize(mean, std)(batch)
        tokens = process_mini_batch(batch)  # [Tb, hw, C]
        Tb, hw, C = tokens.shape
        # pool over flattened tokens; offset indices by this batch's raw-frame offset
        tokens, idx = temporal_avg_pool_with_tail(tokens.reshape(Tb, -1), pool_stride, i * batch_size)
        tokens = tokens.reshape(tokens.shape[0], hw, C)  # [T', hw, C]
        if features is None:
            # preallocate on CPU; concatenating per-batch results ran out of memory
            features = torch.empty((len_data, tokens.shape[1], tokens.shape[2]), dtype=torch.float32, device='cpu')
        features[w:w + tokens.shape[0]] = tokens.detach().cpu()
        full_idx.extend(idx.detach().cpu().numpy())
        w += tokens.shape[0]
    print(f'Got all video frames with len {features.shape}')
    idx = torch.tensor(full_idx, dtype=torch.long)  # representative raw-frame index per pooled row
    return features, idx  # features: [T/stride, hw, C]
# -- Saving pipeline
def save_pca_channels(
        tokens: torch.Tensor,      # [T, C, HW]
        idx,                       # frame indices (len=T)
        save_path: str,
        keep_ratio: float = 0.99,  # variance to keep (0 < r <= 1) or integer n_components
        compress: bool = True,
        save: bool = True):
    """
    Fit PCA across all (T * HW) spatial samples and reduce channel dim C -> K.
    Saves a compact representation as [HW, T, K] along with components/mean.
    """
    assert tokens.dim() == 3, "tokens must be [T, C, HW]"
    T, C, HW = tokens.shape
    # (T, C, HW) -> (T*HW, C); contiguous() preserves memory order for the view
    X = (tokens.permute(0, 2, 1)
         .contiguous()
         .view(T * HW, C)
         .cpu()
         .numpy()
         .astype(np.float32))
    pca = PCA(n_components=keep_ratio, svd_solver="full", random_state=0)
    Z = pca.fit_transform(X)  # [T*HW, K]
    K = Z.shape[1]
    # store as [HW, T, K] (spatial-major on disk)
    Z_disk = Z.reshape(T, HW, K).transpose(1, 0, 2)  # [HW, T, K]
    if save:
        data = {
            "tokens_pca": Z_disk,                 # [HW, T, K]
            "idx": np.asarray(idx),
            "pca_components": pca.components_,    # [K, C]
            "pca_mean": pca.mean_,                # [C]
            "HW": np.int64(HW),
            "explained_var_ratio": pca.explained_variance_ratio_.astype(np.float32),
        }
        (np.savez_compressed if compress else np.savez)(save_path, **data)
    return K, pca.explained_variance_ratio_
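# A minimal sketch of how K falls out of keep_ratio: with a float n_components
# and svd_solver="full", sklearn picks the smallest K whose cumulative explained
# variance reaches that fraction. With save=False the path argument is unused:
#   K, evr = save_pca_channels(tokens_all, idx, 'unused.npz', keep_ratio=0.96, save=False)
#   print(K, evr.sum())   # evr.sum() >= 0.96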
@torch.no_grad()
def save_pca_channels_ipca(
        tokens_TCHW: torch.Tensor,  # [T, C, HW]
        idx,                        # frame indices (len=T)
        save_path: str,
        keep_ratio: float = 0.99):
    """
    1) Estimates an integer n_components K via a PCA on a random sample (to match keep_ratio).
    2) Runs IncrementalPCA(n_components=K) streaming over all T*HW rows.
    3) Saves:
       tokens_pca: [T, HW, K], pca_components: [K, C], pca_mean: [C],
       HW (int), idx, explained_var_ratio (from IncrementalPCA).
    """
    assert tokens_TCHW.ndim == 3
    T, C, HW = tokens_TCHW.shape
    # ---- step 1: estimate K from a random sample of frames
    num_samples = min(1000, T)
    sample_indices = np.random.choice(T, size=num_samples, replace=False)
    K, estimate_var_ratio = save_pca_channels(
        tokens_TCHW[sample_indices], idx[sample_indices], save_path, keep_ratio=keep_ratio, save=False
    )
    print(f'Estimated K: {K}, Explained Variance Ratio: {estimate_var_ratio} for data of length {tokens_TCHW.shape}')
    # ---- step 2: fit IncrementalPCA with that K over the full data
    ipca = IncrementalPCA(n_components=min(K, C), batch_size=None)
    num_batch = 1000  # frames per chunk; each chunk streams num_batch*HW rows of [*, C]
    num_iters = (T + num_batch - 1) // num_batch
    # pass 1: partial_fit over all frames
    for i in range(num_iters):
        start = i * num_batch
        end = min((i + 1) * num_batch, T)
        batch = (tokens_TCHW[start:end].permute(0, 2, 1)
                 .contiguous()
                 .view(-1, C)
                 .cpu()
                 .numpy()
                 .astype(np.float32))
        print(f'Partial fit {100.0 * i / num_iters:.1f}%')
        ipca.partial_fit(batch)
    # pass 2: transform all rows (*, C) -> (*, K)
    reduced = np.zeros((T, HW, ipca.n_components_), dtype=np.float32)
    for i in range(num_iters):
        start = i * num_batch
        end = min((i + 1) * num_batch, T)
        batch = (tokens_TCHW[start:end].permute(0, 2, 1)
                 .contiguous()
                 .view(-1, C)
                 .cpu()
                 .numpy()
                 .astype(np.float32))
        print(f'Transform {100.0 * i / num_iters:.1f}%')
        Z = ipca.transform(batch)
        reduced[start:end] = Z.reshape(end - start, HW, ipca.n_components_)  # [*, HW, K]
    # ---- step 3: save npz
    np.savez_compressed(
        save_path,
        tokens_pca=reduced,                                  # [T, HW, K]
        idx=np.asarray(idx),
        pca_components=ipca.components_.astype(np.float32),  # [K, C]
        pca_mean=ipca.mean_.astype(np.float32),              # [C]
        HW=np.int64(HW),
        explained_var_ratio=ipca.explained_variance_ratio_.astype(np.float32),
        estimate_var_ratio=estimate_var_ratio.astype(np.float32),
        K=np.int64(ipca.n_components_),
        C=np.int64(C),
        T=np.int64(T),
    )
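# Note the on-disk layout difference: save_pca_channels writes tokens_pca as
# [HW, T, K], while this incremental path writes [T, HW, K]. Readers such as
# reconstruct_data below assume the [HW, T, K] layout, so IPCA files need a
# transpose first.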
def reconstruct_data(data_path):
    """Reconstruct channels from PCA data (assumes the [HW, T, K] layout)."""
    data = np.load(data_path)
    tokens_pca = data['tokens_pca']          # [H*W, num_frames, num_components]
    pca_components = data['pca_components']  # [num_components, num_channels]
    pca_mean = data['pca_mean']              # [num_channels]
    # reconstruct the original tokens
    reconstructed_tokens = np.dot(tokens_pca, pca_components) + pca_mean  # [H*W, num_frames, num_channels]
    reconstructed_tokens = reconstructed_tokens.transpose(1, 2, 0)        # [num_frames, num_channels, H*W]
    return reconstructed_tokens
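# A minimal usage sketch (the file path is hypothetical):
#   recon = reconstruct_data('./npz_features_lite_final/test/video_0001.npz')
#   print(recon.shape)   # [num_frames, num_channels, H*W]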
# ---------- run ----------
# process every mp4 under IMG_PATH
last_duration = None
average_duration = None
IMG_PATH = Path(IMG_PATH)
if not IMG_PATH.exists():
    raise FileNotFoundError(f"Path {IMG_PATH} does not exist.")
video_paths = sorted(glob.glob(os.path.join(IMG_PATH, "*.mp4")))
for i, video_path in enumerate(video_paths):
    start = time.time()
    print(f"[info] processing video: {video_path} number {i + 1} of {len(video_paths)}")
    if average_duration is not None:
        print(f"Remaining time: {average_duration * (len(video_paths) - i) / 60:.2f} minutes")
    video_name = Path(video_path).stem
    if os.path.exists(os.path.join(NPZ_FEATURE_PATH, f"{video_name}.npz")):
        print(f"[info] features already exist for {video_path}, skipping.")
        continue
    tokens_all, idx = load_video_with_minibatches(
        video_path,
        pool_stride=TEMPORAL_POOLING,
        size=RESIZE_TO,
        max_frames=None,
        batch_size=BATCH_SIZE,
        device=device
    )  # [T', HW, C]
    print(f'Last elements of idx: {idx[-5:]}')
    tokens_all = tokens_all.transpose(1, 2)  # [T', C, HW]
    print(f'Tokens after reshaping: {tokens_all.shape}')
    if len(tokens_all) < ITERATIVE_THRESHOLD:
        save_pca_channels(tokens_all, idx, os.path.join(NPZ_FEATURE_PATH, f"{video_name}.npz"), PCA_KEEP_RATIO)  # needs [T, C, HW]
    else:  # too many frames for an in-memory PCA
        print(f'Using Incremental PCA for {video_name} with {len(tokens_all)} frames')
        save_pca_channels_ipca(tokens_all, idx, os.path.join(NPZ_FEATURE_PATH, f"{video_name}.npz"), PCA_KEEP_RATIO)
    end = time.time()
    last_duration = end - start
    average_duration = last_duration if average_duration is None else (last_duration + average_duration * i) / (i + 1)
#reconstructed = reconstruct_data('pca_features.npz')
#print(f'Reconstructed shape {reconstructed.shape}')