Skip to content

Instantly share code, notes, and snippets.

@leonardoalt
Last active March 4, 2026 16:43
Show Gist options
  • Select an option

  • Save leonardoalt/45d259099be340492331a343a86ce1b9 to your computer and use it in GitHub Desktop.

Select an option

Save leonardoalt/45d259099be340492331a343a86ce1b9 to your computer and use it in GitHub Desktop.
CUDA diffs for the Call chip GPU tracegen, shown against LoadStore — the closest existing implementation it was adapted from
--- extensions/womir_circuit/cuda/include/womir/adapters/loadstore.cuh 2026-03-04 17:32:05.607231710 +0100
+++ extensions/womir_circuit/cuda/include/womir/adapters/call.cuh 2026-03-04 17:36:26.770791861 +0100
@@ -1,62 +1,69 @@
-// Adapted from <openvm>/extensions/rv32im/circuit/cuda/include/rv32im/adapters/loadstore.cuh
-// Main changes: adds frame pointer (fp) field, WomirExecutionState, fp_read_aux, timestamp +1 shift,
-// imm_lo/imm_hi instead of imm/imm_sign
+// Call adapter CUDA implementation for GPU tracegen.
+// Handles the 6 memory operations (reads/writes) and carry-chain arithmetic
+// for the Call/CallIndirect/Ret instructions.
#pragma once
-#include "primitives/execution.h"
#include "primitives/trace_access.h"
#include "system/memory/controller.cuh"
+#include "system/memory/offline_checker.cuh"
#include "womir/execution.cuh"
using namespace riscv;
-template <typename T> struct WomirLoadStoreAdapterCols {
+// Mirror of CallAdapterCols<T> in Rust (adapters/call.rs)
+template <typename T> struct WomirCallAdapterCols {
WomirExecutionState<T> from_state;
- T rs1_ptr;
- T rs1_data[RV32_REGISTER_NUM_LIMBS];
+
+ T to_fp_operand;
+ T save_fp_ptr;
+ T save_pc_ptr;
+ T to_pc_operand;
+
MemoryReadAuxCols<T> fp_read_aux;
- MemoryReadAuxCols<T> rs1_aux_cols;
+ MemoryReadAuxCols<T> to_fp_read_aux;
+ MemoryReadAuxCols<T> to_pc_read_aux;
+ MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS> save_fp_write_aux;
+ MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS> save_pc_write_aux;
+ // FP_AS write: prev_data is 1 field element (native32 cell type)
+ MemoryWriteAuxCols<T, 1> fp_write_aux;
- /// Will write to rd when Load and read from rs2 when Store
- T rd_rs2_ptr;
- MemoryReadAuxCols<T> read_data_aux;
- T imm_lo;
- T imm_hi;
- /// mem_ptr is the intermediate memory pointer limbs, needed to check the correct addition
- T mem_ptr_limbs[2];
- T mem_as;
- /// prev_data will be provided by the core chip to make a complete MemoryWriteAuxCols
- MemoryBaseAuxCols<T> write_base_aux;
- /// Only writes if `needs_write`.
- T needs_write;
+ T offset_limbs[2];
+ T new_fp_limbs[2];
};
-struct WomirLoadStoreAdapterRecord {
+// Mirror of CallAdapterRecord in Rust
+struct WomirCallAdapterRecord {
uint32_t from_pc;
uint32_t fp;
uint32_t from_timestamp;
- uint32_t rs1_ptr;
- uint32_t rs1_val;
- MemoryReadAuxRecord fp_read_aux;
- MemoryReadAuxRecord rs1_aux_record;
-
- uint32_t rd_rs2_ptr;
- MemoryReadAuxRecord read_data_aux;
- uint16_t imm_lo;
- uint16_t imm_hi;
+ uint32_t to_fp_operand;
+ uint32_t save_fp_ptr;
+ uint32_t save_pc_ptr;
+ uint32_t to_pc_operand;
- uint8_t mem_as;
+ bool has_pc_read;
+ bool has_save;
- uint32_t write_prev_timestamp;
+ MemoryReadAuxRecord fp_read_aux;
+ MemoryReadAuxRecord to_fp_read_aux;
+ MemoryReadAuxRecord to_pc_read_aux;
+ MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS> save_fp_write_aux;
+ MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS> save_pc_write_aux;
+ // FpWriteAuxRecord: prev_timestamp + prev_fp (stored as u32)
+ uint32_t fp_write_prev_timestamp;
+ uint32_t fp_write_prev_fp;
};
-struct WomirLoadStoreAdapter {
+struct WomirCallAdapter {
size_t pointer_max_bits;
VariableRangeChecker range_checker;
MemoryAuxColsFactory mem_helper;
- __device__ WomirLoadStoreAdapter(
+ template <typename T>
+ using Cols = WomirCallAdapterCols<T>;
+
+ __device__ WomirCallAdapter(
size_t pointer_max_bits,
VariableRangeChecker range_checker,
uint32_t timestamp_max_bits
@@ -64,67 +71,141 @@
: pointer_max_bits(pointer_max_bits), range_checker(range_checker),
mem_helper(range_checker, timestamp_max_bits) {}
- __device__ void fill_trace_row(RowSlice row, WomirLoadStoreAdapterRecord record) {
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, from_state.pc, record.from_pc);
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, from_state.fp, record.fp);
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, from_state.timestamp, record.from_timestamp);
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, rs1_ptr, record.rs1_ptr);
-
- auto rs1_data = reinterpret_cast<uint8_t *>(&record.rs1_val);
- COL_WRITE_ARRAY(row, WomirLoadStoreAdapterCols, rs1_data, rs1_data);
+ __device__ void fill_trace_row(RowSlice row, WomirCallAdapterRecord record) {
+ bool has_save = record.has_save;
+ bool has_pc_read = record.has_pc_read;
+ bool has_fp_read = !has_save;
+ uint32_t fp = record.fp;
+ uint32_t to_fp_operand = record.to_fp_operand;
+
+ // Scalar fields
+ COL_WRITE_VALUE(row, Cols, from_state.pc, record.from_pc);
+ COL_WRITE_VALUE(row, Cols, from_state.fp, fp);
+ COL_WRITE_VALUE(row, Cols, from_state.timestamp, record.from_timestamp);
+ COL_WRITE_VALUE(row, Cols, to_fp_operand, to_fp_operand);
+ COL_WRITE_VALUE(row, Cols, save_fp_ptr, record.save_fp_ptr);
+ COL_WRITE_VALUE(row, Cols, save_pc_ptr, record.save_pc_ptr);
+ COL_WRITE_VALUE(row, Cols, to_pc_operand, record.to_pc_operand);
+
+ // Fill in reverse timestamp order to match Rust filler
+
+ // 5. FP write (native32 cell type: prev_data is a field element)
+ {
+ uint32_t timestamp = record.from_timestamp + 5;
+ // Write prev_data as Fp field element
+ using FpWriteAux = MemoryWriteAuxCols<uint8_t, 1>;
+ size_t base = COL_INDEX(Cols, fp_write_aux);
+ // prev_data[0] is a single Fp element
+ row.write(base + offsetof(FpWriteAux, prev_data), Fp(record.fp_write_prev_fp));
+ mem_helper.fill(
+ row.slice_from(base),
+ record.fp_write_prev_timestamp,
+ timestamp
+ );
+ }
- bool needs_write = record.rd_rs2_ptr != UINT32_MAX;
+ // 4. save_pc write (conditional on has_save)
+ {
+ uint32_t timestamp = record.from_timestamp + 4;
+ using WriteAux = MemoryWriteAuxCols<uint8_t, RV32_REGISTER_NUM_LIMBS>;
+ size_t base = COL_INDEX(Cols, save_pc_write_aux);
+ if (has_save) {
+ row.write_array(
+ base + offsetof(WriteAux, prev_data),
+ RV32_REGISTER_NUM_LIMBS,
+ record.save_pc_write_aux.prev_data
+ );
+ mem_helper.fill(
+ row.slice_from(base),
+ record.save_pc_write_aux.prev_timestamp,
+ timestamp
+ );
+ } else {
+ mem_helper.fill_zero(row.slice_from(base));
+ // Also zero out prev_data
+ for (size_t i = 0; i < RV32_REGISTER_NUM_LIMBS; i++) {
+ row.write(base + offsetof(WriteAux, prev_data) + i, 0);
+ }
+ }
+ }
- // Read fp (at from_timestamp + 0)
- mem_helper.fill(
- row.slice_from(COL_INDEX(WomirLoadStoreAdapterCols, fp_read_aux)),
- record.fp_read_aux.prev_timestamp,
- record.from_timestamp
- );
+ // 3. save_fp write (conditional on has_save)
+ {
+ uint32_t timestamp = record.from_timestamp + 3;
+ using WriteAux = MemoryWriteAuxCols<uint8_t, RV32_REGISTER_NUM_LIMBS>;
+ size_t base = COL_INDEX(Cols, save_fp_write_aux);
+ if (has_save) {
+ row.write_array(
+ base + offsetof(WriteAux, prev_data),
+ RV32_REGISTER_NUM_LIMBS,
+ record.save_fp_write_aux.prev_data
+ );
+ mem_helper.fill(
+ row.slice_from(base),
+ record.save_fp_write_aux.prev_timestamp,
+ timestamp
+ );
+ } else {
+ mem_helper.fill_zero(row.slice_from(base));
+ for (size_t i = 0; i < RV32_REGISTER_NUM_LIMBS; i++) {
+ row.write(base + offsetof(WriteAux, prev_data) + i, 0);
+ }
+ }
+ }
- // Read rs1 (at from_timestamp + 1)
- mem_helper.fill(
- row.slice_from(COL_INDEX(WomirLoadStoreAdapterCols, rs1_aux_cols)),
- record.rs1_aux_record.prev_timestamp,
- record.from_timestamp + 1
- );
+ // 2. to_pc_reg read (conditional on has_pc_read)
+ {
+ uint32_t timestamp = record.from_timestamp + 2;
+ if (has_pc_read) {
+ mem_helper.fill(
+ row.slice_from(COL_INDEX(Cols, to_pc_read_aux)),
+ record.to_pc_read_aux.prev_timestamp,
+ timestamp
+ );
+ } else {
+ mem_helper.fill_zero(row.slice_from(COL_INDEX(Cols, to_pc_read_aux)));
+ }
+ }
- if (needs_write) {
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, rd_rs2_ptr, record.rd_rs2_ptr);
- } else {
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, rd_rs2_ptr, 0);
+ // 1. to_fp_reg read (conditional on has_fp_read)
+ {
+ uint32_t timestamp = record.from_timestamp + 1;
+ if (has_fp_read) {
+ mem_helper.fill(
+ row.slice_from(COL_INDEX(Cols, to_fp_read_aux)),
+ record.to_fp_read_aux.prev_timestamp,
+ timestamp
+ );
+ } else {
+ mem_helper.fill_zero(row.slice_from(COL_INDEX(Cols, to_fp_read_aux)));
+ }
}
- // Read data (at from_timestamp + 2)
+ // 0. FP read
mem_helper.fill(
- row.slice_from(COL_INDEX(WomirLoadStoreAdapterCols, read_data_aux)),
- record.read_data_aux.prev_timestamp,
- record.from_timestamp + 2
+ row.slice_from(COL_INDEX(Cols, fp_read_aux)),
+ record.fp_read_aux.prev_timestamp,
+ record.from_timestamp
);
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, imm_lo, record.imm_lo);
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, imm_hi, record.imm_hi);
-
- uint32_t ptr = record.rs1_val + ((uint32_t)record.imm_lo | ((uint32_t)record.imm_hi << 16));
- auto ptr_limbs = reinterpret_cast<uint16_t *>(&ptr);
- COL_WRITE_ARRAY(row, WomirLoadStoreAdapterCols, mem_ptr_limbs, ptr_limbs);
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, mem_as, record.mem_as);
-
- range_checker.add_count((uint32_t)ptr_limbs[0] >> 2, RV32_CELL_BITS * 2 - 2);
- range_checker.add_count((uint32_t)ptr_limbs[1], pointer_max_bits - 16);
-
- COL_WRITE_VALUE(row, WomirLoadStoreAdapterCols, needs_write, needs_write);
- if (needs_write) {
- // Write (at from_timestamp + 3)
- mem_helper.fill(
- row.slice_from(COL_INDEX(WomirLoadStoreAdapterCols, write_base_aux)),
- record.write_prev_timestamp,
- record.from_timestamp + 3
- );
- } else {
- mem_helper.fill_zero(
- row.slice_from(COL_INDEX(WomirLoadStoreAdapterCols, write_base_aux))
- );
+ // Carry-chain limbs for CALL/CALL_INDIRECT (has_save)
+ if (has_save) {
+ uint32_t new_fp = fp + to_fp_operand;
+
+ uint16_t offset_lo = (uint16_t)(to_fp_operand & 0xffff);
+ uint16_t offset_hi = (uint16_t)(to_fp_operand >> 16);
+ uint16_t new_fp_lo = (uint16_t)(new_fp & 0xffff);
+ uint16_t new_fp_hi = (uint16_t)(new_fp >> 16);
+
+ uint16_t limbs[2] = { offset_lo, offset_hi };
+ COL_WRITE_ARRAY(row, Cols, offset_limbs, limbs);
+ uint16_t nfp_limbs[2] = { new_fp_lo, new_fp_hi };
+ COL_WRITE_ARRAY(row, Cols, new_fp_limbs, nfp_limbs);
+
+ range_checker.add_count((uint32_t)offset_lo, 16);
+ range_checker.add_count((uint32_t)offset_hi, pointer_max_bits - 16);
+ range_checker.add_count((uint32_t)new_fp_lo, 16);
+ range_checker.add_count((uint32_t)new_fp_hi, pointer_max_bits - 16);
}
}
};
--- extensions/womir_circuit/cuda/src/loadstore.cu 2026-03-04 17:32:05.608231716 +0100
+++ extensions/womir_circuit/cuda/src/call.cu 2026-03-04 17:36:50.736932727 +0100
@@ -1,177 +1,73 @@
-// Adapted from <openvm>/extensions/rv32im/circuit/cuda/src/loadstore.cu
-// Uses WomirLoadStoreAdapter (with frame pointer) instead of Rv32LoadStoreAdapter
+// GPU tracegen for the Call chip (CALL, CALL_INDIRECT, RET).
+// Uses WomirCallAdapter (with frame pointer) and inlined CallCore logic.
#include "launcher.cuh"
#include "primitives/buffer_view.cuh"
#include "primitives/constants.h"
-#include "primitives/histogram.cuh"
#include "primitives/trace_access.h"
-#include "womir/adapters/loadstore.cuh"
+#include "womir/adapters/call.cuh"
using namespace riscv;
-using namespace program;
-// Core structs inlined from OpenVM (unchanged)
-template <typename T, size_t NUM_CELLS> struct LoadStoreCoreCols {
- T flags[4];
- /// we need to keep the degree of is_valid and is_load to 1
- T is_valid;
- T is_load;
-
- T read_data[NUM_CELLS];
- T prev_data[NUM_CELLS];
- /// write_data will be constrained against read_data and prev_data
- /// depending on the opcode and the shift amount
- T write_data[NUM_CELLS];
+// CallOpcode enum - mirrors Rust CallOpcode
+enum CallOpcode {
+ RET = 0,
+ CALL = 1,
+ CALL_INDIRECT = 2,
+};
+
+// Core columns - mirrors CallCoreCols<T> in Rust
+template <typename T> struct CallCoreCols {
+ T new_fp_data[RV32_REGISTER_NUM_LIMBS];
+ T to_pc_data[RV32_REGISTER_NUM_LIMBS];
+ T old_fp_data[RV32_REGISTER_NUM_LIMBS];
+ T return_pc_data[RV32_REGISTER_NUM_LIMBS];
+ T is_ret;
+ T is_call;
+ T is_call_indirect;
+};
+
+// Core record - mirrors CallCoreRecord in Rust
+struct CallCoreRecord {
+ uint8_t new_fp_data[RV32_REGISTER_NUM_LIMBS];
+ uint8_t to_pc_data[RV32_REGISTER_NUM_LIMBS];
+ uint8_t old_fp_data[RV32_REGISTER_NUM_LIMBS];
+ uint8_t return_pc_data[RV32_REGISTER_NUM_LIMBS];
+ uint8_t local_opcode;
};
-template <size_t NUM_CELLS> struct LoadStoreCoreRecord {
- uint8_t local_opcode;
- uint8_t shift_amount;
- uint8_t read_data[NUM_CELLS];
- // Note: `prev_data` can be from native address space, so we need to use u32
- uint32_t prev_data[NUM_CELLS];
-};
-
-enum Rv32LoadStoreOpcode {
- LOADW,
- /// LOADBU, LOADHU are unsigned extend opcodes, implemented in the same chip with LOADW
- LOADBU,
- LOADHU,
- STOREW,
- STOREH,
- STOREB,
- /// The following are signed extend opcodes
- LOADB,
- LOADH,
-};
-
-template <size_t NUM_CELLS> struct LoadStoreCore {
-
- template <typename T> using Cols = LoadStoreCoreCols<T, NUM_CELLS>;
-
- __device__ void fill_trace_row(RowSlice row, LoadStoreCoreRecord<NUM_CELLS> record) {
- Rv32LoadStoreOpcode opcode = static_cast<Rv32LoadStoreOpcode>(record.local_opcode);
-
- COL_WRITE_VALUE(row, Cols, is_valid, 1);
- COL_WRITE_VALUE(
- row, Cols, is_load, (opcode == LOADW || opcode == LOADBU || opcode == LOADHU)
- );
- COL_WRITE_ARRAY(row, Cols, read_data, record.read_data);
- COL_WRITE_ARRAY(row, Cols, prev_data, record.prev_data);
+struct CallCore {
+ template <typename T> using Cols = CallCoreCols<T>;
- uint8_t flags[4] = {0};
- uint32_t write_data[NUM_CELLS] = {0};
- uint8_t shift = record.shift_amount;
-
- switch (opcode) {
- case LOADW:
-#pragma unroll
- for (size_t i = 0; i < NUM_CELLS; i++) {
- write_data[i] = record.read_data[i];
- }
- flags[0] = 2;
- break;
- case LOADHU:
-#pragma unroll
- for (size_t i = 0; i < NUM_CELLS / 2; i++) {
- write_data[i] = record.read_data[i + shift];
- }
- switch (shift) {
- case 0:
- flags[1] = 2;
- break;
- case 2:
- flags[2] = 2;
- }
- break;
- case LOADBU:
- write_data[0] = record.read_data[shift];
- switch (shift) {
- case 0:
- flags[3] = 2;
- break;
- case 1:
- flags[0] = 1;
- break;
- case 2:
- flags[1] = 1;
- break;
- case 3:
- flags[2] = 1;
- break;
- }
- break;
- case STOREW:
-#pragma unroll
- for (size_t i = 0; i < NUM_CELLS; i++) {
- write_data[i] = record.read_data[i];
- }
- flags[3] = 1;
- break;
- case STOREH:
-#pragma unroll
- for (size_t i = 0; i < NUM_CELLS; i++) {
- if (i >= shift && i < (NUM_CELLS / 2 + shift)) {
- write_data[i] = record.read_data[i - shift];
- } else {
- write_data[i] = record.prev_data[i];
- }
- }
- switch (shift) {
- case 0:
- flags[0] = flags[1] = 1;
- break;
- case 2:
- flags[0] = flags[2] = 1;
- break;
- }
- break;
- case STOREB:
-#pragma unroll
- for (size_t i = 0; i < NUM_CELLS; i++) {
- write_data[i] = record.prev_data[i];
- }
- write_data[shift] = record.read_data[0];
- switch (shift) {
- case 0:
- flags[0] = flags[3] = 1;
- break;
- case 1:
- flags[1] = flags[2] = 1;
- break;
- case 2:
- flags[1] = flags[3] = 1;
- break;
- case 3:
- flags[2] = flags[3] = 1;
- break;
- }
- break;
- default:
- break;
- }
+ __device__ void fill_trace_row(RowSlice row, CallCoreRecord record) {
+ CallOpcode opcode = static_cast<CallOpcode>(record.local_opcode);
- COL_WRITE_ARRAY(row, Cols, flags, flags);
- COL_WRITE_ARRAY(row, Cols, write_data, write_data);
+ COL_WRITE_ARRAY(row, Cols, new_fp_data, record.new_fp_data);
+ COL_WRITE_ARRAY(row, Cols, to_pc_data, record.to_pc_data);
+ COL_WRITE_ARRAY(row, Cols, old_fp_data, record.old_fp_data);
+ COL_WRITE_ARRAY(row, Cols, return_pc_data, record.return_pc_data);
+
+ COL_WRITE_VALUE(row, Cols, is_ret, opcode == RET);
+ COL_WRITE_VALUE(row, Cols, is_call, opcode == CALL);
+ COL_WRITE_VALUE(row, Cols, is_call_indirect, opcode == CALL_INDIRECT);
}
};
-// [Adapter + Core] columns and record
-template <typename T> struct WomirLoadStoreCols {
- WomirLoadStoreAdapterCols<T> adapter;
- LoadStoreCoreCols<T, RV32_REGISTER_NUM_LIMBS> core;
+// Combined columns and record
+template <typename T> struct WomirCallCols {
+ WomirCallAdapterCols<T> adapter;
+ CallCoreCols<T> core;
};
-struct WomirLoadStoreRecord {
- WomirLoadStoreAdapterRecord adapter;
- LoadStoreCoreRecord<RV32_REGISTER_NUM_LIMBS> core;
+struct WomirCallRecord {
+ WomirCallAdapterRecord adapter;
+ CallCoreRecord core;
};
-__global__ void womir_load_store_tracegen(
+__global__ void womir_call_tracegen(
Fp *trace,
size_t height,
size_t width,
- DeviceBufferConstView<WomirLoadStoreRecord> records,
+ DeviceBufferConstView<WomirCallRecord> records,
size_t pointer_max_bits,
uint32_t *range_checker_ptr,
uint32_t range_checker_num_bins,
@@ -182,35 +78,35 @@
if (idx < records.len()) {
auto const &record = records[idx];
- auto adapter = WomirLoadStoreAdapter(
+ auto adapter = WomirCallAdapter(
pointer_max_bits,
VariableRangeChecker(range_checker_ptr, range_checker_num_bins),
timestamp_max_bits
);
adapter.fill_trace_row(row, record.adapter);
- auto core = LoadStoreCore<RV32_REGISTER_NUM_LIMBS>();
- core.fill_trace_row(row.slice_from(COL_INDEX(WomirLoadStoreCols, core)), record.core);
+ auto core = CallCore();
+ core.fill_trace_row(row.slice_from(COL_INDEX(WomirCallCols, core)), record.core);
} else {
- row.fill_zero(0, sizeof(WomirLoadStoreCols<uint8_t>));
+ row.fill_zero(0, sizeof(WomirCallCols<uint8_t>));
}
}
-extern "C" int _womir_load_store_tracegen(
+extern "C" int _womir_call_tracegen(
Fp *d_trace,
size_t height,
size_t width,
- DeviceBufferConstView<WomirLoadStoreRecord> d_records,
+ DeviceBufferConstView<WomirCallRecord> d_records,
size_t pointer_max_bits,
uint32_t *d_range_checker,
uint32_t range_checker_num_bins,
uint32_t timestamp_max_bits
) {
assert((height & (height - 1)) == 0);
- assert(width == sizeof(WomirLoadStoreCols<uint8_t>));
+ assert(width == sizeof(WomirCallCols<uint8_t>));
auto [grid, block] = kernel_launch_params(height);
- womir_load_store_tracegen<<<grid, block>>>(
+ womir_call_tracegen<<<grid, block>>>(
d_trace,
height,
width,
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment