Mirage starting example
from typing import Optional, Callable, Sequence, Any
import torch
from torch import nn, fx
from torch.library import Library
import torch.nn.functional as F
import torch._inductor
import torch._inductor.compile_fx

mirage_lib = Library("mirage", "FRAGMENT")  # noqa
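# "FRAGMENT" adds ops to the "mirage" namespace without claiming exclusive
# ownership of it (a "DEF" library errors if the namespace already exists).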
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: Sequence[str] = (),
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
    import torch.library
    schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    my_lib = target_lib or mirage_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
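
# For reference, the higher-level API the docstring refers to would look
# roughly like this (a sketch for comparison only, not executed here):
#
#   @torch.library.custom_op("mirage::my_op", mutates_args=())
#   def my_op(x: torch.Tensor) -> torch.Tensor:
#       return torch.zeros_like(x)
#
#   @my_op.register_fake
#   def _(x: torch.Tensor) -> torch.Tensor:
#       return torch.empty_like(x)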
# ============================================================
# Mirage placeholder op registration
# ============================================================
def rms_norm(input: torch.Tensor,
             weight: torch.Tensor,
             residual: Optional[torch.Tensor],
             epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    # Placeholder impl; only exercised in eager mode
    print("rms_norm")
    if residual is None:
        residual = input
    return torch.zeros_like(input), residual


def rms_norm_fake(input: torch.Tensor,
                  weight: torch.Tensor,
                  residual: Optional[torch.Tensor],
                  epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input), torch.empty_like(input)


direct_register_custom_op("rms_norm", rms_norm, fake_impl=rms_norm_fake)
def silu_mul(input: torch.Tensor) -> torch.Tensor:
    # Placeholder impl; only exercised in eager mode
    print("silu_mul")
    return torch.zeros_like(input[..., 0:input.shape[-1] // 2])


def silu_mul_fake(input: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input[..., 0:input.shape[-1] // 2])


direct_register_custom_op("silu_mul", silu_mul, fake_impl=silu_mul_fake)
def rope(q: torch.Tensor,
         k: torch.Tensor,
         positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Placeholder impl; only exercised in eager mode
    print("rope")
    return torch.zeros_like(q), torch.zeros_like(k)


def rope_fake(q: torch.Tensor,
              k: torch.Tensor,
              positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k)


direct_register_custom_op("rope", rope, fake_impl=rope_fake)
def quantize(input: torch.Tensor,
             scale: Optional[torch.Tensor],
             dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    # Placeholder impl; only exercised in eager mode
    print("quantize")
    return torch.zeros_like(input, dtype=dtype), scale


def quantize_fake(input: torch.Tensor,
                  scale: Optional[torch.Tensor],
                  dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input, dtype=dtype), scale


direct_register_custom_op("quantize", quantize, fake_impl=quantize_fake)
def attention(q: torch.Tensor,
              k: torch.Tensor,
              v: torch.Tensor) -> torch.Tensor:
    # Placeholder impl; only exercised in eager mode
    print("attention")
    return torch.zeros_like(q)


def attention_fake(q: torch.Tensor,
                   k: torch.Tensor,
                   v: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(q)


direct_register_custom_op("attention", attention, fake_impl=attention_fake)
# ============================================================
# Example PyTorch-model
# ============================================================
class SimpleLlamaLayer(nn.Module):
    def __init__(self,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        if qdtype is None:
            qdtype = dtype
        self.hidden_dim = hidden_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.dtype = dtype
        self.qdtype = qdtype
        self.quantized = qdtype != dtype
        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        rand_wq = lambda *dims, **kwargs: rand_w(*dims, **kwargs).to(dtype=qdtype).t().contiguous().t()  # column-major for scaled-mm
        self.weights = {
            "qkv_proj": rand_wq(hidden_dim, (num_heads + num_kv_heads * 2) * head_size),
            "o_proj": rand_wq(hidden_dim, hidden_dim),
            "gate_up_proj": rand_wq(hidden_dim, 2 * hidden_dim),
            "down_proj": rand_wq(hidden_dim, hidden_dim),
            "input_norm": rand_w(hidden_dim),
            "post_attn_norm": rand_w(hidden_dim),
        }
        if self.quantized:
            self.scales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}
            self.wscales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}
    def _linear(self, input: torch.Tensor, name: str) -> torch.Tensor:
        weight = self.weights[name]
        if not self.quantized:
            return input @ weight
        scale_a, scale_b = self.scales[name], self.wscales[name]
        qinput, scale_a = torch.ops.mirage.quantize(input, scale_a, dtype=self.qdtype)
        return torch._scaled_mm(qinput, weight, scale_a=scale_a, scale_b=scale_b)
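
    # A note on the quantized path (hedged, as torch._scaled_mm details vary
    # across torch versions): the second operand must be column-major, which
    # is why rand_wq above does the .t().contiguous().t() dance, and the
    # scales must be single-element float32 tensors for tensor-wise scaling.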
    def forward(self, input: torch.Tensor, residual: Optional[torch.Tensor], positions: torch.Tensor) \
            -> tuple[torch.Tensor, torch.Tensor]:
        input_norm, residual = torch.ops.mirage.rms_norm(input, self.weights["input_norm"], residual)
        qkv = self._linear(input_norm, "qkv_proj")
        q, k, v = qkv.split_with_sizes([
            self.num_heads * self.head_size,
            self.num_kv_heads * self.head_size,
            self.num_kv_heads * self.head_size,
        ], dim=-1)
        q, k = torch.ops.mirage.rope(q, k, positions)
        out = torch.ops.mirage.attention(q, k, v)
        out2 = self._linear(out, "o_proj")
        out_norm, residual = torch.ops.mirage.rms_norm(out2, self.weights["post_attn_norm"], residual)
        # mlp
        up_gate = self._linear(out_norm, "gate_up_proj")
        silu = torch.ops.mirage.silu_mul(up_gate)
        down = self._linear(silu, "down_proj")
        return down, residual
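
# The (output, residual) return convention mirrors fused residual-add RMSNorm
# kernels (as used in e.g. vLLM): the norm op consumes the incoming residual
# stream and returns the updated one, so the addition can be fused into the
# norm kernel.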
class SimpleLlama(nn.Module):
    def __init__(self,
                 num_layers: int = 32,
                 vocab_size: int = 128256,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        self.weights = {
            "embedding": rand_w(vocab_size, hidden_dim),
            "out_norm": rand_w(hidden_dim),
        }
        self.layers = nn.ModuleList([SimpleLlamaLayer(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            qdtype=qdtype,
        ) for _ in range(num_layers)])

    def forward(self, input: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        x_emb = F.embedding(input, self.weights["embedding"])
        x, residual = x_emb, None
        for layer in self.layers:
            x, residual = layer(x, residual, positions)
        x, _ = torch.ops.mirage.rms_norm(x, self.weights["out_norm"], residual)
        return x
# ============================================================
# Backends
# ============================================================
class AotBackend:
    """Boilerplate that wraps a compile function in aot_autograd so the
    Mirage backend receives normalized post-grad ATen IR (with Inductor's
    default decompositions applied)."""

    def __init__(self, compile_fn: Callable[[fx.GraphModule, Sequence], Callable[[Sequence], Any]]):
        self.compile_fn = compile_fn

    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        from torch._dynamo.backends.common import aot_autograd
        return aot_autograd(
            fw_compiler=self.compile_fn,
            decompositions=torch._inductor.compile_fx.select_decomp_table(),
        )(graph, example_inputs)
# ============================================================
# Skeleton for the actual Mirage backend that takes a Mirage-friendly fx graph and compiles it.
# ============================================================
class MirageBackend:
    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        """
        Receives normalized (post-grad) IR.
        """
        print(graph.graph.python_code(root_module="self").src)
        return self.run

    def run(self, *args, **kwargs):
        print(f"Forward called with {len(args)=} args and {len(kwargs)=} kwargs.")
        return torch.empty_like(args[0], dtype=torch.float16)
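
# A fuller backend would walk the captured graph and hand the mirage ops to
# the compiler, e.g. (a sketch; the actual Mirage compilation API is not
# shown here):
#
#   for node in graph.graph.nodes:
#       if node.op == "call_function" and node.target in (
#               torch.ops.mirage.rms_norm.default,
#               torch.ops.mirage.attention.default):
#           ...  # map the node (and its neighbors) to a fused kernel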
torch.set_default_device("cuda")

model = SimpleLlama()
# 5 random token ids plus one position per token
inputs = [torch.randint(0, 4096, (5,)), torch.arange(0, 5)]
model(*inputs)
compiled_model = torch.compile(model, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_model(*inputs)

qmodel = SimpleLlama(qdtype=torch.float8_e4m3fn)
qmodel(*inputs)
compiled_qmodel = torch.compile(qmodel, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_qmodel(*inputs)
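
# Expected behavior (roughly): the eager calls print the placeholder op names;
# each torch.compile'd model prints the captured post-grad graph once at
# compile time, after which calls route to MirageBackend.run.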