// instruction_read_raf — GitHub gist by @chaosma
// (gist 75af6ae21400928e7fcdb33a73155f75, last active December 8, 2025)
/* Each GPU processes a contiguous chunk of cycles.
If T = total number of cycles, the table below gives the total cost;
if T = number of cycles per GPU, it gives the cost on a single GPU.
In the log_T-round sumcheck, a split sumcheck handles the chunk boundaries between GPUs.
Legend: L = NUM_TABLES, suffix_per_table = number of suffixes per lookup table, M = 2^LOG_M.
| Function                 | Per-Round Complexity               | Max Memory                          |
|--------------------------|------------------------------------|-------------------------------------|
| gen_prover_state         | O(T) (once)                        | O(6 × T + L × suffix_per_table × M) |
| prover_msg_read_checking | O(L × suffix_per_table × M/2)      | O(L × suffix_per_table × M)         |
| prover_msg_raf           | O(M)                               | O(3 × M)                            |
| compute_log_t_message    | O(T)                               | O(3 × T)                            |
| init_P_and_Q             | O(T) (per phase)                   | O(L × suffix_per_table × M)         |
| init_log_t_rounds        | O(T) (once)                        | O(3 × T)                            |
*/
// =============================================================================
// CONSTANTS
// =============================================================================
/// Total number of address bits; one sumcheck round binds one bit.
const LOG_K: usize = 128;
/// Number of prefix-suffix decomposition phases the address rounds are split into.
const PHASES: usize = 16;
/// Sumcheck rounds per phase (derived so it can never disagree with LOG_K / PHASES).
const LOG_M: usize = LOG_K / PHASES;
/// Size of each per-phase table: 2^LOG_M.
const M: usize = 1 << LOG_M;
/// Number of lookup tables (L = NUM_TABLES = LookupTables::COUNT).
const NUM_TABLES: usize = 41;
/// Number of prefixes.
const NUM_PREFIXES: usize = 46;
/// Number of suffixes.
const NUM_SUFFIXES: usize = 43;

// Every phase must bind the same number of address bits.
const _: () = assert!(LOG_K % PHASES == 0, "LOG_K must split evenly into PHASES");
// =============================================================================
// MAIN ENTRY: compute_message
// =============================================================================
fn compute_message(round, previous_claim) -> UniPoly {
if round < LOG_K {
// Phase 1: First 128 rounds (prefix-suffix decomposition)
return compute_prefix_suffix_prover_message(round, previous_claim);
} else {
// Phase 2: Last log_T rounds (degree 3 dense sumcheck using Gruen's trick)
return compute_log_t_message(round, previous_claim);
}
}
// =============================================================================
// PHASE 1: PREFIX-SUFFIX MESSAGE (rounds 0..127)
// =============================================================================
/// Degree-2 prover message for an address round, as evaluations [s(0), s(1), s(2)].
///
/// s(X) is the sum of the read-checking term and the RAF term. Each helper supplies its
/// evaluations at 0 and 2. s(1) is recovered from the sumcheck invariant
/// s(0) + s(1) = previous_claim.
fn compute_prefix_suffix_prover_message(round, previous_claim) -> [Fr; 3] {
    let [read_at_0, read_at_2] = prover_msg_read_checking(round);
    let [raf_at_0, raf_at_2] = prover_msg_raf();
    let eval_at_0 = read_at_0 + raf_at_0;
    let eval_at_2 = read_at_2 + raf_at_2;
    // Uses the `previous_claim` parameter (the original referenced an undefined `prev_claim`).
    [eval_at_0, previous_claim - eval_at_0, eval_at_2]
}
// =============================================================================
// eval read: dense sumcheck
// Complexity: O(M * NUM_TABLES/2 * NUM_SUFFIXES_PER_TABLE)
// Memory: O(M * NUM_TABLES * NUM_SUFFIXES_PER_TABLE)
// =============================================================================
/// Read-checking part of the address-round message: evaluations at X = 0 and X = 2.
///
/// For each half-index b and each lookup table, the table's polynomial is
/// poly(b) = Σ_i prefix_i(b_h) · suffix_i(b_l), formed via `table.combine`.
/// Prefix MLEs are computed on the fly. Suffix polynomials are read from `suffix_polys`.
fn prover_msg_read_checking(round) -> [Fr; 2] {
    // All suffix polynomials share the current length: M, M/2, ..., 2 within a phase.
    let len = suffix_polys[0][0].len();
    let half = len / 2;
    let log_len = log2(len);
    // On odd rounds the previous round's challenge is passed to the prefix MLEs
    // (prefix checkpoints only advance every two rounds).
    let r_x = if round % 2 == 1 { Some(r[round - 1]) } else { None };

    // Per-thread partial sums, reduced across threads by the parallel loop.
    let eval_0 = 0; // Σ P(0)·Q(0)
    let eval_2_left = 0; // Σ P(2)·Q(0)
    let eval_2_right = 0; // Σ P(2)·Q(1)

    parallel_for b in 0..half {
        let b_bits = LookupBits(b, log_len - 1);

        // Evaluate EVERY prefix MLE at c = 0 and c = 2 for this b: O(NUM_PREFIXES) per b.
        // (The original loop never used `p`, so every slot received the same value.)
        let prefixes_c0 = [0; NUM_PREFIXES];
        let prefixes_c2 = [0; NUM_PREFIXES];
        for (p, prefix) in Prefixes::iter().enumerate() {
            prefixes_c0[p] = prefix.prefix_mle(prefix_checkpoints, r_x, 0, b_bits, round);
            prefixes_c2[p] = prefix.prefix_mle(prefix_checkpoints, r_x, 2, b_bits, round);
        }

        // Combine prefixes with each table's suffixes: O(NUM_TABLES * NUM_SUFFIXES_PER_TABLE).
        for (table_idx, table) in LookupTables::iter().enumerate() {
            let suffixes = suffix_polys[table_idx];
            // Suffixes are bound HighToLow, so the current variable separates index b
            // (Q_i(0)) from b + len/2 (Q_i(1)).
            // TODO: switching the binding to LowToHigh requires pairing 2b with 2b + 1 here.
            let suffixes_left: Vec<_> = suffixes.iter().map(|poly| poly[b]).collect();
            let suffixes_right: Vec<_> = suffixes.iter().map(|poly| poly[b + half]).collect();

            eval_0 += table.combine(prefixes_c0, suffixes_left);
            eval_2_left += table.combine(prefixes_c2, suffixes_left);
            eval_2_right += table.combine(prefixes_c2, suffixes_right);
        }
    }

    // Each suffix is multilinear, hence linear in the current variable: Q(2) = 2·Q(1) − Q(0).
    // Because combine is linear in the suffixes,
    // Σ P(2)·Q(2) = 2·Σ P(2)·Q(1) − Σ P(2)·Q(0).
    [eval_0, eval_2_right + eval_2_right - eval_2_left]
}
// =============================================================================
// eval raf (Left/Right Operand + Identity)
// Complexity: O(Q_len)
// =============================================================================
/// RAF part of the address-round message: evaluations at X = 0 and X = 2.
///
/// Sums the sumcheck evaluations of the three prefix-suffix decompositions over every
/// half-index b. The left operand is weighted by gamma; identity and right operand
/// together are weighted by gamma².
fn prover_msg_raf() -> [Field; 2] {
    // Q length shrinks within a phase: M, M/2, ..., 2.
    let half = identity_ps.Q_len() / 2;

    // (at X = 0, at X = 2) running sums.
    let (left_sum_0, left_sum_2) = (0, 0);
    let (id_right_sum_0, id_right_sum_2) = (0, 0);

    parallel_for b in 0..half {
        // Each decomposition evaluates poly(b) = Σ_i P_i(b_h)·Q_i(b_l) at 0 and 2.
        let (id_at_0, id_at_2) = identity_ps.sumcheck_evals(b);
        let (right_at_0, right_at_2) = right_operand_ps.sumcheck_evals(b);
        let (left_at_0, left_at_2) = left_operand_ps.sumcheck_evals(b);

        left_sum_0 += left_at_0;
        left_sum_2 += left_at_2;
        id_right_sum_0 += id_at_0 + right_at_0;
        id_right_sum_2 += id_at_2 + right_at_2;
    }

    // gamma · left + gamma² · (identity + right), at each evaluation point.
    let at_0 = left_sum_0 * gamma + id_right_sum_0 * gamma_sqr;
    let at_2 = left_sum_2 * gamma + id_right_sum_2 * gamma_sqr;
    [at_0, at_2]
}
// =============================================================================
// PHASE 2: LOG_T MESSAGE (Gruen Split EQ)
// Complexity: O(T/2)
// =============================================================================
fn compute_log_t_message(round, previous_claim) -> [Fr; 3] {
let T_current = ra.len(); // Shrinks each round: T >> (round - LOG_K)
let eval_at_0 = 0;
let eval_at_inf = 0;
// MAIN LOOP: O(T_current/2) iterations, parallelized
// Uses Gruen's par_fold_out_in for efficient EQ evaluation
parallel_for j in 0..(T_current/2) {
// Read polynomial values at even/odd indices
let ra_at_0 = ra[2*j];
let ra_at_inf = ra[2*j + 1] - ra[2*j];
let val_at_0 = combined_val_polynomial[2*j];
let val_at_inf = combined_val_polynomial[2*j + 1] - combined_val_polynomial[2*j];
let raf_val_at_0 = combined_raf_val_polynomial[2*j];
let raf_val_at_inf = combined_raf_val_polynomial[2*j + 1] - combined_raf_val_polynomial[2*j];
// Get EQ contribution from Gruen split EQ
let e_in = eq_r_reduction.eval_in(j);
// Accumulate: ra * (val + raf_val) * eq
eval_at_0 += e_in * ra_at_0 * (val_at_0 + raf_val_at_0);
eval_at_inf += e_in * ra_at_inf * (val_at_inf + raf_val_at_inf);
}
// Gruen polynomial construction (degree 3) s(X)=eq(tau,X)*t(X)
// t[0] = eval_at_0; t[inf] = eval_at_inf; prev_claim = s[0] + s[1];
// return [s[0], s[1], s[2]]
return eq_r_reduction.gruen_poly_deg_3(eval_at_0, eval_at_inf, previous_claim);
}
// =============================================================================
// INIT_PHASE (called at start of each of phase)
// Complexity: O(T) for condensation + O(T) for init_suffix_polys
// =============================================================================
fn init_phase(phase) {
let log_m = LOG_M; // 8
let m = M; // 256
// CONDENSATION: multiply u_evals by previous phase's v table
// Complexity: O(T)
if phase != 0 {
parallel_for j in 0..T {
let k = lookup_indices[j];
let (prefix, _) = k.split((PHASES - phase) * log_m);
let k_bound = prefix % m;
u_evals[j] *= v[phase - 1][k_bound];
}
}
// INIT Q for raf suffix
// Complexity: O(T) for each
parallel {
// Left and right share same indices (interleaved operands)
init_Q_dual(left_operand_ps, right_operand_ps,
u_evals, lookup_indices_uninterleave, lookup_indices);
// Identity uses non-interleaved indices
identity_ps.init_Q(u_evals, lookup_indices_identity, lookup_indices);
}
// INIT Q for read suffix
// loop over T cycles, aggregated into O(L*M) buckets
init_suffix_polys(phase);
// INIT P for raf prefix
// Complexity: O(M) per decomposition
identity_ps.init_P(prefix_registry);
right_operand_ps.init_P(prefix_registry);
left_operand_ps.init_P(prefix_registry);
// Reset expanding table for this phase
v[phase].reset(1);
}
// =============================================================================
// INIT_SUFFIX_POLYS
// Complexity: O(T) reads, O(NUM_TABLES * NUM_SUFFIXES * M) memory
// =============================================================================
fn init_suffix_polys(phase) {
let log_m = LOG_M;
let m = M;
// For each table, initialize suffix polynomials
// Outer loop: O(NUM_TABLES)
parallel_for table_idx in 0..NUM_TABLES {
let lookup_indices_for_table = lookup_indices_by_table[table_idx];
let suffixes = table.suffixes(); // NUM_SUFFIXES_PER_TABLE
// Allocate result: NUM_SUFFIXES_PER_TABLE polynomials of size M each
let result = [[0; M]; suffixes.len()];
// MAIN AGGREGATION: O(|lookup_indices_for_table|)
// This is chunked and reduced in parallel
for j in lookup_indices_for_table {
let k = lookup_indices[j];
let (prefix_bits, suffix_bits) = k.split((PHASES - 1 - phase) * log_m);
// For each suffix polynomial, compute contribution
for s in 0..suffixes.len() {
let t = suffix.suffix_mle(suffix_bits); // Table lookup, O(1)
if t != 0 {
let u = u_evals[j];
result[s][prefix_bits % m] += u * t; // Bucket by prefix
}
}
}
suffix_polys[table_idx] = result;
}
}
// =============================================================================
// call it once at start of rest log_T rounds
// Complexity: O(T * PHASES) for ra, O(T * NUM_SUFFIXES_PER_TABLE) for val, O(T) for raf_val
// =============================================================================
/// Materializes the dense polynomials used by the log_T cycle rounds, once all LOG_K address
/// variables are bound.
///
/// * `ra[j]`: product over phases of v[phase] at cycle j's address chunk for that phase.
///   O(T · PHASES).
/// * `combined_val_polynomial[j]`: the cycle's lookup table combined from the final prefix
///   values and the fully bound suffixes. O(T · NUM_SUFFIXES_PER_TABLE).
/// * `combined_raf_val_polynomial[j]`: gamma · Left + gamma² · Right for interleaved operands,
///   otherwise gamma² · Identity. O(T).
///
/// The results are stored into the state read by `compute_log_t_message` and `ingest_challenge`.
/// The original computed them into locals and discarded them.
fn init_log_t_rounds(gamma, gamma_sqr) {
    // RA: ra(j) = Π_phase v[phase][chunk_phase(k_j)].
    let ra_evals = vec![0; T];
    parallel_for j in 0..T {
        let k = lookup_indices[j];
        let product = 1;
        for phase in 0..PHASES {
            let (prefix, _) = k.split((PHASES - 1 - phase) * LOG_M);
            product *= v[phase][prefix % M];
        }
        ra_evals[j] = product;
    }

    // Final prefix evaluations. Assumes every checkpoint is populated after round LOG_K - 1
    // (checkpoints are updated on odd rounds, and LOG_K is even) — TODO confirm.
    let prefixes: Vec<_> = prefix_checkpoints.iter().map(|checkpoint| checkpoint.unwrap()).collect();

    // VAL: every suffix variable is bound, so each suffix is evaluated at 0.
    let combined_val = vec![0; T];
    parallel_for j in 0..T {
        if let Some(table) = lookup_tables[j] {
            let suffix_evals: Vec<_> =
                table.suffixes().iter().map(|suffix| suffix.suffix_mle(0)).collect();
            combined_val[j] = table.combine(prefixes, suffix_evals);
        }
    }

    // RAF VAL: same gamma weighting as prover_msg_raf.
    let combined_raf_val = vec![0; T];
    parallel_for j in 0..T {
        combined_raf_val[j] = if is_interleaved_operands[j] {
            gamma * prefix_registry[LeftOperand] + gamma_sqr * prefix_registry[RightOperand]
        } else {
            gamma_sqr * prefix_registry[Identity]
        };
    }

    // Publish to prover state for the log_T rounds.
    ra = ra_evals;
    combined_val_polynomial = combined_val;
    combined_raf_val_polynomial = combined_raf_val;
}
// bind
// Records the verifier challenge r_j for sumcheck round `round` and binds prover state to it.
// Rounds 0..LOG_K bind address variables. Later rounds bind cycle variables.
// Statement order matters: v[phase] and the checkpoints are updated before the phase
// transition reads them, and init_log_t_rounds runs only after the final address bind.
// NOTE(review): gamma / gamma_sqr are not parameters here; presumably prover state defined elsewhere.
fn ingest_challenge(r_j, round) {
    // After this push, r[round] == r_j.
    r.push(r_j);
    if round < LOG_K {
        let log_m = LOG_M;
        // Phase this address round belongs to (LOG_M rounds per phase).
        let phase = round / log_m;
        // bind read suffix (prefix is computed on the fly)
        // Complexity: O(NUM_TABLES * NUM_SUFFIXES * current_len)
        parallel {
            for table_idx in 0..NUM_TABLES {
                for s in 0..suffix_polys[table_idx].len() {
                    // CHANGE TO: LowToHigh
                    // NOTE(review): prover_msg_read_checking pairs suffixes[s][b] with
                    // suffixes[s][b + len/2], which matches HighToLow; change both together.
                    suffix_polys[table_idx][s].bind_parallel(r_j, HighToLow);
                    // bind halves the polynomial size
                }
            }
        }
        // bind raf
        // Complexity: O(current_Q_len) each
        parallel {
            identity_ps.bind(r_j);
            right_operand_ps.bind(r_j);
            left_operand_ps.bind(r_j);
        }
        // v[phase]=eq(r,x) grows from 1 → 2 → 4 → ... → 256 over 8 rounds
        v[phase].update(r_j);
        // UPDATE CHECKPOINTS (every 2 rounds): on odd rounds, using the two most recent
        // challenges r[round-1] and r[round] (= r_j).
        if (round + 1) % 2 == 0 {
            Prefixes::update_checkpoints(prefix_checkpoints, r[round-1], r[round], ...);
        }
        // PHASE TRANSITION (every LOG_M rounds): after the last round of a phase, advance the
        // RAF prefix registry; then set up the next phase, unless this was the final phase.
        if (round + 1) % log_m == 0 {
            prefix_registry.update_checkpoints();
            if phase != PHASES - 1 {
                init_phase(phase + 1); // O(T)
            }
        }
        // TRANSITION TO LOG_T ROUNDS: all LOG_K address bits are now bound.
        if round + 1 == LOG_K {
            init_log_t_rounds(gamma, gamma_sqr); // O(T * PHASES)
        }
    } else {
        // LOG_T ROUNDS: bind the EQ helper and the ra, val and raf_val polynomials
        // (LowToHigh, matching the 2j / 2j+1 pairing in compute_log_t_message).
        // Complexity: O(current_T) for each
        eq_r_reduction.bind(r_j);
        ra.bind_parallel(r_j, LowToHigh);
        combined_val_polynomial.bind_parallel(r_j, LowToHigh);
        combined_raf_val_polynomial.bind_parallel(r_j, LowToHigh);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment