@vovw
Created July 29, 2025 12:08
#!/usr/bin/env python3
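"""Visualize V-JEPA-2 patch features with PCA.

For each image in --img_dir, the script tiles the still image into a 64-frame
clip, extracts encoder patch features, optionally aggregates them over the
temporal axis, and saves a side-by-side figure of the original image and an
RGB map of the first three PCA components.
"""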
import argparse
import os
import glob
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.preprocessing import minmax_scale
import matplotlib.cm as cm
def parse_args():
    parser = argparse.ArgumentParser(description="V-JEPA-2 PCA Visualizations")
    parser.add_argument("--img_dir", type=str, default="images", help="Image directory")
    parser.add_argument("--model_size", type=str, choices=['large', 'huge'], default='huge',
                        help="V-JEPA-2 model size (default: huge)")
    parser.add_argument("--save_all_modes", action='store_true',
                        help="Save individual images for each visualization mode")
    parser.add_argument("--fg_threshold", type=float, default=0.45,
                        help="Foreground detection threshold (0.0-1.0). Default: 0.45 for V-JEPA")
    parser.add_argument("--vjepa_resolution", type=int, default=256,
                        help="V-JEPA input resolution (default: 256, try 384 or 512 for more detail)")
    parser.add_argument("--no_temporal_avg", action='store_true',
                        help="Don't average temporal patches - use all patches for more spatial detail")
    parser.add_argument("--temporal_method", type=str, choices=['mean', 'max', 'median', 'last'], default='mean',
                        help="How to aggregate temporal patches (default: mean)")
    return parser.parse_args()
def load_vjepa_model(model_size='huge'):
    print(f"Loading V-JEPA-2 (size: {model_size})...")
    options = {
        'huge': {
            'name': 'V-JEPA-2 ViT-Huge',
            'hub_model': 'vjepa2_vit_huge',
            'hf_repo': 'facebook/vjepa2-vith-fpc64-256',
            'preprocessor': 'vjepa2_preprocessor'
        },
        'large': {
            'name': 'V-JEPA-2 ViT-Large',
            'hub_model': 'vjepa2_vit_large',
            'hf_repo': 'facebook/vjepa2-vitl-fpc64-256',
            'preprocessor': 'vjepa2_preprocessor'
        }
    }
    option = options.get(model_size)
    if not option:
        raise ValueError(f"Invalid model size: {model_size}")
    # Prefer torch.hub; fall back to the HuggingFace checkpoint if that fails.
    try:
        processor = torch.hub.load('facebookresearch/vjepa2', option['preprocessor'])
        model = torch.hub.load('facebookresearch/vjepa2', option['hub_model'])
        if isinstance(model, tuple):
            model = model[0]
        model = model.cuda().eval()
        print(f"✅ Loaded {option['name']} via torch.hub!")
        return model, processor, option['name']
    except Exception as e:
        print(f"❌ torch.hub failed: {str(e)[:100]}...")
        try:
            from transformers import AutoModel, AutoImageProcessor
            model = AutoModel.from_pretrained(option['hf_repo'])
            processor = AutoImageProcessor.from_pretrained(option['hf_repo'])
            model = model.cuda().eval()
            print(f"✅ Loaded {option['name']} via transformers!")
            return model, processor, option['name']
        except Exception as e2:
            print(f"❌ HuggingFace failed: {str(e2)[:100]}...")
            raise RuntimeError("Could not load V-JEPA-2 model")
def get_vjepa_features(model, processor, image_pil, args):
    resolution = args.vjepa_resolution
    image_resized = image_pil.resize((resolution, resolution))
    img_array = np.array(image_resized)
    # V-JEPA-2 expects a video clip, so tile the still image into 64 identical frames.
    frames = 64
    video = np.stack([img_array] * frames, axis=0)  # (64, H, W, 3)
    video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2).float()  # (64, 3, H, W)
    video_batch = video_tensor.unsqueeze(0).cuda()  # (1, 64, 3, H, W)
    try:
        video_normalized = (video_batch / 255.0 - 0.5) / 0.5  # Normalize to [-1, 1]
        with torch.no_grad():
            features = model(video_normalized)
        if isinstance(features, tuple):
            features = features[0]
        features_np = features.cpu().numpy()[0]  # (N, C)
    except Exception as e:
        print(f"⚠️ Direct processing failed: {e}")
        try:
            video_np = video_batch.squeeze(0).permute(0, 2, 3, 1).cpu().numpy()  # (64, H, W, 3)
            inputs = processor(video_np)
            if isinstance(inputs, list):
                inputs = torch.stack(inputs)
            inputs = inputs.cuda()
            with torch.no_grad():
                features = model(inputs)
            if isinstance(features, tuple):
                features = features[0]
            features_np = features.cpu().numpy()[0]  # (N, C)
        except Exception as e2:
            print(f"❌ Preprocessor failed: {e2}")
            return None
    # Token layout: (resolution / 16)^2 spatial patches per frame, stacked over temporal patches.
    num_patches = features_np.shape[0]
    spatial_patches_per_dim = resolution // 16
    spatial_patches_per_frame = spatial_patches_per_dim ** 2
    temporal_patches = num_patches // spatial_patches_per_frame
    if num_patches == temporal_patches * spatial_patches_per_frame:
        features_3d = features_np.reshape(temporal_patches, spatial_patches_per_frame, -1)
        if args.no_temporal_avg:
            return features_3d.reshape(-1, features_np.shape[1])
        else:
            # Collapse the temporal axis so each spatial patch keeps a single feature vector.
            if args.temporal_method == 'mean':
                return np.mean(features_3d, axis=0)
            elif args.temporal_method == 'max':
                return np.max(features_3d, axis=0)
            elif args.temporal_method == 'median':
                return np.median(features_3d, axis=0)
            elif args.temporal_method == 'last':
                return features_3d[-1]
    return features_np
def enhanced_pca_viz(features, original_img, model_name="V-JEPA-2", fg_threshold=0.45,
                     vjepa_resolution=256, no_temporal_avg=False, temporal_method='mean'):
    if len(features.shape) == 3:
        features = features[0]
    num_patches = features.shape[0]
    feature_dim = features.shape[1]
    spatial_patches_per_dim = vjepa_resolution // 16
    spatial_patches_per_frame = spatial_patches_per_dim ** 2
    temporal_patches = 64 // 2  # Assuming tubelet_size=2: 32 temporal patches for 64 frames
    config_found = False
    if num_patches == temporal_patches * spatial_patches_per_frame:
        config_found = True
        features_3d = features.reshape(temporal_patches, spatial_patches_per_frame, feature_dim)
        if no_temporal_avg:
            features = features_3d.reshape(-1, feature_dim)
            patch_size_viz = int(np.ceil(np.sqrt(num_patches)))
        else:
            # 'last' accepts (and ignores) axis so the uniform call below works for every method.
            agg_func = {'mean': np.mean, 'max': np.max, 'median': np.median,
                        'last': lambda x, axis=0: x[-1]}[temporal_method]
            features = agg_func(features_3d, axis=0)
            patch_size_viz = spatial_patches_per_dim
    if not config_found:
        patch_size_viz = int(np.ceil(np.sqrt(num_patches)))
    # Pad to a square grid so the patch features can be reshaped into an image.
    original_patch_count = features.shape[0]
    if patch_size_viz ** 2 > original_patch_count:
        padding = np.zeros((patch_size_viz ** 2 - original_patch_count, feature_dim))
        features = np.vstack([features, padding])
    x_norm_patches = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-8)
    # Foreground mask: first PCA component, min-max scaled, then thresholded.
    fg_pca = PCA(n_components=1)
    fg_pca_features = fg_pca.fit_transform(x_norm_patches)
    fg_pca_scaled = minmax_scale(fg_pca_features)
    mask = (fg_pca_scaled > fg_threshold).ravel()[:original_patch_count]
    # Visualizations
    visualizations = {}
    pca_all = PCA(n_components=3)
    pca_features_all = pca_all.fit_transform(x_norm_patches)
    pca_rgb = minmax_scale(pca_features_all)
    visualizations['standard'] = pca_rgb.reshape(patch_size_viz, patch_size_viz, 3)
    # Other modes (foreground_only, enhanced_fg, high_contrast, object_boundaries,
    # fg_clusters, activation_heatmap, etc.) are omitted for brevity; one example
    # built on the foreground mask is sketched below.
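    # A minimal sketch of one omitted mode, 'foreground_only' (an assumption about the
    # original's intent): keep the PCA colors on patches the mask marks as foreground
    # and zero out everything else.
    fg_rgb = np.zeros_like(pca_rgb)
    fg_idx = np.where(mask)[0]
    fg_rgb[fg_idx] = pca_rgb[fg_idx]
    visualizations['foreground_only'] = fg_rgb.reshape(patch_size_viz, patch_size_viz, 3)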
    return visualizations, pca_all.explained_variance_ratio_
def process_image(img_path, model, processor, model_name, args):
    image_pil = Image.open(img_path).convert('RGB')
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    features = get_vjepa_features(model, processor, image_pil, args)
    if features is None:
        print(f"❌ Failed for {img_name}")
        return
    viz, variance = enhanced_pca_viz(features, image_pil, model_name, args.fg_threshold,
                                     args.vjepa_resolution, args.no_temporal_avg, args.temporal_method)
    # Plot and save (simplified to save the standard visualization)
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(image_pil)
    plt.title('Original')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(viz['standard'])
    plt.title(f'{model_name} PCA')
    plt.axis('off')
    filename = f"vjepa_pca_{img_name}.jpg"
    plt.savefig(filename, dpi=150)
    plt.close()
    print(f"💾 Saved: {filename}")
def main():
    args = parse_args()
    image_paths = []
    for ext in ['*.jpg', '*.jpeg', '*.png']:
        image_paths.extend(glob.glob(os.path.join(args.img_dir, ext)))
    if not image_paths:
        print("No images found")
        return
    model, processor, model_name = load_vjepa_model(args.model_size)
    for img_path in image_paths:
        process_image(img_path, model, processor, model_name, args)
    print("Done!")
if __name__ == "__main__":
    main()
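
A minimal sketch of driving the same pipeline from another script, assuming this gist is saved as vjepa_pca.py (the module name is hypothetical) and a CUDA GPU is available; the Namespace fields mirror the flags defined in parse_args above:

import argparse
import glob
from PIL import Image
import vjepa_pca  # hypothetical module name for this gist

args = argparse.Namespace(
    img_dir="images", model_size="large", save_all_modes=False,
    fg_threshold=0.45, vjepa_resolution=256,
    no_temporal_avg=False, temporal_method="mean",
)
model, processor, model_name = vjepa_pca.load_vjepa_model(args.model_size)
img_path = sorted(glob.glob("images/*.jpg"))[0]  # first image in the folder
image = Image.open(img_path).convert("RGB")
features = vjepa_pca.get_vjepa_features(model, processor, image, args)
viz, variance = vjepa_pca.enhanced_pca_viz(
    features, image, model_name, args.fg_threshold,
    args.vjepa_resolution, args.no_temporal_avg, args.temporal_method,
)
print(viz["standard"].shape, variance)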