@stellaraccident
Created January 26, 2026 04:10
MoE components plan for frank-models (iree-link composition)

Plan: Add MoE Components (mul_mat_id, moe_ffn_block)

Overview

Extract MoE primitives from /develop/ai-no-fluff/kb/ben/moe_f32_parameterized.mlir:

  • mul_mat_id - Expert-selected matrix multiplication (gather + batch_matmul)
  • moe_ffn_block - Full MoE FFN block composing routing, expert compute, weighted sum

Key challenge: moe_ffn_block depends on mul_mat_id and swiglu, so we need a systematic way to compose modules without manually inlining their bodies.

Composition Strategy: iree-link

Use iree-link (IREE's compile-time MLIR linker) to merge modules before compilation.

How iree-link works:

  • Resolves external function declarations against provided modules
  • Uses dotted naming: @module_name.func → searches module_name.mlir
  • Produces merged MLIR output for compilation
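The dotted-naming rule is easy to model in a few lines. The helper below is hypothetical and not part of iree-link. It splits a callee symbol into its module and function names and derives the file name that, per the rule above, iree-link searches for. It models only the naming convention, not the linker.

```python
def split_dotted_symbol(symbol: str) -> tuple[str, str, str]:
    """Split '@module_name.func' into (module_name, func, 'module_name.mlir').

    Hypothetical helper: models the naming rule only and does not invoke iree-link.
    """
    name = symbol.lstrip("@")
    module_name, sep, func = name.partition(".")  # split at the first dot
    if not sep or not module_name or not func:
        raise ValueError(f"expected '@module.func', got {symbol!r}")
    return module_name, func, f"{module_name}.mlir"


print(split_dotted_symbol("@activation_components.swiglu"))
# ('activation_components', 'swiglu', 'activation_components.mlir')
```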

Workflow:

  1. Components declare external deps with dotted names (e.g., util.func private @activation_components.swiglu(...))
  2. Test infrastructure runs iree-link to merge before iree-compile

Add link_and_compile() to tests/utils.py:

from pathlib import Path
from typing import List

def link_and_compile(main_path: str, library_paths: List[str], rt) -> IREEModule:
    """Link MLIR modules with iree-link, then compile.

    Assumes COMPONENTS_DIR, compile_mlir and IREEModule are defined in tests/utils.py.
    """
    import subprocess
    import tempfile

    main_full = COMPONENTS_DIR / main_path
    lib_args = []
    for lib in library_paths:
        lib_args.extend(["--link-module", str(COMPONENTS_DIR / lib)])

    # Reserve a temp path for the linked output; iree-link writes to it.
    with tempfile.NamedTemporaryFile(suffix=".mlir", delete=False) as f:
        linked_path = Path(f.name)

    try:
        subprocess.run(
            ["iree-link", str(main_full), *lib_args, "-o", str(linked_path)],
            check=True,
        )
        linked_source = linked_path.read_text()
    finally:
        linked_path.unlink(missing_ok=True)  # don't leak temp files
    return compile_mlir(linked_source, rt)

Test usage:

def moe_module(rt):
    return link_and_compile(
        "moe/moe_ffn_block.mlir",
        ["moe/mul_mat_id.mlir", "activation/swiglu.mlir"],
        rt
    )
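link_and_compile() shells out to iree-link. If the tool is not on PATH, subprocess.run raises FileNotFoundError instead of giving a clear skip. One option is a small probe built on the standard-library shutil.which. This is an assumption, not part of this plan, and the helper name iree_link_available is hypothetical.

```python
import shutil


def iree_link_available() -> bool:
    """Return True if an `iree-link` executable is found on PATH (hypothetical test helper)."""
    return shutil.which("iree-link") is not None


print(iree_link_available())
```

A test module could then apply pytest.mark.skipif(not iree_link_available(), reason="iree-link not found on PATH") to tests that use link_and_compile(). They would be skipped, not errored, on machines without the tool.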

Directory Structure

components/
  moe/
    mul_mat_id.mlir        # NEW - standalone primitive
    moe_ffn_block.mlir     # NEW - calls mul_mat_id + swiglu

oracles/
  moe.py                   # NEW - mul_mat_id, moe_ffn_block

tests/
  utils.py                 # MODIFY - add link_and_compile()
  test_mul_mat_id.py       # NEW
  test_moe_ffn_block.py    # NEW

Implementation Order

  1. link_and_compile() - Test infrastructure using iree-link
  2. mul_mat_id - Leaf primitive (no deps), can test standalone
  3. moe_ffn_block - Uses external declarations, linked at compile time

Component Details

1. mul_mat_id

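Purpose: Expert-selected matrix-vector products. For each expert slot e and token t, the input column input[:, e, t] is multiplied by the weight matrix of expert expert_ids[e, t], so different (slot, token) pairs can use different experts' weights.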
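An explicit loop pins down these semantics before the gather-based lowering. This NumPy sketch uses the signature's layouts: [n_out, n_in, n_expert] weights and [n_in, n_expert_used, n_tokens] input. The name mul_mat_id_loop is hypothetical. The code is a slow reference, not a proposed implementation.

```python
import numpy as np


def mul_mat_id_loop(expert_weights, input, expert_ids):
    """Loop reference for mul_mat_id (illustrative, not optimized).

    expert_weights: [n_out, n_in, n_expert]
    input:          [n_in, n_expert_used, n_tokens]
    expert_ids:     [n_expert_used, n_tokens] (integer expert indices)
    returns:        [n_out, n_expert_used, n_tokens]
    """
    n_out = expert_weights.shape[0]
    n_expert_used, n_tokens = expert_ids.shape
    out = np.zeros((n_out, n_expert_used, n_tokens), dtype=expert_weights.dtype)
    for e in range(n_expert_used):
        for t in range(n_tokens):
            w = expert_weights[:, :, expert_ids[e, t]]  # [n_out, n_in] for the chosen expert
            out[:, e, t] = w @ input[:, e, t]           # [n_out]
    return out


# Two experts: expert 0 is the identity, expert 1 doubles its input.
W = np.stack([np.eye(2), 2 * np.eye(2)], axis=-1)  # [2, 2, 2]
x = np.ones((2, 1, 2))                               # n_in=2, n_expert_used=1, n_tokens=2
ids = np.array([[0, 1]], dtype=np.int32)             # token 0 -> expert 0, token 1 -> expert 1
print(mul_mat_id_loop(W, x, ids)[:, 0, :])           # column t holds token t's result
```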

Module naming (for iree-link resolution):

module @moe_components {
  util.func public @mul_mat_id(...) -> ...
}

When linked, callable as @moe_components.mul_mat_id.

// components/moe/mul_mat_id.mlir
util.func public @mul_mat_id(
    %expert_weights: tensor<?x?x?xf32>,  // [n_out, n_in, n_expert]
    %input: tensor<?x?x?xf32>,           // [n_in, n_expert_used, n_tokens]
    %expert_ids: tensor<?x?xi32>         // [n_expert_used, n_tokens]
) -> tensor<?x?x?xf32>                   // [n_out, n_expert_used, n_tokens]

Algorithm:

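  1. Transpose weights: [n_out, n_in, n_expert] → [n_expert, n_out, n_in]
  2. Flatten indices: [n_expert_used, n_tokens] → [batch], where batch = n_expert_used × n_tokens (row-major)
  3. Gather expert matrices: iree_linalg_ext.gather → [batch, n_out, n_in]
  4. Transpose and reshape input for batched matmul: [n_in, n_expert_used, n_tokens] → [n_expert_used, n_tokens, n_in] → [batch, n_in, 1]
  5. Batched matmul: linalg.batch_matmul → [batch, n_out, 1]
  6. Reshape and transpose output: [batch, n_out, 1] → [n_expert_used, n_tokens, n_out] → [n_out, n_expert_used, n_tokens]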
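The six steps map directly onto NumPy array operations. The sketch below spells out each intermediate shape, and the function name mul_mat_id_steps is hypothetical. The batch dimension is flattened row-major, so b = e * n_tokens + t. That is why the input must be transposed before the step-4 reshape and the output transposed after the step-6 reshape: both keep the batch order in line with the flattened indices.

```python
import numpy as np


def mul_mat_id_steps(expert_weights, input, expert_ids):
    """Step-by-step NumPy mirror of the gather + batch_matmul plan (sketch)."""
    n_out, n_in, _ = expert_weights.shape
    n_expert_used, n_tokens = expert_ids.shape
    batch = n_expert_used * n_tokens

    weights_t = expert_weights.transpose(2, 0, 1)            # 1. [n_expert, n_out, n_in]
    flat_ids = expert_ids.reshape(batch)                     # 2. [batch], b = e * n_tokens + t
    gathered = np.take(weights_t, flat_ids, axis=0)          # 3. [batch, n_out, n_in]
    rhs = input.transpose(1, 2, 0).reshape(batch, n_in, 1)   # 4. [batch, n_in, 1]
    prod = np.matmul(gathered, rhs)                          # 5. [batch, n_out, 1]
    # 6. [batch, n_out, 1] -> [n_expert_used, n_tokens, n_out] -> [n_out, n_expert_used, n_tokens]
    return prod.reshape(n_expert_used, n_tokens, n_out).transpose(2, 0, 1)


rng = np.random.default_rng(0)
W = rng.standard_normal((3, 4, 5))                       # n_out=3, n_in=4, n_expert=5
x = rng.standard_normal((4, 2, 6))                       # n_in=4, n_expert_used=2, n_tokens=6
ids = rng.integers(0, 5, size=(2, 6)).astype(np.int32)   # [n_expert_used, n_tokens]
out = mul_mat_id_steps(W, x, ids)
print(out.shape)  # (3, 2, 6)
```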

IREE ops used: iree_linalg_ext.gather, linalg.batch_matmul

2. moe_ffn_block

Purpose: Complete MoE FFN layer (replaces dense FFN in transformer).

External declarations (resolved by iree-link):

// Declare dependencies with dotted module names
util.func private @moe_components.mul_mat_id(
    tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?xi32>
) -> tensor<?x?x?xf32>

util.func private @activation_components.swiglu(
    tensor<?x?x?xf32>, tensor<?x?x?xf32>
) -> tensor<?x?x?xf32>
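Because the declarations above use dotted names, a component's dependencies can be found by scanning its text. This sketch assumes every external dependency is declared as `util.func private @module.func(`, exactly as shown. Under that assumption, a regex can list the (module, function) pairs a component needs. That list could help check that the library_paths passed to link_and_compile() is complete. This is a heuristic textual scan with a hypothetical helper name, not an MLIR parser.

```python
import re

# Heuristic: matches `util.func private @module.func(` as written in the declarations above.
_EXTERN_DECL = re.compile(
    r"util\.func\s+private\s+@([A-Za-z_][\w$]*)\.([A-Za-z_][\w$.]*)\s*\("
)


def external_deps(mlir_text: str) -> list[tuple[str, str]]:
    """Return (module, function) pairs for dotted private util.func declarations."""
    return _EXTERN_DECL.findall(mlir_text)


decls = """
util.func private @moe_components.mul_mat_id(
    tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?xi32>
) -> tensor<?x?x?xf32>
util.func private @activation_components.swiglu(
    tensor<?x?x?xf32>, tensor<?x?x?xf32>
) -> tensor<?x?x?xf32>
"""
print(external_deps(decls))
# [('moe_components', 'mul_mat_id'), ('activation_components', 'swiglu')]
```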

Public signature:

// components/moe/moe_ffn_block.mlir
util.func public @moe_ffn_block(
    %input: tensor<?x?xf32>,              // [n_tokens, n_embd]
    %router_weights: tensor<?x?xf32>,     // [n_expert, n_embd]
    %up_weights: tensor<?x?x?xf32>,       // [n_ff, n_embd, n_expert]
    %gate_weights: tensor<?x?x?xf32>,     // [n_ff, n_embd, n_expert]
    %down_weights: tensor<?x?x?xf32>,     // [n_embd, n_ff, n_expert]
    %n_expert_used: index                  // top-k (typically 2)
) -> tensor<?x?xf32>                      // [n_tokens, n_embd]

Algorithm:

  1. Router: router_weights @ input^T → logits [n_expert, n_tokens]
  2. Softmax: probabilities over the expert dimension
  3. Top-k: iree_linalg_ext.topk → weights + indices [n_expert_used, n_tokens]
  4. Replicate input: transpose and broadcast to [n_embd, n_expert_used, n_tokens] (input_rep)
  5. UP projection: mul_mat_id(up_weights, input_rep, ids) → [n_ff, k, n]
  6. GATE projection: mul_mat_id(gate_weights, input_rep, ids) → [n_ff, k, n]
  7. Activation: swiglu(gate, up) → [n_ff, k, n]
  8. DOWN projection: mul_mat_id(down_weights, activated, ids) → [n_embd, k, n]
  9. Weight & sum: multiply by routing weights, sum over the expert dim → [n_embd, n_tokens]
  10. Transpose: → [n_tokens, n_embd]

(Here k = n_expert_used and n = n_tokens.)

Dependencies: @mul_mat_id, @swiglu

IREE ops used: linalg.matmul, linalg.softmax, iree_linalg_ext.topk
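Steps 1-3 are the only part of the block that decides which experts run. The sketch below isolates them: a numerically stable softmax over the expert axis, then top-k by descending probability, as in the oracle. The name route_tokens is hypothetical. As in the plan, the top-k weights are not renormalized. Some MoE models do renormalize them to sum to 1, which would be an extra step.

```python
import numpy as np


def route_tokens(x, router_weights, k):
    """Router -> softmax -> top-k for x [n_tokens, n_embd], router_weights [n_expert, n_embd]."""
    logits = router_weights @ x.T                            # [n_expert, n_tokens]
    shifted = logits - logits.max(axis=0, keepdims=True)     # stabilize exp()
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=0, keepdims=True)
    ids = np.argsort(-probs, axis=0, kind="stable")[:k]      # [k, n_tokens], best expert first
    weights = np.take_along_axis(probs, ids, axis=0)         # [k, n_tokens]
    return weights, ids.astype(np.int32)


# 4 experts, 3 tokens, 2-dim embeddings; token i points toward expert i by construction.
router = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
tokens = np.array([[5.0, 0.0], [0.0, 5.0], [-5.0, 0.0]])
weights, ids = route_tokens(tokens, router, k=2)
print(ids[0])  # [0 1 2]
```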

NumPy Oracles

oracles/moe.py

import numpy as np


def mul_mat_id(expert_weights, input, expert_ids):
    """Expert-selected matrix multiply.

    Args:
        expert_weights: [n_out, n_in, n_expert]
        input: [n_in, n_expert_used, n_tokens]
        expert_ids: [n_expert_used, n_tokens] (int32)
    Returns:
        [n_out, n_expert_used, n_tokens]
    """
    n_out, n_in, n_expert = expert_weights.shape
    n_expert_used, n_tokens = expert_ids.shape

    # Transpose weights for easier indexing
    # [n_expert, n_out, n_in]
    weights_t = expert_weights.transpose(2, 0, 1)

    # Gather: select expert matrices
    # [n_expert_used, n_tokens, n_out, n_in]
    gathered = weights_t[expert_ids]

    # Reshape input: [n_expert_used, n_tokens, n_in]
    input_t = input.transpose(1, 2, 0)

    # Batched matmul per (expert_used, token)
    output = np.einsum('etoi,eti->eto', gathered, input_t)

    # Transpose to [n_out, n_expert_used, n_tokens]
    return output.transpose(2, 0, 1)


def moe_ffn_block(input, router_weights, up_weights, gate_weights,
                  down_weights, n_expert_used):
    """Full MoE FFN block.

    Args:
        input: [n_tokens, n_embd]
        router_weights: [n_expert, n_embd]
        up_weights: [n_ff, n_embd, n_expert]
        gate_weights: [n_ff, n_embd, n_expert]
        down_weights: [n_embd, n_ff, n_expert]
        n_expert_used: int (top-k)
    Returns:
        [n_tokens, n_embd]
    """
    from oracles.activation import swiglu

    n_tokens, n_embd = input.shape
    n_expert = router_weights.shape[0]

    # 1. Router logits: [n_expert, n_tokens]
    logits = router_weights @ input.T

    # 2. Softmax over experts
    probs = np.exp(logits - logits.max(axis=0, keepdims=True))
    probs = probs / probs.sum(axis=0, keepdims=True)

    # 3. Top-k selection
    top_indices = np.argsort(-probs, axis=0)[:n_expert_used]  # [k, n_tokens]
    top_weights = np.take_along_axis(probs, top_indices, axis=0)  # [k, n_tokens]

    # 4. Replicate input: [n_embd, n_expert_used, n_tokens]
    input_rep = np.broadcast_to(
        input.T[:, None, :],
        (n_embd, n_expert_used, n_tokens)
    ).copy()

    # 5-6. Expert projections
    up = mul_mat_id(up_weights, input_rep, top_indices.astype(np.int32))
    gate = mul_mat_id(gate_weights, input_rep, top_indices.astype(np.int32))

    # 7. Activation
    activated = swiglu(gate, up)

    # 8. Down projection
    down = mul_mat_id(down_weights, activated, top_indices.astype(np.int32))

    # 9. Weight and sum: [n_embd, n_tokens]
    weighted = down * top_weights[None, :, :]  # broadcast weights
    summed = weighted.sum(axis=1)  # sum over experts

    # 10. Transpose: [n_tokens, n_embd]
    return summed.T

Test Strategy

test_mul_mat_id.py

  • Basic correctness with known expert selection
  • Single expert (n_expert_used=1)
  • Multiple experts (n_expert_used=2, 4)
  • Edge cases: all same expert, sequential experts
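The "all same expert" edge case has a closed-form expectation. If every entry of expert_ids is j, mul_mat_id reduces to an ordinary matmul of expert_weights[:, :, j] with every (slot, token) column. The pytest-style sketch below checks that. The oracle's logic is copied inline so the snippet runs on its own, and the test name and shapes are illustrative.

```python
import numpy as np


def mul_mat_id(expert_weights, input, expert_ids):
    # Same logic as oracles/moe.py: gather per-(slot, token) expert matrices, then matvec.
    gathered = expert_weights.transpose(2, 0, 1)[expert_ids]  # [k, n_tokens, n_out, n_in]
    out = np.einsum("etoi,iet->eto", gathered, input)          # [k, n_tokens, n_out]
    return out.transpose(2, 0, 1)                              # [n_out, k, n_tokens]


def test_all_same_expert():
    rng = np.random.default_rng(0)
    n_out, n_in, n_expert, k, n_tokens = 3, 4, 5, 2, 6
    weights = rng.standard_normal((n_out, n_in, n_expert))
    x = rng.standard_normal((n_in, k, n_tokens))
    j = 3
    ids = np.full((k, n_tokens), j, dtype=np.int32)
    expected = np.einsum("oi,iet->oet", weights[:, :, j], x)  # dense matmul with expert j
    np.testing.assert_allclose(mul_mat_id(weights, x, ids), expected, rtol=1e-10, atol=1e-10)


test_all_same_expert()
```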

test_moe_ffn_block.py

  • Full forward pass with small dimensions
  • Verify routing (check which experts selected)
  • Compare against oracle end-to-end
  • Use link_and_compile() to compose modules
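A per-token loop gives a second reference, independent of the vectorized oracle, for the end-to-end comparison. It mirrors the routing math but skips the replicate/gather layout tricks. The sketch assumes swiglu(gate, up) = SiLU(gate) * up, which should be checked against oracles/activation.py, and the helper names are hypothetical. It also checks a sanity property. When every expert has identical weights and n_expert_used == n_expert, the routing weights sum to 1. The block then collapses to a single dense SwiGLU FFN.

```python
import numpy as np


def swiglu(gate, up):
    # ASSUMED definition: SiLU(gate) * up. Check oracles/activation.py before relying on it.
    return gate / (1.0 + np.exp(-gate)) * up


def moe_ffn_block_loop(x, router_weights, up_weights, gate_weights, down_weights, k):
    """Per-token reference: route each token, run its top-k experts, weighted-sum the outputs."""
    n_tokens, n_embd = x.shape
    out = np.zeros((n_tokens, n_embd))
    for t in range(n_tokens):
        logits = router_weights @ x[t]                       # [n_expert]
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()                                 # softmax over experts
        for j in np.argsort(-probs)[:k]:                     # top-k expert indices
            h = swiglu(gate_weights[:, :, j] @ x[t], up_weights[:, :, j] @ x[t])  # [n_ff]
            out[t] += probs[j] * (down_weights[:, :, j] @ h)                      # [n_embd]
    return out


def tile_experts(w, n_expert):
    """Give every expert the same [rows, cols] matrix -> [rows, cols, n_expert]."""
    return np.repeat(w[:, :, None], n_expert, axis=2)


# Identical experts + k == n_expert: routing weights sum to 1, so MoE == dense SwiGLU FFN.
rng = np.random.default_rng(0)
n_tokens, n_embd, n_ff, n_expert = 4, 8, 16, 3
x = rng.standard_normal((n_tokens, n_embd))
router = rng.standard_normal((n_expert, n_embd))
up = rng.standard_normal((n_ff, n_embd))
gate = rng.standard_normal((n_ff, n_embd))
down = rng.standard_normal((n_embd, n_ff))
moe = moe_ffn_block_loop(x, router, tile_experts(up, n_expert), tile_experts(gate, n_expert),
                         tile_experts(down, n_expert), k=n_expert)
dense = swiglu(x @ gate.T, x @ up.T) @ down.T             # [n_tokens, n_embd]
print(np.allclose(moe, dense))  # True
```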

link_and_compile() Implementation

from pathlib import Path
from typing import List, Optional


def link_and_compile(
    main_path: str,
    library_paths: List[str],
    rt,
    library_dir: Optional[Path] = None,
) -> IREEModule:
    """Link MLIR modules with iree-link, then compile.

    Args:
        main_path: Primary module (relative to COMPONENTS_DIR)
        library_paths: Dependency modules to link (relative to COMPONENTS_DIR)
        rt: IREERuntime instance
        library_dir: Optional library search path for auto-discovery
    """
    import subprocess
    import tempfile

    main_full = COMPONENTS_DIR / main_path

    # Build iree-link command
    cmd = ["iree-link", str(main_full)]
    for lib in library_paths:
        cmd.extend(["--link-module", str(COMPONENTS_DIR / lib)])
    if library_dir is not None:
        cmd.extend(["--library-path", str(library_dir)])

    # Reserve a temp path for the linked output; iree-link writes to it.
    with tempfile.NamedTemporaryFile(suffix=".mlir", delete=False) as f:
        linked_path = Path(f.name)
    cmd.extend(["-o", str(linked_path)])

    try:
        subprocess.run(cmd, check=True)
        linked_source = linked_path.read_text()
    finally:
        linked_path.unlink(missing_ok=True)  # cleanup even if linking fails
    return compile_mlir(linked_source, rt)

Known Issues

  • attention_gqa is xfailed due to iree-org/iree#23277 (dynamic shapes + attention op)
  • Same issue may affect moe_ffn_block if using dynamic expert selection
  • May need to xfail or use bound hints

Module Naming Convention

For iree-link to resolve cross-module calls:

  • Each MLIR file uses module @<name>_components { ... }
  • External calls use @<name>_components.<function>

File                      Module name
------------------------  ---------------------------------------
activation/swiglu.mlir    @activation_components (already exists)
moe/mul_mat_id.mlir       @moe_components
moe/moe_ffn_block.mlir    @moe_ffn_components

Files to Create/Modify

File                                Action
----------------------------------  -----------------------------------
tests/utils.py                      Add link_and_compile()
components/moe/mul_mat_id.mlir      Create (module @moe_components)
components/moe/moe_ffn_block.mlir   Create (module @moe_ffn_components)
oracles/moe.py                      Create
tests/test_mul_mat_id.py            Create
tests/test_moe_ffn_block.py         Create

Verification

cd /develop/ai-no-fluff/frank-models
pytest tests/test_mul_mat_id.py -v
pytest tests/test_moe_ffn_block.py -v

Source Reference

/develop/ai-no-fluff/kb/ben/moe_f32_parameterized.mlir

  • mul_mat_id: lines 362-467
  • moe_ffn_block: lines 536-688