@galv
Created August 22, 2025 19:33
cudaGraphExecNodeSetParams timing
// graph_1000_chain.cu
// nvcc -std=c++17 -O2 graph_1000_chain.cu -o graph_1000_chain
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <array>
#include <cstring>
#include <chrono>
#include <iostream>
#define CHECK_CUDA(call) \
  do { \
    cudaError_t _status = (call); \
    if (_status != cudaSuccess) { \
      std::fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__, \
                   cudaGetErrorString(_status)); \
      std::exit(EXIT_FAILURE); \
    } \
  } while (0)
constexpr int NUM_NODES = 1000;
// Simple bytewise copy kernel: copies `size` bytes from src -> dst.
__global__ void copy_bytes_kernel(const unsigned char* __restrict__ src,
                                  unsigned char* __restrict__ dst,
                                  size_t size) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  size_t stride = blockDim.x * gridDim.x;
  for (size_t i = idx; i < size; i += stride) {
    dst[i] = src[i];
  }
}
int main() {
  // ---- Config ----
  const size_t sizeBytes = 1 << 20; // 1 MiB
  const int blockSize = 256;
  const int gridSize = 1024; // Kernel uses a grid-stride loop; need not equal sizeBytes / blockSize
  // ---- Host buffers (two different sources to prove the param update works) ----
  std::vector<unsigned char> h_srcA(sizeBytes), h_srcB(sizeBytes), h_out(sizeBytes, 0);
  for (size_t i = 0; i < sizeBytes; ++i) {
    h_srcA[i] = static_cast<unsigned char>((i * 7u) & 0xFFu);
    h_srcB[i] = static_cast<unsigned char>((255u - i) & 0xFFu); // distinctly different pattern
  }
  // ---- Device buffers ----
  unsigned char *d_srcA = nullptr, *d_srcB = nullptr, *d_dst = nullptr;
  CHECK_CUDA(cudaMalloc(&d_srcA, sizeBytes));
  CHECK_CUDA(cudaMalloc(&d_srcB, sizeBytes));
  CHECK_CUDA(cudaMalloc(&d_dst, sizeBytes));
  CHECK_CUDA(cudaMemcpy(d_srcA, h_srcA.data(), sizeBytes, cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(d_srcB, h_srcB.data(), sizeBytes, cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemset(d_dst, 0, sizeBytes));
  // ---- Create graph and add 1000 kernel nodes in a linear chain ----
  cudaGraph_t graph;
  CHECK_CUDA(cudaGraphCreate(&graph, 0));
  std::vector<cudaGraphNode_t> nodes(NUM_NODES);
  // Keep per-node kernel argument storage alive for the duration of graph creation.
  std::vector<std::array<void*, 3>> initialArgs(NUM_NODES);
  dim3 grid(gridSize, 1, 1), block(blockSize, 1, 1);
  for (int i = 0; i < NUM_NODES; ++i) {
    initialArgs[i] = { (void*)&d_srcA, (void*)&d_dst, (void*)&sizeBytes };
    cudaKernelNodeParams params{};
    params.func = (void*)copy_bytes_kernel;
    params.gridDim = grid;
    params.blockDim = block;
    params.sharedMemBytes = 0;
    params.kernelParams = initialArgs[i].data();
    params.extra = nullptr;
    const cudaGraphNode_t* deps = (i == 0) ? nullptr : &nodes[i - 1];
    const size_t numDeps = (i == 0) ? 0 : 1;
    CHECK_CUDA(cudaGraphAddKernelNode(&nodes[i], graph, deps, numDeps, &params));
  }
  // ---- Instantiate the graph ----
  cudaGraphExec_t graphExec = nullptr;
  CHECK_CUDA(cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0));
  auto t0 = std::chrono::high_resolution_clock::now();
  // ---- Update the source pointer in *all* kernel nodes before launch ----
  // Prepare updated per-node kernel arg storage (using d_srcB instead of d_srcA).
  std::vector<std::array<void*, 3>> updatedArgs(NUM_NODES);
  for (int i = 0; i < NUM_NODES; ++i) {
    updatedArgs[i] = { (void*)&d_srcB, (void*)&d_dst, (void*)&sizeBytes };
    // Prefer the generic API when available; otherwise, fall back to the kernel-specific setter.
    // cudaGraphExecNodeSetParams was introduced in recent CUDA (12.x). Adjust the guard if needed.
#if defined(CUDART_VERSION) && (CUDART_VERSION >= 12040)
    // Generic path: cudaGraphNodeParams carries a cudaKernelNodeParamsV2 payload in its union.
    cudaGraphNodeParams generic{};
    generic.type = cudaGraphNodeTypeKernel;
    generic.kernel.func = (void*)copy_bytes_kernel;
    generic.kernel.gridDim = grid;
    generic.kernel.blockDim = block;
    generic.kernel.sharedMemBytes = 0;
    generic.kernel.kernelParams = updatedArgs[i].data();
    generic.kernel.extra = nullptr;
    CHECK_CUDA(cudaGraphExecNodeSetParams(graphExec, nodes[i], &generic));
#else
    // Fallback path: the kernel-specific setter takes the original (non-V2) cudaKernelNodeParams.
    cudaKernelNodeParams kparams{};
    kparams.func = (void*)copy_bytes_kernel;
    kparams.gridDim = grid;
    kparams.blockDim = block;
    kparams.sharedMemBytes = 0;
    kparams.kernelParams = updatedArgs[i].data();
    kparams.extra = nullptr;
    CHECK_CUDA(cudaGraphExecKernelNodeSetParams(graphExec, nodes[i], &kparams));
#endif
  }
  auto t1 = std::chrono::high_resolution_clock::now();
  double total_us = std::chrono::duration<double, std::micro>(t1 - t0).count();
  double us_per_node = total_us / NUM_NODES;
  std::printf("Updated %d nodes in %.3f ms (%.3f us per node)\n",
              NUM_NODES, total_us / 1000.0, us_per_node);
  // ---- Launch the graph ----
  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreate(&stream));
  CHECK_CUDA(cudaGraphLaunch(graphExec, stream));
  CHECK_CUDA(cudaStreamSynchronize(stream));
  // ---- Validate result: d_dst should equal h_srcB (since we updated src to d_srcB) ----
  CHECK_CUDA(cudaMemcpy(h_out.data(), d_dst, sizeBytes, cudaMemcpyDeviceToHost));
  size_t mismatches = 0;
  for (size_t i = 0; i < sizeBytes; ++i) {
    if (h_out[i] != h_srcB[i]) {
      if (mismatches < 10) {
        std::fprintf(stderr, "Mismatch at %zu: got %u expected %u\n",
                     i, (unsigned)h_out[i], (unsigned)h_srcB[i]);
      }
      ++mismatches;
    }
  }
  if (mismatches == 0) {
    std::puts("Success: destination matches updated source buffer (d_srcB).");
  } else {
    std::fprintf(stderr, "Validation failed: %zu mismatches.\n", mismatches);
  }
  // ---- Cleanup ----
  CHECK_CUDA(cudaStreamDestroy(stream));
  CHECK_CUDA(cudaGraphExecDestroy(graphExec));
  CHECK_CUDA(cudaGraphDestroy(graph));
  CHECK_CUDA(cudaFree(d_srcA));
  CHECK_CUDA(cudaFree(d_srcB));
  CHECK_CUDA(cudaFree(d_dst));
  return (mismatches == 0) ? EXIT_SUCCESS : EXIT_FAILURE;
}
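
For reference, these are the two setter entry points the guard above chooses between, roughly as declared in recent CUDA Runtime API headers (exact qualifiers and the toolkit version where each first appears may differ; consult the docs for your install):

// Generic, type-erased setter (CUDA 12.x): node type and payload travel in cudaGraphNodeParams.
cudaError_t cudaGraphExecNodeSetParams(cudaGraphExec_t graphExec, cudaGraphNode_t node,
                                       cudaGraphNodeParams* nodeParams);

// Kernel-specific setter (available in older toolkits as well): takes the original cudaKernelNodeParams.
cudaError_t cudaGraphExecKernelNodeSetParams(cudaGraphExec_t hGraphExec, cudaGraphNode_t node,
                                             const cudaKernelNodeParams* pNodeParams);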

It takes about 700 ns to update each node:

$ ./exec_node_set_params
Updated 1000 nodes in 0.780 ms (0.780 us per node)
Success: destination matches updated source buffer (d_srcB).
$ ./exec_node_set_params
Updated 1000 nodes in 0.698 ms (0.698 us per node)
Success: destination matches updated source buffer (d_srcB).
$ ./exec_node_set_params
Updated 1000 nodes in 0.723 ms (0.723 us per node)
Success: destination matches updated source buffer (d_srcB).
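
For comparison, when every node's parameters change at once there is also a whole-graph path: mutate the nodes of the original (non-executable) graph, then push all the changes into the instantiated graph with a single cudaGraphExecUpdate call. A minimal sketch, not timed here, reusing the variables from the program above and assuming the CUDA 12.x cudaGraphExecUpdateResultInfo signature:

// Sketch only: update the topology graph's nodes, then apply everything to graphExec at once.
for (int i = 0; i < NUM_NODES; ++i) {
  cudaKernelNodeParams p{};
  p.func = (void*)copy_bytes_kernel;
  p.gridDim = grid;
  p.blockDim = block;
  p.sharedMemBytes = 0;
  p.kernelParams = updatedArgs[i].data();  // same per-node arg storage; must stay alive
  p.extra = nullptr;
  CHECK_CUDA(cudaGraphKernelNodeSetParams(nodes[i], &p));  // modifies `graph`, not `graphExec`
}
cudaGraphExecUpdateResultInfo info{};
CHECK_CUDA(cudaGraphExecUpdate(graphExec, graph, &info));  // succeeds only if the topology is unchanged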
