Last active
November 1, 2025 18:52
-
-
Save embg/f88f84fbe42c63e51d7cf146b368a9ef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #include "ATen/ATen.h" | |
| #include "torch/extension.h" | |
| #include <cuda_runtime.h> // cudaMalloc, cudaMemcpy, cudaFree | |
| #include <cuda.h> // Driver API (usually don't need this) | |
| cudaError_t setProp(CUmemAllocationProp *prop, bool UseCompressibleMemory) | |
| { | |
| CUdevice currentDevice; | |
| cudaError_enum e = cuCtxGetDevice(¤tDevice); | |
| if (e != CUDA_SUCCESS) | |
| return static_cast<cudaError_t>(e); | |
| memset(prop, 0, sizeof(CUmemAllocationProp)); | |
| prop->type = CU_MEM_ALLOCATION_TYPE_PINNED; | |
| prop->location.type = CU_MEM_LOCATION_TYPE_DEVICE; | |
| prop->location.id = currentDevice; | |
| if (UseCompressibleMemory) | |
| prop->allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_GENERIC; | |
| return cudaSuccess; | |
| } | |
| cudaError_t allocateCompressible(void **adr, size_t size, bool UseCompressibleMemory) | |
| { | |
| CUmemAllocationProp prop = {}; | |
| cudaError_t err = setProp(&prop, UseCompressibleMemory); | |
| if (err != cudaSuccess) | |
| return err; | |
| size_t granularity = 0; | |
| if (cuMemGetAllocationGranularity(&granularity, &prop, | |
| CU_MEM_ALLOC_GRANULARITY_MINIMUM) != CUDA_SUCCESS) { | |
| return static_cast<cudaError>(101); | |
| } | |
| size = ((size - 1) / granularity + 1) * granularity; | |
| CUdeviceptr dptr; | |
| if (cuMemAddressReserve(&dptr, size, 0, 0, 0) != CUDA_SUCCESS) { | |
| return static_cast<cudaError>(102); | |
| } | |
| CUmemGenericAllocationHandle allocationHandle; | |
| if (cuMemCreate(&allocationHandle, size, &prop, 0) != CUDA_SUCCESS) { | |
| return static_cast<cudaError>(103); | |
| } | |
| // Check if cuMemCreate was able to allocate compressible memory. | |
| if (UseCompressibleMemory) { | |
| CUmemAllocationProp allocationProp = {}; | |
| cuMemGetAllocationPropertiesFromHandle(&allocationProp, allocationHandle); | |
| if (allocationProp.allocFlags.compressionType != CU_MEM_ALLOCATION_COMP_GENERIC) { | |
| fprintf(stderr, "Could not allocate compressible memory... so waiving execution\n"); | |
| exit(1); | |
| } | |
| } | |
| if (cuMemMap(dptr, size, 0, allocationHandle, 0) != CUDA_SUCCESS) { | |
| return static_cast<cudaError>(104); | |
| } | |
| if (cuMemRelease(allocationHandle) != CUDA_SUCCESS) { | |
| return static_cast<cudaError>(105); | |
| } | |
| CUmemAccessDesc accessDescriptor; | |
| accessDescriptor.location.id = prop.location.id; | |
| accessDescriptor.location.type = prop.location.type; | |
| accessDescriptor.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; | |
| if (cuMemSetAccess(dptr, size, &accessDescriptor, 1) != CUDA_SUCCESS) { | |
| return static_cast<cudaError>(106); | |
| } | |
| *adr = (void *)dptr; | |
| return cudaSuccess; | |
| } | |
| // NOTE: LEAKS MEMORY !!! | |
| at::Tensor alloc_bf16_1d(int64_t num_elts) { | |
| auto options = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA, 0); | |
| void* data; | |
| const auto ret = allocateCompressible(&data, num_elts * 2, true); | |
| TORCH_CHECK(ret == cudaSuccess, "allocateCompressible failed: ", ret); | |
| return torch::from_blob( | |
| data, | |
| {num_elts}, | |
| options | |
| ); | |
| } | |
| // NOTE: LEAKS MEMORY !!! | |
| at::Tensor alloc_u8_1d(int64_t num_elts) { | |
| auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA, 0); | |
| void* data; | |
| const auto ret = allocateCompressible(&data, num_elts * 2, true); | |
| TORCH_CHECK(ret == cudaSuccess, "allocateCompressible failed: ", ret); | |
| return torch::from_blob( | |
| data, | |
| {num_elts}, | |
| options | |
| ); | |
| } | |
| TORCH_LIBRARY(compressed_mem, m) { | |
| m.def("alloc_bf16_1d", alloc_bf16_1d); | |
| m.def("alloc_u8_1d", alloc_u8_1d); | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.cuda | |
| import triton | |
| import triton.language as tl | |
| from triton.testing import do_bench, do_bench_cudagraph | |
| import time | |
| import nvtx | |
# Repetition count. NOTE(review): not referenced by the code in this module; the
# benchmark script defines its own REPS.
REPS = 20
torch.cuda.init()
# Load the compressed_mem extension, which registers the torch.ops.compressed_mem.* ops.
# The path is relative to the working directory and tied to a linux x86-64 / CPython 3.10 build.
torch.ops.load_library("build/lib.linux-x86_64-cpython-310/compressed_mem.so")
torch.manual_seed(0)
# Small CUDA allocation made at import time. Presumably it forces CUDA context creation,
# which the extension's allocator needs before it can be called — TODO confirm.
foo = torch.ones(42, dtype=torch.bfloat16, device="cuda")
# Number of elements in every tensor produced by input_data() (2**31).
NUM_ELTS = 2 * 1024 * 1024 * 1024
def pack(x):
    """Split bfloat16 values into two byte tensors.

    Returns ``(exponent, sign_and_frac)``, both uint8 tensors shaped like ``x``:
    ``exponent`` holds bits 7-14 of each 16-bit pattern (the 8 exponent bits), and
    ``sign_and_frac`` holds the 7 mantissa bits shifted left by one with the sign
    bit in bit 0.
    """
    bits = x.view(torch.int16)
    exponent = ((bits >> 7) & 0xFF).to(torch.uint8)
    sign_bit = ((bits >> 15) & 0x1).to(torch.uint8)
    mantissa_shifted = ((bits << 1) & 0xFF).to(torch.uint8)
    return exponent, mantissa_shifted | sign_bit
# Rebuild bfloat16 values from the two byte streams produced by pack() and store 2 * x.
#
# Each program handles BLOCK_SIZE consecutive elements. For each element, the exponent
# byte becomes bits 7-14, bit 0 of the sign/mantissa byte becomes the sign (bit 15), and
# that byte's bits 1-7 become the 7 mantissa bits (bits 0-6). The reassembled 16-bit
# pattern is bitcast to bfloat16 and doubled.
@triton.jit
def times2_kernel(x_exp_ptr,  # *Pointer* to the exponent bytes (first output of pack()).
                  x_sign_and_frac_ptr,  # *Pointer* to the (mantissa << 1) | sign bytes.
                  output_ptr,  # *Pointer* to output vector.
                  n_elements,  # Size of the vector.
                  BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
                  # NOTE: `constexpr` so it can be used as a shape value.
                  ):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # program_id is 32-bit. Widen before multiplying so block_start/offsets cannot
    # wrap around when n_elements exceeds 2**31.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_exp = tl.load(x_exp_ptr + offsets, mask=mask)
    x_sign_and_frac = tl.load(x_sign_and_frac_ptr + offsets, mask=mask)
    x_exp_16 = x_exp.to(tl.uint16) << 7
    x_sign_and_frac_16 = x_sign_and_frac.to(tl.uint16)
    # In 16-bit arithmetic `<< 15` keeps only bit 0 (the sign), moved to bit 15.
    x_sign = x_sign_and_frac_16 << 15
    x_frac = x_sign_and_frac_16 >> 1
    x = x_exp_16 | x_sign | x_frac
    x = x.to(tl.bfloat16, bitcast=True)
    output = 2 * x
    tl.store(output_ptr + offsets, output, mask=mask)
def times2(x_exp, x_signfrac):
    """Return ``2 * x`` as a new 1-D bfloat16 tensor, where ``x`` is given in pack()'s split form.

    Args:
        x_exp: exponent bytes, the first value returned by ``pack()``.
        x_signfrac: ``(mantissa << 1) | sign`` bytes, the second value returned by
            ``pack()``. Must have the same number of elements as ``x_exp``.

    Returns:
        A bfloat16 tensor with ``x_exp.numel()`` elements on ``x_exp``'s device.

    Raises:
        ValueError: if the inputs have different element counts. The kernel bounds
            its loads by ``x_exp.numel()`` and would read past a shorter ``x_signfrac``.
    """
    # NOTE(review): the kernel indexes flat memory, so inputs are assumed contiguous.
    n_elements = x_exp.numel()
    if x_signfrac.numel() != n_elements:
        raise ValueError(
            "x_exp and x_signfrac must have the same number of elements, "
            f"got {n_elements} and {x_signfrac.numel()}"
        )
    # Allocate on the inputs' device rather than whatever "cuda" currently means.
    output = torch.empty(n_elements, dtype=torch.bfloat16, device=x_exp.device)
    if n_elements == 0:
        # Nothing to compute; don't launch an empty grid.
        return output
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    times2_kernel[grid](x_exp, x_signfrac, output, n_elements, BLOCK_SIZE=4096)
    return output
def bench_with_sleep(fn, steps, sleep_ms):
    """Time ``fn()`` ``steps`` times with CUDA events and return the fastest run in ms.

    After queuing each run the host sleeps ``sleep_ms`` milliseconds, so the next run is
    normally enqueued only after the previous one has finished rather than back-to-back.

    Args:
        fn: zero-argument callable that enqueues the GPU work to time.
        steps: number of timed runs (must be >= 1).
        sleep_ms: host-side pause after each run, in milliseconds.

    Returns:
        The minimum elapsed time over all runs, in milliseconds.

    Raises:
        ValueError: if ``steps`` < 1 (there would be no timing to return).
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(steps)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(steps)]
    for start, end in zip(start_events, end_events):
        # record() with no argument uses the *current* stream, the one fn() launches
        # onto by default (not necessarily the default stream).
        start.record()
        fn()
        end.record()
        time.sleep(sleep_ms / 1000)
    torch.cuda.synchronize()
    return min(start.elapsed_time(end) for start, end in zip(start_events, end_events))
def input_data(type_, plus=0):
    """Return a NUM_ELTS-element bfloat16 CUDA tensor of benchmark input.

    Args:
        type_: ``"zeros"``; ``"mean_10_var_1"`` (``randn() + 10``); or
            ``"mean_0_var_1"`` (``randn() + plus``).
        plus: offset added to the ``"mean_0_var_1"`` distribution only.

    Raises:
        ValueError: for an unknown ``type_``.
    """
    if type_ == "zeros":
        return torch.zeros(NUM_ELTS, dtype=torch.bfloat16, device="cuda")
    if type_ == "mean_10_var_1":
        return torch.randn(NUM_ELTS, dtype=torch.bfloat16, device="cuda") + 10
    if type_ == "mean_0_var_1":
        return torch.randn(NUM_ELTS, dtype=torch.bfloat16, device="cuda") + plus
    raise ValueError(f"unknown type: {type_!r}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.cuda | |
| import triton | |
| import triton.language as tl | |
| from triton.testing import do_bench, do_bench_cudagraph | |
| import time | |
| import nvtx | |
| from compression import bench_with_sleep, input_data, pack, times2 | |
# Timed runs per variant in bench().
REPS = 50
torch.cuda.init()
torch.ops.load_library("build/lib.linux-x86_64-cpython-310/compressed_mem.so")
torch.manual_seed(0)
# necessary for compressed_mem not to crash
# (presumably because it creates the CUDA context the extension's allocator queries — TODO confirm)
foo = torch.ones(42, dtype=torch.bfloat16, device="cuda")
# Take the element count from compression so these buffers always match what
# input_data() produces, instead of duplicating the constant here.
from compression import NUM_ELTS
# Exponent bytes live in compressible memory; sign/mantissa bytes and the plain bf16
# baseline use ordinary PyTorch allocations.
compressed_exps = torch.ops.compressed_mem.alloc_u8_1d(NUM_ELTS)
signfracs = torch.empty(NUM_ELTS, dtype=torch.uint8, device="cuda")
uncompressed = torch.empty(NUM_ELTS, dtype=torch.bfloat16, device="cuda")
def bench(type_):
    """Compare ``2 * x`` on plain bf16 storage against times2() on the packed storage.

    Fills the module-level buffers with ``input_data(type_)`` (plain copy and pack()ed
    split form), verifies that times2() reproduces ``2 * uncompressed``, then prints the
    best-of-REPS time in ms for each variant.

    Raises:
        RuntimeError: if the packed computation does not match the plain one.
    """
    print(type_ + ":")
    benchmark_input = input_data(type_)
    uncompressed.copy_(benchmark_input)
    exps_tmp, signfracs_tmp = pack(benchmark_input)
    # pack() returns new tensors, so the staging input can be released before timing.
    del benchmark_input
    compressed_exps.copy_(exps_tmp)
    signfracs.copy_(signfracs_tmp)
    del exps_tmp, signfracs_tmp
    # Explicit check rather than `assert`, which `python -O` would strip.
    if not torch.allclose(times2(compressed_exps, signfracs), 2 * uncompressed):
        raise RuntimeError(f"{type_}: times2() on packed data does not match 2 * uncompressed")
    uncompressed_ms = bench_with_sleep(lambda: 2 * uncompressed, REPS, 100)
    time.sleep(0.1)
    compressed_ms = bench_with_sleep(lambda: times2(compressed_exps, signfracs), REPS, 100)
    print(f"Uncompressed ms: {uncompressed_ms}")
    print(f"Compressed ms: {compressed_ms}")
    print()
# Benchmark each input distribution in turn.
for _input_kind in ("zeros", "mean_10_var_1", "mean_0_var_1"):
    bench(_input_kind)
Author
embg
commented
Oct 25, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment