# fine tune: V-JEPA 2 video classification on a UCF-101 subset

from huggingface_hub import hf_hub_download
import tarfile

dataset_id = "sayakpaul/ucf101-subset"
fname = "UCF101_subset.tar.gz"
fpath = hf_hub_download(repo_id=dataset_id, filename=fname, repo_type="dataset")
with tarfile.open(fpath) as t:
    t.extractall(".")
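
# Expected layout after extraction (the split and label logic below depend on it):
#   UCF101_subset/{train,val,test}/<ClassName>/<clip>.avi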

import os
import pathlib
import random
from functools import partial
from typing import List

import torch
import wandb
from torch.utils.data import Dataset, DataLoader
from torchcodec.decoders import VideoDecoder
from torchcodec.samplers import clips_at_random_indices
from torchvision.transforms import v2
from transformers import VJEPA2VideoProcessor, VJEPA2ForVideoClassification


def set_gpu_device(gpu_id: int):
    """Set the CUDA device to use.

    Must run before the first CUDA call; torch initializes CUDA lazily, so
    setting CUDA_VISIBLE_DEVICES here still takes effect.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)


class VideoDataset(Dataset):
    """Custom dataset for loading video files and labels."""

    def __init__(self, file_paths: List[pathlib.Path], label_map: dict):
        self.file_paths = file_paths
        self.label_map = label_map

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, idx: int):
        path = self.file_paths[idx]
        decoder = VideoDecoder(path)
        # parts[2] is the class-name directory: UCF101_subset/<split>/<ClassName>/<clip>.avi
        label = self.label_map[path.parts[2]]
        return decoder, label
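

# __getitem__ returns a lazy VideoDecoder rather than decoded frames; the actual
# frame sampling and decoding happen per batch in video_collate_fn below.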
def video_collate_fn(
    samples: List[tuple], frames_per_clip: int, transforms: v2.Compose
):
    """Sample clips and apply transforms to a batch."""
    clips, labels = [], []
    for decoder, lbl in samples:
        # FrameBatch.data has shape (num_clips, frames_per_clip, C, H, W)
        clip = clips_at_random_indices(
            decoder,
            num_clips=1,
            num_frames_per_clip=frames_per_clip,
            num_indices_between_frames=3,
        ).data
        clips.append(clip)
        labels.append(lbl)
    videos = torch.cat(clips, dim=0)  # (batch, frames, C, H, W)
    videos = transforms(videos)
    return videos, torch.tensor(labels)


def split_dataset(paths: List[pathlib.Path]):
    """Split paths into train/val/test based on directory keyword."""
    train, val, test = [], [], []
    for p in paths:
        if "train" in p.parts:
            train.append(p)
        elif "val" in p.parts:
            val.append(p)
        elif "test" in p.parts:
            test.append(p)
        else:
            raise ValueError(f"Unrecognized split for path: {p}")
    return train, val, test


def evaluate(
    loader: DataLoader,
    model: VJEPA2ForVideoClassification,
    processor: VJEPA2VideoProcessor,
    device: torch.device,
) -> float:
    """Compute accuracy over a dataset."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for vids, labels in loader:
            inputs = processor(vids, return_tensors="pt").to(device)
            labels = labels.to(device)
            logits = model(**inputs).logits
            preds = logits.argmax(-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total > 0 else 0.0
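

# Note: evaluation batches also go through clips_at_random_indices (see
# video_collate_fn), so reported accuracy carries some clip-sampling noise.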


def main():
    # Configuration
    gpu_id = 1
    batch_size = 2
    num_workers = 8
    lr = 1e-5
    num_epochs = 5
    accumulation_steps = 4  # effective batch size = batch_size * accumulation_steps = 8

    set_gpu_device(gpu_id)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load and shuffle data paths
    root = pathlib.Path("UCF101_subset")
    all_paths = list(root.rglob("*.avi"))
    random.shuffle(all_paths)
    train_paths, val_paths, test_paths = split_dataset(all_paths)
    print(f"Splits -> train: {len(train_paths)}, val: {len(val_paths)}, test: {len(test_paths)}")

    # Label mappings
    classes = sorted({p.parts[2] for p in all_paths})
    label2id = {c: i for i, c in enumerate(classes)}
    id2label = {i: c for c, i in label2id.items()}

    # Datasets
    train_ds = VideoDataset(train_paths, label2id)
    val_ds = VideoDataset(val_paths, label2id)
    test_ds = VideoDataset(test_paths, label2id)

    # Model & processor
    model_name = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2"
    processor = VJEPA2VideoProcessor.from_pretrained(model_name)
    model = VJEPA2ForVideoClassification.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map=device,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )
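
    # ignore_mismatched_sizes=True lets from_pretrained replace the pretrained
    # SSv2 classification head with a freshly initialized one sized for the new
    # label set defined by label2id/id2label.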

    # Transforms
    train_transforms = v2.Compose([
        v2.RandomResizedCrop((processor.crop_size["height"], processor.crop_size["width"])),
        v2.RandomHorizontalFlip(),
    ])
    eval_transforms = v2.Compose([
        v2.CenterCrop((processor.crop_size["height"], processor.crop_size["width"]))
    ])
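
    # The crops match the processor's crop_size; any remaining rescaling and
    # normalization happen inside the processor() call in the train/eval loops.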

    # DataLoaders
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=partial(
            video_collate_fn,
            frames_per_clip=model.config.frames_per_clip,
            transforms=train_transforms,
        ),
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=partial(
            video_collate_fn,
            frames_per_clip=model.config.frames_per_clip,
            transforms=eval_transforms,
        ),
        num_workers=num_workers,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=partial(
            video_collate_fn,
            frames_per_clip=model.config.frames_per_clip,
            transforms=eval_transforms,
        ),
        num_workers=num_workers,
        pin_memory=True,
    )

    # Freeze the pretrained V-JEPA 2 backbone; only the new classification head
    # (everything outside model.vjepa2) remains trainable
    for param in model.vjepa2.parameters():
        param.requires_grad = False

    # Optimizer over the unfrozen parameters (the model computes the loss itself
    # when labels are passed, so no separate criterion is needed)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr)

    run = wandb.init(
        entity="ariG23498",
        project="fine tune vjep2",
    )

    # Training loop with gradient accumulation and end-of-epoch evaluation
    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()
        for step, (vids, labels) in enumerate(train_loader, start=1):
            inputs = processor(vids, return_tensors="pt").to(device)
            labels = labels.to(device)
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss / accumulation_steps  # scale for accumulation
            loss.backward()
            running_loss += loss.item()
            if step % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                print(f"Epoch {epoch} Step {step}: Accumulated Loss = {running_loss:.4f}")
                run.log({"train/loss": running_loss})
                running_loss = 0.0
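
        # If len(train_loader) is not a multiple of accumulation_steps, gradients
        # from the trailing micro-batches are discarded by the zero_grad() at the
        # start of the next epoch rather than applied.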

        # End-of-epoch evaluation
        val_acc = evaluate(val_loader, model, processor, device)
        print(f"Epoch {epoch} Validation Accuracy: {val_acc:.4f}")
        run.log({"val/accuracy": val_acc})

    # Final test evaluation
    test_acc = evaluate(test_loader, model, processor, device)
    print(f"Final Test Accuracy: {test_acc:.4f}")
    run.log({"test/accuracy": test_acc})
    run.finish()

    # Upload the fine-tuned model and processor to the Hub
    model.push_to_hub("ariG23498/vjepa2-vitl-fpc16-256-ssv2-uvf101")
    processor.push_to_hub("ariG23498/vjepa2-vitl-fpc16-256-ssv2-uvf101")


if __name__ == "__main__":
    main()
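
# A minimal inference sketch with the pushed checkpoint (repo id as above; the
# clip path is a placeholder):
#
#   processor = VJEPA2VideoProcessor.from_pretrained("ariG23498/vjepa2-vitl-fpc16-256-ssv2-uvf101")
#   model = VJEPA2ForVideoClassification.from_pretrained("ariG23498/vjepa2-vitl-fpc16-256-ssv2-uvf101")
#   decoder = VideoDecoder("path/to/clip.avi")
#   clip = clips_at_random_indices(
#       decoder, num_clips=1, num_frames_per_clip=model.config.frames_per_clip
#   ).data
#   inputs = processor(clip, return_tensors="pt")
#   pred = model(**inputs).logits.argmax(-1).item()
#   print(model.config.id2label[pred])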