Last active
August 24, 2025 16:54
-
-
Save ulassbin/5303807fba27b1fbb8ccd5bea03ef99e to your computer and use it in GitHub Desktop.
DinoV3 Feature Extraction and Saving Script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # Tool to get DinoV3 features from videos, and save them in a compressed manner | |
| # Compression is done with PCA and IncrementalPCA for larger inputs | |
| import os, sys | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torchvision import transforms | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import torch.nn.functional as F | |
| import torchvision.transforms.v2 as T | |
| import torch.nn as nn | |
| from sklearn.decomposition import PCA, IncrementalPCA | |
# ---------------- config ----------------
REPO_DIR = '/abyss/home/dinov3/dinov3-main'
BACKBONE_PATH = '/abyss/home/dinov3/weights/dinov3_vits16_pretrain_lvd1689m-08c60483.pth'
IMG_PATH = '/abyss/home/THUMOS_RAW/raw_videos/TH14_test_set_mp4/0'  # directory of .mp4 videos to process
OUT_DIR_SINGLE = './feat_channels'  # per-channel PNGs
OUT_DIR_TILE = './feat_tiles'       # tiled collage PNGs
FULL_FEATURE_PATH = './npz_features_lite_final/test/'
NPZ_FEATURE_PATH = './npz_features_lite_final/test/'  # one PCA-compressed .npz per video
MODEL_NAME = 'dinov3_vits16'
PATCH_SIZE = 16
RESIZE_TO = 448  # square input side, must be divisible by PATCH_SIZE (previously 224)
# collage
GRID_HW = (14, 14)  # collage grid: tiles up to 14 * 14 = 196 channels
CELL_PX = 16        # side length in pixels of one collage cell
BATCH_SIZE = 200    # frames per backbone forward pass
SAVE_FULL_FEATURES = False
TEMPORAL_POOLING = 4   # temporal average-pooling window, in frames
PCA_KEEP_RATIO = 0.96  # passed to PCA n_components: fraction of variance to keep
ITERATIVE_THRESHOLD = 4000  # at or above this many pooled rows, use IncrementalPCA
# ----------------------------------------
sys.path.append(REPO_DIR)
os.makedirs(OUT_DIR_SINGLE, exist_ok=True)
os.makedirs(OUT_DIR_TILE, exist_ok=True)
# np.savez / np.savez_compressed do not create missing parent directories.
os.makedirs(NPZ_FEATURE_PATH, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'[info] using device: {device}')
# ---------- utils ----------
class FeaturesModule(nn.Module):
    """nn.Module wrapper around a backbone loaded with ``torch.hub.load``.

    Calling the module runs the backbone's ``forward_features`` and returns
    only the ``"x_norm_patchtokens"`` entry of its result.
    """

    def __init__(self, _repo_dir, _model_name, _source, _pretrained, _weights, _check_hash):
        super().__init__()
        hub_options = dict(
            source=_source,
            pretrained=_pretrained,
            weights=_weights,
            check_hash=_check_hash,
        )
        self.backbone = torch.hub.load(_repo_dir, _model_name, **hub_options)

    def forward(self, x):
        feature_dict = self.backbone.forward_features(x)
        patch_tokens = feature_dict["x_norm_patchtokens"]
        return patch_tokens
def minmax01(arr, eps=1e-6):
    """Linearly rescale *arr* into [0, 1] using its global min and max.

    The denominator is never smaller than *eps*, so a constant array maps to
    all zeros instead of dividing by zero.
    """
    lowest = float(arr.min())
    spread = float(arr.max()) - lowest
    denominator = max(spread, eps)
    return (arr - lowest) / denominator
def save_channel_images(feat_chw, out_dir, prefix='feat'):
    """Write every channel of a 3-D [C, H, W] tensor as its own grayscale PNG.

    Each channel is min-max normalized with ``minmax01`` and written to
    ``<out_dir>/<prefix>_c<index:04d>.png``. *out_dir* must already exist.
    """
    num_channels, _, _ = feat_chw.shape  # unpacking enforces a 3-D input
    for channel_idx in range(num_channels):
        normalized = minmax01(feat_chw[channel_idx].numpy())
        png_path = Path(out_dir) / f"{prefix}_c{channel_idx:04d}.png"
        plt.imsave(str(png_path), normalized, cmap='gray')
