Inline Info Reproducer

Build Triton

git clone https://github.com/triton-lang/triton.git && cd triton
pip install -e .

Run test.py

CUDA_LAUNCH_BLOCKING=1 cuda-gdb --args python test.py

In cuda-gdb

(cuda-gdb) r

The kernel hangs by design: the mbarrier is initialized but nothing ever arrives on it, so the wait spins forever. Once it is stuck, interrupt with CTRL+C and you will see

0x0000000628f96150 in repro<<<(1,1,1),(128,1,1)>>> () at /root/code/triton/repro.py:8
8	    mbarrier.wait(b, 0)
(cuda-gdb) up
Initial frame selected; you cannot go up.
(cuda-gdb) bt
#0  0x0000000628f96150 in repro<<<(1,1,1),(128,1,1)>>> () at /root/code/triton/repro.py:8

Apparently there's only a single frame: even though mbarrier.wait is reached through bar and foo, the inlined frames are missing, so up and bt show nothing above the kernel itself.

Build and run test.cu

nvcc -O3 -opt-info inline -lineinfo -arch sm_90a ./test.cu -o test.out
cuda-gdb ./test.out

In cuda-gdb

(cuda-gdb) r

This kernel hangs the same way; interrupt with CTRL+C and you will see

(cuda-gdb) bt
#0  repro<<<(1,1,1),(32,1,1)>>> () at /mnt/data/keren/./test.cu:9 in _Z4waitPm inlined from test.cu:24

This is much better than the line info Triton emits: even with -O3, cuda-gdb still reports inlined from test.cu:24, because nvcc records the call site at which wait was inlined into repro.
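
To look at those inlining records directly, one option is to pull the cubin out of the executable and disassemble it with line annotations. This is a sketch using cuobjdump and nvdisasm from the CUDA toolkit; substitute the cubin file name that -xelf actually extracts:

cuobjdump -xelf all ./test.out
nvdisasm -gi <extracted>.cubin

Here -gi (--print-line-info-inline) asks nvdisasm to annotate the SASS with line info including the inlined-from information.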

test.cu

#include <cuda/ptx>
#include <cuda/std/cstdint>
#include <cstdio>

using u64 = cuda::std::uint64_t;
using u32 = cuda::std::uint32_t;

__device__ void wait(u64 *barrier_generic) {
  while (!cuda::ptx::mbarrier_test_wait(barrier_generic,
                                        static_cast<u64>(0))) {
    // busy-loop: nothing ever arrives on the barrier, so this spins forever
  }
}

__global__ void repro() {
  extern __shared__ u64 smem[];
  if (threadIdx.x == 0) {
    u32 count = 1;
    cuda::ptx::mbarrier_init(smem, count);
  }

  __syncthreads();

  wait(smem);  // the hang observed under cuda-gdb happens in here
  if (threadIdx.x == 0) {
    printf("Barrier released\n");
  }
}

int main() {
  // one block of 32 threads, plus 8 bytes of dynamic shared storage for the barrier
  repro<<<1, 32, sizeof(cuda::std::uint64_t)>>>();
  cudaDeviceSynchronize();
  return 0;
}
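
The inlined from annotation comes from inlining metadata in the generated debug info. At the PTX level this is carried by .loc directives with function_name and inlined_at attributes. A schematic sketch of the shape (the $L__info_wait label is hypothetical, not literal nvcc output):

.loc 1 24 0                                                  // call site inside repro(), test.cu:24
.loc 1 9 0, function_name $L__info_wait, inlined_at 1 24 0   // body of wait(), marked inlined at test.cu:24

Presumably the PTX Triton generates omits the inlined_at part, which is why cuda-gdb collapses everything into a single frame.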
test.py

from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.hopper import mbarrier


@gluon.jit
def foo(b):
    mbarrier.wait(b, 0)  # spins forever: nothing ever arrives on the barrier


@gluon.jit
def bar(b):
    foo(b)


@gluon.jit
def repro():
    b = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(b, count=1)
    bar(b)
    mbarrier.invalidate(b)


repro[(1, )]()
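
To see which .loc directives Triton actually emitted, without launching the hanging kernel, you can compile it and dump the PTX. A sketch, assuming gluon.jit exposes the same warmup/compile path and .asm dict as triton.jit:

# Sketch, assuming gluon.jit kernels compile the way triton.jit kernels do:
# warmup() compiles without launching and returns the compiled-kernel handle.
compiled = repro.warmup(grid=(1,))
for line in compiled.asm["ptx"].splitlines():
    if ".loc" in line:  # the debugging directives; with this bug, none carry inlined_at
        print(line)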