@ofou
Created January 23, 2026 23:37
Run Audio Flamingo 3 on macOS
#!/usr/bin/env python3
"""
Audio Flamingo 3 transcription script with chunking at silences.
Supports short/long audio files, various prompts, and MPS/CUDA/CPU.
https://huggingface.co/nvidia/audio-flamingo-3-hf
"""
import os
import tempfile
from typing import Dict, List, Optional, Tuple

import numpy as np
import soundfile as sf
import torch
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

# Configuration constants
CHUNK_DURATION: float = 10.0  # target chunk length, in seconds
SILENCE_SEARCH_WINDOW: float = 1.0  # seconds searched on each side of a chunk boundary
MIN_SILENCE_DURATION: float = 0.1  # window length, in seconds, used to measure silence

PROMPTS: Dict[str, str] = {
    "transcribe": "Transcribe the input speech.",
    "transcribe_detailed": "Transcribe the input speech with proper capitalization and punctuation.",
    "caption_detailed": (
        "Generate a detailed caption for the input audio, describing all notable speech, "
        "sound, and musical events comprehensively. In the caption, transcribe all spoken "
        "content by all speakers in the audio precisely."
    ),
    "caption": "Describe the audio in detail, including any speech, sounds, or music you hear.",
    "reasoning": "Please think and reason about the input audio before you respond.",
    "qa_emotion": "How does the tone of speech change throughout the audio?",
    "qa_sounds": "What sounds are present in this audio?",
    "qa_music": "What elements make this music feel the way it does?",
}

def get_device() -> Tuple[torch.device, torch.dtype]:
    """Detect the best available device and a matching dtype."""
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    if torch.backends.mps.is_available():
        return torch.device("mps"), torch.float16
    # Half precision is poorly supported on CPU, so fall back to float32 there.
    return torch.device("cpu"), torch.float32

def move_to_device(inputs: Dict, device: torch.device, dtype: torch.dtype) -> Dict:
    """Move inputs to the target device, casting only floating-point tensors."""
    processed = {}
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            if v.is_floating_point():
                # Audio features and other float tensors follow the model dtype.
                processed[k] = v.to(device=device, dtype=dtype)
            else:
                # Integer tensors (input_ids, attention_mask) keep their dtype.
                processed[k] = v.to(device)
        else:
            processed[k] = v
    return processed

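# Single-chunk transcription path: apply_transcription_request() builds the prompt plus
# audio features, model.generate() produces token ids, and batch_decode() is given only
# the tokens after input_ids.shape[1], so the echoed prompt is not included in the output.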
def transcribe_audio(
    processor: AutoProcessor,
    model: AudioFlamingo3ForConditionalGeneration,
    audio_path: str,
    use_deterministic: bool = True,
) -> str:
    """Transcribe single audio file using transcription request."""
    device, dtype = model.device, next(model.parameters()).dtype
    inputs = processor.apply_transcription_request(audio=audio_path)
    processed_inputs = move_to_device(inputs, device, dtype)
    generate_kwargs = {
        "max_new_tokens": 256,
        "do_sample": False,
    }
    if not use_deterministic:
        generate_kwargs.update(
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    outputs = model.generate(**processed_inputs, **generate_kwargs)
    decoded = processor.batch_decode(
        outputs[:, processed_inputs["input_ids"].shape[1] :],
        skip_special_tokens=True,
        strip_prefix=True,
    )
    return decoded[0].strip()

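# Free-form prompting path: the prompt text and the audio file are wrapped in a single-turn
# chat conversation and rendered with apply_chat_template(). Sampling (temperature/top_p)
# is enabled here because these prompts are open-ended captions and questions, whereas the
# transcription path above decodes deterministically.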
def process_audio_with_prompt(
    processor: AutoProcessor,
    model: AudioFlamingo3ForConditionalGeneration,
    audio_path: str,
    prompt: str,
    strip_prefix: Optional[bool] = None,
) -> str:
    """Process audio with a custom chat prompt."""
    if strip_prefix is None:
        prompt_lower = prompt.lower()
        strip_prefix = "transcribe" in prompt_lower or prompt in [
            PROMPTS["transcribe"],
            PROMPTS["transcribe_detailed"],
        ]
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "audio", "path": audio_path},
            ],
        }
    ]
    device, dtype = model.device, next(model.parameters()).dtype
    inputs = processor.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",  # ensure input_ids is a tensor for the slicing below
    )
    processed_inputs = move_to_device(inputs, device, dtype)
    generate_kwargs = {
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    outputs = model.generate(**processed_inputs, **generate_kwargs)
    decoded = processor.batch_decode(
        outputs[:, processed_inputs["input_ids"].shape[1] :],
        skip_special_tokens=True,
        strip_prefix=strip_prefix,
    )
    return decoded[0].strip()

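# Silence detection: slide a window of MIN_SILENCE_DURATION seconds (in steps of a quarter
# window) across +/- SILENCE_SEARCH_WINDOW seconds around the nominal cut point, compute the
# RMS energy of each window, and return the centre of the quietest one. With the defaults,
# that is a 0.1 s window searched over a 2 s span around each 10 s boundary.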
def find_silence_point(
    audio_data: np.ndarray,
    sample_rate: int,
    target_sample: int,
    search_window_samples: int,
) -> int:
    """Find the nearest silence point to target_sample within the search window."""
    search_start = max(0, target_sample - search_window_samples)
    search_end = min(len(audio_data), target_sample + search_window_samples)
    if search_start >= search_end:
        return target_sample
    segment = audio_data[search_start:search_end]
    window_size = max(1, int(MIN_SILENCE_DURATION * sample_rate))
    min_energy = float("inf")
    best_offset = search_window_samples
    step = max(1, window_size // 4)  # guard against a zero step at tiny sample rates
    for i in range(0, len(segment) - window_size + 1, step):
        window = segment[i : i + window_size]
        energy = np.sqrt(np.mean(window**2))
        if energy < min_energy:
            min_energy = energy
            best_offset = i + window_size // 2
    return search_start + best_offset

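# Chunking: greedily advance through the file in CHUNK_DURATION-sized steps, snapping each
# boundary to the quietest nearby point. The max(actual_end, min_end) guard keeps every
# chunk at least half of CHUNK_DURATION long, so the loop always makes forward progress.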
def find_chunk_boundaries(
    audio_data: np.ndarray,
    sample_rate: int,
) -> List[Tuple[int, int]]:
    """Split audio into chunks at silence points."""
    total_samples = len(audio_data)
    chunk_samples = int(CHUNK_DURATION * sample_rate)
    search_window_samples = int(SILENCE_SEARCH_WINDOW * sample_rate)
    boundaries = []
    current_start = 0
    while current_start < total_samples:
        target_end = current_start + chunk_samples
        if target_end >= total_samples:
            boundaries.append((current_start, total_samples))
            break
        actual_end = find_silence_point(
            audio_data, sample_rate, target_end, search_window_samples
        )
        min_end = current_start + chunk_samples // 2
        actual_end = max(actual_end, min_end)
        boundaries.append((current_start, actual_end))
        current_start = actual_end
    return boundaries

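# Long-file transcription: downmix to mono, split at silences, write each chunk to a
# temporary WAV, transcribe it independently, and join the chunk transcripts with spaces.
# Temporary files are always removed in the finally block. Stitching independently
# transcribed chunks can still occasionally duplicate or drop a word at a boundary.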
def transcribe_long_audio(
    processor: AutoProcessor,
    model: AudioFlamingo3ForConditionalGeneration,
    audio_file: str,
) -> str:
    """Transcribe long audio by chunking at silences."""
    audio_data, sample_rate = sf.read(audio_file)
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)
    total_duration = len(audio_data) / sample_rate
    print(f"Audio duration: {total_duration:.1f}s")
    if total_duration <= CHUNK_DURATION:
        return transcribe_audio(processor, model, audio_file)
    boundaries = find_chunk_boundaries(audio_data, sample_rate)
    num_chunks = len(boundaries)
    print(
        f"Processing {num_chunks} chunks (cutting at silences to avoid word breaks)..."
    )
    transcriptions = []
    temp_files = []
    try:
        for i, (start_sample, end_sample) in enumerate(boundaries):
            chunk_data = audio_data[start_sample:end_sample]
            chunk_duration = len(chunk_data) / sample_rate
            start_time = start_sample / sample_rate
            end_time = end_sample / sample_rate
            print(
                f"\n[Chunk {i+1}/{num_chunks}] {start_time:.1f}s - {end_time:.1f}s ({chunk_duration:.1f}s)"
            )
            temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            sf.write(temp_file.name, chunk_data, sample_rate)
            temp_files.append(temp_file.name)
            transcription = transcribe_audio(processor, model, temp_file.name)
            transcriptions.append(transcription)
            print(transcription)
    finally:
        for temp_file in temp_files:
            if os.path.exists(temp_file):
                os.unlink(temp_file)
    return " ".join(transcriptions)

def main() -> None:
    """Run demo with model loading and examples."""
    model_id = "nvidia/audio-flamingo-3-hf"
    device, dtype = get_device()
    print("Loading Audio Flamingo 3 from Hugging Face...")
    print(f"Using device: {device} (dtype: {dtype})")
    processor = AutoProcessor.from_pretrained(model_id)
    if device.type == "mps":
        model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=dtype, low_cpu_mem_usage=True
        )
        model = model.to(device)
    else:
        model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
    print("✓ Model loaded!")
    print(f"Model device: {next(model.parameters()).device}")

    # Example 1: URL transcription
    print("\n" + "=" * 80)
    print("Example 1: Transcribe from URL")
    print("=" * 80)
    url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"
    result = transcribe_audio(processor, model, url)
    print(result)

    # Example 2: Local long file
    print("\n" + "=" * 80)
    print("Example 2: Transcribe local file")
    print("=" * 80)
    local_wav_files = [f for f in os.listdir(".") if f.endswith(".wav")]
    if local_wav_files:
        audio_file = local_wav_files[0]
        print(f"Using: {audio_file}")
        result = transcribe_long_audio(processor, model, audio_file)
        print("\n" + "=" * 80)
        print("Full Transcription:")
        print("=" * 80)
        print(result)
    else:
        print("No .wav files found in current directory.")

    # Example 3: Prompt variations
    print("\n" + "=" * 80)
    print("Example 3: Audio understanding with prompts")
    print("=" * 80)
    if local_wav_files:
        audio_file = local_wav_files[0]
        prompts_to_try = [
            ("Detailed Caption", PROMPTS["caption_detailed"]),
            ("Simple Caption", PROMPTS["caption"]),
            ("Reasoning Mode", PROMPTS["reasoning"] + " " + PROMPTS["caption"]),
        ]
        for name, text in prompts_to_try:
            print(f"\n--- {name} ---")
            print(f"Prompt: {text[:80]}...")
            try:
                result = process_audio_with_prompt(processor, model, audio_file, text)
                preview = result[:200] + "..." if len(result) > 200 else result
                print(f"Result: {preview}")
            except Exception as e:
                print(f"Error: {e}")
                break  # stop trying further prompts after the first failure


if __name__ == "__main__":
    main()