Run Audio Flamingo 3 on macOS
#!/usr/bin/env python3
"""
Audio Flamingo 3 transcription script with chunking at silences.
Supports short/long audio files, various prompts, and MPS/CUDA/CPU.
https://huggingface.co/nvidia/audio-flamingo-3-hf
"""
import os
import tempfile
from typing import Dict, List, Optional, Tuple

import numpy as np
import soundfile as sf
import torch
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

# Configuration constants
CHUNK_DURATION: float = 10.0  # target chunk length, in seconds
SILENCE_SEARCH_WINDOW: float = 1.0  # seconds to search around a cut for silence
MIN_SILENCE_DURATION: float = 0.1  # length of the silence window, in seconds

PROMPTS: Dict[str, str] = {
    "transcribe": "Transcribe the input speech.",
    "transcribe_detailed": "Transcribe the input speech with proper capitalization and punctuation.",
    "caption_detailed": (
        "Generate a detailed caption for the input audio, describing all notable speech, "
        "sound, and musical events comprehensively. In the caption, transcribe all spoken "
        "content by all speakers in the audio precisely."
    ),
    "caption": "Describe the audio in detail, including any speech, sounds, or music you hear.",
    "reasoning": "Please think and reason about the input audio before you respond.",
    "qa_emotion": "How does the tone of speech change throughout the audio?",
    "qa_sounds": "What sounds are present in this audio?",
    "qa_music": "What elements make this music feel the way it does?",
}

def get_device() -> Tuple[torch.device, torch.dtype]:
    """Detect optimal device and dtype."""
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    if torch.backends.mps.is_available():
        return torch.device("mps"), torch.float16
    # Half precision is poorly supported on CPU; fall back to float32 there.
    return torch.device("cpu"), torch.float32

def move_to_device(inputs: Dict, device: torch.device, dtype: torch.dtype) -> Dict:
    """Move inputs to device, casting only floating-point tensors to the model dtype."""
    processed = {}
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            if v.is_floating_point():
                processed[k] = v.to(device=device, dtype=dtype)
            else:
                # Integer tensors (input_ids, attention_mask) keep their dtype.
                processed[k] = v.to(device)
        else:
            processed[k] = v
    return processed

def transcribe_audio(
    processor: AutoProcessor,
    model: AudioFlamingo3ForConditionalGeneration,
    audio_path: str,
    use_deterministic: bool = True,
) -> str:
    """Transcribe single audio file using transcription request."""
    device, dtype = model.device, next(model.parameters()).dtype
    inputs = processor.apply_transcription_request(audio=audio_path)
    processed_inputs = move_to_device(inputs, device, dtype)
    generate_kwargs = {
        "max_new_tokens": 256,
        "do_sample": False,
    }
    if not use_deterministic:
        generate_kwargs.update(
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    outputs = model.generate(**processed_inputs, **generate_kwargs)
    decoded = processor.batch_decode(
        outputs[:, processed_inputs["input_ids"].shape[1] :],
        skip_special_tokens=True,
        strip_prefix=True,
    )
    return decoded[0].strip()

def process_audio_with_prompt(
    processor: AutoProcessor,
    model: AudioFlamingo3ForConditionalGeneration,
    audio_path: str,
    prompt: str,
    strip_prefix: Optional[bool] = None,
) -> str:
    """Process audio with custom chat prompt."""
    if strip_prefix is None:
        prompt_lower = prompt.lower()
        strip_prefix = "transcribe" in prompt_lower or prompt in [
            PROMPTS["transcribe"],
            PROMPTS["transcribe_detailed"],
        ]
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "audio", "path": audio_path},
            ],
        }
    ]
    device, dtype = model.device, next(model.parameters()).dtype
    inputs = processor.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
    )
    processed_inputs = move_to_device(inputs, device, dtype)
    generate_kwargs = {
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    outputs = model.generate(**processed_inputs, **generate_kwargs)
    decoded = processor.batch_decode(
        outputs[:, processed_inputs["input_ids"].shape[1] :],
        skip_special_tokens=True,
        strip_prefix=strip_prefix,
    )
    return decoded[0].strip()

def find_silence_point(
    audio_data: np.ndarray,
    sample_rate: int,
    target_sample: int,
    search_window_samples: int,
) -> int:
    """Find the nearest silence point to target_sample within the search window."""
    search_start = max(0, target_sample - search_window_samples)
    search_end = min(len(audio_data), target_sample + search_window_samples)
    if search_start >= search_end:
        return target_sample
    segment = audio_data[search_start:search_end]
    window_size = max(1, int(MIN_SILENCE_DURATION * sample_rate))
    min_energy = float("inf")
    # Fall back to the target position if no quieter window is found.
    best_offset = target_sample - search_start
    step = max(1, window_size // 4)
    for i in range(0, len(segment) - window_size + 1, step):
        window = segment[i : i + window_size]
        # RMS energy of the window; lower means closer to silence.
        energy = np.sqrt(np.mean(window**2))
        if energy < min_energy:
            min_energy = energy
            best_offset = i + window_size // 2
    return search_start + best_offset

def find_chunk_boundaries(
    audio_data: np.ndarray,
    sample_rate: int,
) -> List[Tuple[int, int]]:
    """Split audio into chunks at silence points."""
    total_samples = len(audio_data)
    chunk_samples = int(CHUNK_DURATION * sample_rate)
    search_window_samples = int(SILENCE_SEARCH_WINDOW * sample_rate)
    boundaries = []
    current_start = 0
    while current_start < total_samples:
        target_end = current_start + chunk_samples
        if target_end >= total_samples:
            boundaries.append((current_start, total_samples))
            break
        actual_end = find_silence_point(
            audio_data, sample_rate, target_end, search_window_samples
        )
        # Never cut a chunk shorter than half the target duration.
        min_end = current_start + chunk_samples // 2
        actual_end = max(actual_end, min_end)
        boundaries.append((current_start, actual_end))
        current_start = actual_end
    return boundaries

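# Illustration of the boundary logic above (an assumed example, not executed here):
# with CHUNK_DURATION = 10.0 s and a 25 s mono file, targets fall near 10 s and
# 20 s; each cut is nudged by at most SILENCE_SEARCH_WINDOW = 1.0 s toward the
# quietest MIN_SILENCE_DURATION = 0.1 s window, giving roughly
# 0-10 s, 10-20 s, and 20-25 s chunks.
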
def transcribe_long_audio(
    processor: AutoProcessor,
    model: AudioFlamingo3ForConditionalGeneration,
    audio_file: str,
) -> str:
    """Transcribe long audio by chunking at silences."""
    audio_data, sample_rate = sf.read(audio_file)
    if audio_data.ndim > 1:
        # Downmix multi-channel audio to mono.
        audio_data = audio_data.mean(axis=1)
    total_duration = len(audio_data) / sample_rate
    print(f"Audio duration: {total_duration:.1f}s")
    if total_duration <= CHUNK_DURATION:
        return transcribe_audio(processor, model, audio_file)
    boundaries = find_chunk_boundaries(audio_data, sample_rate)
    num_chunks = len(boundaries)
    print(
        f"Processing {num_chunks} chunks (cutting at silences to avoid word breaks)..."
    )
    transcriptions = []
    temp_files = []
    try:
        for i, (start_sample, end_sample) in enumerate(boundaries):
            chunk_data = audio_data[start_sample:end_sample]
            chunk_duration = len(chunk_data) / sample_rate
            start_time = start_sample / sample_rate
            end_time = end_sample / sample_rate
            print(
                f"\n[Chunk {i+1}/{num_chunks}] {start_time:.1f}s - {end_time:.1f}s ({chunk_duration:.1f}s)"
            )
            temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            temp_file.close()
            sf.write(temp_file.name, chunk_data, sample_rate)
            temp_files.append(temp_file.name)
            transcription = transcribe_audio(processor, model, temp_file.name)
            transcriptions.append(transcription)
            print(transcription)
    finally:
        for temp_file in temp_files:
            if os.path.exists(temp_file):
                os.unlink(temp_file)
    return " ".join(transcriptions)

def main() -> None:
    """Run demo with model loading and examples."""
    model_id = "nvidia/audio-flamingo-3-hf"
    device, dtype = get_device()
    print("Loading Audio Flamingo 3 from HuggingFace...")
    print(f"Using device: {device} (dtype: {dtype})")
    processor = AutoProcessor.from_pretrained(model_id)
    if device.type == "mps":
        # On Apple Silicon, load first and then move the whole model to MPS explicitly.
        model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=dtype, low_cpu_mem_usage=True
        )
        model = model.to(device)
    else:
        model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
    print("✓ Model loaded!")
    print(f"Model device: {next(model.parameters()).device}")

    # Example 1: URL transcription
    print("\n" + "=" * 80)
    print("Example 1: Transcribe from URL")
    print("=" * 80)
    url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"
    result = transcribe_audio(processor, model, url)
    print(result)

    # Example 2: Local long file
    print("\n" + "=" * 80)
    print("Example 2: Transcribe local file")
    print("=" * 80)
    local_wav_files = [f for f in os.listdir(".") if f.endswith(".wav")]
    if local_wav_files:
        audio_file = local_wav_files[0]
        print(f"Using: {audio_file}")
        result = transcribe_long_audio(processor, model, audio_file)
        print("\n" + "=" * 80)
        print("Full Transcription:")
        print("=" * 80)
        print(result)
    else:
        print("No .wav files found in current directory.")

    # Example 3: Prompt variations
    print("\n" + "=" * 80)
    print("Example 3: Audio understanding with prompts")
    print("=" * 80)
    if local_wav_files:
        audio_file = local_wav_files[0]
        prompts_to_try = [
            ("Detailed Caption", PROMPTS["caption_detailed"]),
            ("Simple Caption", PROMPTS["caption"]),
            ("Reasoning Mode", PROMPTS["reasoning"] + " " + PROMPTS["caption"]),
        ]
        for name, text in prompts_to_try:
            print(f"\n--- {name} ---")
            print(f"Prompt: {text[:80]}...")
            try:
                result = process_audio_with_prompt(processor, model, audio_file, text)
                preview = result[:200] + "..." if len(result) > 200 else result
                print(f"Result: {preview}")
            except Exception as e:
                print(f"Error: {e}")
                break


if __name__ == "__main__":
    main()
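
A minimal sketch of reusing the helpers above from another script, assuming the gist has been saved locally as af3_transcribe.py; the module name and the sample.wav file below are illustrative assumptions, not part of the gist.

from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

# Hypothetical module name for a local copy of this gist.
from af3_transcribe import PROMPTS, get_device, process_audio_with_prompt

# Load the model once, then ask a single question about one local file.
device, dtype = get_device()
processor = AutoProcessor.from_pretrained("nvidia/audio-flamingo-3-hf")
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
    "nvidia/audio-flamingo-3-hf", torch_dtype=dtype, low_cpu_mem_usage=True
).to(device)

answer = process_audio_with_prompt(processor, model, "sample.wav", PROMPTS["qa_sounds"])
print(answer)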