Fine-tune V-JEPA 2 for video classification on a UCF-101 subset. The first script downloads and extracts the dataset from the Hugging Face Hub; the second freezes the pretrained V-JEPA 2 encoder and fine-tunes only the classification head.
from huggingface_hub import hf_hub_download
import tarfile

# Download the UCF-101 subset archive from the Hub and extract it locally.
dataset_id = "sayakpaul/ucf101-subset"
fname = "UCF101_subset.tar.gz"

fpath = hf_hub_download(repo_id=dataset_id, filename=fname, repo_type="dataset")
with tarfile.open(fpath) as t:
    t.extractall(".")
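
The archive is expected to unpack to UCF101_subset/<split>/<class>/<video>.avi with train, val, and test splits; the training script below relies on that layout to derive labels. A quick sanity check, sketched under that directory-structure assumption:

import pathlib
from collections import Counter

# Count the extracted videos per split; assumes
# UCF101_subset/<split>/<class>/<video>.avi, so parts[1] is the split name.
root = pathlib.Path("UCF101_subset")
print(Counter(p.parts[1] for p in root.rglob("*.avi")))

With the data in place, the training script follows.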
import os
import pathlib
import random
from functools import partial
from typing import List

import torch
import wandb
from torch.utils.data import Dataset, DataLoader
from torchcodec.decoders import VideoDecoder
from torchcodec.samplers import clips_at_random_indices
from torchvision.transforms import v2
from transformers import VJEPA2VideoProcessor, VJEPA2ForVideoClassification

def set_gpu_device(gpu_id: int):
    """Set the CUDA device to use (call before CUDA is initialized)."""
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)


class VideoDataset(Dataset):
    """Custom dataset for loading video files and labels."""

    def __init__(self, file_paths: List[pathlib.Path], label_map: dict):
        self.file_paths = file_paths
        self.label_map = label_map

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, idx: int):
        path = self.file_paths[idx]
        decoder = VideoDecoder(path)
        # Paths look like UCF101_subset/<split>/<class>/<video>.avi,
        # so the class name is the third path component.
        label = self.label_map[path.parts[2]]
        return decoder, label

def video_collate_fn(
    samples: List[tuple], frames_per_clip: int, transforms: v2.Compose
):
    """Sample clips and apply transforms to a batch."""
    clips, labels = [], []
    for decoder, lbl in samples:
        clip = clips_at_random_indices(
            decoder,
            num_clips=1,
            num_frames_per_clip=frames_per_clip,
            num_indices_between_frames=3,
        ).data
        clips.append(clip)
        labels.append(lbl)
    videos = torch.cat(clips, dim=0)
    videos = transforms(videos)
    return videos, torch.tensor(labels)

def split_dataset(paths: List[pathlib.Path]):
    """Split paths into train/val/test based on directory keyword."""
    train, val, test = [], [], []
    for p in paths:
        if "train" in p.parts:
            train.append(p)
        elif "val" in p.parts:
            val.append(p)
        elif "test" in p.parts:
            test.append(p)
        else:
            raise ValueError(f"Unrecognized split for path: {p}")
    return train, val, test

def evaluate(
    loader: DataLoader,
    model: VJEPA2ForVideoClassification,
    processor: VJEPA2VideoProcessor,
    device: torch.device,
) -> float:
    """Compute accuracy over a dataset."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for vids, labels in loader:
            inputs = processor(vids, return_tensors="pt").to(device)
            labels = labels.to(device)
            logits = model(**inputs).logits
            preds = logits.argmax(-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total > 0 else 0.0

def main():
    # Configuration
    gpu_id = 1
    batch_size = 2
    num_workers = 8
    lr = 1e-5
    num_epochs = 5
    accumulation_steps = 4  # effective batch size = batch_size * accumulation_steps = 8

    set_gpu_device(gpu_id)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load and shuffle data paths
    root = pathlib.Path("UCF101_subset")
    all_paths = list(root.rglob("*.avi"))
    random.shuffle(all_paths)
    train_paths, val_paths, test_paths = split_dataset(all_paths)
    print(f"Splits -> train: {len(train_paths)}, val: {len(val_paths)}, test: {len(test_paths)}")

    # Label mappings (the class name is the third path component)
    classes = sorted({p.parts[2] for p in all_paths})
    label2id = {c: i for i, c in enumerate(classes)}
    id2label = {i: c for c, i in label2id.items()}

    # Datasets
    train_ds = VideoDataset(train_paths, label2id)
    val_ds = VideoDataset(val_paths, label2id)
    test_ds = VideoDataset(test_paths, label2id)
| # Model & processor | |
| model_name = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2" | |
| processor = VJEPA2VideoProcessor.from_pretrained(model_name) | |
| model = VJEPA2ForVideoClassification.from_pretrained( | |
| model_name, | |
| torch_dtype=torch.float32, | |
| device_map=device, | |
| label2id=label2id, | |
| id2label=id2label, | |
| ignore_mismatched_sizes=True, | |
| ) | |
| # Transforms | |
| train_transforms = v2.Compose([ | |
| v2.RandomResizedCrop((processor.crop_size["height"], processor.crop_size["width"])), | |
| v2.RandomHorizontalFlip(), | |
| ]) | |
| eval_transforms = v2.Compose([ | |
| v2.CenterCrop((processor.crop_size["height"], processor.crop_size["width"])) | |
| ]) | |
    # DataLoaders
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=partial(video_collate_fn, frames_per_clip=model.config.frames_per_clip, transforms=train_transforms),
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=partial(video_collate_fn, frames_per_clip=model.config.frames_per_clip, transforms=eval_transforms),
        num_workers=num_workers,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=partial(video_collate_fn, frames_per_clip=model.config.frames_per_clip, transforms=eval_transforms),
        num_workers=num_workers,
        pin_memory=True,
    )
    # Freeze the base encoder so only the classification head is trained
    for param in model.vjepa2.parameters():
        param.requires_grad = False

    # Optimizer over the remaining trainable parameters (the model computes
    # the cross-entropy loss internally when labels are passed)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr)

    run = wandb.init(
        entity="ariG23498",
        project="fine tune vjep2",
    )
    # Training loop with gradient accumulation and evaluation
    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()
        for step, (vids, labels) in enumerate(train_loader, start=1):
            inputs = processor(vids, return_tensors="pt").to(device)
            labels = labels.to(device)
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss / accumulation_steps
            loss.backward()
            running_loss += loss.item()
            if step % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                print(f"Epoch {epoch} Step {step}: Accumulated Loss = {running_loss:.4f}")
                run.log({"train/loss": running_loss})
                running_loss = 0.0

        # End-of-epoch evaluation
        val_acc = evaluate(val_loader, model, processor, device)
        print(f"Epoch {epoch} Validation Accuracy: {val_acc:.4f}")
        run.log({"val/accuracy": val_acc})

    # Final test evaluation
    test_acc = evaluate(test_loader, model, processor, device)
    print(f"Final Test Accuracy: {test_acc:.4f}")
    run.log({"test/accuracy": test_acc})
    run.finish()

    # Upload the fine-tuned model and processor to the Hub
    model.push_to_hub("ariG23498/vjepa2-vitl-fpc16-256-ssv2-uvf101")
    processor.push_to_hub("ariG23498/vjepa2-vitl-fpc16-256-ssv2-uvf101")


if __name__ == "__main__":
    main()
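
Once the push succeeds, the fine-tuned checkpoint can be loaded back for inference. A minimal sketch, assuming the repo id from the push_to_hub calls above and a placeholder video path (replace it with a real file); the clip sampling mirrors the training collate function:

import torch
from torchcodec.decoders import VideoDecoder
from torchcodec.samplers import clips_at_random_indices
from transformers import VJEPA2VideoProcessor, VJEPA2ForVideoClassification

repo_id = "ariG23498/vjepa2-vitl-fpc16-256-ssv2-uvf101"
processor = VJEPA2VideoProcessor.from_pretrained(repo_id)
model = VJEPA2ForVideoClassification.from_pretrained(repo_id).eval()

# Sample one clip, mirroring the training-time sampling parameters.
decoder = VideoDecoder("path/to/video.avi")  # placeholder path
clip = clips_at_random_indices(
    decoder,
    num_clips=1,
    num_frames_per_clip=model.config.frames_per_clip,
    num_indices_between_frames=3,
).data

with torch.no_grad():
    inputs = processor(clip, return_tensors="pt")
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])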