Skip to content

Instantly share code, notes, and snippets.

@chaosma
Created December 9, 2025 03:38
Show Gist options
  • Select an option

  • Save chaosma/c5424b85a7a1bfa9279121a80dbacbbb to your computer and use it in GitHub Desktop.

Select an option

Save chaosma/c5424b85a7a1bfa9279121a80dbacbbb to your computer and use it in GitHub Desktop.
Instruction read raf
/* Each GPU handles one contiguous chunk of cycles.
If T = total_cycles, the table below gives the total cost;
if T = cycles_in_one_GPU, it gives the cost of a single GPU.
In the log_T sumcheck rounds, a split sumcheck handles the boundary between GPUs' chunks.
| Function                 | Per-Round Complexity               | Max Memory                       |
|--------------------------|------------------------------------|----------------------------------|
| gen_prover_state         | O(T) (once)                        | O(6 × T + L × suffix_per_table × M) |
| prover_msg_read_checking | O(L × suffix_per_table × M/2)      | O(L × suffix_per_table × M)      |
| prover_msg_raf           | O(M)                               | O(3 × M)                         |
| compute_log_t_message    | O(T)                               | O(3 × T)                         |
| init_P_and_Q             | O(T) (per phase)                   | O(L × suffix_per_table × M)      |
| init_log_t_rounds        | O(T) (once)                        | O(3 × T)                         |
*/
// =============================================================================
// CONSTANTS
// =============================================================================
constexpr int LOG_K = 128;            // Total address bits
constexpr int PHASES = 16;            // Number of decomposition phases
// Each phase must own a whole number of address bits.
static_assert(LOG_K % PHASES == 0, "LOG_K must split evenly into PHASES chunks");
constexpr int LOG_M = LOG_K / PHASES; // Address bits (= sumcheck rounds) per phase: 8
constexpr int M = 1 << LOG_M;         // Size of each phase's domain: 2^LOG_M = 256
constexpr int NUM_TABLES = 41;        // L = NUM_TABLES = LookupTables::COUNT
constexpr int NUM_PREFIXES = 46;      // Number of prefixes
constexpr int NUM_SUFFIXES = 43;      // Number of suffixes
// =============================================================================
// MAIN ENTRY: compute_message
// =============================================================================
// Produce the prover's message for sumcheck round `round`.
// Rounds 0..LOG_K-1 bind address bits through the prefix-suffix prover; every
// later round binds one cycle variable through the dense log_T prover.
// Returns whatever three evaluations the selected helper produces.
std::array<Fr, 3> compute_message(int round, Fr previous_claim) {
    const bool binding_address_bits = round < LOG_K;
    if (!binding_address_bits) {
        // Cycle rounds: degree-3 dense sumcheck (Gruen split-eq)
        return compute_log_t_message(round, previous_claim);
    }
    // Address rounds: prefix-suffix decomposition
    return compute_prefix_suffix_prover_message(round, previous_claim);
}
// =============================================================================
// PHASE 1: PREFIX-SUFFIX MESSAGE (rounds 0..127)
// =============================================================================
std::array<Fr, 3> compute_prefix_suffix_prover_message(int round, Fr previous_claim) {
std::array<Fr, 2> read_checking = {0, 0}; // [eval_at_0, eval_at_2]
std::array<Fr, 2> raf = {0, 0}; // [eval_at_0, eval_at_2]
read_checking = prover_msg_read_checking(round); // eval read
raf = prover_msg_raf(); // eval raf
Fr eval_at_0 = read_checking[0] + raf[0];
Fr eval_at_2 = read_checking[1] + raf[1];
// evals at [0,1,2]
return {eval_at_0, previous_claim - eval_at_0, eval_at_2};
}
// =============================================================================
// eval read: dense sumcheck
// Complexity: O(M * NUM_TABLES/2 * NUM_SUFFIXES_PER_TABLE)
// Memory: O(M * NUM_TABLES * NUM_SUFFIXES_PER_TABLE)
// =============================================================================
std::array<Fr, 2> prover_msg_read_checking(int round) {
int len = suffix_polys[0][0].len(); // =M, M/2, ..., 2 for each phase
int log_len = log2(len);
// Get challenge from previous round (if odd round)
std::optional<Fr> r_x = (round % 2 == 1) ? std::optional<Fr>(r[round-1]) : std::nullopt;
Fr eval_0 = 0;
Fr eval_2_left = 0;
Fr eval_2_right = 0;
#pragma omp parallel for reduction(+:eval_0, eval_2_left, eval_2_right)
for (int b = 0; b < len/2; b++) {
LookupBits b_bits(b, log_len - 1);
// Compute prefix on the fly at (c, b), O(NUM_PREFIX*M)
std::array<Fr, NUM_PREFIXES> prefixes_c0 = {0};
std::array<Fr, NUM_PREFIXES> prefixes_c2 = {0};
for (int p = 0; p < NUM_PREFIXES; p++) {
prefixes_c0[p] = prefix_mle(prefix_checkpoints, r_x, 0, b_bits, round);
prefixes_c2[p] = prefix_mle(prefix_checkpoints, r_x, 2, b_bits, round);
}
// For each lookup table, combine prefixes with suffixes
// O(NUM_TABLES * NUM_SUFFIXES_PER_TABLE)
for (int table_idx = 0; table_idx < NUM_TABLES; table_idx++) {
auto& suffixes = suffix_polys[table_idx];
// Rust: HighToLow
// CHANGE to: LowToHigh
std::vector<Fr> suffixes_left(suffixes.size());
std::vector<Fr> suffixes_right(suffixes.size());
for (size_t s = 0; s < suffixes.size(); s++) {
suffixes_left[s] = suffixes[s][b];
suffixes_right[s] = suffixes[s][b + len/2];
}
// Combine using table-specific formula
// b = (b_h, b_l);
// poly(b) = Σ_i prefixes_i[b_h] × suffixes_i[b_l]
Fr combined_0 = table.combine(prefixes_c0, suffixes_left);
Fr combined_2_left = table.combine(prefixes_c2, suffixes_left); // sum_i P_i(2)*Q_i(0)
Fr combined_2_right = table.combine(prefixes_c2, suffixes_right); // sum_i P_i(2)*Q_i(1)
// Accumulate
eval_0 += combined_0;
eval_2_left += combined_2_left;
eval_2_right += combined_2_right;
}
}
// Final result using quadratic interpolation trick
return {eval_0, eval_2_right + eval_2_right - eval_2_left};
}
// =============================================================================
// eval raf (Left/Right Operand + Identity)
// Complexity: O(Q_len)
// =============================================================================
std::array<Fr, 2> prover_msg_raf() {
int len = identity_ps.Q_len(); // len = M, M/2,...,2 for each phase
// Accumulators
Fr left_0 = 0;
Fr left_2 = 0;
Fr right_0 = 0; // Actually stores identity + right
Fr right_2 = 0;
#pragma omp parallel for reduction(+:left_0, left_2, right_0, right_2)
for (int b = 0; b < len/2; b++) {
// b = (b_h, b_l);
// evaluate poly(b)=sum_i P_i(b_h)*Q_i(b_l) at 0, 2
auto [i0, i2] = identity_ps.sumcheck_evals(b);
auto [r0, r2] = right_operand_ps.sumcheck_evals(b);
auto [l0, l2] = left_operand_ps.sumcheck_evals(b);
// Accumulate: left separate, identity+right combined
left_0 += l0;
left_2 += l2;
right_0 += i0 + r0;
right_2 += i2 + r2;
}
// Apply gamma weights: gamma * left + gamma^2 * (identity + right)
return {
left_0 * gamma + right_0 * gamma_sqr,
left_2 * gamma + right_2 * gamma_sqr
};
}
// =============================================================================
// PHASE 2: LOG_T MESSAGE (Gruen Split EQ)
// Complexity: O(T/2)
// =============================================================================
std::array<Fr, 3> compute_log_t_message(int round, Fr previous_claim) {
int T_current = ra.len(); // Shrinks each round: T >> (round - LOG_K)
Fr eval_at_0 = 0;
Fr eval_at_inf = 0;
// MAIN LOOP: O(T_current/2) iterations, parallelized
// Uses Gruen's par_fold_out_in for efficient EQ evaluation
#pragma omp parallel for reduction(+:eval_at_0, eval_at_inf)
for (int j = 0; j < T_current/2; j++) {
// Read polynomial values at even/odd indices
Fr ra_at_0 = ra[2*j];
Fr ra_at_inf = ra[2*j + 1] - ra[2*j];
Fr val_at_0 = combined_val_polynomial[2*j];
Fr val_at_inf = combined_val_polynomial[2*j + 1] - combined_val_polynomial[2*j];
Fr raf_val_at_0 = combined_raf_val_polynomial[2*j];
Fr raf_val_at_inf = combined_raf_val_polynomial[2*j + 1] - combined_raf_val_polynomial[2*j];
// Get EQ contribution from Gruen split EQ
Fr e_in = eq_r_reduction.eval_in(j);
// Accumulate: ra * (val + raf_val) * eq
eval_at_0 += e_in * ra_at_0 * (val_at_0 + raf_val_at_0);
eval_at_inf += e_in * ra_at_inf * (val_at_inf + raf_val_at_inf);
}
// Gruen polynomial construction (degree 3) s(X)=eq(tau,X)*t(X)
// t[0] = eval_at_0; t[inf] = eval_at_inf; prev_claim = s[0] + s[1];
// return [s[0], s[1], s[2]]
return eq_r_reduction.gruen_poly_deg_3(eval_at_0, eval_at_inf, previous_claim);
}
// =============================================================================
// INIT_PHASE (called at start of each of phase)
// Complexity: O(T) for condensation + O(T) for init_suffix_polys
// =============================================================================
// Set up prover state for address phase `phase`.
//  1. For every phase except 0, multiply each u_evals[j] by the previous phase's
//     v table, indexed by that phase's LOG_M-bit address chunk (condensation, O(T)).
//  2. Rebuild the Q side of the RAF decompositions and the read-checking
//     suffix polynomials.
//  3. Rebuild the P side of the RAF decompositions from prefix_registry.
//  4. Reset v[phase] with value 1.
void init_phase(int phase) {
    const int chunk_bits = LOG_M;
    const int chunk_size = M;

    if (phase != 0) {
        // The previous phase's chunk is the lowest chunk_bits of the prefix above this shift.
        const int shift = (PHASES - phase) * chunk_bits;
        #pragma omp parallel for
        for (int cycle = 0; cycle < T; cycle++) {
            const auto [high_bits, unused_low_bits] = lookup_indices[cycle].split(shift);
            const int bucket = high_bits % chunk_size;
            u_evals[cycle] *= v[phase - 1][bucket];
        }
    }

    // Q side of the RAF decompositions: the two sections run concurrently.
    #pragma omp parallel sections
    {
        #pragma omp section
        {
            // Left/right operands share the uninterleaved operand indices.
            init_Q_dual(left_operand_ps, right_operand_ps,
                        u_evals, lookup_indices_uninterleave, lookup_indices);
        }
        #pragma omp section
        {
            // Identity uses its own (non-interleaved) indices.
            identity_ps.init_Q(u_evals, lookup_indices_identity, lookup_indices);
        }
    }

    // Q side of read checking: aggregates the T cycles into per-table buckets.
    init_suffix_polys(phase);

    // P side of the RAF decompositions.
    identity_ps.init_P(prefix_registry);
    right_operand_ps.init_P(prefix_registry);
    left_operand_ps.init_P(prefix_registry);

    v[phase].reset(1);
}
// =============================================================================
// INIT_SUFFIX_POLYS
// Complexity: O(T) reads, O(NUM_TABLES * NUM_SUFFIXES * M) memory
// =============================================================================
void init_suffix_polys(int phase) {
int log_m = LOG_M;
int m = M;
// For each table, initialize suffix polynomials
// Outer loop: O(NUM_TABLES)
#pragma omp parallel for
for (int table_idx = 0; table_idx < NUM_TABLES; table_idx++) {
auto& lookup_indices_for_table = lookup_indices_by_table[table_idx];
auto suffixes = table.suffixes(); // NUM_SUFFIXES_PER_TABLE
// Allocate result: NUM_SUFFIXES_PER_TABLE polynomials of size M each
std::vector<std::vector<Fr>> result(suffixes.size(), std::vector<Fr>(m, 0));
// MAIN AGGREGATION: O(|lookup_indices_for_table|)
// This is chunked and reduced in parallel
for (int j : lookup_indices_for_table) {
auto k = lookup_indices[j];
auto [prefix_bits, suffix_bits] = k.split((PHASES - 1 - phase) * log_m);
// For each suffix polynomial, compute contribution
for (size_t s = 0; s < suffixes.size(); s++) {
auto t = suffix.suffix_mle(suffix_bits); // Table lookup, O(1)
if (t != 0) {
Fr u = u_evals[j];
result[s][prefix_bits % m] += u * t; // Bucket by prefix
}
}
}
suffix_polys[table_idx] = result;
}
}
// =============================================================================
// call it once at start of rest log_T rounds
// Complexity: O(T * PHASES) for ra, O(T) for val/raf_val
// =============================================================================
// Called once after the last address round to materialize the dense
// polynomials consumed by the log_T cycle rounds:
//   ra[j]                          = prod over phases of v[phase][chunk_phase(k_j)]
//   combined_val_polynomial[j]     = table_j.combine(final prefixes, suffixes at 0),
//                                    or 0 when cycle j has no lookup table
//   combined_raf_val_polynomial[j] = gamma*Left + gamma_sqr*Right for interleaved
//                                    operands, gamma_sqr*Identity otherwise
// The results are stored in the globals read by compute_log_t_message and bound
// by ingest_challenge. Previously they were built in locals (one of which
// shadowed the global `ra`) and then discarded.
// Complexity: O(T * PHASES) for ra, O(T * suffixes_per_table) for val, O(T) for raf.
void init_log_t_rounds(Fr gamma, Fr gamma_sqr) {
    const int log_m = LOG_M;
    const int m = M;

    // RA: product over phases of that phase's v entry.
    std::vector<Fr> ra_evals(T, 0);
    #pragma omp parallel for
    for (int j = 0; j < T; j++) {
        auto k = lookup_indices[j];
        Fr product = 1;
        for (int phase = 0; phase < PHASES; phase++) {
            auto [prefix, unused_suffix] = k.split((PHASES - 1 - phase) * log_m);
            int k_bound = prefix % m;
            product *= v[phase][k_bound];
        }
        ra_evals[j] = product;
    }

    // Final prefix values. Every address round is bound at this point, so each
    // checkpoint should hold a value.
    // NOTE(review): assumes prefix_checkpoints[p] is a std::optional<Fr>; confirm.
    std::array<Fr, NUM_PREFIXES> prefixes = {0};
    for (int p = 0; p < NUM_PREFIXES; p++) {
        prefixes[p] = prefix_checkpoints[p].value();
    }

    // VAL: each cycle's table evaluated with all suffixes bound to 0.
    std::vector<Fr> combined_val(T, 0);
    #pragma omp parallel for
    for (int j = 0; j < T; j++) {
        if (!lookup_tables[j].has_value()) {
            continue; // no lookup on this cycle: contributes 0
        }
        auto table = lookup_tables[j].value();
        auto table_suffixes = table.suffixes();
        std::vector<Fr> suffixes(table_suffixes.size());
        for (size_t s = 0; s < table_suffixes.size(); s++) {
            // The original referenced an undeclared `suffix` here.
            suffixes[s] = table_suffixes[s].suffix_mle(0);
        }
        combined_val[j] = table.combine(prefixes, suffixes);
    }

    // RAF_VAL: gamma-weighted operand / identity prefix values.
    std::vector<Fr> combined_raf_val(T, 0);
    #pragma omp parallel for
    for (int j = 0; j < T; j++) {
        if (is_interleaved_operands[j]) {
            combined_raf_val[j] = gamma * prefix_registry[LeftOperand]
                                + gamma_sqr * prefix_registry[RightOperand];
        } else {
            combined_raf_val[j] = gamma_sqr * prefix_registry[Identity];
        }
    }

    // Publish to the globals used by the log_T rounds.
    // NOTE(review): assumes these polynomial types are assignable from
    // std::vector<Fr>; adapt to their constructor if not.
    ra = ra_evals;
    combined_val_polynomial = combined_val;
    combined_raf_val_polynomial = combined_raf_val;
}
// bind
// Absorb the verifier challenge r_j for sumcheck round `round` and bind the
// prover state to it. r_j is appended to `r` in every round.
//
// Address rounds (round < LOG_K):
//   - bind the read-checking suffix polynomials (HighToLow) and the three RAF
//     prefix-suffix decompositions to r_j;
//   - update the current phase's table v[phase] with r_j;
//   - after every second round, fold the last two challenges into prefix_checkpoints;
//   - after each LOG_M rounds, update prefix_registry and initialize the next
//     phase (if any);
//   - after the last address round, materialize the log_T state.
// Cycle rounds (round >= LOG_K): bind eq_r_reduction and the three dense
// polynomials LowToHigh.
void ingest_challenge(Fr r_j, int round) {
    r.push_back(r_j);
    if (round < LOG_K) {
        int log_m = LOG_M;
        int phase = round / log_m;

        // Bind read-checking suffixes (prefixes are computed on the fly).
        // Complexity: O(NUM_TABLES * NUM_SUFFIXES * current_len)
        // The inner trip count suffix_polys[table_idx].size() depends on the outer
        // index. The nest is therefore non-rectangular, and collapse(2) on it is
        // not conforming OpenMP, so parallelize over tables only.
        #pragma omp parallel for
        for (int table_idx = 0; table_idx < NUM_TABLES; table_idx++) {
            auto& table_suffix_polys = suffix_polys[table_idx];
            for (size_t s = 0; s < table_suffix_polys.size(); s++) {
                // HighToLow must match the b / b + len/2 indexing used by
                // prover_msg_read_checking. Binding halves the polynomial.
                table_suffix_polys[s].bind_parallel(r_j, HighToLow);
            }
        }

        // Bind RAF decompositions. Complexity: O(current_Q_len) each.
        #pragma omp parallel sections
        {
            #pragma omp section
            identity_ps.bind(r_j);
            #pragma omp section
            right_operand_ps.bind(r_j);
            #pragma omp section
            left_operand_ps.bind(r_j);
        }

        // v[phase] = eq(r, x) grows 1 -> 2 -> 4 -> ... -> 256 over LOG_M rounds.
        v[phase].update(r_j);

        // Prefix checkpoints advance once per pair of challenges.
        if ((round + 1) % 2 == 0) {
            Prefixes::update_checkpoints(prefix_checkpoints, r[round - 1], r[round]);
        }

        // Phase transition every LOG_M rounds.
        if ((round + 1) % log_m == 0) {
            prefix_registry.update_checkpoints();
            if (phase != PHASES - 1) {
                init_phase(phase + 1); // O(T)
            }
        }

        // Hand off to the log_T rounds after the last address bit.
        if (round + 1 == LOG_K) {
            init_log_t_rounds(gamma, gamma_sqr); // O(T * PHASES)
        }
    } else {
        // Cycle rounds: bind ra, val, raf_val. Complexity: O(current_T) each.
        eq_r_reduction.bind(r_j);
        ra.bind_parallel(r_j, LowToHigh);
        combined_val_polynomial.bind_parallel(r_j, LowToHigh);
        combined_raf_val_polynomial.bind_parallel(r_j, LowToHigh);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment