Annotated libtpu VLIW dump for the operation: max(x[64,32] @ w[32,64], axis=1). Input in VMEM; weights in HBM (need DMA).
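For reference, the source-level computation this dump lowers is a single matmul followed by a row max. A minimal JAX equivalent (the dump itself is the XLA:TPU lowering; this snippet is just the semantic reference):

```python
import jax
import jax.numpy as jnp

@jax.jit
def max_matmul(x, w):                # x: f32[64, 32], w: f32[32, 64]
    return jnp.max(x @ w, axis=1)    # -> f32[64]
```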
// Operation: max(x[64,32] @ w[32,64], axis=1). Input in VMEM; weights in HBM (need DMA).
0x0 : { %v_const_neg_inf = vmov -inf // Init accumulator for max reduction
     ;; %ptr_x_in = inlined_call_operand.vmem [shape: f32[64,32]] // operand 0: input x
     ;; %ptr_w_hbm = inlined_call_operand.hbm [shape: f32[32,64]] // operand 1: weights w (HBM)
     ;; %ptr_out_max = inlined_call_operand.vmem [shape: f32[64]] // operand 2: output
     ;; %ptr_out_full = inlined_call_operand.vmem [shape: f32[64,64]] // operand 3: scratch }
0x1 : { %6 = vst [vmem:[#allocation1] sm:$0xff] /*vst_source=*/%v_const_neg_inf } // #allocation1 now holds a vector of -inf (reused later as the global accumulator)
// First Phase: get w out of HBM into vmem (the MXU and VPU can only use vmem)
// Prepare the load of w from HBM to vmem at the address #allocation2
0x2 : { %7 = vsyncpa [#sync_flag], 0
     ;; %offset_w_hbm = sshll.u32 %ptr_w_hbm, 4
     ;; %offset_w_vmem = smov [#allocation2]
     ;; %ptr_src_hbm = int_to_ptr.hbm [resolvable:$true] %offset_w_hbm }
0x3 : { %offset_w_vmem_shifted = sshll.u32 %offset_w_vmem, 4
     ;; %ptr_dst_vmem = int_to_ptr.vmem [resolvable:$true] %offset_w_vmem_shifted }
// Start the actual load from hbm to vmem (512 granules x 16 bytes per granule == 32 x 64 fp32s)
// This asynchronously loads data from hbm to vmem, atomically incrementing the pointee of the address [#sync_flag] for each successful load
0x4 : { %20 = dma.hbm_to_vmem [thread:$0] /*hbm=*/%ptr_src_hbm, /*size=*/512, /*vmem=*/%ptr_dst_vmem, /*sync=*/[#sync_flag] }
// Wait for the load to be done (for the # of successful loads at [#sync_flag] to reach all 512)
0x5 : { %290 = dma.done.wait [#sync_flag], 512 }
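The same start/wait DMA pattern is exposed at the Pallas level. A minimal sketch, assuming the jax.experimental.pallas TPU dialect (pltpu.make_async_copy and friends; exact APIs vary across JAX versions, and this kernel is illustrative, not the source of this dump):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, w_hbm_ref, out_ref, w_vmem_ref, sem):
    # dma.hbm_to_vmem + dma.done.wait: start the copy, then block on the
    # DMA semaphore (the [#sync_flag] counter in the dump).
    copy = pltpu.make_async_copy(w_hbm_ref, w_vmem_ref, sem)
    copy.start()
    copy.wait()
    out_ref[...] = jnp.max(x_ref[...] @ w_vmem_ref[...], axis=1, keepdims=True)

@jax.jit
def max_matmul(x, w):
    return pl.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),  # x already in VMEM
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),   # w left in HBM
        ],
        out_shape=jax.ShapeDtypeStruct((64, 1), jnp.float32),
        scratch_shapes=[
            pltpu.VMEM((32, 64), jnp.float32),  # landing buffer (#allocation2)
            pltpu.SemaphoreType.DMA,            # completion flag (#sync_flag)
        ],
    )(x, w)
```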
// Second Phase: load the chunks of weights w and inputs x from vmem into vregs
// Note that in the systolic array, the inputs x will be stationary (fixed in the array), while
// the weights w will be streamed in.
// There's a specific layout for the streamed operand (w), so the w chunks must also be transposed
// by the XLU (transpose unit) into a layout compatible with the MXU.
// This stage just alternates (on chunks of w and x):
//   Load w_{i+1}
//   Load x_{i+1}
//   Push x_i into the MXUs' stationary slots (the vmatpush.msra)
//   Transpose w_i into an MXU-compatible layout using the xlu0
// for 4 MXUs (mxu0, mxu1, mxu2, mxu3). A Python sketch of this schedule follows the transpose at 0x11.
0x6 : { %291 = vsyncadd [#sync_flag], -512
     ;; %v_w_chunk0 = vld [vmem:[#allocation2] sm:$0xff] // Load w chunk 0
     ;; %v_x_chunk0 = vld [vmem:[%ptr_x_in + $0x18] sm:$0xff] // Load x chunk 0
     ;; %v_lane_id = vlaneseq } // Used later as a mask since we only feed a subset of the MXU
0x7 : { %61 = vxpose.xlu0.b32.start [1/4] /*vx=*/%v_w_chunk0 // Start transpose of w chunk 0
     ;; %v_x_chunk1 = vld [vmem:[%ptr_x_in + $0x10] sm:$0xff] // Load x chunk 1
     ;; %86 = vmatpush.msra.mxu0 %v_x_chunk0 // Push x chunk 0 to MXU 0
     ;; %v_x_chunk2 = vld [vmem:[%ptr_x_in + $0x8] sm:$0xff] // Load x chunk 2
     ;; %v_w_chunk1 = vld [vmem:[#allocation2 + $0x10] sm:$0xff] } // Load w chunk 1
0x8 : { %252 = vmatpush.msra.mxu1 %v_x_chunk0 // Push x chunk 0 to MXU 1
     ;; %253 = vmatpush.msra.mxu2 %v_x_chunk0 // Push x chunk 0 to MXU 2
     ;; %v_w_chunk2 = vld [vmem:[#allocation2 + $0x8] sm:$0xff] // Load w chunk 2
     ;; %v_w_chunk3 = vld [vmem:[#allocation2 + $0x18] sm:$0xff] // Load w chunk 3
     ;; %v_lane_idx = vand.u32 127, %v_lane_id } // Mask lane ids to 0..127 (keep every lane)
0x9 : { %254 = vmatpush.msra.mxu3 %v_x_chunk0 // Push x chunk 0 to MXU 3
     ;; %92 = vmatpush.msra.mxu0 %v_x_chunk1 // Push x chunk 1 to MXU 0
     ;; %v_x_chunk3 = vld [vmem:[%ptr_x_in] sm:$0xff] } // Load x chunk 3 (final)
0xa : { %255 = vmatpush.msra.mxu1 %v_x_chunk1 // Push x chunk 1 to MXU 1
     ;; %256 = vmatpush.msra.mxu2 %v_x_chunk1 // Push x chunk 1 to MXU 2
     ;; %vm_is_valid = vcmp.lt.s32.totalorder %v_lane_idx, 64 // Mask: only turn on if lane < 64 (since this is a (64x32) x (32x64) matmul)
     ;; %v_global_acc = vld [vmem:[#allocation1] ss:$0 sm:$0xff] } // Load -inf accumulator (allocation1)
// Push the remaining input chunks to the MXUs
0xb : { %257 = vmatpush.msra.mxu3 %v_x_chunk1
     ;; %98 = vmatpush.msra.mxu0 %v_x_chunk2 }
0xc : { %258 = vmatpush.msra.mxu1 %v_x_chunk2
     ;; %259 = vmatpush.msra.mxu2 %v_x_chunk2 }
0xd : { %260 = vmatpush.msra.mxu3 %v_x_chunk2
     ;; %102 = vmatpush.msra.mxu0 %v_x_chunk3 }
0xe : { %261 = vmatpush.msra.mxu1 %v_x_chunk3
     ;; %262 = vmatpush.msra.mxu2 %v_x_chunk3 }
// Finish transposing weights w
0xf : { %62 = vxpose.xlu0.b32.cont [2/4] /*vx=*/%v_w_chunk2
     ;; %263 = vmatpush.msra.mxu3 %v_x_chunk3 }
0x10 : { %63 = vxpose.xlu0.b32.cont [3/4] /*vx=*/%v_w_chunk1 }
0x11 : { %64 = vxpose.xlu0.b32.end [4/4] /*vx=*/%v_w_chunk3 }
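The alternating schedule above, as a plain Python loop. `load_x`/`load_w`, `push_stationary`, and `transpose_step` are illustrative stand-ins for `vld`, `vmatpush.msra.mxu*`, and the `vxpose.xlu0` stages, and the chunk heights are assumptions, not values read off the dump:

```python
import numpy as np

x = np.zeros((64, 32), np.float32)   # placeholder operands
w = np.zeros((32, 64), np.float32)
msra, trf = [], []                   # stand-ins for MSRA slots / XLU output FIFO

load_x = lambda i: x[i * 16:(i + 1) * 16]   # vld x chunk i (height assumed)
load_w = lambda i: w[i * 8:(i + 1) * 8]     # vld w chunk i (height assumed)
push_stationary = msra.append               # vmatpush.msra.mxu0..3
transpose_step = lambda c: trf.append(np.ascontiguousarray(c.T))  # xlu0 stage

w_chunk, x_chunk = load_w(0), load_x(0)
for i in range(4):
    nxt = (load_w(i + 1), load_x(i + 1)) if i < 3 else (None, None)  # load chunk i+1 ...
    push_stationary(x_chunk)   # ... while pushing x_i into the stationary slots
    transpose_step(w_chunk)    # ... and feeding w_i through the transpose unit
    w_chunk, x_chunk = nxt
```

The point is the software pipelining: the bundle that pushes and transposes chunk i also issues the loads for chunk i+1, hiding VMEM latency behind MXU/XLU work.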
// Third Phase: the actual matmul using the systolic array (see the numpy sketch after 0x21)
// This will be of the form:
//   Fetch a transposed weight chunk from xlu0
//   Push it into the activations of the systolic array
// MXU0
0x12 : { %v_w_transposed_0 = vpop.trf.xlu0 } // Pop formatted w
0x13 : { %103 = vmatmul.f32.vlgmr.mxu0 %v_w_transposed_0 } // Matmul part A (load accumulator)
0x14 : { %v_w_transposed_1 = vpop.trf.xlu0 }
0x15 : { %111 = vmatmul.f32.gmra.mxu0 %v_w_transposed_1 } // Matmul part B (accumulate)
// MXU1
0x16 : { %v_w_transposed_2 = vpop.trf.xlu0 }
0x17 : { %125 = vmatmul.f32.vlgmr.mxu1 %v_w_transposed_2 }
0x18 : { %v_w_transposed_3 = vpop.trf.xlu0 }
0x19 : { %139 = vmatmul.f32.gmra.mxu1 %v_w_transposed_3 }
// MXU2
0x1a : { %v_w_transposed_4 = vpop.trf.xlu0 }
0x1b : { %153 = vmatmul.f32.vlgmr.mxu2 %v_w_transposed_4 }
0x1c : { %v_w_transposed_5 = vpop.trf.xlu0 }
0x1d : { %167 = vmatmul.f32.gmra.mxu2 %v_w_transposed_5 }
// MXU3
0x1e : { %v_w_transposed_6 = vpop.trf.xlu0 }
0x1f : { %181 = vmatmul.f32.vlgmr.mxu3 %v_w_transposed_6 }
0x20 : { %v_w_transposed_7 = vpop.trf.xlu0 }
0x21 : { %195 = vmatmul.f32.gmra.mxu3 %v_w_transposed_7 }
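Logically, phases two and three compute x @ w as eight [64, 8] column slices, two per MXU. A numpy sketch of that decomposition (the 8-column slice width is inferred from the eight stores into the [64, 64] scratch at offsets 0x0..0x38 below, so treat it as an assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 32)).astype(np.float32)   # stationary operand
w = rng.standard_normal((32, 64)).astype(np.float32)   # streamed operand

w_t = w.T                                  # the xlu0 transpose into MXU layout
chunks = []                                # the 8 vpop.f32.mrf results
for i in range(8):
    streamed = w_t[i * 8:(i + 1) * 8]      # one transposed w chunk, [8, 32]
    chunks.append(x @ streamed.T)          # one [64, 8] result slice
out = np.concatenate(chunks, axis=1)       # the [64, 64] scratch matrix
assert np.allclose(out, x @ w, atol=1e-4)
```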
// Fourth Phase: fetch the matmul results and accumulate the max (sketched in numpy after 0x38)
// MXU0
0x22 : { %v_res_raw_0 = vpop.f32.mrf.mxu0 } // Pop result chunk 0 of the matmul from MXU0
0x23 : { %110 = vst [vmem:[%ptr_out_full] sm:$0xff] %v_res_raw_0 // Save to the scratch intermediate matrix H [64, 64]
     ;; %v_res_masked_0 = vsel %vm_is_valid, %v_res_raw_0, -inf } // Mask padding lanes (keep lanes < 64)
0x24 : { %v_res_raw_1 = vpop.f32.mrf.mxu0 } // Pop result chunk 1 of the matmul from MXU0
0x25 : { %245 = vst [vmem:[%ptr_out_full + $0x8] sm:$0xff] %v_res_raw_1 // Save to the scratch intermediate matrix H (at offset 0x8)
     ;; %v_res_masked_1 = vsel %vm_is_valid, %v_res_raw_1, -inf }
0x26 : { %v_max_acc_0 = vmax.f32 %v_res_masked_0, %v_res_masked_1 } // max01 = max(chunk_0, chunk_1)
// MXU1
0x27 : { %v_res_raw_2 = vpop.f32.mrf.mxu1 }
0x28 : { %246 = vst [vmem:[%ptr_out_full + $0x10] sm:$0xff] %v_res_raw_2
     ;; %v_res_masked_2 = vsel %vm_is_valid, %v_res_raw_2, -inf }
0x29 : { %v_max_acc_1 = vmax.f32 %v_max_acc_0, %v_res_masked_2 } // max012 = max(max01, chunk_2)
0x2a : { %v_res_raw_3 = vpop.f32.mrf.mxu1 }
0x2b : { %247 = vst [vmem:[%ptr_out_full + $0x18] sm:$0xff] %v_res_raw_3
     ;; %v_res_masked_3 = vsel %vm_is_valid, %v_res_raw_3, -inf }
0x2c : { %v_max_acc_2 = vmax.f32 %v_max_acc_1, %v_res_masked_3 } // max0123 = max(max012, chunk_3)
// MXU2
0x2d : { %v_res_raw_4 = vpop.f32.mrf.mxu2 }
0x2e : { %248 = vst [vmem:[%ptr_out_full + $0x20] sm:$0xff] %v_res_raw_4
     ;; %v_res_masked_4 = vsel %vm_is_valid, %v_res_raw_4, -inf }
0x2f : { %v_max_acc_3 = vmax.f32 %v_max_acc_2, %v_res_masked_4 } // max01234 = max(max0123, chunk_4)
0x30 : { %v_res_raw_5 = vpop.f32.mrf.mxu2 }
0x31 : { %249 = vst [vmem:[%ptr_out_full + $0x28] sm:$0xff] %v_res_raw_5
     ;; %v_res_masked_5 = vsel %vm_is_valid, %v_res_raw_5, -inf }
0x32 : { %v_max_acc_4 = vmax.f32 %v_max_acc_3, %v_res_masked_5 } // max012345 = max(max01234, chunk_5)
// MXU3
0x33 : { %v_res_raw_6 = vpop.f32.mrf.mxu3 }
0x34 : { %250 = vst [vmem:[%ptr_out_full + $0x30] sm:$0xff] %v_res_raw_6
     ;; %v_res_masked_6 = vsel %vm_is_valid, %v_res_raw_6, -inf }
0x35 : { %v_max_acc_5 = vmax.f32 %v_max_acc_4, %v_res_masked_6 } // max0123456 = max(max012345, chunk_6)
0x36 : { %v_res_raw_7 = vpop.f32.mrf.mxu3 }
0x37 : { %v_res_masked_7 = vsel %vm_is_valid, %v_res_raw_7, -inf
     ;; %251 = vst [vmem:[%ptr_out_full + $0x38] sm:$0xff] %v_res_raw_7 }
0x38 : { %v_vertical_max_final = vmax.f32 %v_max_acc_5, %v_res_masked_7 } // Final max01234567 = max(max0123456, chunk_7); shape [64, 8]
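Phase four is an elementwise max fold over the eight popped slices, with invalid lanes forced to -inf first so padding never wins. A self-contained numpy sketch, with the lane geometry simplified to the logical [64, 8] shape the comments use (on hardware the mask zeroes lanes >= 64 of a 128-lane vreg):

```python
import numpy as np

rng = np.random.default_rng(0)
chunks = rng.standard_normal((8, 64, 8)).astype(np.float32)  # 8 vpop'd slices
lane_valid = np.ones((64, 8), dtype=bool)    # vcmp.lt: lane < 64 (all valid here)
neg_inf = np.float32(-np.inf)

acc = np.full((64, 8), neg_inf, np.float32)  # %v_const_neg_inf init
for chunk in chunks:
    masked = np.where(lane_valid, chunk, neg_inf)   # vsel %vm_is_valid
    acc = np.maximum(acc, masked)                   # vmax.f32 fold
assert np.allclose(acc, chunks.max(axis=0))
```

This fold reduces across slices, i.e. across groups of output columns; the eight columns remaining within a slice are handled by the rotation reduction below.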
// Logarithmic horizontal reduction over v_vertical_max_final [64, 8] down to [64, 1]
// This is a standard logarithmic reduction using lane rotations
0x39 : { %v_rot_4 = vrot.slane %v_vertical_max_final, 4 } // Rotate by 4: [1,2,3,4,5,6,7,8] -> [5,6,7,8,1,2,3,4]
0x3a : { %v_horiz_max_1 = vmax.f32 %v_vertical_max_final, %v_rot_4 } // [max(1,5), max(2,6), max(3,7), max(4,8), ...]
0x3b : { %v_rot_2 = vrot.slane %v_horiz_max_1, 2 } // Rotate by 2: [max15, max26, max37, max48] -> [max37, max48, max15, max26]
0x3c : { %v_horiz_max_2 = vmax.f32 %v_horiz_max_1, %v_rot_2 } // [max1357, max2468, ...]
0x3d : { %v_rot_1 = vrot.slane %v_horiz_max_2, 1 } // Rotate by 1: [max1357, max2468] -> [max2468, max1357]
0x3e : { %v_horiz_max_final = vmax.f32 %v_horiz_max_2, %v_rot_1 } // Reduces down to max12345678 in every lane; the first column is the one read
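The same rotate-and-max ladder in numpy; after shifts of 4, 2, and 1 every lane holds the row max, which is why the final store can take a single lane:

```python
import numpy as np

rng = np.random.default_rng(0)
v = rng.standard_normal((64, 8)).astype(np.float32)
orig = v.copy()
for shift in (4, 2, 1):                           # vrot.slane by 4, 2, 1
    v = np.maximum(v, np.roll(v, shift, axis=1))  # vmax.f32 with the rotation
assert np.allclose(v, orig.max(axis=1, keepdims=True))  # all 8 lanes agree
```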
// Final Phase: global accumulate and write to output
0x3f : { %v_final_result = vmax.f32 %v_global_acc, %v_horiz_max_final } // Merge with the global accumulator (tiling)
// Note the sm:$0x1: we only take the first column (lane) of v_final_result
0x40 : { %230 = vst [vmem:[#allocation1] sm:$0x1] %v_final_result } // Reuse allocation1 (which held a vector of -inf) as the final [64x1] max
0x41 : { %235 = vsyncpa [#sync_flag], 1 }
0x42 : { %v_out_load = vld [vmem:[#allocation1] sm:$0x1] }
0x43 : { %241 = vst [vmem:[%ptr_out_max] sm:$0x1] %v_out_load } // Write the final output
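The merge with %v_global_acc at 0x3f is the running-max pattern for tiling: if the matmul were split over column tiles, each tile's row max would fold into a persistent -inf-initialized accumulator, which is what #allocation1 holds. A JAX sketch with a hypothetical 16-column tile width (the dump itself processes all 64 columns in one tile):

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (64, 32))
w = jax.random.normal(jax.random.PRNGKey(1), (32, 64))

acc = jnp.full((64,), -jnp.inf)                 # the #allocation1 accumulator
for j in range(0, 64, 16):                      # hypothetical tile width
    tile_max = jnp.max(x @ w[:, j:j + 16], axis=1)
    acc = jnp.maximum(acc, tile_max)            # the 0x3f vmax.f32 merge
assert jnp.allclose(acc, jnp.max(x @ w, axis=1), atol=1e-5)
```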