Skip to content

Instantly share code, notes, and snippets.

@melonedo
Created August 9, 2025 17:04
Show Gist options
  • Select an option

  • Save melonedo/8edeade85962c6de8c3721f61885fdd1 to your computer and use it in GitHub Desktop.

Select an option

Save melonedo/8edeade85962c6de8c3721f61885fdd1 to your computer and use it in GitHub Desktop.
Generate a generic intra-warp shuffle for use with CUDA
# The algorithm proposed by https://github.com/triton-lang/triton/pull/7558
from dataclasses import dataclass
from enum import Enum
def inverse_permutation(P):
    """
    Given a permutation P, return its inverse permutation.

    ``P`` must contain every integer in ``range(len(P))`` exactly once. The
    result ``Q`` satisfies ``Q[P[i]] == i`` for every ``i``.

    Raises:
        ValueError: if ``P`` is not a permutation of ``range(len(P))``.
    """
    n = len(P)
    P_inv = [None] * n
    for i, p in enumerate(P):
        # Check the range explicitly: a negative p would otherwise wrap around
        # via Python's negative indexing and silently accept an invalid input.
        if not 0 <= p < n or P_inv[p] is not None:
            raise ValueError(f"not a permutation of range({n}): {P!r}")
        P_inv[p] = i
    # n distinct in-range values were written, so every slot is now filled.
    return P_inv
def compose_permutation(a, b):
    """Return ``a`` composed with ``b``: element ``i`` of the result is ``a[b[i]]``."""
    return list(map(a.__getitem__, b))
def is_identity_permutation(P):
    """Return True when every position of ``P`` holds its own index (an empty ``P`` counts as identity)."""
    return all(idx == val for idx, val in enumerate(P))
class OpCode(Enum):
    """Kinds of operation tuple produced by ``RegLaneLayout.perform_shuffle``."""

    # Swap two registers in every lane whose lane id has the given mask bit set.
    COND_SWAP = 0
    # Replace a register with the value read from lane ``shuffle_laneid ^ mask``.
    WARP_SHUFFLE = 1
    # Compute ``shuffle_laneid`` by permuting the bits of the lane id.
    PERMUTE_LANEID = 2
@dataclass
class RegLaneLayout:
    """
    A simplified linear layout with only register-id and lane-id bits.

    Bit indices ``0 .. reg_dim - 1`` are register-id bits; indices
    ``>= reg_dim`` are lane-id bits.
    """

    reg_dim: int  # number of register-id bits; all higher bit indices are lane-id bits

    def factorize(self, P):
        """
        Split the bit permutation ``P`` into in-group and cross-group parts.

        ``P`` is treated as an adjacency list (bit ``i`` goes to bit ``P[i]``).
        Each cycle is walked, and every run of the form
        ``[r_begin, l_begin, l_begin+1, ..., l_end, r_end]`` is detected, where
        ``r`` is a register bit (index < reg_dim) and ``l`` is a lane bit
        (index >= reg_dim).

        Returns:
            ``(p_reg, p_lane, cross)``, where:
            - ``p_reg`` permutes the register bits;
            - ``p_lane`` permutes the lane bits (indices relative to reg_dim);
            - ``cross`` is a list of ``(register_bit, lane_bit)`` pairs to swap,
              with ``lane_bit`` also relative to reg_dim.

            Mapping each bit through ``p_reg``/``p_lane`` and then through the
            ``cross`` swaps reproduces ``P``.

        Raises:
            ValueError: if ``P`` is not a permutation of ``range(len(P))``.
        """
        n = len(P)
        if sorted(P) != list(range(n)):
            raise ValueError(f"P must be a permutation of range({n}), got {P!r}")
        in_group = [None] * n
        visited = [False] * n
        cross_group = []
        for begin in range(n):
            if visited[begin]:
                continue
            cur = begin
            r_end, r_begin, l_begin, l_end = [None] * 4
            while True:
                assert not visited[cur]
                visited[cur] = True
                nxt = P[cur]
                if cur < self.reg_dim:
                    if nxt < self.reg_dim:
                        in_group[cur] = nxt
                    else:
                        # register -> lane edge: a cross-group run starts here
                        assert r_begin is None and l_begin is None
                        r_begin = cur
                        l_begin = nxt
                else:
                    if nxt < self.reg_dim:
                        # lane -> register edge: the current cross-group run ends here
                        assert l_end is None and r_end is None
                        l_end = cur
                        r_end = nxt
                    else:
                        in_group[cur] = nxt - self.reg_dim
                if r_end is not None:
                    # Cycles are entered at their smallest index, which is a
                    # register whenever the cycle contains one, so an end is
                    # never seen before its begin.
                    assert (
                        r_begin is not None
                        and l_begin is not None
                        and l_end is not None
                    )
                    in_group[r_begin] = r_end
                    in_group[l_end] = l_begin - self.reg_dim
                    cross_group.append((r_end, l_begin - self.reg_dim))
                    r_end, r_begin, l_begin, l_end = [None] * 4
                if begin == nxt:
                    break
                cur = nxt
        assert all(visited)
        return in_group[: self.reg_dim], in_group[self.reg_dim :], cross_group

    def emit_cond_swap(self, ops, reg, lane):
        """
        Append COND_SWAP ops to ``ops``.

        The ops swap every pair of registers that differ only in register bit
        ``reg``, in lanes whose lane id has bit ``lane`` set.
        """
        mask = 1 << lane
        tile_size = 1 << (reg + 1)
        for tile_begin in range(0, 1 << self.reg_dim, tile_size):
            for i in range(tile_size // 2):
                ops.append(
                    (
                        OpCode.COND_SWAP,
                        tile_begin + i,
                        tile_begin + i + tile_size // 2,
                        mask,
                    )
                )

    def perform_shuffle(self, p_reg, p_lane, p_cross):
        """
        Lower a factorization (as returned by ``factorize``) into op tuples.

        The emitted sequence is:
        1. COND_SWAPs keyed on the inverse-permuted lane bits;
        2. one PERMUTE_LANEID, whose payload is None when ``p_lane`` is the
           identity;
        3. a WARP_SHUFFLE per register (skipped when it would be a no-op);
        4. COND_SWAPs keyed on the original lane bits.

        Raises:
            NotImplementedError: if ``p_reg`` is not the identity permutation.
        """
        if not is_identity_permutation(p_reg):
            # p_reg is otherwise never used, so a non-identity value would yield
            # ops that do not realize the requested permutation. Raise explicitly
            # (not assert) so the check survives `python -O`.
            raise NotImplementedError(
                "perform_shuffle only supports an identity register permutation"
            )
        ops = []
        p_lane_inv = inverse_permutation(p_lane)
        for reg, lane in p_cross:
            self.emit_cond_swap(ops, reg, p_lane_inv[lane])
        p_lane_trivial = is_identity_permutation(p_lane)
        ops.append((OpCode.PERMUTE_LANEID, None if p_lane_trivial else p_lane_inv))
        for r in range(1 << self.reg_dim):
            # Register r reads from lane (shuffle_laneid ^ mask); mask flips the
            # lane bits paired with each register bit set in r.
            mask = 0
            for reg, lane in p_cross:
                if r & (1 << reg):
                    mask |= 1 << p_lane_inv[lane]
            if mask != 0 or not p_lane_trivial:
                ops.append((OpCode.WARP_SHUFFLE, r, mask))
        for reg, lane in p_cross:
            self.emit_cond_swap(ops, reg, lane)
        return ops
def emit_cuda(ops):
    """
    Translate op tuples into a snippet of CUDA C++ source.

    The snippet references ``laneid`` and ``reg[...]`` without declaring them,
    so both must already be in scope. ``shuffle_laneid`` is declared by the
    PERMUTE_LANEID op and read by subsequent WARP_SHUFFLE lines. Intermediate
    values go into fresh ``auto tmpN`` variables. After all ops, the newest
    value of every touched register is written back as ``reg[i] = tmpN;``.

    Raises:
        ValueError: if an op carries an unrecognized opcode.
    """
    lines = []
    latest = {}  # register index -> name of the variable holding its newest value
    tmp_count = 0

    def fresh_name():
        nonlocal tmp_count
        name = f"tmp{tmp_count}"
        tmp_count += 1
        return name

    def current(idx):
        return latest.get(idx, f"reg[{idx}]")

    for op in ops:
        kind = op[0]
        if kind == OpCode.COND_SWAP:
            lo, hi, mask = op[1:]
            old_lo, old_hi = current(lo), current(hi)
            new_lo, new_hi = fresh_name(), fresh_name()
            cond = f"laneid & {mask}"
            # Written as two selects into fresh variables, not an in-place swap.
            lines.append(f"auto {new_lo} = {cond} ? {old_hi} : {old_lo};")
            lines.append(f"auto {new_hi} = {cond} ? {old_lo} : {old_hi};")
            latest[lo] = new_lo
            latest[hi] = new_hi
        elif kind == OpCode.WARP_SHUFFLE:
            idx, mask = op[1:]
            source = current(idx)
            target = fresh_name()
            lines.append(
                f"auto {target} = __shfl_sync(0xffffffff, {source}, shuffle_laneid ^ {mask});"
            )
            latest[idx] = target
        elif kind == OpCode.PERMUTE_LANEID:
            perm = op[1]
            if perm:
                # Bits above len(perm) pass through; each low bit moves to perm[bit].
                low_bits = (1 << len(perm)) - 1
                terms = [f"int shuffle_laneid = (laneid & ~{low_bits})"]
                for src_bit, dst_bit in enumerate(perm):
                    bit = 1 << src_bit
                    if dst_bit == src_bit:
                        terms.append(f"(laneid & {bit})")
                    elif dst_bit > src_bit:
                        terms.append(f"((laneid & {bit}) << {dst_bit - src_bit})")
                    else:
                        terms.append(f"((laneid & {bit}) >> {src_bit - dst_bit})")
                lines.append(" | ".join(terms) + ";")
            else:
                lines.append("int shuffle_laneid = laneid;")
        else:
            raise ValueError(f"Unknown operator {op}")
    lines.extend(f"reg[{idx}] = {name};" for idx, name in latest.items())
    return "\n".join(lines)
def emulate(num_threads, num_regs, ops, regs=None):
    """
    Run ``ops`` on a simulated warp in pure Python and return the registers.

    Args:
        num_threads: number of lanes to simulate.
        num_regs: registers per lane; used only to build the default input.
        ops: op tuples such as those produced by ``RegLaneLayout.perform_shuffle``.
        regs: optional initial contents indexed as ``regs[lane][reg]``. The
            caller's object is not modified; a copy is used. When omitted or
            empty, lane ``t`` register ``r`` starts with ``r + t * num_regs``.

    Returns:
        A new list of per-lane register lists after all ops have run.

    Raises:
        ValueError: if an op carries an unrecognized opcode.
    """
    if not regs:
        regs = [[r + t * num_regs for r in range(num_regs)] for t in range(num_threads)]
    else:
        # Copy so the caller's data is not mutated in place.
        regs = [list(row) for row in regs]
    # laneids[i] is the shuffle lane id of lane i (changed by PERMUTE_LANEID).
    laneids = list(range(num_threads))
    for op in ops:
        code = op[0]
        if code == OpCode.COND_SWAP:
            a, b, mask = op[1:]
            for lane, row in enumerate(regs):
                if lane & mask:
                    row[a], row[b] = row[b], row[a]
        elif code == OpCode.WARP_SHUFFLE:
            dst, mask = op[1:]
            # Gather every source before writing anything: all lanes read
            # their sources before any lane's destination is overwritten.
            shuffled = [regs[lane ^ mask][dst] for lane in laneids]
            for row, value in zip(regs, shuffled):
                row[dst] = value
        elif code == OpCode.PERMUTE_LANEID:
            perm = op[1]
            if not perm:
                continue
            low_bits = (1 << len(perm)) - 1
            for j, old_id in enumerate(laneids):
                # Bits above len(perm) pass through; each low bit i moves to perm[i].
                new_id = old_id & ~low_bits
                for i, l in enumerate(perm):
                    if l > i:
                        new_id |= (old_id & (1 << i)) << (l - i)
                    elif l < i:
                        new_id |= (old_id & (1 << i)) >> (i - l)
                    else:
                        new_id |= old_id & (1 << i)
                laneids[j] = new_id
        else:
            raise ValueError(f"Unknown operator {op}")
    return regs
def main():
    """
    Run the worked example end to end.

    Factorize a 3-bit permutation (1 register bit, 2 lane bits), lower it to
    ops, print the generated CUDA, and emulate it on 4 lanes x 2 registers.
    """
    layout = RegLaneLayout(1)
    P = inverse_permutation([2, 0, 1])
    factors = layout.factorize(P)
    print(P, factors)
    ops = layout.perform_shuffle(*factors)
    print(ops)
    print(emit_cuda(ops))
    regs = emulate(4, 2, ops)
    print(regs)


# Guard the demo so importing this module has no side effects.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment