Skip to content

Instantly share code, notes, and snippets.

@staghado
Created October 6, 2024 21:44
Show Gist options
  • Select an option

  • Save staghado/bea306a5614f52ace4a814b321602b0b to your computer and use it in GitHub Desktop.

Select an option

Save staghado/bea306a5614f52ace4a814b321602b0b to your computer and use it in GitHub Desktop.
Pi estimation in Triton because why not
import triton
import triton.language as tl
import torch
import time
import numpy as np
@triton.jit
def pi_kernel(
    total_ptr,
    BLOCK_SIZE: tl.constexpr,
    ITERATIONS_PER_THREAD: tl.constexpr,
):
    """Count random points that land inside the unit circle.

    Each program instance handles BLOCK_SIZE lanes; every lane draws
    ITERATIONS_PER_THREAD points (x, y), counts those with
    x*x + y*y <= 1, and the per-program total is atomically added to the
    integer stored at ``total_ptr``.

    Args:
        total_ptr: pointer to the accumulator that receives the hit count
            (the host wrapper in this file passes a 1-element int64 tensor).
        BLOCK_SIZE: number of lanes per program instance (compile-time).
        ITERATIONS_PER_THREAD: points drawn per lane (compile-time).
    """
    pid = tl.program_id(0)
    tid = tl.arange(0, BLOCK_SIZE)

    # Global lane index, widened to 64 bits *before* it is multiplied by the
    # iteration count so that large grids cannot wrap a 32-bit integer.
    lane = pid.to(tl.int64) * BLOCK_SIZE + tid

    counter = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)
    for i in range(ITERATIONS_PER_THREAD):
        # One distinct seed per (lane, iteration) pair.
        seed = lane * ITERATIONS_PER_THREAD + i
        x = tl.rand(seed, tid)
        # Draw y from the same seed but a disjoint offset range. Using
        # (seed + 1, tid) here would be the exact call that produces x on
        # the next iteration, so every y would be reused as the following x
        # and consecutive samples would be correlated.
        y = tl.rand(seed, tid + BLOCK_SIZE)
        inside_circle = (x * x + y * y <= 1.0)
        counter += tl.where(inside_circle, 1, 0)

    # Reduce the lanes of this program, then publish with one atomic add.
    block_total = tl.sum(counter)
    tl.atomic_add(total_ptr, block_total)
def estimate_pi(NBLOCKS, BLOCK_SIZE, ITERATIONS_PER_THREAD):
    """Estimate pi on the GPU by launching ``pi_kernel``.

    Args:
        NBLOCKS: number of program instances in the 1-D launch grid.
        BLOCK_SIZE: lanes per program instance; must be a power of two
            because the kernel builds its lane index with ``tl.arange``.
        ITERATIONS_PER_THREAD: points drawn by each lane.

    Returns:
        A ``(pi_estimate, total_points)`` tuple, where ``total_points`` is
        ``NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD``.

    Raises:
        ValueError: if any argument is smaller than 1, or ``BLOCK_SIZE`` is
            not a power of two.
    """
    # Validate before touching the GPU so that bad sizes fail with a clear
    # message instead of a launch/compile error or a division by zero.
    for label, value in (
        ("NBLOCKS", NBLOCKS),
        ("BLOCK_SIZE", BLOCK_SIZE),
        ("ITERATIONS_PER_THREAD", ITERATIONS_PER_THREAD),
    ):
        if value < 1:
            raise ValueError(f"{label} must be a positive integer, got {value!r}")
    if BLOCK_SIZE & (BLOCK_SIZE - 1):
        raise ValueError(f"BLOCK_SIZE must be a power of two, got {BLOCK_SIZE!r}")

    # Single int64 accumulator that every program instance adds into.
    total = torch.zeros(1, dtype=torch.int64, device='cuda')
    pi_kernel[(NBLOCKS,)](
        total,
        BLOCK_SIZE=BLOCK_SIZE,
        ITERATIONS_PER_THREAD=ITERATIONS_PER_THREAD,
    )

    # .item() copies the counter back to the host as a Python int.
    total_hits = total.item()
    total_points = NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD
    # hits / points approximates pi / 4 (quarter circle inside unit square).
    pi_estimate = 4.0 * total_hits / total_points
    return pi_estimate, total_points
def estimate_pi_python(total_points, chunk_size=1 << 22):
    """Estimate pi on the CPU with NumPy's global random generator.

    Points are generated and tested in chunks of at most ``chunk_size`` so
    memory use stays bounded no matter how large ``total_points`` is
    (allocating every coordinate at once needs several full-length float64
    arrays). When ``total_points <= chunk_size`` the random draws happen in
    the same order as a single-shot computation, so seeded results are
    unchanged in that case.

    Args:
        total_points: number of random points to test; must be >= 1.
        chunk_size: maximum number of points held in memory at once.

    Returns:
        The estimate ``4 * hits / total_points`` as a NumPy float.

    Raises:
        ValueError: if ``total_points`` or ``chunk_size`` is smaller than 1.
    """
    if total_points < 1:
        raise ValueError(f"total_points must be a positive integer, got {total_points!r}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    # Kept as a NumPy integer so the returned estimate is a NumPy float,
    # matching what np.sum-based code returns.
    hits = np.int64(0)
    remaining = total_points
    while remaining > 0:
        n = min(chunk_size, remaining)
        x = np.random.rand(n)
        y = np.random.rand(n)
        hits += np.count_nonzero((x * x + y * y) <= 1.0)
        remaining -= n

    pi_estimate = 4.0 * hits / total_points
    return pi_estimate
if __name__ == "__main__":
    # Benchmark configuration: NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD
    # random points are tested by each implementation.
    NBLOCKS = 1024
    BLOCK_SIZE = 1024
    ITERATIONS_PER_THREAD = 1000
    PI_REF = 3.141592653589793
    total_points = NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD

    # Warm-up: a one-block launch with the same compile-time constants, so
    # the one-time kernel compilation and CUDA start-up cost are paid here
    # and not counted in the timed run below.
    estimate_pi(1, BLOCK_SIZE, ITERATIONS_PER_THREAD)
    torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time_gpu = time.perf_counter()
    pi_estimate_gpu, _ = estimate_pi(NBLOCKS, BLOCK_SIZE, ITERATIONS_PER_THREAD)
    torch.cuda.synchronize()
    end_time_gpu = time.perf_counter()

    start_time_python = time.perf_counter()
    pi_estimate_python = estimate_pi_python(total_points)
    end_time_python = time.perf_counter()

    print(f"Total of {total_points/10**9:.2f}G random tests")
    print(f"Triton Pi ~= {pi_estimate_gpu:.15f}")
    print(f"Numpy Pi ~= {pi_estimate_python:.15f}")
    print(f"Numpy error : {abs(PI_REF - pi_estimate_python):.15f}")
    print(f"Triton error : {abs(PI_REF - pi_estimate_gpu):.15f}")
    print(f"Triton time : {end_time_gpu - start_time_gpu:.6f} seconds")
    print(f"Numpy time: {end_time_python - start_time_python:.6f} seconds")
@staghado
Copy link
Author

staghado commented Oct 6, 2024

Total of 10.49G random tests
Triton Pi ~= 3.141591793823242
Numpy Pi ~= 3.141590239334107
Numpy error : 0.000002414255686
Triton error : 0.000000859766551
Triton time : 3.579296 seconds
Numpy time: 195.734966 seconds

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment