Skip to content

Instantly share code, notes, and snippets.

@padmalcom
Last active October 11, 2025 18:35
Show Gist options
  • Select an option

  • Save padmalcom/87132218d90f161d822b71ce13985bed to your computer and use it in GitHub Desktop.

Select an option

Save padmalcom/87132218d90f161d822b71ce13985bed to your computer and use it in GitHub Desktop.
import cv2
import os
from PIL import Image
import tempfile
from transformers import AutoProcessor, LlavaForConditionalGeneration
import glob
from tqdm import tqdm
import math
import torch
import hashlib
import shutil
import json
import imagehash
# Dependencies: pip install pillow transformers opencv-python tqdm imagehash
# Install PyTorch separately, choosing the build that matches your device setup.

# Directory that is searched recursively for input videos.
ROOT_DIR = "/datasets/video/"
# Glob patterns (one per supported container/format) used to find videos.
VID_EXTENSIONS = ["*.mp4", "*.avi", "*.mpg", "*.mpeg", "*.gif"]
# Directory that receives the hash-named video copies and their JSON labels.
TARGET_DIR = "/datasets/videos_labeled/"
# User prompt sent together with every single frame.
QUESTION_PROMPT = "Generate a caption for this image."
# User prompt prefix for the text-only pass that merges the per-frame captions.
SUMMARIZE_PROMPT = "Summarize all frame descriptions beginning with 'Frame' and the number to one text as if you describe a video. Don't simply list the frames, create a story. This is the text to summarize: "
# Use an uncensored model so the caption vocabulary is not restricted and the
# model does not answer with refusal messages.
# Apply Detoxify or RLHF/RLAIF before/during training on the generated labels.
MODEL_NAME = "fancyfeast/llama-joycaption-beta-one-hf-llava"
def _evenly_spaced_indices(total_frames, max_frames):
    """Return up to ``max_frames`` evenly spaced frame indices in ``[0, total_frames)``."""
    num_frames = min(total_frames, max_frames)
    if num_frames <= 0:
        # Empty/unreadable stream or max_frames <= 0: nothing to sample
        # (and avoids dividing by zero below).
        return []
    stride = total_frames / num_frames
    return [math.floor(i * stride) for i in range(num_frames)]


def extract_keyframes_from_video(video_path, output_dir="frames", min_frames=1, max_frames=5, step=30):
    """Sample up to ``max_frames`` evenly spaced frames from a video or GIF.

    Args:
        video_path: Path of the file to read. Paths ending in ``.gif``
            (case-insensitive) are decoded with Pillow, everything else
            with OpenCV.
        output_dir: Directory that is created if it does not exist. No frame
            is written to it; the parameter is kept for compatibility.
        min_frames: Unused; kept for compatibility.
        max_frames: Maximum number of frames to return.
        step: Unused; the stride is derived from the frame count instead.

    Returns:
        A list of RGB ``PIL.Image`` objects. The list is empty when the
        source reports no frames or none of the selected frames can be read.

    Raises:
        IOError: If OpenCV cannot open ``video_path``.
    """
    os.makedirs(output_dir, exist_ok=True)
    frames = []

    if video_path.lower().endswith(".gif"):
        with Image.open(video_path) as im:
            for frame_index in _evenly_spaced_indices(im.n_frames, max_frames):
                im.seek(frame_index)
                # convert() returns a new image, so the frame survives closing
                # the file, and palette-mode GIF frames become plain RGB.
                frames.append(im.convert("RGB"))
    else:
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise IOError(f"Error opening video: {video_path}")
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            for frame_index in _evenly_spaced_indices(total_frames, max_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
                ret, frame = cap.read()
                if not ret:
                    # Skip frames the decoder cannot deliver.
                    continue
                # OpenCV decodes to BGR; Pillow expects RGB channel order.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            # Release the capture handle even if opening or reading failed.
            cap.release()

    if not frames:
        print("No frames extracted.")
    return frames
def list_videos():
    """Collect all video files below ``ROOT_DIR``.

    Every pattern in ``VID_EXTENSIONS`` is searched recursively, so files
    directly inside ``ROOT_DIR`` and in any sub-directory are found.

    Returns:
        A list of matching paths, grouped by pattern in the order of
        ``VID_EXTENSIONS``.
    """
    video_files = []
    for ext in VID_EXTENSIONS:
        # os.path.join keeps the pattern valid whether or not ROOT_DIR ends
        # with a path separator.
        search_path = os.path.join(ROOT_DIR, "**", ext)
        video_files.extend(glob.glob(search_path, recursive=True))
    return video_files
def generate_caption(frame, processor, model):
    """Generate a caption for a single frame.

    Args:
        frame: Image handed to the processor as ``images``.
        processor: Processor providing ``apply_chat_template``, the callable
            that builds the model inputs, and ``tokenizer``.
        model: Generation model; inputs are moved to its device and the pixel
            values are cast to its dtype.

    Returns:
        The decoded caption with surrounding whitespace stripped.
    """
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner."
        },
        {
            "role": "user",
            "content": QUESTION_PROMPT
        }
    ]
    with torch.no_grad():
        convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        # Follow the model's placement instead of assuming a 'cuda' device.
        inputs = processor(text=[convo_string], images=frame, return_tensors="pt").to(model.device)
        # The processor's pixel values must match the dtype of the model weights.
        inputs['pixel_values'] = inputs['pixel_values'].to(model.dtype)
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            suppress_tokens=None,
            use_cache=True,
            temperature=0.6,
            top_k=None,
            top_p=0.9,
            pad_token_id=processor.tokenizer.eos_token_id
        )[0]
        # Drop the prompt tokens so only the newly generated text is decoded.
        generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
        caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return caption.strip()
def remove_similar_frames(frames, hash_func=imagehash.phash, hash_size=8, threshold=15):
    """Drop frames that look too similar to a frame that was already kept.

    Args:
        frames: Frames to filter, processed in order.
        hash_func: Callable invoked as ``hash_func(frame, hash_size=hash_size)``.
            Subtracting two of its results is assumed to yield a distance.
        hash_size: Forwarded to ``hash_func``.
        threshold: A frame is kept only if its distance to every previously
            kept hash is strictly greater than this value.

    Returns:
        The kept frames in their original order.
    """
    kept_frames = []
    kept_hashes = []
    for candidate in frames:
        candidate_hash = hash_func(candidate, hash_size=hash_size)
        for known_hash in kept_hashes:
            if abs(candidate_hash - known_hash) <= threshold:
                # Too close to a frame we already have; discard the candidate.
                break
        else:
            kept_frames.append(candidate)
            kept_hashes.append(candidate_hash)
    return kept_frames
def summarize(text, processor, model):
    """Condense per-frame captions into one video description.

    Args:
        text: Concatenated frame captions; appended to ``SUMMARIZE_PROMPT``.
        processor: Processor providing ``apply_chat_template``, the callable
            that builds the model inputs, and ``tokenizer``.
        model: Generation model; inputs are moved to its device.

    Returns:
        The decoded summary with surrounding whitespace stripped.
    """
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner."
        },
        {
            "role": "user",
            "content": SUMMARIZE_PROMPT + text
        }
    ]
    with torch.no_grad():
        convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        # Text-only request: no image is passed. Follow the model's placement
        # instead of assuming a 'cuda' device.
        inputs = processor(text=[convo_string], images=None, return_tensors="pt").to(model.device)
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            suppress_tokens=None,
            use_cache=True,
            temperature=0.6,
            top_k=None,
            top_p=0.9,
            pad_token_id=processor.tokenizer.eos_token_id
        )[0]
        # Drop the prompt tokens so only the newly generated text is decoded.
        generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
        caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return caption.strip()
if __name__ == "__main__":
    # Script entry point: caption every video below ROOT_DIR and store a
    # hash-named copy plus a JSON label in TARGET_DIR.
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="bfloat16", device_map=0)
    llava_model.eval()

    videos = list_videos()
    print("There are", len(videos), "videos.")
    os.makedirs(TARGET_DIR, exist_ok=True)

    for v in tqdm(videos):
        # The content hash names the outputs, so identical files are labeled
        # only once. hashlib.file_digest requires Python 3.11+.
        with open(v, "rb") as f:
            digest = hashlib.file_digest(f, "sha256").hexdigest()

        label_path = os.path.join(TARGET_DIR, digest + '.json')
        if os.path.exists(label_path):
            # Already labeled in an earlier run.
            continue

        with tempfile.TemporaryDirectory() as tmpdirname:
            try:
                frames = extract_keyframes_from_video(v, output_dir=tmpdirname)
            except OSError as err:
                # One unreadable file must not abort the whole batch.
                print("Could not read frames from", v, "-", err)
                continue

        frames = remove_similar_frames(frames)
        if not frames:
            continue

        # Describe each remaining frame, then merge the captions into one story.
        text = "".join(
            f"Frame{i}:" + generate_caption(frame, processor, llava_model) + " "
            for i, frame in enumerate(frames)
        )
        summary = summarize(text, processor, llava_model)

        # Copy the video under its hash name.
        _, ext = os.path.splitext(v)
        video_name = digest + ext
        shutil.copy(v, os.path.join(TARGET_DIR, video_name))

        # Write the label last: its presence marks the video as fully processed.
        label = {
            "videos": video_name,
            "messages": [
                {"role": "user", "content": "<video> Describe this video."},
                {"role": "assistant", "content": summary},
            ],
        }
        with open(label_path, "w", encoding="utf-8") as f:
            json.dump(label, f, indent=4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment