Skip to content

Instantly share code, notes, and snippets.

@jeromeku
Forked from fengxie/pass_list_as_vector.py
Created September 14, 2025 17:12
Show Gist options
  • Select an option

  • Save jeromeku/9a19019b119a3a7dfc5d7ed15a58f985 to your computer and use it in GitHub Desktop.

Select an option

Save jeromeku/9a19019b119a3a7dfc5d7ed15a58f985 to your computer and use it in GitHub Desktop.
CuTe DSL: pass a list from Python and convert it to a vector for the kernel
from typing import List
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
@cute.kernel
def kernel_use_vec_as_arg(vec, res: cute.Tensor):
    """Store ``vec`` into ``res`` and print the first 10 elements of ``vec``.

    Args:
        vec: Vector value handed over by the host-side JIT function. The
            caller in this file passes the result of ``fragment.load()``;
            presumably a ``TensorSSA`` of ``Int32`` -- confirm against the
            CuTe DSL version in use.
        res: Destination tensor that receives ``vec`` via ``res.store``.
    """
    res.store(vec)

    # NOTE(review): the trip count 10 is hard-coded; it has to agree with the
    # number of elements the caller packs into ``vec``.
    for i in cutlass.range(10):
        cute.printf("vec[%d]: %d", i, vec[i])

    # Alternative way to dump the whole vector, left disabled by the author:
    # cute.print_tensor(vec)
@cute.jit
def pass_list_as_vector(xs: List[cutlass.Int32], res: cute.Tensor):
    """Copy a Python list into a CuTe fragment and pass it to a kernel by value.

    The elements of ``xs`` are written one by one into a 10-element ``Int32``
    fragment; the fragment is then loaded and handed to
    ``kernel_use_vec_as_arg`` as a kernel argument.

    Args:
        xs: Values to forward to the kernel. The fragment is created with a
            fixed size of 10, so ``xs`` is expected to hold at most 10
            elements (NOTE(review): not checked here -- confirm how the DSL
            reacts to a longer list).
        res: Tensor forwarded unchanged to the kernel, which stores the
            vector into it.
    """
    vec = cute.make_fragment(10, dtype=cutlass.Int32)

    # Plain Python ``enumerate`` over the host list: one element assignment
    # into the fragment per list entry.
    for i, x in enumerate(xs):
        vec[i] = x

    # Launch once, after the fragment is fully populated, with a single
    # block containing a single thread.
    kernel_use_vec_as_arg(vec.load(), res).launch(grid=[1, 1, 1], block=[1, 1, 1])
import torch
# Host-side driver: allocate a 10-element int32 result buffer on the GPU,
# run the JIT function on the values 1..10, then print the buffer.
res = torch.zeros(10, dtype=torch.int32, device="cuda")

host_values = list(range(1, 11))  # [1, 2, ..., 10]
pass_list_as_vector(host_values, from_dlpack(res))

# Wait for the GPU work to finish before reading ``res`` on the host.
torch.cuda.synchronize()
print(res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment