import cv2
import os
from PIL import Image
import tempfile
from transformers import AutoProcessor, LlavaForConditionalGeneration
import glob
from tqdm import tqdm
import math
import torch
import hashlib
import shutil
import json
import imagehash

# pip install pillow transformers opencv-python tqdm imagehash
# Install PyTorch depending on your device setup.

ROOT_DIR = "/datasets/video/"
VID_EXTENSIONS = ["*.mp4", "*.avi", "*.mpg", "*.mpeg", "*.gif"]
TARGET_DIR = "/datasets/videos_labeled/"

QUESTION_PROMPT = "Generate a caption for this image."
SUMMARIZE_PROMPT = "Summarize all frame descriptions beginning with 'Frame' and the number to one text as if you describe a video. Don't simply list the frames, create a story. This is the text to summarize: "

# Use an uncensored model so as not to limit the conversation text vocabulary,
# and to avoid questionable rejection messages. Apply Detoxify or RLHF/RLAIF
# before/during training.
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
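
# A minimal sketch of the Detoxify screening mentioned above (assumes
# `pip install detoxify`; `caption_text` and the 0.5 cutoff are illustrative):
#   from detoxify import Detoxify
#   scores = Detoxify("original").predict(caption_text)
#   keep_caption = scores["toxicity"] < 0.5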

def extract_keyframes_from_video(video_path, output_dir="frames", max_frames=5):
    """Sample up to max_frames frames, evenly spaced across the video, as PIL images."""
    frames = []
    os.makedirs(output_dir, exist_ok=True)
    if video_path.endswith(".gif"):
        with Image.open(video_path) as im:
            total_frames = im.n_frames
            num_frames = min(total_frames, max_frames)
            if num_frames == 0:
                return []
            step = total_frames / num_frames
            chosen_indices = [math.floor(i * step) for i in range(num_frames)]
            for frame_index in chosen_indices:
                im.seek(frame_index)
                frames.append(im.copy())
    else:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Error opening video: {video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        num_frames = min(total_frames, max_frames)
        if num_frames == 0:
            cap.release()
            return []
        step = total_frames / num_frames
        chosen_indices = [math.floor(i * step) for i in range(num_frames)]
        for frame_index in chosen_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if not ret:
                continue
            # OpenCV decodes to BGR; convert to RGB before building the PIL image
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        cap.release()
    if len(frames) == 0:
        print("No frames extracted.")
    return frames
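
# Illustrative usage ("clip.mp4" is a hypothetical path):
#   frames = extract_keyframes_from_video("clip.mp4", max_frames=5)
#   # -> a list of up to five PIL.Image frames, evenly spaced over the clip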

def list_videos():
    """Recursively collect all video files below ROOT_DIR."""
    video_files = []
    for ext in VID_EXTENSIONS:
        search_path = os.path.join(ROOT_DIR, "**", ext)
        video_files.extend(glob.glob(search_path, recursive=True))
    return video_files

def generate_caption(frame, processor, model):
    """Caption a single PIL image with the LLaVA model."""
    with torch.no_grad():
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": QUESTION_PROMPT}
        ]
        convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[convo_string], images=frame, return_tensors="pt").to('cuda')
        inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            suppress_tokens=None,
            use_cache=True,
            temperature=0.6,
            top_k=None,
            top_p=0.9,
            pad_token_id=processor.tokenizer.eos_token_id
        )[0]
        # strip the prompt tokens, keeping only the newly generated ones
        generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
        caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return caption.strip()

def remove_similar_frames(frames, hash_func=imagehash.phash, hash_size=8, threshold=15):
    """Drop frames that are perceptually close to an already kept frame."""
    unique_frames = []
    seen_hashes = []
    for frame in frames:
        h = hash_func(frame, hash_size=hash_size)
        # keep the frame only if its hash is far enough from every kept hash
        if all(abs(h - prev_hash) > threshold for prev_hash in seen_hashes):
            unique_frames.append(frame)
            seen_hashes.append(h)
    return unique_frames
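
# Note on the threshold: phash with hash_size=8 produces a 64-bit hash, so the
# distance between two hashes lies in [0, 64]. threshold=15 keeps a frame only
# if it differs from every kept frame in more than 15 of 64 bits (~23%);
# raise it to drop more near-duplicates, lower it to keep more frames.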

def summarize(text, processor, model):
    """Condense the per-frame captions into one video description (text-only call)."""
    with torch.no_grad():
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": SUMMARIZE_PROMPT + text}
        ]
        convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[convo_string], images=None, return_tensors="pt").to('cuda')
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            suppress_tokens=None,
            use_cache=True,
            temperature=0.6,
            top_k=None,
            top_p=0.9,
            pad_token_id=processor.tokenizer.eos_token_id
        )[0]
        generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
        summary = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return summary.strip()

if __name__ == "__main__":
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="bfloat16", device_map=0)
    llava_model.eval()

    videos = list_videos()
    print("There are", len(videos), "videos.")
    os.makedirs(TARGET_DIR, exist_ok=True)

    for v in tqdm(videos):
        # name outputs by content hash (hashlib.file_digest needs Python 3.11+)
        with open(v, "rb") as f:
            digest = hashlib.file_digest(f, "sha256").hexdigest()

        # skip videos that were already labeled in an earlier run
        if os.path.exists(os.path.join(TARGET_DIR, digest + '.json')):
            continue

        with tempfile.TemporaryDirectory() as tmpdirname:
            frames = extract_keyframes_from_video(v, output_dir=tmpdirname)
            frames = remove_similar_frames(frames)
            text = ""
            if len(frames) > 0:
                # caption each frame, then summarize the captions into one text
                for i, frame in enumerate(frames):
                    text += f"Frame{i}: " + generate_caption(frame, processor=processor, model=llava_model) + " "
                summary = summarize(text, processor=processor, model=llava_model)

                # copy the video next to its label, named by its hash
                _, ext = os.path.splitext(v)
                target_file_name = os.path.join(TARGET_DIR, digest + ext)
                shutil.copy(v, target_file_name)

                # save the label as chat-style JSON
                label = {
                    "videos": digest + ext,
                    "messages": [
                        {"role": "user", "content": "<video> Describe this video."},
                        {"role": "assistant", "content": summary}
                    ]
                }
                with open(os.path.join(TARGET_DIR, digest + '.json'), "w", encoding="utf-8") as f:
                    json.dump(label, f, indent=4)
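
# A resulting <sha256>.json label file looks like this (contents illustrative):
# {
#     "videos": "3f8a...c1.mp4",
#     "messages": [
#         {"role": "user", "content": "<video> Describe this video."},
#         {"role": "assistant", "content": "The video opens with ..."}
#     ]
# }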