Skip to content

Instantly share code, notes, and snippets.

@jsbeaudry
Created October 18, 2025 18:23
Show Gist options
  • Select an option

  • Save jsbeaudry/e0465bff190530862f884ac43872e33d to your computer and use it in GitHub Desktop.

Select an option

Save jsbeaudry/e0465bff190530862f884ac43872e33d to your computer and use it in GitHub Desktop.
import json
import warnings
import librosa
import soundfile as sf
import numpy as np
from datasets import load_dataset
import mlx.core as mx
from nanocodec_mlx.models.audio_codec import AudioCodecModel
# Silence UserWarning-category warnings for the rest of the run.
warnings.filterwarnings("ignore", category=UserWarning)

# ---------------- Config ----------------
# Placeholders: replace with the owner and name of your Hugging Face dataset.
user = "hf_username"
repo = "hf_repo"
dataset_name = f"{user}/{repo}"  # Hugging Face dataset name
split = "train"  # Dataset split (train/test/validation)
# Output file name is derived from the dataset coordinates and the split.
output_json = f"{user}-{repo}-{split}-nemo-encoded.json"

# ---------------- Load Dataset ----------------
print(f"📦 Loading dataset: {dataset_name}")
dataset = load_dataset(dataset_name, split=split)
print(f"📄 Loaded dataset with {len(dataset)} samples")

# ---------------- Load MLX Codec Model ----------------
print("🔄 Loading MLX codec model...")
model = AudioCodecModel.from_pretrained(
    "nineninesix/nemo-nano-codec-22khz-0.6kbps-12.5fps-MLX"
)
sample_rate = 22050  # fixed model sample rate
print(f"✅ Model loaded (sample rate: {sample_rate}Hz)")
# ---------------- Helper Function ----------------
def load_audio_from_array(audio_dict, target_sr=None):
    """Return a sample's waveform as float32 at the requested sample rate.

    Args:
        audio_dict: Mapping with an ``"array"`` entry (the waveform) and a
            ``"sampling_rate"`` entry (its rate in Hz).
        target_sr: Rate to resample to. Defaults to the module-level
            ``sample_rate`` when omitted, which is what existing callers get.

    Returns:
        A ``np.float32`` NumPy array. It is passed through
        ``librosa.resample`` only when the source rate differs from the
        target rate.
    """
    if target_sr is None:
        target_sr = sample_rate

    # np.asarray leaves ndarrays untouched but also accepts plain sequences,
    # so the final ``.astype`` no longer fails on list input.
    audio_array = np.asarray(audio_dict["array"])
    sr = audio_dict["sampling_rate"]

    if sr != target_sr:
        print(f"⚙️ Resampling from {sr}Hz → {target_sr}Hz")
        audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=target_sr)

    return audio_array.astype(np.float32)
# ---------------- Encode Function ----------------
def encode_audio(audio_dict):
    """Run one dataset audio entry through the MLX codec.

    Args:
        audio_dict: Audio entry, handed unchanged to
            ``load_audio_from_array``.

    Returns:
        Dict with ``nano_layer_1`` .. ``nano_layer_4`` (rows 0-3 of the
        token array after its leading axis is squeezed, as lists of ints)
        and ``encoded_len`` (first element of the length array returned by
        ``model.encode``, as an int).
    """
    waveform = load_audio_from_array(audio_dict)

    # Prepend two singleton axes; the original author labels this [B, C, T].
    batch = mx.array(waveform, dtype=mx.float32)[None, None, :]
    n_samples = batch.shape[-1]
    lengths = mx.array([n_samples], dtype=mx.int32)

    codes, codes_len = model.encode(batch, lengths)

    # MLX array -> NumPy -> nested Python lists so the result is JSON-friendly.
    layers = np.array(codes).squeeze(0).astype(int).tolist()

    record = {f"nano_layer_{i + 1}": layers[i] for i in range(4)}
    record["encoded_len"] = int(np.array(codes_len)[0])
    return record
# ---------------- Encode All Samples ----------------
results = []
for idx, sample in enumerate(dataset):
    position = idx + 1  # 1-based index used in the progress messages
    # Deliberate best-effort: a failing sample is reported and skipped so a
    # single bad record does not abort the whole run.
    try:
        print(f"\n🎧 Encoding sample {position}/{len(dataset)}")
        encoded = encode_audio(sample["audio"])
        # Add metadata if available
        encoded["text"] = sample.get("text", "")
        # Change the "speaker_id" key to match your dataset's speaker column.
        encoded["speaker"] = sample.get("speaker_id", "anon")
        results.append(encoded)
        print(f"✅ Encoded sample {position} ({encoded['encoded_len']} frames)")
    except Exception as e:
        print(f"❌ Error encoding sample {position}: {e}")

# ---------------- Save to JSON ----------------
# ensure_ascii=False writes non-ASCII text verbatim, so the file encoding is
# pinned to UTF-8 instead of relying on the platform's locale default.
with open(output_json, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)
print(f"\n💾 Encoded data saved to: {output_json}")
print(f"📊 Successfully encoded {len(results)}/{len(dataset)} samples")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment