@senstella
Created May 6, 2025 14:04
A simple script to convert NeMo Parakeet weights to MLX.
import torch
from safetensors.torch import save_file

INPUT_NAME = "model_weights.ckpt"
OUTPUT_NAME = "model.safetensors"

state = torch.load(INPUT_NAME, map_location="cpu")
new_state = {}

for key, value in state.items():
    # Skip the audio preprocessor buffers and the BatchNorm step counters.
    if key.startswith("preprocessor"):
        continue
    if 'num_batches_tracked' in key:
        continue

    # Move the input-channel axis last for convolution-style weights.
    if 'conv' in key or 'ctc_decoder' in key or key == "decoder.decoder_layers.0.weight":
        if len(value.shape) == 4:
            value = value.permute((0, 2, 3, 1))
        elif len(value.shape) == 3:
            value = value.permute((0, 2, 1))

    # Rename LSTM weights: weight_ih_lN -> N.Wx, weight_hh_lN -> N.Wh.
    if 'weight_ih_l' in key:
        key = key.replace('weight_ih_l', '') + '.Wx'
    if 'weight_hh_l' in key:
        key = key.replace('weight_hh_l', '') + '.Wh'

    if 'bias_ih_l' in key or 'bias_hh_l' in key:
        # Both biases of a layer map to the same key, so they are summed.
        key = key.replace('bias_ih_l', '').replace('bias_hh_l', '') + '.bias'
        new_state[key] = value if new_state.get(key) is None else value + new_state[key]
    else:
        # permute() returns a non-contiguous view, which safetensors refuses to save.
        new_state[key] = value.contiguous()

save_file(new_state, OUTPUT_NAME)
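The two transformations in the script can be checked without loading a checkpoint. As far as I can tell, the permutes move the input-channel axis last, which matches the channels-last weight layout MLX convolutions use, and the renames target the `Wx`, `Wh` and `bias` parameter names of MLX's LSTM, which keeps a single bias where PyTorch keeps two. The sketch below works only on shape tuples and key strings; the helper names, the `lstm.` prefix and the shapes are made up for illustration and are not taken from a real Parakeet checkpoint.

```python
# Illustrative only: operates on key names and shape tuples, not real tensors.
# Helper names, key prefixes and shapes are hypothetical.

def permuted_shape(shape, order):
    """Return the shape a tensor would have after permute(order)."""
    return tuple(shape[axis] for axis in order)


def rename_lstm_key(key):
    """Apply the same LSTM renaming rules as the conversion script."""
    if "weight_ih_l" in key:
        return key.replace("weight_ih_l", "") + ".Wx"
    if "weight_hh_l" in key:
        return key.replace("weight_hh_l", "") + ".Wh"
    if "bias_ih_l" in key or "bias_hh_l" in key:
        return key.replace("bias_ih_l", "").replace("bias_hh_l", "") + ".bias"
    return key


# 2-D convolution weight: (out, in, kH, kW) -> (out, kH, kW, in).
conv2d_shape = permuted_shape((256, 1, 3, 3), (0, 2, 3, 1))
# 1-D convolution weight: (out, in, k) -> (out, k, in).
conv1d_shape = permuted_shape((1024, 1024, 9), (0, 2, 1))

# Both bias names of layer 0 collapse onto one key, which is why the script sums them.
renamed = [
    rename_lstm_key(k)
    for k in (
        "lstm.weight_ih_l0",
        "lstm.weight_hh_l0",
        "lstm.bias_ih_l0",
        "lstm.bias_hh_l0",
    )
]

print(conv2d_shape, conv1d_shape, renamed)
```

Summing the two biases does not change the result, because a PyTorch LSTM only ever uses them added together inside each gate.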
@senstella
Author

> Would this script work to convert Canary?

The encoder part would convert correctly, but the decoder is completely different (Canary uses a Transformer decoder, like Whisper), so that part won't work!

@itsklimov
itsklimov commented Oct 8, 2025

I am new to conversions. How are you figuring out the conversion parameters?
