class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[32, 1024, 768]", primals_2: "bf16[50257, 768]", primals_3: "bf16[50257]", primals_4: "i64[32, 1024]", tangents_1: "bf16[]"):
        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:135 in f, code: x = x * 2
        mul: "bf16[32, 1024, 768]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None

        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:126 in forward, code: return self.ce(self.linear(x).view(B * T, -1), y.view(-1))
        view: "bf16[32768, 768]" = torch.ops.aten.view.default(mul, [32768, 768]); mul = None
        permute: "bf16[768, 50257]" = torch.ops.aten.permute.default(primals_2, [1, 0]); primals_2 = None
        view_3: "i64[32768]" = torch.ops.aten.view.default(primals_4, [-1]); primals_4 = None
        ne: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)


class inner_f(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "bf16[32, 1024, 768]"; primals_2: "bf16[50257, 768]"; primals_3: "bf16[50257]"; primals_4: "i64[32, 1024]"; tangents_1: "bf16[]";

        primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:135 in f, code: x = x * 2
        mul: "bf16[32, 1024, 768]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None

        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:126 in forward, code: return self.ce(self.linear(x).view(B * T, -1), y.view(-1))
        view: "bf16[32768, 768]" = torch.ops.aten.view.default(mul, [32768, 768]); mul = None
"""
NOTE: the script right now only benchmark the latency of the attention
kernel itself. The following things are excluded
- add new key/value to the cache
- setup BlockMask for flex-attention
- etc.
"""
import math
import torch
from torch import nn
from torch import distributed
import contextlib
import os

# Set vLLM environment variables before importing vLLM so they are guaranteed
# to be picked up.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
os.environ["VLLM_ATTENTION_BACKEND"] = os.getenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")

from vllm import LLM, SamplingParams
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py
index af867f4..5f7f2ac 100644
--- a/src/liger_kernel/ops/rms_norm.py
+++ b/src/liger_kernel/ops/rms_norm.py
@@ -450,6 +450,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+        sm_count = sm_count * 32
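# Context for the one-line change above (an interpretation, not stated in the
# diff): on CUDA the launch grid is sized from the SM count, while XPU exposes
# execution-unit counts, so the EU count is scaled by a constant factor. A
# sketch of querying both (the factor 32 is taken from the diff, not derived):
import torch

if torch.cuda.is_available():
    sm_count = torch.cuda.get_device_properties(0).multi_processor_count
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    sm_count = torch.xpu.get_device_properties(0).gpu_eu_count * 32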
# Use fp32, especially for numerical stability.
def triton_per_fused__to_copy_add_div_expand_mul_pow_squeeze_sum_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
    xnumel = 32768
    r0_numel = 768
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * RSPLIT_SIZE
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
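# Reading the kernel preamble above: each program instance covers RSPLIT_SIZE
# rows of a 32768 x 768 reduction, so the 1-D launch grid would be sized as
# below (a sketch inferred from xoffset = tl.program_id(0) * RSPLIT_SIZE; the
# actual autotuned RSPLIT_SIZE/XBLOCK values are not shown in the snippet):
import triton

def grid_for(rsplit_size: int, xnumel: int = 32768):
    return (triton.cdiv(xnumel, rsplit_size),)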
diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py
index b4f175183..b539fd1f8 100644
--- a/vllm/benchmarks/latency.py
+++ b/vllm/benchmarks/latency.py
@@ -101,7 +101,7 @@ def main(args: argparse.Namespace):
     sampling_params = SamplingParams(
         n=args.n,
-        temperature=1.0,
+        temperature=0.0,
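# Why the change above helps a latency benchmark (an interpretation, not stated
# in the diff): temperature=0.0 makes vLLM decode greedily, so repeated runs
# generate the same tokens and the measured latency is not perturbed by
# sampling randomness. Equivalent standalone construction:
from vllm import SamplingParams

sampling_params = SamplingParams(n=1, temperature=0.0, max_tokens=128)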
import torch

torch._inductor.config.combo_kernels = True
torch._inductor.config.fx_graph_cache = False

@torch.compile
def f(x, y):
    return x + 1, y * 2

# x = torch.randn(1024, device="cuda")
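# A sketch of driving the snippet above (assumption: the commented-out line was
# meant as input setup; shapes and device are illustrative):
x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out1, out2 = f(x, y)  # with combo_kernels=True, the two independent pointwise ops may be fused into one kernel launch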