// Operation: max(x[64,32] @ w[32,64], axis=1), Input in VMEM. Weights in HBM (need DMA).
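// Reference semantics, as a hedged NumPy sketch (illustrative only, not part of the dump):
//   import numpy as np
//   def kernel(x, w):                # x: f32[64, 32] (VMEM), w: f32[32, 64] (HBM)
//       h = x @ w                    # f32[64, 64] intermediate (the scratch output below)
//       return h.max(axis=1)         # f32[64] row-wise max (the real output below)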
0x0 : { %v_const_neg_inf = vmov -inf // Init accumulator for Max reduction
;; %ptr_x_in = inlined_call_operand.vmem [shape: f32[64,32]] // operand 0: input x
;; %ptr_w_hbm = inlined_call_operand.hbm [shape: f32[32,64]] // operand 1: weights w (HBM)
;; %ptr_out_max = inlined_call_operand.vmem [shape: f32[64]] // operand 2: output
;; %ptr_out_full = inlined_call_operand.vmem [shape: f32[64,64]] // operand 3: scratch }
0x1 : { %6 = vst [vmem:[#allocation1] sm:$0xff] /*vst_source=*/%v_const_neg_inf } // #allocation1 now holds a vector of -inf (reloaded at 0xa as the global accumulator, reused at 0x40 for the result)
// First Phase: DMA w out of HBM into VMEM (the MXU and VPU can only access VMEM)
// Prepare the load of w from HBM into VMEM at the address #allocation2
0x2 : { %7 = vsyncpa [#sync_flag], 0 // Reset the DMA completion counter to 0
;; %offset_w_hbm = sshll.u32 %ptr_w_hbm, 4 // Shift left by 4 (x16, presumably scaling to a byte address given 16 B granules)
;; %offset_w_vmem = smov [#allocation2] // Destination offset in VMEM
;; %ptr_src_hbm = int_to_ptr.hbm [resolvable:$true] %offset_w_hbm } // DMA source pointer
0x3 : { %offset_w_vmem_shifted = sshll.u32 %offset_w_vmem, 4 // Same x16 scaling for the destination
;; %ptr_dst_vmem = int_to_ptr.vmem [resolvable:$true] %offset_w_vmem_shifted } // DMA destination pointer
// Start the actual load from HBM to VMEM (512 granules x 16 bytes per granule == 8192 B == 32x64 f32s)
// This asynchronously loads data from hbm to vmem, atomically incrementing the pointee of the address [#sync_flag] for each successful load
0x4 : { %20 = dma.hbm_to_vmem [thread:$0] /*hbm=*/%ptr_src_hbm, /*size=*/512, /*vmem=*/%ptr_dst_vmem, /*sync=*/[#sync_flag] }
// Wait for the load to be done (for the # of successful loads at [#sync_flag] to reach all 512)
0x5 : { %290 = dma.done.wait [#sync_flag], 512 }
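// Size check for the transfer above, plus the async pattern it implements, as a hedged
// Python sketch (start_dma/read_flag are illustrative names, not libtpu API):
//   GRANULE = 16                                   # bytes per DMA granule
//   assert 32 * 64 * 4 // GRANULE == 512           # f32[32, 64] is 8192 B == 512 granules
//   start_dma(src=w_hbm, dst=w_vmem, granules=512, flag=sync_flag)  # flag += 1 per granule
//   while read_flag(sync_flag) < 512:              # 0x5 blocks until every granule has landed
//       pass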
// Second Phase: load the chunks of weights w and inputs x from VMEM into vregs
// Note that in the systolic array the inputs x are stationary (held fixed in the array),
// while the weights w are streamed through.
// The streamed operand (w) needs a specific layout, so the w chunks must first be
// transposed by the XLU (transpose unit) into a layout the MXU can consume.
// This stage alternates over chunks of w and x (a sketch follows the list):
// Load w_{i+1}
// Load x_{i+1}
// Push x_i into the MXUs' stationary slots (the vmatpush.msra)
// Transpose w_i into an MXU-compatible layout using xlu0
// for all 4 MXUs (mxu0, mxu1, mxu2, mxu3)
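// The same pipeline as a hedged Python sketch (illustrative names, not libtpu API):
//   for i in range(4):                           # 4 chunks each of w and x
//       if i + 1 < 4:
//           w_chunk[i + 1] = vld(w_vmem, i + 1)  # prefetch the next w chunk
//           x_chunk[i + 1] = vld(x_vmem, i + 1)  # prefetch the next x chunk
//       for mxu in (mxu0, mxu1, mxu2, mxu3):
//           mxu.push_stationary(x_chunk[i])      # vmatpush.msra.*
//       xlu0.transpose_step(w_chunk[i])          # vxpose.xlu0.b32.{start,cont,end}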
0x6 : { %291 = vsyncadd [#sync_flag], -512 // Consume the 512 DMA completions (reset the flag for reuse)
;; %v_w_chunk0 = vld [vmem:[#allocation2] sm:$0xff] // Load w chunk 0
;; %v_x_chunk0 = vld [vmem:[%ptr_x_in + $0x18] sm:$0xff] // Load x chunk 0
;; %v_lane_id = vlaneseq } // Lane ids 0..127; used later to build the mask, since only a subset of the MXU lanes carry real output
0x7 : { %61 = vxpose.xlu0.b32.start [1/4] /*vx=*/%v_w_chunk0 // Start Transpose W Chunk 0
;; %v_x_chunk1 = vld [vmem:[%ptr_x_in + $0x10] sm:$0xff] // Load x chunk 1
;; %86 = vmatpush.msra.mxu0 %v_x_chunk0 // Push x 0 to MXU 0
;; %v_x_chunk2 = vld [vmem:[%ptr_x_in + $0x8] sm:$0xff] // Load x chunk 2
;; %v_w_chunk1 = vld [vmem:[#allocation2 + $0x10] sm:$0xff]}// Load w chunk 1
0x8 : { %252 = vmatpush.msra.mxu1 %v_x_chunk0 // Push x chunk 0 to MXU 1
;; %253 = vmatpush.msra.mxu2 %v_x_chunk0 // Push x chunk 0 to MXU 2
;; %v_w_chunk2 = vld [vmem:[#allocation2 + $0x8] sm:$0xff] // Load w chunk 2
;; %v_w_chunk3 = vld [vmem:[#allocation2 + $0x18] sm:$0xff] // Load w chunk 3
;; %v_lane_idx = vand.u32 127, %v_lane_id } // Lane id & 127 (identity for 128 lanes); feeds the lane < 64 compare at 0xa
0x9 : { %254 = vmatpush.msra.mxu3 %v_x_chunk0 // Push x chunk 0 to MXU 3
;; %92 = vmatpush.msra.mxu0 %v_x_chunk1 // Push x chunk 1 to MXU 0
;; %v_x_chunk3 = vld [vmem:[%ptr_x_in] sm:$0xff] } // Load x chunk 3 (Final)
0xa : { %255 = vmatpush.msra.mxu1 %v_x_chunk1 // Push x chunk 1 to MXU 1
;; %256 = vmatpush.msra.mxu2 %v_x_chunk1 // Push x chunk 1 to MXU 2
;; %vm_is_valid = vcmp.lt.s32.totalorder %v_lane_idx, 64 // Mask: lane < 64 is valid; the (64x32) x (32x64) result is only 64 wide, so lanes 64..127 are padding
;; %v_global_acc = vld [vmem:[#allocation1] ss:$0 sm:$0xff] } // Load -inf accumulator (allocation1)
// Push the remaining input chunks to the MXUs
0xb : { %257 = vmatpush.msra.mxu3 %v_x_chunk1
;; %98 = vmatpush.msra.mxu0 %v_x_chunk2 }
0xc : { %258 = vmatpush.msra.mxu1 %v_x_chunk2
;; %259 = vmatpush.msra.mxu2 %v_x_chunk2 }
0xd : { %260 = vmatpush.msra.mxu3 %v_x_chunk2
;; %102 = vmatpush.msra.mxu0 %v_x_chunk3 }
0xe : { %261 = vmatpush.msra.mxu1 %v_x_chunk3
;; %262 = vmatpush.msra.mxu2 %v_x_chunk3 }
// Finish transposing weights w
0xf : { %62 = vxpose.xlu0.b32.cont [2/4] /*vx=*/%v_w_chunk2
;; %263 = vmatpush.msra.mxu3 %v_x_chunk3 }
0x10 : { %63 = vxpose.xlu0.b32.cont [3/4] /*vx=*/%v_w_chunk1 }
0x11 : { %64 = vxpose.xlu0.b32.end [4/4] /*vx=*/%v_w_chunk3 }
// Third Phase: the actual matmul on the systolic array
// For each transposed weight chunk (a sketch follows the list):
// Pop the transposed weight chunk from xlu0
// Stream it through the systolic array against the stationary x
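// Per-MXU shape of this phase, as a hedged Python sketch (names are illustrative; the
// load/accumulate split follows the "Load Acc" / "Accumulate" annotations below):
//   for mxu in (mxu0, mxu1, mxu2, mxu3):
//       wt_a = xlu0.pop()          # vpop.trf.xlu0
//       mxu.matmul_load(wt_a)      # vmatmul.f32.vlgmr.*: result regs  = x_stat @ wt_a
//       wt_b = xlu0.pop()
//       mxu.matmul_acc(wt_b)       # vmatmul.f32.gmra.*:  result regs += x_stat @ wt_b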
// MXU0
0x12 : { %v_w_transposed_0 = vpop.trf.xlu0 } // Pop formatted W
0x13 : { %103 = vmatmul.f32.vlgmr.mxu0 %v_w_transposed_0 } // Matmul Part A (Load Acc)
0x14 : { %v_w_transposed_1 = vpop.trf.xlu0 }
0x15 : { %111 = vmatmul.f32.gmra.mxu0 %v_w_transposed_1 } // Matmul Part B (Accumulate)
// MXU1
0x16 : { %v_w_transposed_2 = vpop.trf.xlu0 }
0x17 : { %125 = vmatmul.f32.vlgmr.mxu1 %v_w_transposed_2 }
0x18 : { %v_w_transposed_3 = vpop.trf.xlu0 }
0x19 : { %139 = vmatmul.f32.gmra.mxu1 %v_w_transposed_3 }
// MXU2
0x1a : { %v_w_transposed_4 = vpop.trf.xlu0 }
0x1b : { %153 = vmatmul.f32.vlgmr.mxu2 %v_w_transposed_4 }
0x1c : { %v_w_transposed_5 = vpop.trf.xlu0 }
0x1d : { %167 = vmatmul.f32.gmra.mxu2 %v_w_transposed_5 }
// MXU3
0x1e : { %v_w_transposed_6 = vpop.trf.xlu0 }
0x1f : { %181 = vmatmul.f32.vlgmr.mxu3 %v_w_transposed_6 }
0x20 : { %v_w_transposed_7 = vpop.trf.xlu0 }
0x21 : { %195 = vmatmul.f32.gmra.mxu3 %v_w_transposed_7 }
// Fourth Phase: fetch the matmul results and accumulate the running max
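// The masked running max, as a hedged NumPy sketch (chunks stands for the 8 popped
// results; one 128-lane vreg is simplified to a 1-D array):
//   import numpy as np
//   lane   = np.arange(128)                                      # vlaneseq
//   masked = [np.where(lane < 64, c, -np.inf) for c in chunks]   # vsel with %vm_is_valid
//   vmax8  = masked[0]
//   for m in masked[1:]:                                         # chained vmax.f32 below
//       vmax8 = np.maximum(vmax8, m)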
// MXU0
0x22 : { %v_res_raw_0 = vpop.f32.mrf.mxu0 } // Pop result chunk 0 of the matmul from MXU0
0x23 : { %110 = vst [vmem:[%ptr_out_full] sm:$0xff] %v_res_raw_0 // Save to the scratch intermediate matrix H [64, 64]
;; %v_res_masked_0 = vsel %vm_is_valid, %v_res_raw_0, -inf } // Mask the padding lanes (>= 64) to -inf
0x24 : { %v_res_raw_1 = vpop.f32.mrf.mxu0 } // Pop result chunk 1 of the matmul from MXU0
0x25 : { %245 = vst [vmem:[%ptr_out_full + $0x8] sm:$0xff] %v_res_raw_1 // Save to the scratch intermediate matrix H (at offset 8)
;; %v_res_masked_1 = vsel %vm_is_valid, %v_res_raw_1, -inf }
0x26 : { %v_max_acc_0 = vmax.f32 %v_res_masked_0, %v_res_masked_1 } // max01 = max(chunk_0, chunk_1)
// MXU1
0x27 : { %v_res_raw_2 = vpop.f32.mrf.mxu1 }
0x28 : { %246 = vst [vmem:[%ptr_out_full + $0x10] sm:$0xff] %v_res_raw_2
;; %v_res_masked_2 = vsel %vm_is_valid, %v_res_raw_2, -inf }
0x29 : { %v_max_acc_1 = vmax.f32 %v_max_acc_0, %v_res_masked_2 } // max012 = max(max01, chunk_2)
0x2a : { %v_res_raw_3 = vpop.f32.mrf.mxu1 }
0x2b : { %247 = vst [vmem:[%ptr_out_full + $0x18] sm:$0xff] %v_res_raw_3
;; %v_res_masked_3 = vsel %vm_is_valid, %v_res_raw_3, -inf }
0x2c : { %v_max_acc_2 = vmax.f32 %v_max_acc_1, %v_res_masked_3 } // Accumulate max0123 = max(max012, chunk_3)
// MXU2
0x2d : { %v_res_raw_4 = vpop.f32.mrf.mxu2 }
0x2e : { %248 = vst [vmem:[%ptr_out_full + $0x20] sm:$0xff] %v_res_raw_4
;; %v_res_masked_4 = vsel %vm_is_valid, %v_res_raw_4, -inf }
0x2f : { %v_max_acc_3 = vmax.f32 %v_max_acc_2, %v_res_masked_4 } // Accumulate max01234 = max(max0123, chunk_4)
0x30 : { %v_res_raw_5 = vpop.f32.mrf.mxu2 }
0x31 : { %249 = vst [vmem:[%ptr_out_full + $0x28] sm:$0xff] %v_res_raw_5
;; %v_res_masked_5 = vsel %vm_is_valid, %v_res_raw_5, -inf }
0x32 : { %v_max_acc_4 = vmax.f32 %v_max_acc_3, %v_res_masked_5 } // Accumulate max012345 = max(max01234, chunk_5)
// MXU3
0x33 : { %v_res_raw_6 = vpop.f32.mrf.mxu3 }
0x34 : { %250 = vst [vmem:[%ptr_out_full + $0x30] sm:$0xff] %v_res_raw_6
;; %v_res_masked_6 = vsel %vm_is_valid, %v_res_raw_6, -inf }
0x35 : { %v_max_acc_5 = vmax.f32 %v_max_acc_4, %v_res_masked_6 } // Accumulate max0123456 = max(max012345, chunk_6)
0x36 : { %v_res_raw_7 = vpop.f32.mrf.mxu3 }
0x37 : { %v_res_masked_7 = vsel %vm_is_valid, %v_res_raw_7, -inf
;; %251 = vst [vmem:[%ptr_out_full + $0x38] sm:$0xff] %v_res_raw_7 }
0x38 : { %v_vertical_max_final = vmax.f32 %v_max_acc_5, %v_res_masked_7 } // Final vmax8 = max(max0123456, chunk_7); shape [64, 8]
// Logarithmic horizontal reduction over v_vertical_max_final, from [64, 8] down to [64, 1]
// This is a standard butterfly (tree) reduction using lane rotations: log2(8) = 3 rotate-and-max steps
0x39 : { %v_rot_4 = vrot.slane %v_vertical_max_final, 4 } // Rotate 4, so [1,2,3,4,5,6,7,8] -> [5,6,7,8,1,2,3,4]
0x3a : { %v_horiz_max_1 = vmax.f32 %v_vertical_max_final, %v_rot_4 }// [max(1,5), max(2,6), max(3,7), max(4,8), ...]
0x3b : { %v_rot_2 = vrot.slane %v_horiz_max_1, 2 } // Rotate 2, so [max15, max26, max37, max48] -> [max37, max48, max15, max26]
0x3c : { %v_horiz_max_2 = vmax.f32 %v_horiz_max_1, %v_rot_2 } // [max1357, max2468]
0x3d : { %v_rot_1 = vrot.slane %v_horiz_max_2, 1 } // Rotate 1, so [max1357, max2468] -> [max2468, max1357]
0x3e : { %v_horiz_max_final = vmax.f32 %v_horiz_max_2, %v_rot_1 } // This just reduces down to max12345678 in the first column
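// The same butterfly in a hedged NumPy sketch (vmax8 is the [64, 8] value from 0x38):
//   import numpy as np
//   acc = vmax8                                    # [64, 8]
//   for shift in (4, 2, 1):                        # log2(8) = 3 rotate-and-max steps
//       acc = np.maximum(acc, np.roll(acc, shift, axis=1))
//   row_max = acc[:, 0]                            # every column now holds the row max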
// Final Phase: Global accumulate and write to output
0x3f : { %v_final_result = vmax.f32 %v_global_acc, %v_horiz_max_final } // Merge with the global accumulator (supports tiled outer loops)
// Note the sm:$0x1, we only take the first column (lane) of v_final_result
0x40 : { %230 = vst [vmem:[#allocation1] sm:$0x1] %v_final_result } // Reuse allocation1 (which was a vector of -inf) to be the final [64x1] max
0x41 : { %235 = vsyncpa [#sync_flag], 1 } // Set [#sync_flag] to 1 (likely a fence before the read-back below)
0x42 : { %v_out_load = vld [vmem:[#allocation1] sm:$0x1] } // Read the reduced result back (first lane only)
0x43 : { %241 = vst [vmem:[%ptr_out_max] sm:$0x1] %v_out_load } // Write the final [64] row-max to the output