def save_tiled_collage(feat_chw, out_png, grid_hw=(50,50), cell_px=48):
    """Save the first channels of a [C, H, W] tensor as one tiled image.

    Channels fill a ``grid_hw`` grid in row-major order (at most
    ``grid_hw[0] * grid_hw[1]`` of them). Each channel is min-max normalized,
    quantized to 8 bits, bilinearly resized to ``cell_px`` x ``cell_px`` and
    placed in its cell. Unused cells stay 0. The result is written with the
    'viridis' colormap.
    """
    num_channels, _, _ = feat_chw.shape
    n_rows, n_cols = grid_hw
    canvas = np.zeros((n_rows * cell_px, n_cols * cell_px), dtype=np.float32)
    for channel_idx in range(min(n_rows * n_cols, num_channels)):
        top = (channel_idx // n_cols) * cell_px
        left = (channel_idx % n_cols) * cell_px
        normalized = minmax01(feat_chw[channel_idx].numpy())
        as_bytes = (normalized * 255).astype(np.uint8)
        cell = Image.fromarray(as_bytes).resize((cell_px, cell_px), Image.BILINEAR)
        canvas[top:top + cell_px, left:left + cell_px] = np.asarray(cell, dtype=np.uint8) / 255.0
    plt.imsave(out_png, canvas, cmap='viridis')
def coerce_tokens(obj):
    """
    Return a token tensor [B, N, C] from the various return types of backbones.

    Accepted inputs:
      - tensor [B, N, C]                     -> returned unchanged
      - tensor [B, C, H', W'] (feature map)  -> flattened to [B, H'*W', C]
      - (tokens, cls) pair                   -> tokens
      - list of per-layer outputs            -> last entry, then unwrapped as above
      - tuple of (tokens, cls) pairs         -> tokens of the last pair

    Raises:
        AssertionError: if no tensor is found after unwrapping.
    """
    if isinstance(obj, list):
        obj = obj[-1]
    elif isinstance(obj, tuple) and obj and isinstance(obj[-1], (tuple, list)):
        # A tuple whose entries are themselves sequences is a per-layer
        # collection such as ((tokens, cls), ...). Keep the last layer.
        # Before this branch, such input reached the assert below.
        obj = obj[-1]
    if isinstance(obj, (tuple, list)):
        obj = obj[0]  # (tokens, cls) -> tokens
    assert isinstance(obj, torch.Tensor), f"Unexpected output type: {type(obj)}"
    if obj.ndim == 4:  # [B, C, H', W'] -> [B, H'*W', C]
        B, C, H, W = obj.shape
        obj = obj.permute(0, 2, 3, 1).reshape(B, H * W, C)
    return obj
# ---------- load backbone ----------
# Load the hub model from the local repo and checkpoint (hash check disabled),
# move it to `device` and switch it to eval mode.
backbone = FeaturesModule(
    _repo_dir=REPO_DIR,
    _model_name=MODEL_NAME,
    _source='local',
    _pretrained=True,
    _weights=BACKBONE_PATH,
    _check_hash=False,
).to(device).eval()
# With several visible GPUs, split each batch across them via DataParallel.
if torch.cuda.device_count() > 1:
    print(f'Lets use {torch.cuda.device_count()} GPUS!')
    backbone = nn.DataParallel(backbone).to(device)
def try_get_tokens_api(x):
    """Best-effort: get last-block patch tokens via ``get_intermediate_layers``.

    The global ``backbone`` is a FeaturesModule, possibly wrapped in
    nn.DataParallel. ``blocks`` / ``get_intermediate_layers`` live on the
    underlying hub model, so the wrappers are peeled off first. Calling them on
    the wrapper always raised AttributeError.

    Returns:
        The tensor produced by ``coerce_tokens`` on success, or None if the
        call fails. The exception is printed, not raised.
    """
    # Peel off nn.DataParallel (.module) and FeaturesModule (.backbone).
    model = backbone.module if isinstance(backbone, nn.DataParallel) else backbone
    model = getattr(model, "backbone", model)
    try:
        last_idx = (len(model.blocks) - 1) if hasattr(model, "blocks") else (model.num_blocks - 1)
        # `n` is passed by keyword; it may be keyword-only in the hub model's
        # signature (NOTE(review): confirm against the DINOv3 repo).
        out = model.get_intermediate_layers(x, n=[last_idx], reshape=False, return_class_token=True)
        # Assumes one entry per requested layer; the last (only) one should be
        # a (tokens, cls) pair, which coerce_tokens reduces to tokens.
        return coerce_tokens(out[-1])
    except Exception as exc:
        # Deliberate best-effort: report and fall back to None.
        print(f'Exception in get_intermediate_layers: {type(exc)}: {exc}')
        return None
def get_tokens_hook(x):
    """Capture the output of the backbone's last transformer block via a forward hook.

    The target is the module whose qualified name ends in ``blocks.<i>`` with
    the largest ``i``. The previous heuristic took the longest name containing
    'blocks', which selects a nested submodule (e.g. ``blocks.10.attn.proj_drop``)
    rather than the last block.

    Returns:
        ``coerce_tokens`` applied to the captured (detached) output.
    """
    last_block, last_pos = None, -1
    for name, module in backbone.named_modules():
        parts = name.split('.')
        if len(parts) >= 2 and parts[-2] == 'blocks' and parts[-1].isdigit():
            pos = int(parts[-1])
            if pos >= last_pos:
                last_block, last_pos = module, pos
    assert last_block is not None, "Could not find 'blocks' module."

    grabbed = {}

    def hook(_, __, out):
        # Block outputs may be wrapped in a tuple/list; keep the first tensor.
        t = out[0] if isinstance(out, (tuple, list)) else out
        grabbed['t'] = t.detach()

    handle = last_block.register_forward_hook(hook)
    try:
        _ = backbone(x)
    finally:
        # Always detach the hook, even if the forward pass raises.
        handle.remove()
    # NOTE(review): under nn.DataParallel the hook fires once per replica and
    # only the last replica's slice is kept.
    return coerce_tokens(grabbed['t'])
| import torch | |
def temporal_avg_pool_with_tail(x: torch.Tensor, stride: int, idx_offset: int=0):
    """
    Average-pool a [T, F] sequence over time in windows of ``stride`` rows.

    Complete windows are averaged normally. If T is not a multiple of
    ``stride``, the shorter trailing remainder is averaged into one extra row.

    Args:
        x: [T, F] tensor (T = time, F = flattened feature dim).
        stride: window size (int > 0).
        idx_offset: constant added (in place) to every returned index.

    Returns:
        pooled: [ceil(T / stride), F] tensor.
        idxs: [ceil(T / stride)] integer tensor on x.device holding each
            window's middle row (window_start + window_len // 2) + idx_offset.
    """
    assert x.dim() == 2, "x must be [T, F]"
    assert stride > 0, "stride must be > 0"
    n_steps, n_feat = x.shape
    n_windows, leftover = divmod(n_steps, stride)
    cut = n_windows * stride
    pooled_parts = []
    centre_parts = []
    if n_windows > 0:
        windows = x[:cut].view(n_windows, stride, n_feat)
        pooled_parts.append(windows.mean(1))
        centre_parts.append(torch.arange(n_windows, device=x.device) * stride + (stride // 2))
    if leftover > 0:
        pooled_parts.append(x[cut:].mean(0, keepdim=True))
        centre_parts.append(torch.tensor([cut + leftover // 2], device=x.device))
    if pooled_parts:
        pooled = torch.cat(pooled_parts, dim=0)
        centres = torch.cat(centre_parts, dim=0)
    else:
        pooled = x.new_zeros((0, n_feat))
        centres = torch.empty(0, dtype=torch.long, device=x.device)
    centres += idx_offset
    return pooled, centres
def process_mini_batch(batch):
    """Run the global ``backbone`` on one preprocessed batch under inference mode.

    The caller is responsible for moving ``batch`` to the model's device. The
    backbone output is returned unchanged, and both shapes are printed.
    """
    with torch.inference_mode():
        feats = backbone(batch)
    print(f'Input {batch.shape}, Feats {feats.shape}')
    return feats
def load_video_with_minibatches(
        path,
        pool_stride,
        size=224,
        max_frames=None,
        batch_size=BATCH_SIZE,
        device="cuda"):
    """Decode a video, run the backbone in mini-batches and pool tokens over time.

    The whole clip is decoded with ``torchvision.io.read_video``. It is then
    processed ``batch_size`` frames at a time:
      - scaled to [0, 1];
      - resized to ``size`` x ``size``;
      - ImageNet-normalized;
      - passed through ``process_mini_batch``;
      - average-pooled over time with ``temporal_avg_pool_with_tail``.

    Pooling uses windows of ``pool_stride`` frames. Each mini-batch is pooled
    independently, so its remainder forms its own shorter window.

    Args:
        path: video file path.
        pool_stride: temporal pooling window, in frames.
        size: square side length the frames are resized to.
        max_frames: if not None, only the first ``max_frames`` frames are used.
        batch_size: frames per backbone call.
        device: device frames are moved to for preprocessing and inference.

    Returns:
        (features, idx), where:
          - ``features`` is a float32 CPU tensor [T', N, C], with N and C
            taken from the backbone output;
          - ``idx`` is ``arange(T')``, i.e. pooled row indices, not source
            frame numbers.

    Raises:
        ValueError: if no frames are decoded.
    """
    import torchvision.io as io

    # pts_unit="sec" reads the same full range as the default and avoids
    # torchvision's warning about pts units.
    vframes, _, _ = io.read_video(path, pts_unit="sec", output_format="TCHW")  # [T, C, H, W] uint8
    if max_frames is not None:
        vframes = vframes[:max_frames]
    num_frames = len(vframes)
    if num_frames == 0:
        raise ValueError(f"No frames decoded from {path}")

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # Built once and reused for every mini-batch.
    resize = T.Resize((size, size))
    normalize = T.Normalize(mean, std)

    batch_starts = range(0, num_frames, batch_size)
    # Each mini-batch is pooled on its own (its tail becomes an extra row), so
    # count pooled rows per batch. ceil(num_frames / pool_stride) is only
    # correct when batch_size is a multiple of pool_stride.
    len_data = sum(
        (min(batch_size, num_frames - s) + pool_stride - 1) // pool_stride
        for s in batch_starts
    )

    features = None
    w = 0  # next free row in `features`
    for s in batch_starts:
        batch = vframes[s:s + batch_size].to(device).float() / 255
        batch = normalize(resize(batch))
        tokens = process_mini_batch(batch)  # [Tb, N, C]
        Tb, hw, C = tokens.shape
        # Pool the flattened [Tb, N*C] view, then restore [T', N, C].
        pooled, _ = temporal_avg_pool_with_tail(tokens.reshape(Tb, -1), pool_stride)
        pooled = pooled.reshape(pooled.shape[0], hw, C)
        if features is None:
            # Preallocate on CPU so only one mini-batch lives on the device.
            features = torch.empty((len_data, hw, C), dtype=torch.float32, device='cpu')
        features[w:w + pooled.shape[0]] = pooled.detach().cpu()
        w += pooled.shape[0]

    print(f'Got all video frames with len {features.shape}')
    idx = torch.arange(features.shape[0], dtype=torch.long, device=features.device)
    return features, idx  # features are T/stride, hw, C size
# -- Saving pipeline
def save_pca_channels(
    tokens: torch.Tensor,   # [T, C, HW]
    idx,                    # frame indices (len=T)
    save_path: str,
    keep_ratio: float = 0.99,  # PCA n_components: variance fraction in (0, 1) or an int
    compress: bool = True,
    save: bool = True
):
    """
    Reduce the channel dim C -> K with one PCA fitted on all T*HW spatial rows.

    When ``save`` is true, writes an .npz (compressed if ``compress``) with:
      - tokens_pca: [HW, T, K] (spatial-major);
      - idx;
      - pca_components: [K, C];
      - pca_mean: [C];
      - HW;
      - explained_var_ratio.

    Returns:
        (K, explained_variance_ratio) of the fitted PCA.
    """
    assert tokens.dim() == 3, "tokens must be [T, C, HW]"
    n_frames, n_chan, n_hw = tokens.shape

    # [T, C, HW] -> [T, HW, C] -> rows [T*HW, C]
    rows = tokens.transpose(1, 2).reshape(n_frames * n_hw, n_chan)
    X = rows.cpu().numpy().astype(np.float32)

    pca = PCA(n_components=keep_ratio, svd_solver="full", random_state=0)
    projected = pca.fit_transform(X)  # [T*HW, K]
    n_comp = projected.shape[1]

    if save:
        spatial_major = np.swapaxes(projected.reshape(n_frames, n_hw, n_comp), 0, 1)  # [HW, T, K]
        writer = np.savez_compressed if compress else np.savez
        writer(
            save_path,
            tokens_pca=spatial_major,
            idx=np.asarray(idx),
            pca_components=pca.components_,
            pca_mean=pca.mean_,
            HW=np.int64(n_hw),
            explained_var_ratio=pca.explained_variance_ratio_.astype(np.float32),
        )
    return n_comp, pca.explained_variance_ratio_
@torch.no_grad()
def save_pca_channels_ipca(
    tokens_TCHW: torch.Tensor,   # [T, C, HW]
    idx,                         # frame indices (len=T)
    save_path: str,
    keep_ratio: float = 0.99,
    num_samples: int = 1000,
    num_batch: int = 1000,
    rng=None,
):
    """
    Compress [T, C, HW] tokens with IncrementalPCA and save them to ``save_path``.

    1) Chooses K by fitting an exact PCA (``save_pca_channels`` with
       ``save=False``) on up to ``num_samples`` randomly chosen frames.
    2) Streams all T*HW rows through IncrementalPCA(n_components=K) in chunks
       of ``num_batch`` frames: one partial_fit pass, then one transform pass.
    3) Saves with np.savez_compressed:
         - tokens_pca: [HW, T, K], the same layout as save_pca_channels;
         - idx;
         - pca_components: [K, C];
         - pca_mean: [C];
         - HW;
         - explained_var_ratio;
         - estimate_var_ratio (from the sampled PCA);
         - K, C, T.

    Args:
        tokens_TCHW: tensor of shape [T, C, HW].
        idx: per-frame indices; must support indexing with an integer array.
        save_path: output .npz path.
        keep_ratio: forwarded to PCA ``n_components`` (fraction in (0, 1) or int).
        num_samples: max frames used to estimate K (clamped to T).
        num_batch: frames per IncrementalPCA chunk.
        rng: None to sample with the global ``np.random`` state, or a seed or
            ``np.random.Generator`` for reproducible sampling.
    """
    assert tokens_TCHW.ndim == 3
    n_frames, C, HW = tokens_TCHW.shape

    def _rows(chunk):
        # [t, C, HW] -> float32 numpy rows [t*HW, C]
        return (chunk.permute(0, 2, 1)
                     .contiguous()
                     .view(-1, C)
                     .cpu()
                     .numpy()
                     .astype(np.float32))

    # ---- step 1: estimate K from a random subset of frames.
    # Clamp the sample size: choice(..., replace=False) raises if size > n_frames.
    sampler = np.random if rng is None else np.random.default_rng(rng)
    sample_indices = sampler.choice(n_frames, size=min(num_samples, n_frames), replace=False)
    K, estimate_var_ratio = save_pca_channels(
        tokens_TCHW[sample_indices], idx[sample_indices], save_path, keep_ratio=keep_ratio, save=False
    )
    print(f'Estimated K: {K}, Explained Variance Ratio: {estimate_var_ratio} for data of length {tokens_TCHW.shape}')

    # ---- step 2: fit IncrementalPCA with that K over the full data
    ipca = IncrementalPCA(n_components=min(K, C), batch_size=None)
    chunk_starts = range(0, n_frames, num_batch)
    num_iters = len(chunk_starts)
    # pass 1: partial_fit over every chunk of frames
    for it, start in enumerate(chunk_starts):
        batch = _rows(tokens_TCHW[start:start + num_batch])
        print(f'Partial Iter {100.0*it/num_iters}')
        ipca.partial_fit(batch)

    # pass 2: project each chunk and store it as [HW, T, K], the layout this
    # docstring, save_pca_channels and reconstruct_data use.
    # (It was previously written as [T, HW, K].)
    n_comp = ipca.n_components_
    reduced = np.zeros((HW, n_frames, n_comp), dtype=np.float32)
    for it, start in enumerate(chunk_starts):
        end = min(start + num_batch, n_frames)
        batch = _rows(tokens_TCHW[start:end])
        print(f'Fitting {100.0*it/num_iters}')
        Z = ipca.transform(batch).reshape(end - start, HW, n_comp)  # [t, HW, K]
        reduced[:, start:end] = Z.transpose(1, 0, 2)                # -> [HW, t, K]

    # ---- step 3: save npz
    np.savez_compressed(
        save_path,
        tokens_pca=reduced,  # [HW, T, K]
        idx=np.asarray(idx),
        pca_components=ipca.components_.astype(np.float32),  # [K, C]
        pca_mean=ipca.mean_.astype(np.float32),              # [C]
        HW=np.int64(HW),
        explained_var_ratio=ipca.explained_variance_ratio_.astype(np.float32),
        estimate_var_ratio=estimate_var_ratio.astype(np.float32),
        K=np.int64(n_comp),
        C=np.int64(C),
        T=np.int64(n_frames),
    )
def reconstruct_data(data_path):
    """Reconstruct full-width channel tokens from a PCA ``.npz`` file.

    Expects the layout written by ``save_pca_channels``:
      - tokens_pca: [HW, T, K];
      - pca_components: [K, C];
      - pca_mean: [C].

    NOTE(review): a file whose ``tokens_pca`` is stored as [T, HW, K] would
    come back with T and HW swapped. Confirm the producer's layout.

    Returns:
        np.ndarray of shape [T, C, HW].
    """
    # The context manager closes the file handle that NpzFile keeps open.
    with np.load(data_path) as data:
        tokens_pca = data['tokens_pca']          # [HW, T, K]
        pca_components = data['pca_components']  # [K, C]
        pca_mean = data['pca_mean']              # [C]
    # Invert the projection: Z @ components + mean -> [HW, T, C]
    reconstructed_tokens = np.dot(tokens_pca, pca_components) + pca_mean
    return reconstructed_tokens.transpose(1, 2, 0)  # [T, C, HW]
# ---------- run ----------
# For every .mp4 in IMG_PATH:
#   - extract temporally pooled patch tokens;
#   - compress them with PCA (IncrementalPCA for long videos);
#   - save one .npz per video in NPZ_FEATURE_PATH.
import glob
import time

last_duration = None
average_duration = None
IMG_PATH = Path(IMG_PATH)
if not IMG_PATH.exists():
    raise FileNotFoundError(f"Path {IMG_PATH} does not exist.")
# Glob once, sorted for a deterministic order, instead of re-globbing every iteration.
video_paths = sorted(glob.glob(os.path.join(IMG_PATH, "*.mp4")))
num_videos = len(video_paths)
num_processed = 0  # counts only videos actually processed (skips excluded from timing)
for i, video_path in enumerate(video_paths):
    start = time.time()
    print(f"[info] processing video: {video_path} number {i} out of {num_videos}")
    if last_duration is not None:
        print(f"Remaining time: {average_duration * (num_videos - i) / 60:.2f} minutes")
    video_name = Path(video_path).stem
    save_path = os.path.join(NPZ_FEATURE_PATH, f"{video_name}.npz")
    if os.path.exists(save_path):
        print(f"[info] features already exist for {video_path}, skipping.")
        continue
    tokens_all, idx = load_video_with_minibatches(
        video_path,
        pool_stride=TEMPORAL_POOLING,
        size=RESIZE_TO,
        max_frames=None,
        batch_size=BATCH_SIZE,
        device=device
    )  # [T', HW, C]
    print(f'Last elements of idx: {idx[-5:]}')
    tokens_all = tokens_all.transpose(1, 2)  # [T', C, HW], the layout the PCA savers expect
    print(f'Tokens after correctly reshaping {tokens_all.shape}')
    if len(tokens_all) < ITERATIVE_THRESHOLD:
        save_pca_channels(tokens_all, idx, save_path, PCA_KEEP_RATIO)  # This needs [T,C,HW]
    else:  # Incremental PCA for long videos
        print(f'Using Incremental PCA for {video_name} with {len(tokens_all)} frames')
        save_pca_channels_ipca(tokens_all, idx, save_path, PCA_KEEP_RATIO)
    last_duration = time.time() - start
    num_processed += 1
    # Running mean over processed videos only. The enumerate index would also
    # count skipped videos and skew the weights.
    if average_duration is None:
        average_duration = last_duration
    else:
        average_duration += (last_duration - average_duration) / num_processed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment