Created
August 9, 2025 17:04
-
-
Save melonedo/8edeade85962c6de8c3721f61885fdd1 to your computer and use it in GitHub Desktop.
Generate a generic intra-warp shuffle for use with CUDA
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # The algorithm proposed by https://github.com/triton-lang/triton/pull/7558 | |
| from dataclasses import dataclass | |
| from enum import Enum | |
def inverse_permutation(P):
    """
    Return the inverse of the permutation ``P``.

    ``P[i]`` is the index that ``i`` is sent to; the returned list ``Q``
    satisfies ``Q[P[i]] == i`` for every position ``i`` of ``P``.
    An assertion fails if some slot of the result was never written
    (for example when ``P`` contains a duplicate entry).
    """
    inverse = [None] * len(P)
    for source, target in enumerate(P):
        inverse[target] = source
    # Every slot must have been filled exactly by the loop above.
    unfilled = [slot for slot in inverse if slot is None]
    assert not unfilled
    return inverse
def compose_permutation(a, b):
    """
    Return the list ``c`` with ``c[i] == a[b[i]]``, i.e. apply ``b`` then ``a``.
    """
    composed = []
    for index in b:
        composed.append(a[index])
    return composed
def is_identity_permutation(P):
    """
    Return ``True`` when every entry of ``P`` equals its own position.

    An empty ``P`` counts as the identity.
    """
    return not any(position != value for position, value in enumerate(P))
class OpCode(Enum):
    """
    Tags for the operations in the op lists built by RegLaneLayout.perform_shuffle.

    Each op is a tuple whose first element is one of these members; the
    remaining elements are the operands noted next to each member (as they
    are unpacked by emit_cuda and emulate).
    """

    # (COND_SWAP, a, b, mask): lanes whose id has a bit in common with
    # `mask` exchange the values of registers `a` and `b`.
    COND_SWAP = 0
    # (WARP_SHUFFLE, reg, mask): every lane replaces register `reg` with the
    # value read from the lane whose id is its shuffle lane id XOR `mask`.
    WARP_SHUFFLE = 1
    # (PERMUTE_LANEID, p): derive the shuffle lane id from the lane id by
    # moving bit i to bit p[i]; a falsy `p` leaves the lane id unchanged.
    PERMUTE_LANEID = 2
@dataclass
class RegLaneLayout:
    """
    A simplified generic linear layout in which only register-id bits and
    lane-id bits are supported.

    Bit indices below ``reg_dim`` are register-id bits; bit indices at or
    above ``reg_dim`` are lane-id bits.
    """

    reg_dim: int

    def factorize(self, P):
        """
        Split the permutation ``P`` into in-group and cross-group parts.

        ``P`` is read as an adjacency list (``i -> P[i]``) and each of its
        cycles is walked once. While walking, runs of the shape
        ``[r_begin, l_begin, l_begin+1, ..., l_end, r_end]`` are detected,
        where ``r`` is an index ``< reg_dim`` and ``l`` is an index
        ``>= reg_dim``.

        Returns ``(p_reg, p_lane, cross)``:

        * ``p_reg``  - the in-group map for the first ``reg_dim`` indices,
        * ``p_lane`` - the in-group map for the remaining indices, with
          targets stored relative to ``reg_dim``,
        * ``cross``  - one ``(r_end, l_begin - reg_dim)`` pair per detected run.
        """
        size = len(P)
        in_group = [None] * size
        visited = [False] * size
        cross_group = []

        for begin in range(size):
            if visited[begin]:
                continue

            # Endpoints of the register -> lanes -> register run in progress.
            r_begin = l_begin = l_end = r_end = None
            cur = begin
            while True:
                assert not visited[cur]
                visited[cur] = True
                succ = P[cur]
                cur_is_reg = cur < self.reg_dim
                succ_is_reg = succ < self.reg_dim

                if cur_is_reg and succ_is_reg:
                    in_group[cur] = succ
                elif cur_is_reg:
                    # Edge leaving the register group: a run starts here.
                    assert r_begin is None and l_begin is None
                    r_begin, l_begin = cur, succ
                elif succ_is_reg:
                    # Edge re-entering the register group: the run ends here.
                    assert l_end is None and r_end is None
                    l_end, r_end = cur, succ
                else:
                    in_group[cur] = succ - self.reg_dim

                if r_end is not None:
                    # Cycles are entered from their smallest index, and
                    # register indices come first, so a run's end is never
                    # seen before its begin.
                    assert None not in (r_begin, l_begin, l_end)
                    lane_entry = l_begin - self.reg_dim
                    in_group[r_begin] = r_end
                    in_group[l_end] = lane_entry
                    cross_group.append((r_end, lane_entry))
                    r_begin = l_begin = l_end = r_end = None

                if succ == begin:
                    break
                cur = succ

        assert all(visited)
        return in_group[: self.reg_dim], in_group[self.reg_dim :], cross_group

    def emit_cond_swap(self, ops, reg, lane):
        """
        Append COND_SWAP ops pairing registers that differ only in bit ``reg``.

        For every register index whose bit ``reg`` is clear, one
        ``(OpCode.COND_SWAP, low, low + (1 << reg), 1 << lane)`` tuple is
        appended to ``ops``. Nothing is returned; ``ops`` is extended in place.
        """
        lane_mask = 1 << lane
        tile_size = 1 << (reg + 1)
        half = tile_size // 2
        for tile_begin in range(0, 1 << self.reg_dim, tile_size):
            ops.extend(
                (OpCode.COND_SWAP, low, low + half, lane_mask)
                for low in range(tile_begin, tile_begin + half)
            )

    def perform_shuffle(self, p_reg, p_lane, p_cross):
        """
        Build the op list realising a factorization returned by ``factorize``.

        Only an identity ``p_reg`` is handled (asserted). The emitted list is,
        in order: conditional swaps keyed on the inverse-permuted lane bit of
        every cross pair, one PERMUTE_LANEID op (operand ``None`` when
        ``p_lane`` is the identity, otherwise the inverse of ``p_lane``), one
        WARP_SHUFFLE per register that needs it, and finally conditional
        swaps keyed on the lane bit of every cross pair.
        """
        assert is_identity_permutation(p_reg)
        ops = []
        p_lane_inv = inverse_permutation(p_lane)

        for reg, lane in p_cross:
            self.emit_cond_swap(ops, reg, p_lane_inv[lane])

        p_lane_trivial = is_identity_permutation(p_lane)
        ops.append((OpCode.PERMUTE_LANEID, None if p_lane_trivial else p_lane_inv))

        # (register bit, lane bit) for each cross pair, used to build XOR masks.
        cross_bits = [(1 << reg, 1 << p_lane_inv[lane]) for reg, lane in p_cross]
        for r in range(1 << self.reg_dim):
            mask = 0
            for reg_bit, lane_bit in cross_bits:
                if r & reg_bit:
                    mask |= lane_bit
            # A shuffle is skipped only when it would read the lane's own value.
            if mask != 0 or not p_lane_trivial:
                ops.append((OpCode.WARP_SHUFFLE, r, mask))

        for reg, lane in p_cross:
            self.emit_cond_swap(ops, reg, lane)
        return ops
def emit_cuda(ops):
    """
    Render an op list as CUDA source text.

    The generated statements refer to an array ``reg`` and an integer
    ``laneid`` that the surrounding CUDA code has to provide. Intermediate
    results go into fresh ``tmpN`` variables; after all ops, every register
    that was rewritten is stored back with ``reg[i] = tmpN;``.

    Returns the statements joined by newlines. Raises ``ValueError`` for an
    op whose tag is not a known ``OpCode``.
    """
    lines = []
    # Register index -> name of the variable currently holding its value.
    current = {}
    next_tmp = 0

    for op in ops:
        kind = op[0]
        if kind == OpCode.COND_SWAP:
            a, b, mask = op[1:]
            src_a = current.get(a, f"reg[{a}]")
            src_b = current.get(b, f"reg[{b}]")
            dst_a, dst_b = f"tmp{next_tmp}", f"tmp{next_tmp + 1}"
            next_tmp += 2
            # The swap is expressed as two selects rather than a branch.
            lines.append(f"auto {dst_a} = laneid & {mask} ? {src_b} : {src_a};")
            lines.append(f"auto {dst_b} = laneid & {mask} ? {src_a} : {src_b};")
            current[a] = dst_a
            current[b] = dst_b
        elif kind == OpCode.WARP_SHUFFLE:
            reg, mask = op[1:]
            src = current.get(reg, f"reg[{reg}]")
            dst = f"tmp{next_tmp}"
            next_tmp += 1
            lines.append(
                f"auto {dst} = __shfl_sync(0xffffffff, {src}, shuffle_laneid ^ {mask});"
            )
            current[reg] = dst
        elif kind == OpCode.PERMUTE_LANEID:
            p = op[1]
            if not p:
                lines.append("int shuffle_laneid = laneid;")
                continue
            # TODO: Optimize this
            # Bits above len(p) are kept; each lower bit i is moved to bit p[i].
            terms = [f"int shuffle_laneid = (laneid & ~{(1 << len(p)) - 1})"]
            for src_bit, dst_bit in enumerate(p):
                picked = f"(laneid & {1 << src_bit})"
                if dst_bit > src_bit:
                    terms.append(f"({picked} << {dst_bit - src_bit})")
                elif dst_bit < src_bit:
                    terms.append(f"({picked} >> {src_bit - dst_bit})")
                else:
                    terms.append(picked)
            lines.append(" | ".join(terms) + ";")
        else:
            raise ValueError(f"Unknown operator {op}")

    for index, name in current.items():
        lines.append(f"reg[{index}] = {name};")
    return "\n".join(lines)
def emulate(num_threads, num_regs, ops, regs=None):
    """
    Execute an op list on a simulated group of threads.

    ``regs`` is a list with one register list per thread. When it is missing
    or empty, thread ``t`` starts with the values
    ``t * num_regs .. t * num_regs + num_regs - 1``. The register lists are
    modified in place and the (same) outer list is returned.

    Raises ``ValueError`` for an op whose tag is not a known ``OpCode``.
    """
    if not regs:
        regs = [
            list(range(t * num_regs, t * num_regs + num_regs))
            for t in range(num_threads)
        ]
    # Shuffle lane id of each thread; changed only by PERMUTE_LANEID.
    laneids = list(range(num_threads))

    for op in ops:
        kind = op[0]
        if kind == OpCode.COND_SWAP:
            a, b, mask = op[1:]
            for lane, thread_regs in enumerate(regs):
                if lane & mask:
                    thread_regs[a], thread_regs[b] = thread_regs[b], thread_regs[a]
        elif kind == OpCode.WARP_SHUFFLE:
            dst, mask = op[1:]
            # Gather every value first so all reads see the pre-shuffle state.
            fetched = [regs[lane ^ mask][dst] for lane in laneids]
            for thread_regs, value in zip(regs, fetched):
                thread_regs[dst] = value
        elif kind == OpCode.PERMUTE_LANEID:
            p = op[1]
            if not p:
                continue
            keep_high = ~((1 << len(p)) - 1)
            permuted = []
            for old_id in laneids:
                new_id = old_id & keep_high
                for src_bit, dst_bit in enumerate(p):
                    bit = old_id & (1 << src_bit)
                    if dst_bit > src_bit:
                        new_id |= bit << (dst_bit - src_bit)
                    elif dst_bit < src_bit:
                        new_id |= bit >> (src_bit - dst_bit)
                    else:
                        new_id |= bit
                permuted.append(new_id)
            laneids = permuted
        else:
            raise ValueError(f"Unknown operator {op}")
    return regs
# Demo: one register bit followed by two lane bits, permuted by a 3-cycle.
layout = RegLaneLayout(reg_dim=1)
P = inverse_permutation([2, 0, 1])

# Split the permutation into register, lane and cross-group parts.
factors = layout.factorize(P)
print(P, factors)

# Lower the factorization to an op list and show it, plus the CUDA rendering.
ops = layout.perform_shuffle(*factors)
print(ops)
print(emit_cuda(ops))

# Run the ops on 4 simulated threads with 2 registers each;
# for this input thread t ends up holding [t, t + 4].
regs = emulate(num_threads=4, num_regs=2, ops=ops)
print(regs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment