Created March 13, 2026 21:12
-
-
Save lucasnewman/00ac4dfa3eee4abeb088793cfe5be946 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/ChatterboxConfig.swift b/Sources/MLXAudioTTS/Models/Chatterbox/ChatterboxConfig.swift | |
| index 38c6dad..35d8c88 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/ChatterboxConfig.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/ChatterboxConfig.swift | |
| @@ -476,7 +476,7 @@ public struct ChatterboxConfiguration: Codable, Sendable { | |
| let altGlobalQuant = try container.decodeIfPresent( | |
| BaseConfiguration.Quantization.self, forKey: .quantizationConfig | |
| ) | |
| - self.quantization = globalQuant ?? altGlobalQuant ?? baseConfig?.quantization | |
| + self.quantization = globalQuant ?? altGlobalQuant ?? baseConfig?.perLayerQuantization?.quantization | |
| self.perLayerQuantization = baseConfig?.perLayerQuantization | |
| } | |
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/ChatterboxModel.swift b/Sources/MLXAudioTTS/Models/Chatterbox/ChatterboxModel.swift | |
| index 0cc565c..9c75323 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/ChatterboxModel.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/ChatterboxModel.swift | |
| @@ -407,11 +407,9 @@ public final class ChatterboxModel: Module, SpeechGenerationModel, @unchecked Se | |
| /// - mel2wav.m_source.l_linear.{w,b} (Linear, not Conv1d) | |
| /// - mel2wav.m_source.l_sin_gen.* (no parameters) | |
| static func remapRegularMel2WavKey(_ key: String) -> String { | |
| - var k = key | |
| - | |
| // Strip mel2wav. prefix, process, then re-add | |
| - guard k.hasPrefix("mel2wav.") else { return k } | |
| - let subKey = String(k.dropFirst("mel2wav.".count)) | |
| + guard key.hasPrefix("mel2wav.") else { return key } | |
| + let subKey = String(key.dropFirst("mel2wav.".count)) | |
| // Match all Conv1d/ConvTranspose1d terminal patterns: | |
| // conv_pre.{weight,bias} | |
| @@ -721,7 +719,7 @@ public final class ChatterboxModel: Module, SpeechGenerationModel, @unchecked Se | |
| print("[Chatterbox] Text tokenized: \(textTokens.shape)") | |
| let temperature = generationParameters.temperature | |
| - let topP = generationParameters.topP ?? 0.95 | |
| + let topP = generationParameters.topP | |
| // Cap max tokens: when using reference audio without prompt speech tokens, | |
| // the model may not generate EOS reliably, so use a smaller limit. | |
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/CAMPPlus.swift b/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/CAMPPlus.swift | |
| index f02c9fd..466c455 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/CAMPPlus.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/CAMPPlus.swift | |
| @@ -144,7 +144,7 @@ private func statisticsPooling(_ x: MLXArray, axis: Int = -1) -> MLXArray { | |
| /// pools each segment, expands back to original length. | |
| private func segPooling(_ x: MLXArray, segLen: Int = 100) -> MLXArray { | |
| // x: (B, C, T) | |
| - let B = x.dim(0), C = x.dim(1), T = x.dim(2) | |
| + let T = x.dim(2) | |
| if T <= segLen { | |
| // Single segment — just mean, expand to match T | |
| @@ -304,7 +304,7 @@ class FCM: Module { | |
| // x: (B, F, T) in PyTorch NCHW convention | |
| // MLX Conv2d expects NHWC: (B, H, W, C) | |
| // Treat F as H, T as W, C=1 | |
| - let B = x.dim(0), F = x.dim(1), T = x.dim(2) | |
| + let B = x.dim(0) | |
| // (B, F, T) -> (B, F, T, 1) — NHWC with C=1 | |
| var out = x.expandedDimensions(axis: 3) | |
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/ConformerEncoder.swift b/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/ConformerEncoder.swift | |
| index ad3309e..ef668b7 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/ConformerEncoder.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/ConformerEncoder.swift | |
| @@ -587,8 +587,6 @@ class S3GenUpsample1D: Module { | |
| -> (MLXArray, MLXArray) | |
| { | |
| // inputs: (B, C, T) PyTorch format | |
| - let B = inputs.dim(0), C = inputs.dim(1) | |
| - | |
| // Repeat each timestep stride times | |
| var outputs = MLX.repeated(inputs, count: strideVal, axis: 2) | |
| @@ -779,7 +777,7 @@ class UpsampleConformerEncoder: Module { | |
| // Create encoder layers | |
| var encoderLayers = [S3GenConformerEncoderLayer]() | |
| - for i in 0 ..< numBlocks { | |
| + for _ in 0 ..< numBlocks { | |
| let attnModule: Module | |
| if selfattentionLayerType == "rel_selfattn" { | |
| attnModule = S3GenRelPositionMultiHeadedAttention( | |
| @@ -819,7 +817,7 @@ class UpsampleConformerEncoder: Module { | |
| // Create up-encoder layers | |
| var upEncoderLayers = [S3GenConformerEncoderLayer]() | |
| - for i in 0 ..< numUpBlocks { | |
| + for _ in 0 ..< numUpBlocks { | |
| let attnModule: Module | |
| if selfattentionLayerType == "rel_selfattn" { | |
| attnModule = S3GenRelPositionMultiHeadedAttention( | |
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/HiFTGenerator.swift b/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/HiFTGenerator.swift | |
| index ae67b07..8dab39e 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/HiFTGenerator.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/HiFTGenerator.swift | |
| @@ -223,7 +223,7 @@ class SineGen: Module { | |
| func callAsFunction(_ f0: MLXArray) -> (MLXArray, MLXArray, MLXArray) { | |
| // f0: (B, 1, T) | |
| - let B = f0.dim(0), T = f0.dim(2) | |
| + let B = f0.dim(0) | |
| // Create harmonics multiplier: [1, 2, ..., harmonicNum+1] | |
| let harmonicMult = (MLXArray(1 ... (harmonicNum + 1)).asType(.float32)) | |
| @@ -285,7 +285,7 @@ class SourceModuleHnNSF: Module { | |
| func callAsFunction(_ x: MLXArray) -> (MLXArray, MLXArray, MLXArray) { | |
| // x: (B, 1, T) F0 values | |
| // Generate sine harmonics — SineGen expects (B, 1, T) directly | |
| - var (sineWavs, uv, _) = lSinGen(x) | |
| + let (sineWavs, uv, _) = lSinGen(x) | |
| // sineWavs: (B, harmonics+1, T), uv: (B, 1, T) | |
| // Merge harmonics with linear layer | |
| @@ -310,7 +310,7 @@ private func reverseAlongAxis(_ x: MLXArray, axis: Int) -> MLXArray { | |
| /// Short-Time Fourier Transform for HiFi-GAN. | |
| func hifigan_stft(x: MLXArray, nFft: Int, hopLength: Int, window: MLXArray) -> (MLXArray, MLXArray) | |
| { | |
| - let B = x.dim(0), T = x.dim(1) | |
| + let B = x.dim(0) | |
| // Reflect padding | |
| let padLength = nFft / 2 | |
| @@ -358,7 +358,7 @@ func hifigan_istft( | |
| let real = clippedMag * MLX.cos(phase) | |
| let imag = clippedMag * MLX.sin(phase) | |
| - let B = real.dim(0), F = real.dim(1), numFrames = real.dim(2) | |
| + let B = real.dim(0), numFrames = real.dim(2) | |
| // Construct complex spectrum | |
| let imagUnit = MLXArray(real: Float(0), imaginary: Float(1)) | |
| @@ -500,7 +500,7 @@ class HiFTGenerator: Module { | |
| // For [8, 5, 3]: reversed = [3, 5, 8], [:-1] = [3, 5], so [1, 3, 5] | |
| // downsample_cum = cumprod([1, 3, 5]) = [1, 3, 15] | |
| // reversed = [15, 3, 1] | |
| - var downsampleRates = [1] + Array(upsampleRates.reversed().dropLast()) | |
| + let downsampleRates = [1] + Array(upsampleRates.reversed().dropLast()) | |
| var downsampleCum: [Int] = [] | |
| var cumProd = 1 | |
| for r in downsampleRates { | |
| @@ -646,7 +646,7 @@ class HiFTGenerator: Module { | |
| func callAsFunction(_ speechFeat: MLXArray, cacheSource: MLXArray? = nil) | |
| -> (MLXArray, MLXArray) | |
| { | |
| - var cache = cacheSource ?? MLXArray.zeros([1, 1, 0]) | |
| + let cache = cacheSource ?? MLXArray.zeros([1, 1, 0]) | |
| // Predict F0 | |
| let f0 = f0Predictor(speechFeat) | |
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/S3GenMel.swift b/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/S3GenMel.swift | |
| index 28751e4..1515ed4 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/S3GenMel.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/S3Gen/S3GenMel.swift | |
| @@ -42,8 +42,7 @@ func s3genMelSpectrogram( | |
| y: MLXArray, | |
| nFft: Int = 1920, numMels: Int = 80, | |
| samplingRate: Int = 24000, hopSize: Int = 480, | |
| - winSize: Int = 1920, fmin: Int = 0, fmax: Int = 8000 | |
| -) -> MLXArray { | |
| + winSize: Int = 1920, fmin: Int = 0, fmax: Int = 8000) -> MLXArray { | |
| var input = y | |
| let was1D = input.ndim == 1 | |
| if was1D { | |
| @@ -62,12 +61,11 @@ func s3genMelSpectrogram( | |
| for i in 0 ..< B { | |
| let spec = stft( | |
| audio: input[i], window: window, | |
| - nFft: nFft, hopLength: hopSize, | |
| - center: false) | |
| + nFft: nFft, hopLength: hopSize) | |
| specs.append(spec) | |
| } | |
| // Stack: each spec is (T', F) -> (B, T', F) | |
| - var spec = MLX.stacked(specs, axis: 0) | |
| + let spec = MLX.stacked(specs, axis: 0) | |
| // Magnitude | |
| let magnitudes = abs(spec) // (B, T', F) | |
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/T3/T3Model.swift b/Sources/MLXAudioTTS/Models/Chatterbox/T3/T3Model.swift | |
| index 17ddf74..22bcdae 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/T3/T3Model.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/T3/T3Model.swift | |
| @@ -409,7 +409,7 @@ public class T3Model: Module { | |
| let inputEmbeddings = MLX.concatenated([condEmbForInput, textEmbResult, bosEmbed], axis: 1) | |
| // Create KV cache | |
| - var cache = makeCache() | |
| + let cache = makeCache() | |
| // Initial forward pass to fill cache | |
| var hidden = tfmr(inputEmbeddings, cache: cache) | |
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/VoiceEncoder/VoiceEncoder.swift b/Sources/MLXAudioTTS/Models/Chatterbox/VoiceEncoder/VoiceEncoder.swift | |
| index 332a91e..a68a0b8 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/VoiceEncoder/VoiceEncoder.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/VoiceEncoder/VoiceEncoder.swift | |
| @@ -122,7 +122,7 @@ public class VoiceEncoder: Module { | |
| var finalHiddenStates = [MLXArray]() | |
| for layer in lstmLayers { | |
| - let (allH, allC) = layer(output) | |
| + let (allH, _) = layer(output) | |
| output = allH | |
| // Extract final timestep hidden state | |
| diff --git a/Sources/MLXAudioTTS/Models/Chatterbox/VoiceEncoder/VoiceEncoderMelSpec.swift b/Sources/MLXAudioTTS/Models/Chatterbox/VoiceEncoder/VoiceEncoderMelSpec.swift | |
| index 945e970..efec5c4 100644 | |
| --- a/Sources/MLXAudioTTS/Models/Chatterbox/VoiceEncoder/VoiceEncoderMelSpec.swift | |
| +++ b/Sources/MLXAudioTTS/Models/Chatterbox/VoiceEncoder/VoiceEncoderMelSpec.swift | |
| @@ -50,7 +50,7 @@ public func voiceEncoderMelSpectrogram( | |
| } | |
| // Stack: (B, T', F) | |
| - var spec = MLX.stacked(specs, axis: 0) | |
| + let spec = MLX.stacked(specs, axis: 0) | |
| // Magnitudes | |
| var specMagnitudes = MLX.abs(spec) |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.