mlirs
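Two TritonGPU IR listings for the test kernel @hoist_convert_above_extf_and_remat, a small f16 matmul with a bias add followed by a per-row mean/variance normalization. In the first listing the three arith.extf ops and the row reductions stay in the #mma layout produced by tt.dot; the second listing appears to be the same kernel after the triton_gpu.convert_layout to #blocked1 is hoisted above the extf ops and rematerialized once per use, so the reductions and the store epilogue already run in #blocked1.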
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @hoist_convert_above_extf_and_remat(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16>) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x256xf32, #mma>
    %c32_i32 = arith.constant 32 : i32
    %cst_0 = arith.constant dense<256> : tensor<32x1xi32, #blocked>
    %cst_1 = arith.constant dense<256> : tensor<32x1xi32, #blocked1>
    %cst_2 = arith.constant dense<256> : tensor<256x1xi32, #blocked>
    %c64_i32 = arith.constant 64 : i32
    %c256_i32 = arith.constant 256 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_3 = arith.constant dense<1.000000e-03> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %cst_4 = arith.constant dense<2.560000e+02> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c32_i32 : i32
    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %4 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked>
    %5 = arith.addi %4, %3 : tensor<32x1xi32, #blocked>
    %6 = arith.muli %5, %cst_0 : tensor<32x1xi32, #blocked>
    %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %11 = tt.broadcast %9 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked>
    %14 = arith.muli %13, %cst_2 : tensor<256x1xi32, #blocked>
    %15 = tt.broadcast %10 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked>
    %16 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %17 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked>
    %18 = scf.for %arg7 = %c0_i32 to %c256_i32 step %c64_i32 iter_args(%arg8 = %cst) -> (tensor<32x256xf32, #mma>) : i32 {
      %60 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #blocked>
      %61 = arith.addi %6, %60 : tensor<32x1xi32, #blocked>
      %62 = tt.broadcast %61 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked>
      %63 = arith.addi %62, %11 : tensor<32x64xi32, #blocked>
      %64 = tt.splat %arg7 : i32 -> tensor<256x1xi32, #blocked>
      %65 = arith.addi %14, %64 : tensor<256x1xi32, #blocked>
      %66 = tt.broadcast %65 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked>
      %67 = arith.addi %66, %15 : tensor<256x64xi32, #blocked>
      %68 = tt.addptr %16, %63 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
      %69 = tt.load %68 : tensor<32x64x!tt.ptr<f16>, #blocked>
      %70 = tt.addptr %17, %67 : tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<256x64xi32, #blocked>
      %71 = tt.load %70 : tensor<256x64x!tt.ptr<f16>, #blocked>
      %72 = triton_gpu.local_alloc %71 : (tensor<256x64xf16, #blocked>) -> !tt.memdesc<256x64xf16, #shared>
      %73 = tt.trans %72 {order = array<i32: 1, 0>} : !tt.memdesc<256x64xf16, #shared> -> !tt.memdesc<64x256xf16, #shared1>
      %74 = triton_gpu.local_load %73 : !tt.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %75 = triton_gpu.convert_layout %69 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %76 = tt.dot %75, %74, %arg8 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma>
      scf.yield %76 : tensor<32x256xf32, #mma>
    }
    %19 = arith.truncf %18 : tensor<32x256xf32, #mma> to tensor<32x256xf16, #mma>
    %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
    %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %22 = tt.expand_dims %20 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2>
    %23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1>
    %24 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<1x256x!tt.ptr<f16>, #blocked2>
    %25 = tt.addptr %24, %22 : tensor<1x256x!tt.ptr<f16>, #blocked2>, tensor<1x256xi32, #blocked2>
    %26 = tt.load %25 : tensor<1x256x!tt.ptr<f16>, #blocked2>
    %27 = triton_gpu.convert_layout %26 : tensor<1x256xf16, #blocked2> -> tensor<1x256xf16, #mma>
    %28 = tt.broadcast %27 : tensor<1x256xf16, #mma> -> tensor<32x256xf16, #mma>
    %29 = arith.addf %19, %28 : tensor<32x256xf16, #mma>
    %30 = arith.extf %29 : tensor<32x256xf16, #mma> to tensor<32x256xf32, #mma>
    %31 = arith.extf %29 : tensor<32x256xf16, #mma> to tensor<32x256xf32, #mma>
    %32 = arith.extf %29 : tensor<32x256xf16, #mma> to tensor<32x256xf32, #mma>
    %33 = "tt.reduce"(%30) <{axis = 1 : i32}> ({
    ^bb0(%arg7: f32, %arg8: f32):
      %60 = arith.addf %arg7, %arg8 : f32
      tt.reduce.return %60 : f32
    }) : (tensor<32x256xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %34 = arith.divf %33, %cst_4 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %35 = arith.mulf %31, %31 : tensor<32x256xf32, #mma>
    %36 = "tt.reduce"(%35) <{axis = 1 : i32}> ({
    ^bb0(%arg7: f32, %arg8: f32):
      %60 = arith.addf %arg7, %arg8 : f32
      tt.reduce.return %60 : f32
    }) : (tensor<32x256xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %37 = arith.divf %36, %cst_4 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %38 = arith.mulf %34, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %39 = arith.subf %37, %38 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %40 = math.sqrt %39 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %41 = arith.addf %40, %cst_3 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    %42 = tt.expand_dims %34 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xf32, #mma>
    %43 = tt.expand_dims %41 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xf32, #mma>
    %44 = tt.broadcast %42 : tensor<32x1xf32, #mma> -> tensor<32x256xf32, #mma>
    %45 = arith.subf %32, %44 : tensor<32x256xf32, #mma>
    %46 = tt.broadcast %43 : tensor<32x1xf32, #mma> -> tensor<32x256xf32, #mma>
    %47 = arith.divf %45, %46 : tensor<32x256xf32, #mma>
    %48 = arith.truncf %47 : tensor<32x256xf32, #mma> to tensor<32x256xf16, #mma>
    %49 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %50 = tt.expand_dims %49 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
    %51 = arith.muli %50, %cst_1 : tensor<32x1xi32, #blocked1>
    %52 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked1>
    %53 = arith.addi %52, %51 : tensor<32x1xi32, #blocked1>
    %54 = tt.broadcast %53 : tensor<32x1xi32, #blocked1> -> tensor<32x256xi32, #blocked1>
    %55 = tt.broadcast %23 : tensor<1x256xi32, #blocked1> -> tensor<32x256xi32, #blocked1>
    %56 = arith.addi %54, %55 : tensor<32x256xi32, #blocked1>
    %57 = tt.splat %arg5 : !tt.ptr<f16> -> tensor<32x256x!tt.ptr<f16>, #blocked1>
    %58 = tt.addptr %57, %56 : tensor<32x256x!tt.ptr<f16>, #blocked1>, tensor<32x256xi32, #blocked1>
    %59 = triton_gpu.convert_layout %48 : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked1>
    tt.store %58, %59 : tensor<32x256x!tt.ptr<f16>, #blocked1>
    tt.return
  }
}
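The second listing, which appears to be the same kernel after the convert_layout is hoisted above the three arith.extf ops and rematerialized per use: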
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @hoist_convert_above_extf_and_remat(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16>) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x256xf32, #mma>
    %c32_i32 = arith.constant 32 : i32
    %cst_0 = arith.constant dense<256> : tensor<32x1xi32, #blocked>
    %cst_1 = arith.constant dense<256> : tensor<32x1xi32, #blocked1>
    %cst_2 = arith.constant dense<256> : tensor<256x1xi32, #blocked>
    %c64_i32 = arith.constant 64 : i32
    %c256_i32 = arith.constant 256 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_3 = arith.constant dense<1.000000e-03> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %cst_4 = arith.constant dense<2.560000e+02> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c32_i32 : i32
    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %4 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked>
    %5 = arith.addi %4, %3 : tensor<32x1xi32, #blocked>
    %6 = arith.muli %5, %cst_0 : tensor<32x1xi32, #blocked>
    %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %11 = tt.broadcast %9 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked>
    %14 = arith.muli %13, %cst_2 : tensor<256x1xi32, #blocked>
    %15 = tt.broadcast %10 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked>
    %16 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %17 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked>
    %18 = scf.for %arg7 = %c0_i32 to %c256_i32 step %c64_i32 iter_args(%arg8 = %cst) -> (tensor<32x256xf32, #mma>) : i32 {
      %62 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #blocked>
      %63 = arith.addi %6, %62 : tensor<32x1xi32, #blocked>
      %64 = tt.broadcast %63 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked>
      %65 = arith.addi %64, %11 : tensor<32x64xi32, #blocked>
      %66 = tt.splat %arg7 : i32 -> tensor<256x1xi32, #blocked>
      %67 = arith.addi %14, %66 : tensor<256x1xi32, #blocked>
      %68 = tt.broadcast %67 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked>
      %69 = arith.addi %68, %15 : tensor<256x64xi32, #blocked>
      %70 = tt.addptr %16, %65 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
      %71 = tt.load %70 : tensor<32x64x!tt.ptr<f16>, #blocked>
      %72 = tt.addptr %17, %69 : tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<256x64xi32, #blocked>
      %73 = tt.load %72 : tensor<256x64x!tt.ptr<f16>, #blocked>
      %74 = triton_gpu.local_alloc %73 : (tensor<256x64xf16, #blocked>) -> !tt.memdesc<256x64xf16, #shared>
      %75 = tt.trans %74 {order = array<i32: 1, 0>} : !tt.memdesc<256x64xf16, #shared> -> !tt.memdesc<64x256xf16, #shared1>
      %76 = triton_gpu.local_load %75 : !tt.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %77 = triton_gpu.convert_layout %71 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %78 = tt.dot %77, %76, %arg8 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma>
      scf.yield %78 : tensor<32x256xf32, #mma>
    }
    %19 = arith.truncf %18 : tensor<32x256xf32, #mma> to tensor<32x256xf16, #mma>
    %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
    %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %22 = tt.expand_dims %20 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2>
    %23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1>
    %24 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<1x256x!tt.ptr<f16>, #blocked2>
    %25 = tt.addptr %24, %22 : tensor<1x256x!tt.ptr<f16>, #blocked2>, tensor<1x256xi32, #blocked2>
    %26 = tt.load %25 : tensor<1x256x!tt.ptr<f16>, #blocked2>
    %27 = triton_gpu.convert_layout %26 : tensor<1x256xf16, #blocked2> -> tensor<1x256xf16, #mma>
    %28 = tt.broadcast %27 : tensor<1x256xf16, #mma> -> tensor<32x256xf16, #mma>
    %29 = arith.addf %19, %28 : tensor<32x256xf16, #mma>
    %30 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked1>
    %31 = arith.extf %30 : tensor<32x256xf16, #blocked1> to tensor<32x256xf32, #blocked1>
    %32 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked1>
    %33 = arith.extf %32 : tensor<32x256xf16, #blocked1> to tensor<32x256xf32, #blocked1>
    %34 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked1>
    %35 = arith.extf %34 : tensor<32x256xf16, #blocked1> to tensor<32x256xf32, #blocked1>
    %36 = "tt.reduce"(%31) <{axis = 1 : i32}> ({
    ^bb0(%arg7: f32, %arg8: f32):
      %62 = arith.addf %arg7, %arg8 : f32
      tt.reduce.return %62 : f32
    }) : (tensor<32x256xf32, #blocked1>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %37 = arith.divf %36, %cst_4 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %38 = arith.mulf %33, %33 : tensor<32x256xf32, #blocked1>
    %39 = "tt.reduce"(%38) <{axis = 1 : i32}> ({
    ^bb0(%arg7: f32, %arg8: f32):
      %62 = arith.addf %arg7, %arg8 : f32
      tt.reduce.return %62 : f32
    }) : (tensor<32x256xf32, #blocked1>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %40 = arith.divf %39, %cst_4 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %41 = arith.mulf %37, %37 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %42 = arith.subf %40, %41 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %43 = math.sqrt %42 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %44 = arith.addf %43, %cst_3 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %45 = tt.expand_dims %37 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xf32, #blocked1>
    %46 = tt.expand_dims %44 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xf32, #blocked1>
    %47 = tt.broadcast %45 : tensor<32x1xf32, #blocked1> -> tensor<32x256xf32, #blocked1>
    %48 = arith.subf %35, %47 : tensor<32x256xf32, #blocked1>
    %49 = tt.broadcast %46 : tensor<32x1xf32, #blocked1> -> tensor<32x256xf32, #blocked1>
    %50 = arith.divf %48, %49 : tensor<32x256xf32, #blocked1>
    %51 = arith.truncf %50 : tensor<32x256xf32, #blocked1> to tensor<32x256xf16, #blocked1>
    %52 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %53 = tt.expand_dims %52 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
    %54 = arith.muli %53, %cst_1 : tensor<32x1xi32, #blocked1>
    %55 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked1>
    %56 = arith.addi %55, %54 : tensor<32x1xi32, #blocked1>
    %57 = tt.broadcast %56 : tensor<32x1xi32, #blocked1> -> tensor<32x256xi32, #blocked1>
    %58 = tt.broadcast %23 : tensor<1x256xi32, #blocked1> -> tensor<32x256xi32, #blocked1>
    %59 = arith.addi %57, %58 : tensor<32x256xi32, #blocked1>
    %60 = tt.splat %arg5 : !tt.ptr<f16> -> tensor<32x256x!tt.ptr<f16>, #blocked1>
    %61 = tt.addptr %60, %59 : tensor<32x256x!tt.ptr<f16>, #blocked1>, tensor<32x256xi32, #blocked1>
    tt.store %61, %51 : tensor<32x256x!tt.ptr<f16>, #blocked1>
    tt.return
  }
}
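For orientation, here is a minimal sketch of the kind of Triton kernel that could lower to IR of this shape. It is a reconstruction from the IR above, not taken from the gist: the kernel and parameter names are hypothetical, only pointers corresponding to %arg0, %arg1, %arg2 and %arg5 are used, and the remaining arguments are omitted.

# Hypothetical Triton kernel reconstructed from the IR above; names and the
# assumption that A, B and the output are 256-wide row-major are mine.
import triton
import triton.language as tl


@triton.jit
def matmul_bias_rownorm(a_ptr, b_ptr, bias_ptr, out_ptr,
                        BLOCK_M: tl.constexpr,   # 32 in the IR
                        BLOCK_N: tl.constexpr,   # 256 in the IR
                        BLOCK_K: tl.constexpr):  # 64 in the IR
    pid = tl.program_id(0)
    rm = pid * BLOCK_M + tl.arange(0, BLOCK_M)   # rows handled by this program
    rn = tl.arange(0, BLOCK_N)                   # all 256 output columns
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # K loop (K happens to equal BLOCK_N = 256 in this test): A is read as
    # [BLOCK_M, BLOCK_K] tiles, B as [BLOCK_N, BLOCK_K] tiles that the IR
    # transposes through shared memory (local_alloc + tt.trans + local_load).
    for k in range(0, BLOCK_N, BLOCK_K):
        rk = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + rm[:, None] * BLOCK_N + rk[None, :])
        b = tl.load(b_ptr + rn[:, None] * BLOCK_N + rk[None, :])
        acc += tl.dot(a, tl.trans(b))
    # Bias add in f16 (truncf + addf in the IR), then extend back to f32.
    x = acc.to(tl.float16) + tl.load(bias_ptr + rn)[None, :]
    xf = x.to(tl.float32)
    # Per-row normalization: mean, E[x^2], sqrt(var) + 1e-3, then divide.
    mean = tl.sum(xf, axis=1) / BLOCK_N
    mean_sq = tl.sum(xf * xf, axis=1) / BLOCK_N
    std = tl.sqrt(mean_sq - mean * mean) + 1e-3
    y = (xf - mean[:, None]) / std[:, None]
    tl.store(out_ptr + rm[:, None] * BLOCK_N + rn[None, :], y.to(tl.float16))

The part the gist seems to exercise is only the epilogue: the f16 sum `x` is extended to f32 in three separate uses (sum, sum of squares, and the final subtraction), which is where the two listings differ in whether the layout conversion sits below or above the `extf` ops.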