@bbrowning
Created November 21, 2025 20:56
Changes required to get the latest main of vLLM running Qwen3 MoE NVFP4 on a DGX Spark
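
In short, the patch adds a second CUTLASS capability check, cutlass_moe_mm_supports_fp4, that only accepts compute capability 10.x/11.x (integer 100 to 119), and routes the NVFP4 fused-MoE detection through it. DGX Spark's GB10 reports compute capability 12.1 (121 via to_int()), so the dense FP4 GEMM check keeps passing while the CUTLASS MoE path is reported as unsupported, and detect_nvfp4_moe_support() can fall back to a non-CUTLASS FP4 MoE path (presumably Marlin, given the is_fp4_marlin_supported import in that module). A rough Python sketch of the gate, illustrative only and not part of the patch, assuming a CUDA 13.0 runtime (reported as 13000):

# Illustrative sketch, not part of the patch: mirrors the C++ checks in
# csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu after this change.

def scaled_mm_supports_fp4(capability: int, runtime: int) -> bool:
    # Dense FP4 GEMM check: unchanged, SM 10.x and SM 12.x both pass.
    return capability >= 100 and runtime >= 12080

def moe_mm_supports_fp4(capability: int, runtime: int) -> bool:
    # New grouped/MoE FP4 check: SM 12.x is excluded.
    return 100 <= capability < 120 and runtime >= 12080

# DGX Spark (GB10) reports capability 12.1 -> 121; CUDA 13.0 reports 13000.
assert scaled_mm_supports_fp4(121, 13000)        # dense path still available
assert not moe_mm_supports_fp4(121, 13000)       # CUTLASS MoE path disabled
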
diff --git a/csrc/ops.h b/csrc/ops.h
index f8bdc61aa..933c64db0 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -218,6 +218,7 @@ bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
 bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
+bool cutlass_moe_mm_supports_fp4(int64_t cuda_device_capability);
 
 void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                            torch::Tensor const& B, torch::Tensor const& A_sf,
diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
index 9cba2828a..996d9d5d3 100644
--- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
+++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
@@ -51,4 +51,10 @@ bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
   int runtimeVersion;
   cudaRuntimeGetVersion(&runtimeVersion);
   return cuda_device_capability >= 100 && runtimeVersion >= 12080;
-}
\ No newline at end of file
+}
+
+bool cutlass_moe_mm_supports_fp4(int64_t cuda_device_capability) {
+  int runtimeVersion;
+  cudaRuntimeGetVersion(&runtimeVersion);
+  return cuda_device_capability >= 100 && cuda_device_capability < 120 && runtimeVersion >= 12080;
+}
diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
index 1001af05f..4e0c34ae0 100644
--- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
+++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
@@ -68,8 +68,7 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
 #endif
 
 #if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
-    defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
-    defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
+    defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
 void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 5af74c2c2..9c8e0fcb7 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -531,6 +531,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // of the given capability
   ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
+
+  // Check if cutlass_moe_mm_supports_fp4 is supported for CUDA devices
+  // of the given capability
+  ops.def("cutlass_moe_mm_supports_fp4(int cuda_device_capability) -> bool");
+  ops.impl("cutlass_moe_mm_supports_fp4", &cutlass_moe_mm_supports_fp4);
 #endif
 
   // Quantized GEMM for GPTQ
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 0f625a794..214ab6909 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -739,6 +739,9 @@ if hasattr(torch.ops._C, "ggml_moe_a8_vec"):
 def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability)
 
+def cutlass_moe_mm_supports_fp4(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_moe_mm_supports_fp4(cuda_device_capability)
+
 
 def cutlass_blockwise_scaled_grouped_mm(
     output: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
index 44c5b027d..be2ae920b 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     is_fp4_marlin_supported,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_moe_supported,
     cutlass_fp4_supported,
 )
 
@@ -31,7 +32,7 @@ class NvFp4Support:
 
 def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
     """Detect platform support for NV-FP4 fused-MoE path"""
-    cutlass_supported = cutlass_fp4_supported()
+    cutlass_supported = cutlass_fp4_moe_supported()
 
     allow_flashinfer = cutlass_supported and (
         is_flashinfer_fp4_cutlass_moe_available()
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index d056d3404..8f6717570 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -11,7 +11,7 @@ import numpy
 import torch
 from torch import fx
 
-from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
+from vllm._custom_ops import cutlass_moe_mm_supports_fp4, cutlass_scaled_mm_supports_fp4
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
@@ -685,3 +685,11 @@ def cutlass_fp4_supported() -> bool:
     capability_tuple = current_platform.get_device_capability()
     capability = -1 if capability_tuple is None else capability_tuple.to_int()
     return cutlass_scaled_mm_supports_fp4(capability)
+
+
+def cutlass_fp4_moe_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+    return cutlass_moe_mm_supports_fp4(capability)
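
After rebuilding the extension with these changes, a quick sanity check on the Spark (a sketch, not part of the patch; it only uses the ops and helpers added above):

# Assumes vLLM was rebuilt with the patch above and a single GB10 GPU is visible.
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

cap = current_platform.get_device_capability().to_int()        # expect 121 on GB10
print("capability:", cap)
print("dense fp4 :", ops.cutlass_scaled_mm_supports_fp4(cap))  # expect True
print("moe fp4   :", ops.cutlass_moe_mm_supports_fp4(cap))     # expect False

With the MoE check returning False, detect_nvfp4_moe_support() stops selecting the CUTLASS/FlashInfer fused-MoE path on this hardware and falls back to whatever the rest of the config allows, which is what lets the Qwen3 MoE NVFP4 checkpoint load.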