#include "ATen/ATen.h"
#include "torch/extension.h"
#include <cuda_runtime.h> // cudaMalloc, cudaMemcpy, cudaFree
#include <cuda.h> // Driver API (usually don't need this)
// Fill in an allocation property struct for the current device, optionally
// requesting generic compression. Requires an active CUDA context (hence the
// dummy tensor allocation in the Python scripts below).
cudaError_t setProp(CUmemAllocationProp *prop, bool UseCompressibleMemory)
{
    CUdevice currentDevice;
    CUresult e = cuCtxGetDevice(&currentDevice);
    if (e != CUDA_SUCCESS)
        return static_cast<cudaError_t>(e);

    memset(prop, 0, sizeof(CUmemAllocationProp));
    prop->type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop->location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop->location.id = currentDevice;

    if (UseCompressibleMemory)
        prop->allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_GENERIC;

    return cudaSuccess;
}
// Allocate `size` bytes of device memory via the CUDA VMM driver API, optionally
// marked as compressible. Error codes 101-106 are arbitrary sentinels identifying
// which driver call failed.
cudaError_t allocateCompressible(void **adr, size_t size, bool UseCompressibleMemory)
{
    CUmemAllocationProp prop = {};
    cudaError_t err = setProp(&prop, UseCompressibleMemory);
    if (err != cudaSuccess)
        return err;

    // Round the requested size up to the allocation granularity.
    size_t granularity = 0;
    if (cuMemGetAllocationGranularity(&granularity, &prop,
                                      CU_MEM_ALLOC_GRANULARITY_MINIMUM) != CUDA_SUCCESS) {
        return static_cast<cudaError_t>(101);
    }
    size = ((size - 1) / granularity + 1) * granularity;

    // Reserve a virtual address range and create the physical allocation.
    CUdeviceptr dptr;
    if (cuMemAddressReserve(&dptr, size, 0, 0, 0) != CUDA_SUCCESS) {
        return static_cast<cudaError_t>(102);
    }
    CUmemGenericAllocationHandle allocationHandle;
    if (cuMemCreate(&allocationHandle, size, &prop, 0) != CUDA_SUCCESS) {
        return static_cast<cudaError_t>(103);
    }

    // Check if cuMemCreate was able to allocate compressible memory.
    if (UseCompressibleMemory) {
        CUmemAllocationProp allocationProp = {};
        cuMemGetAllocationPropertiesFromHandle(&allocationProp, allocationHandle);
        if (allocationProp.allocFlags.compressionType != CU_MEM_ALLOCATION_COMP_GENERIC) {
            fprintf(stderr, "Could not allocate compressible memory... so waiving execution\n");
            exit(1);
        }
    }

    if (cuMemMap(dptr, size, 0, allocationHandle, 0) != CUDA_SUCCESS) {
        return static_cast<cudaError_t>(104);
    }
    // The mapping keeps the allocation alive; the handle itself is no longer needed.
    if (cuMemRelease(allocationHandle) != CUDA_SUCCESS) {
        return static_cast<cudaError_t>(105);
    }

    // Enable read/write access from the allocating device.
    CUmemAccessDesc accessDescriptor = {};
    accessDescriptor.location.id = prop.location.id;
    accessDescriptor.location.type = prop.location.type;
    accessDescriptor.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    if (cuMemSetAccess(dptr, size, &accessDescriptor, 1) != CUDA_SUCCESS) {
        return static_cast<cudaError_t>(106);
    }

    *adr = (void *)dptr;
    return cudaSuccess;
}
// NOTE: LEAKS MEMORY !!! from_blob is given no deleter and the mapping is never
// released (cuMemUnmap / cuMemAddressFree).
at::Tensor alloc_bf16_1d(int64_t num_elts) {
    auto options = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA, 0);
    void* data;
    const auto ret = allocateCompressible(&data, num_elts * 2, true); // 2 bytes per bf16 element
    TORCH_CHECK(ret == cudaSuccess, "allocateCompressible failed: ", ret);
    return torch::from_blob(data, {num_elts}, options);
}
// NOTE: LEAKS MEMORY !!! (see above)
at::Tensor alloc_u8_1d(int64_t num_elts) {
    auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA, 0);
    void* data;
    const auto ret = allocateCompressible(&data, num_elts, true); // 1 byte per uint8 element
    TORCH_CHECK(ret == cudaSuccess, "allocateCompressible failed: ", ret);
    return torch::from_blob(data, {num_elts}, options);
}
TORCH_LIBRARY(compressed_mem, m) {
    m.def("alloc_bf16_1d", alloc_bf16_1d);
    m.def("alloc_u8_1d", alloc_u8_1d);
}
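The Python scripts below load the extension from build/lib.linux-x86_64-cpython-310/compressed_mem.so, which points at a standard setuptools build. A minimal sketch of what that build file could look like (the file names and flags here are assumptions, not part of the gist):

# setup.py -- hypothetical build sketch, not part of the original gist
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

setup(
    name="compressed_mem",
    ext_modules=[
        CUDAExtension(
            name="compressed_mem",
            sources=["compressed_mem.cpp"],  # assumed source file name
            libraries=["cuda"],              # link libcuda for the driver API (cuMem*)
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

Running python setup.py build would then produce the .so under build/lib.linux-x86_64-cpython-310/.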
# compression.py: shared helpers (packing, Triton kernel, benchmarking) imported by the benchmark script below.
import torch
import torch.cuda
import triton
import triton.language as tl
from triton.testing import do_bench, do_bench_cudagraph
import time
import nvtx
REPS = 20
torch.cuda.init()
torch.ops.load_library("build/lib.linux-x86_64-cpython-310/compressed_mem.so")
torch.manual_seed(0)
foo = torch.ones(42, dtype=torch.bfloat16, device="cuda")
NUM_ELTS = 2 * 1024 * 1024 * 1024
def pack(x):
    # Split each bf16 value (1 sign bit, 8 exponent bits, 7 mantissa bits) into two bytes:
    # the exponent byte, and a byte with the mantissa in bits 7..1 and the sign in bit 0.
    # (In the benchmark below, only the exponent bytes go into compressible memory.)
    x = x.view(torch.int16)
    x_exp = ((x >> 7) & 0xFF).to(torch.uint8)    # exponent bits [14:7]
    x_frac = ((x << 1) & 0xFF).to(torch.uint8)   # mantissa bits [6:0], shifted into bits 7..1
    x_sign = ((x >> 15) & 0x1).to(torch.uint8)   # sign bit -> bit 0
    x_sign_and_frac = x_frac | x_sign
    return x_exp, x_sign_and_frac
@triton.jit
def times2_kernel(x_exp_ptr,           # *Pointer* to the exponent bytes.
                  x_sign_and_frac_ptr, # *Pointer* to the sign+mantissa bytes.
                  output_ptr,          # *Pointer* to output vector.
                  n_elements,          # Size of the vector.
                  BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
                  # NOTE: `constexpr` so it can be used as a shape value.
                  ):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_exp = tl.load(x_exp_ptr + offsets, mask=mask)
    x_sign_and_frac = tl.load(x_sign_and_frac_ptr + offsets, mask=mask)
    # Reassemble the original bf16 bit pattern (inverse of pack() above):
    # sign in bit 15, exponent in bits 14..7, mantissa in bits 6..0.
    x_exp_16 = x_exp.to(tl.uint16) << 7
    x_sign_and_frac_16 = x_sign_and_frac.to(tl.uint16)
    x_sign = x_sign_and_frac_16 << 15
    x_frac = x_sign_and_frac_16 >> 1
    x = x_exp_16 | x_sign | x_frac
    x = x.to(tl.bfloat16, bitcast=True)
    output = 2 * x
    tl.store(output_ptr + offsets, output, mask=mask)
def times2(x_exp, x_signfrac):
    n_elements = x_exp.numel()
    output = torch.empty(n_elements, dtype=torch.bfloat16, device="cuda")
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    times2_kernel[grid](x_exp, x_signfrac, output, n_elements, BLOCK_SIZE=4096)
    return output
def bench_with_sleep(fn, steps, sleep_ms):
    # Time `fn` with CUDA events, sleeping between launches, and report the fastest rep.
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(steps)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(steps)]
    for i in range(steps):
        start_events[i].record(torch.cuda.default_stream())
        fn()
        end_events[i].record(torch.cuda.default_stream())
        time.sleep(sleep_ms / 1000)
    torch.cuda.synchronize()
    times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return min(times)
def input_data(type_, plus=0):
    if type_ == "zeros":
        return torch.zeros(NUM_ELTS, dtype=torch.bfloat16, device="cuda")
    elif type_ == "mean_10_var_1":
        return torch.randn(NUM_ELTS, dtype=torch.bfloat16, device="cuda") + 10
    elif type_ == "mean_0_var_1":
        return torch.randn(NUM_ELTS, dtype=torch.bfloat16, device="cuda") + plus
    else:
        raise Exception("unknown type")
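As a quick usage sketch of pack() and times2() (hypothetical snippet, not part of the gist; the benchmark script below performs the same check at 2G elements):

# Hypothetical sanity check: pack a small tensor, then reconstruct-and-double it on the GPU.
from compression import pack, times2
import torch

x = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
x_exp, x_signfrac = pack(x)
assert torch.allclose(times2(x_exp, x_signfrac), 2 * x)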
# Benchmark script: a separate file that imports the helpers above from compression.py.
import torch
import torch.cuda
import triton
import triton.language as tl
from triton.testing import do_bench, do_bench_cudagraph
import time
import nvtx
from compression import bench_with_sleep, input_data, pack, times2
REPS = 50
torch.cuda.init()
torch.ops.load_library("build/lib.linux-x86_64-cpython-310/compressed_mem.so")
torch.manual_seed(0)
# Necessary for compressed_mem not to crash: this dummy allocation makes PyTorch
# create a CUDA context, which cuCtxGetDevice() in the extension requires.
foo = torch.ones(42, dtype=torch.bfloat16, device="cuda")
NUM_ELTS = 2 * 1024 * 1024 * 1024
compressed_exps = torch.ops.compressed_mem.alloc_u8_1d(NUM_ELTS)
signfracs = torch.empty(NUM_ELTS, dtype=torch.uint8, device="cuda")
uncompressed = torch.empty(NUM_ELTS, dtype=torch.bfloat16, device="cuda")
def bench(type_):
    print(type_ + ":")
    benchmark_input = input_data(type_)
    uncompressed.copy_(benchmark_input)
    exps_tmp, signfracs_tmp = pack(benchmark_input)
    compressed_exps.copy_(exps_tmp)
    signfracs.copy_(signfracs_tmp)
    del exps_tmp, signfracs_tmp
    assert torch.allclose(times2(compressed_exps, signfracs), 2 * uncompressed)
    uncompressed_ms = bench_with_sleep(lambda: 2 * uncompressed, REPS, 100)
    time.sleep(0.1)
    compressed_ms = bench_with_sleep(lambda: times2(compressed_exps, signfracs), REPS, 100)
    print(f"Uncompressed ms: {uncompressed_ms}")
    print(f"Compressed ms: {compressed_ms}")
    print()
bench("zeros")
bench("mean_10_var_1")
bench("mean_0_var_1")

embg commented Oct 25, 2025

zeros:
Uncompressed ms: 2.434783935546875
Compressed ms: 2.0980160236358643

mean_10_var_1:
Uncompressed ms: 2.4096319675445557
Compressed ms: 2.271807909011841

mean_0_var_1:
Uncompressed ms: 2.434880018234253
Compressed ms: 2.650912046432495
