Extract MoE primitives from /develop/ai-no-fluff/kb/ben/moe_f32_parameterized.mlir:
- `mul_mat_id` - Expert-selected matrix multiplication (gather + batch_matmul)
- `moe_ffn_block` - Full MoE FFN block composing routing, expert compute, and weighted sum
Key challenge: `moe_ffn_block` depends on `mul_mat_id` and `swiglu`. Need systematic composition without manual inlining.
Use iree-link (IREE's compile-time MLIR linker) to merge modules before compilation.
How `iree-link` works:

- Resolves external function declarations against provided modules
- Uses dotted naming: `@module_name.func` → searches `module_name.mlir`
- Produces merged MLIR output for compilation
Workflow:
- Components declare external deps with dotted names (e.g., `util.func private @activation_components.swiglu(...)`)
- Test infrastructure runs `iree-link` to merge before `iree-compile`
Add `link_and_compile()` to `tests/utils.py`:
```python
# Assumes Path, List, COMPONENTS_DIR, IREEModule, and compile_mlir are
# already available at module level in tests/utils.py.
def link_and_compile(main_path: str, library_paths: List[str], rt) -> IREEModule:
    """Link MLIR modules with iree-link, then compile."""
    import subprocess
    import tempfile

    main_full = COMPONENTS_DIR / main_path
    lib_args = []
    for lib in library_paths:
        lib_args.extend(["--link-module", str(COMPONENTS_DIR / lib)])
    with tempfile.NamedTemporaryFile(suffix=".mlir", delete=False) as f:
        linked_path = f.name
    subprocess.run([
        "iree-link", str(main_full), *lib_args, "-o", linked_path
    ], check=True)
    linked_source = Path(linked_path).read_text()
    return compile_mlir(linked_source, rt)
```

Test usage:
```python
def moe_module(rt):
    return link_and_compile(
        "moe/moe_ffn_block.mlir",
        ["moe/mul_mat_id.mlir", "activation/swiglu.mlir"],
        rt
    )
```

File layout:

```
components/
  moe/
    mul_mat_id.mlir        # NEW - standalone primitive
    moe_ffn_block.mlir     # NEW - calls mul_mat_id + swiglu
oracles/
  moe.py                   # NEW - mul_mat_id, moe_ffn_block
tests/
  utils.py                 # MODIFY - add link_and_compile()
  test_mul_mat_id.py       # NEW
  test_moe_ffn_block.py    # NEW
```
- `link_and_compile()` - Test infrastructure using `iree-link`
- `mul_mat_id` - Leaf primitive (no deps), can be tested standalone
- `moe_ffn_block` - Uses external declarations, linked at compile time
Purpose of `mul_mat_id`: matrix multiply where each (expert slot, token) pair uses a different expert's weight matrix, selected by `expert_ids`.
Module naming (for `iree-link` resolution):

```mlir
module @moe_components {
  util.func public @mul_mat_id(...) -> ...
}
```

When linked, callable as `@moe_components.mul_mat_id`.
```mlir
// components/moe/mul_mat_id.mlir
util.func public @mul_mat_id(
    %expert_weights: tensor<?x?x?xf32>,  // [n_out, n_in, n_expert]
    %input: tensor<?x?x?xf32>,           // [n_in, n_expert_used, n_tokens]
    %expert_ids: tensor<?x?xi32>         // [n_expert_used, n_tokens]
) -> tensor<?x?x?xf32>                   // [n_out, n_expert_used, n_tokens]
```

Algorithm (sketched in NumPy below):

- Transpose weights: `[n_out, n_in, n_expert]` → `[n_expert, n_out, n_in]`
- Flatten indices: `[n_expert_used, n_tokens]` → `[batch]`
- Gather expert matrices: `iree_linalg_ext.gather` → `[batch, n_out, n_in]`
- Reshape input for batched matmul: → `[batch, n_in, 1]`
- Batched matmul: `linalg.batch_matmul` → `[batch, n_out, 1]`
- Reshape output: → `[n_out, n_expert_used, n_tokens]`
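The same steps as a minimal NumPy sketch (function name and the flattened-batch layout are illustrative; the oracle in `oracles/moe.py` below uses an equivalent einsum formulation instead):

```python
import numpy as np

def mul_mat_id_batched(expert_weights, x, expert_ids):
    """Sketch of the gather + batch_matmul formulation (illustrative name)."""
    # Transpose weights: [n_out, n_in, n_expert] -> [n_expert, n_out, n_in]
    weights_t = expert_weights.transpose(2, 0, 1)
    n_out, n_in = weights_t.shape[1:]
    n_expert_used, n_tokens = expert_ids.shape
    batch = n_expert_used * n_tokens

    # Flatten indices: [n_expert_used, n_tokens] -> [batch]
    flat_ids = expert_ids.reshape(batch)
    # Gather expert matrices: [batch, n_out, n_in]
    gathered = weights_t[flat_ids]
    # Reshape input: [n_in, n_expert_used, n_tokens] -> [batch, n_in, 1]
    x_batched = x.transpose(1, 2, 0).reshape(batch, n_in, 1)
    # Batched matmul: [batch, n_out, n_in] @ [batch, n_in, 1] -> [batch, n_out, 1]
    y = gathered @ x_batched
    # Reshape output: -> [n_out, n_expert_used, n_tokens]
    return y.reshape(n_expert_used, n_tokens, n_out).transpose(2, 0, 1)
```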
IREE ops used: `iree_linalg_ext.gather`, `linalg.batch_matmul`
Purpose of `moe_ffn_block`: complete MoE FFN layer (replaces the dense FFN in a transformer).
External declarations (resolved by `iree-link`):

```mlir
// Declare dependencies with dotted module names
util.func private @moe_components.mul_mat_id(
    tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?xi32>
) -> tensor<?x?x?xf32>

util.func private @activation_components.swiglu(
    tensor<?x?x?xf32>, tensor<?x?x?xf32>
) -> tensor<?x?x?xf32>
```

Public signature:
```mlir
// components/moe/moe_ffn_block.mlir
util.func public @moe_ffn_block(
    %input: tensor<?x?xf32>,           // [n_tokens, n_embd]
    %router_weights: tensor<?x?xf32>,  // [n_expert, n_embd]
    %up_weights: tensor<?x?x?xf32>,    // [n_ff, n_embd, n_expert]
    %gate_weights: tensor<?x?x?xf32>,  // [n_ff, n_embd, n_expert]
    %down_weights: tensor<?x?x?xf32>,  // [n_embd, n_ff, n_expert]
    %n_expert_used: index              // top-k (typically 2)
) -> tensor<?x?xf32>                   // [n_tokens, n_embd]
```

Algorithm:

- Router: `router_weights @ input.T` → logits `[n_expert, n_tokens]`
- Softmax: probabilities over experts
- Top-k: `iree_linalg_ext.topk` → weights + indices `[n_expert_used, n_tokens]`
- Replicate input: broadcast to `[n_embd, n_expert_used, n_tokens]`
- UP projection: `mul_mat_id(up_weights, input, ids)` → `[n_ff, k, n]`
- GATE projection: `mul_mat_id(gate_weights, input, ids)` → `[n_ff, k, n]`
- Activation: `swiglu(gate, up)` → `[n_ff, k, n]`
- DOWN projection: `mul_mat_id(down_weights, activated, ids)` → `[n_embd, k, n]`
- Weight & sum: multiply by routing weights, sum over expert dim → `[n_embd, n]`
- Transpose: → `[n_tokens, n_embd]`
Dependencies: `@moe_components.mul_mat_id`, `@activation_components.swiglu`
IREE ops used: `linalg.matmul`, `linalg.softmax`, `iree_linalg_ext.topk`
```python
# oracles/moe.py
import numpy as np

def mul_mat_id(expert_weights, input, expert_ids):
    """Expert-selected matrix multiply.

    Args:
        expert_weights: [n_out, n_in, n_expert]
        input: [n_in, n_expert_used, n_tokens]
        expert_ids: [n_expert_used, n_tokens] (int32)

    Returns:
        [n_out, n_expert_used, n_tokens]
    """
    n_out, n_in, n_expert = expert_weights.shape
    n_expert_used, n_tokens = expert_ids.shape

    # Transpose weights for easier indexing: [n_expert, n_out, n_in]
    weights_t = expert_weights.transpose(2, 0, 1)
    # Gather: select expert matrices
    # [n_expert_used, n_tokens, n_out, n_in]
    gathered = weights_t[expert_ids]
    # Reshape input: [n_expert_used, n_tokens, n_in]
    input_t = input.transpose(1, 2, 0)
    # Batched matmul per (expert_used, token)
    output = np.einsum('etoi,eti->eto', gathered, input_t)
    # Transpose to [n_out, n_expert_used, n_tokens]
    return output.transpose(2, 0, 1)
```
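As a quick sanity check, the oracle should agree with a naive per-token loop; a minimal sketch with arbitrary small sizes (not part of the planned test files):

```python
import numpy as np

rng = np.random.default_rng(0)
n_out, n_in, n_expert, k, n_tokens = 5, 3, 4, 2, 6
w = rng.standard_normal((n_out, n_in, n_expert)).astype(np.float32)
x = rng.standard_normal((n_in, k, n_tokens)).astype(np.float32)
ids = rng.integers(0, n_expert, size=(k, n_tokens)).astype(np.int32)

# Naive reference: pick each slot's expert matrix and multiply per column
expected = np.empty((n_out, k, n_tokens), dtype=np.float32)
for e in range(k):
    for t in range(n_tokens):
        expected[:, e, t] = w[:, :, ids[e, t]] @ x[:, e, t]

np.testing.assert_allclose(mul_mat_id(w, x, ids), expected, rtol=1e-5, atol=1e-6)
```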
```python
# oracles/moe.py (continued)
def moe_ffn_block(input, router_weights, up_weights, gate_weights,
                  down_weights, n_expert_used):
    """Full MoE FFN block.

    Args:
        input: [n_tokens, n_embd]
        router_weights: [n_expert, n_embd]
        up_weights: [n_ff, n_embd, n_expert]
        gate_weights: [n_ff, n_embd, n_expert]
        down_weights: [n_embd, n_ff, n_expert]
        n_expert_used: int (top-k)

    Returns:
        [n_tokens, n_embd]
    """
    from oracles.activation import swiglu

    n_tokens, n_embd = input.shape
    n_expert = router_weights.shape[0]

    # 1. Router logits: [n_expert, n_tokens]
    logits = router_weights @ input.T
    # 2. Softmax over experts
    probs = np.exp(logits - logits.max(axis=0, keepdims=True))
    probs = probs / probs.sum(axis=0, keepdims=True)
    # 3. Top-k selection
    top_indices = np.argsort(-probs, axis=0)[:n_expert_used]      # [k, n_tokens]
    top_weights = np.take_along_axis(probs, top_indices, axis=0)  # [k, n_tokens]
    # 4. Replicate input: [n_embd, n_expert_used, n_tokens]
    input_rep = np.broadcast_to(
        input.T[:, None, :],
        (n_embd, n_expert_used, n_tokens)
    ).copy()
    # 5-6. Expert projections
    up = mul_mat_id(up_weights, input_rep, top_indices.astype(np.int32))
    gate = mul_mat_id(gate_weights, input_rep, top_indices.astype(np.int32))
    # 7. Activation
    activated = swiglu(gate, up)
    # 8. Down projection
    down = mul_mat_id(down_weights, activated, top_indices.astype(np.int32))
    # 9. Weight and sum: [n_embd, n_tokens]
    weighted = down * top_weights[None, :, :]  # broadcast routing weights
    summed = weighted.sum(axis=1)              # sum over experts
    # 10. Transpose: [n_tokens, n_embd]
    return summed.T
```

Test coverage for `mul_mat_id`:

- Basic correctness with known expert selection
- Single expert (`n_expert_used=1`)
- Multiple experts (`n_expert_used=2, 4`)
- Edge cases: all same expert, sequential experts

Test coverage for `moe_ffn_block` (see the test sketch below):

- Full forward pass with small dimensions
- Verify routing (check which experts are selected)
- Compare against the oracle end-to-end
- Use `link_and_compile()` to compose modules
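A minimal sketch of what `tests/test_moe_ffn_block.py` could look like (the `rt` fixture, the import paths, and the compiled module's call syntax are assumptions; adapt to the actual harness):

```python
# tests/test_moe_ffn_block.py (sketch)
import numpy as np
import pytest

from oracles.moe import moe_ffn_block      # assumed import path
from tests.utils import link_and_compile   # assumed import path

@pytest.fixture
def moe_module(rt):  # `rt` runtime fixture assumed to exist in conftest.py
    return link_and_compile(
        "moe/moe_ffn_block.mlir",
        ["moe/mul_mat_id.mlir", "activation/swiglu.mlir"],
        rt,
    )

def test_matches_oracle(moe_module):
    rng = np.random.default_rng(42)
    n_tokens, n_embd, n_ff, n_expert, k = 4, 8, 16, 4, 2
    x = rng.standard_normal((n_tokens, n_embd)).astype(np.float32)
    router = rng.standard_normal((n_expert, n_embd)).astype(np.float32)
    up = rng.standard_normal((n_ff, n_embd, n_expert)).astype(np.float32)
    gate = rng.standard_normal((n_ff, n_embd, n_expert)).astype(np.float32)
    down = rng.standard_normal((n_embd, n_ff, n_expert)).astype(np.float32)

    expected = moe_ffn_block(x, router, up, gate, down, k)
    # Call syntax for the compiled module is an assumption:
    actual = moe_module.moe_ffn_block(x, router, up, gate, down, k)
    np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-5)
```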
```python
# Final version for tests/utils.py; Path/List/Optional imports shown for
# self-containment, COMPONENTS_DIR and compile_mlir assumed module-level.
from pathlib import Path
from typing import List, Optional

def link_and_compile(
    main_path: str,
    library_paths: List[str],
    rt,
    library_dir: Optional[Path] = None,
) -> IREEModule:
    """Link MLIR modules with iree-link, then compile.

    Args:
        main_path: Primary module (relative to COMPONENTS_DIR)
        library_paths: Dependency modules to link
        rt: IREERuntime instance
        library_dir: Optional library search path for auto-discovery
    """
    import subprocess
    import tempfile

    main_full = COMPONENTS_DIR / main_path

    # Build iree-link command
    cmd = ["iree-link", str(main_full)]
    for lib in library_paths:
        cmd.extend(["--link-module", str(COMPONENTS_DIR / lib)])
    if library_dir:
        cmd.extend(["--library-path", str(library_dir)])

    with tempfile.NamedTemporaryFile(suffix=".mlir", delete=False) as f:
        linked_path = f.name
    cmd.extend(["-o", linked_path])
    subprocess.run(cmd, check=True)

    linked_source = Path(linked_path).read_text()
    Path(linked_path).unlink()  # cleanup
    return compile_mlir(linked_source, rt)
```

Known issues:

- `attention_gqa` is xfailed due to iree-org/iree#23277 (dynamic shapes + attention op)
- The same issue may affect `moe_ffn_block` if using dynamic expert selection
- May need to xfail or use bound hints (see the marker sketch below)
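If it comes to that, the same xfail pattern used for `attention_gqa` applies; a marker sketch (reason text illustrative):

```python
import pytest

# Mirror the attention_gqa workaround if dynamic expert selection breaks:
@pytest.mark.xfail(reason="iree-org/iree#23277: dynamic shapes", strict=False)
def test_moe_ffn_block_dynamic(moe_module):
    ...
```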
For `iree-link` to resolve cross-module calls:

- Each MLIR file uses `module @<name>_components { ... }`
- External calls use `@<name>_components.<function>`
| File | Module Name |
|---|---|
| `activation/swiglu.mlir` | `@activation_components` (already exists) |
| `moe/mul_mat_id.mlir` | `@moe_components` |
| `moe/moe_ffn_block.mlir` | `@moe_ffn_components` |
| File | Action |
|---|---|
| `tests/utils.py` | Add `link_and_compile()` |
| `components/moe/mul_mat_id.mlir` | Create (module `@moe_components`) |
| `components/moe/moe_ffn_block.mlir` | Create (module `@moe_ffn_components`) |
| `oracles/moe.py` | Create |
| `tests/test_mul_mat_id.py` | Create |
| `tests/test_moe_ffn_block.py` | Create |
```sh
cd /develop/ai-no-fluff/frank-models
pytest tests/test_mul_mat_id.py -v
pytest tests/test_moe_ffn_block.py -v
```

Reference: `/develop/ai-no-fluff/kb/ben/moe_f32_parameterized.mlir`

- `mul_mat_id`: lines 362-467
- `moe_ffn_block`: lines 536-688