Created
December 9, 2025 03:38
-
-
Save chaosma/c5424b85a7a1bfa9279121a80dbacbbb to your computer and use it in GitHub Desktop.
Instruction read raf
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* Each GPU corresponds to a chunk of cycles. | |
| if T = total_cycles, then following is for total cost | |
| if T = cycles_in_one_GPU, then following is for cost of one GPU. | |
| In the log_T rounds sumcheck, use split-sumchecker to deal with boundary across GPUs. | |
| | Function | Per-Round Complexity | Max Memory | | |
| |--------------------------|------------------------------------|----------------------------------| | |
| | gen_prover_state | O(T) (once) | O(6*T + L*suffix_per_table*M) | | |
| | prover_msg_read_checking | O(L × suffix_per_table × M/2) | O(L × suffix_per_table × M) | | |
| | prover_msg_raf | O(M) | O(3 × M) | | |
| | compute_log_t_message | O(T) | O(3 × T) | | |
| | init_P_and_Q | O(T) (per phase) | O(L × suffix_per_table × M) | | |
| | init_log_t_rounds | O(T) (once) | O(3 × T) | | |
| */ | |
| // ============================================================================= | |
| // CONSTANTS | |
| // ============================================================================= | |
// Sumcheck dimensions. LOG_M and M are derived from LOG_K and PHASES so that
// retuning one of them cannot leave the others silently out of sync.
constexpr int LOG_K = 128;             // Total address bits
constexpr int PHASES = 16;             // Number of decomposition phases
constexpr int LOG_M = LOG_K / PHASES;  // Rounds per phase (8)
constexpr int M = 1 << LOG_M;          // 2^LOG_M = size of each phase (256)
constexpr int NUM_TABLES = 41;         // L = NUM_TABLES = LookupTables::COUNT
constexpr int NUM_PREFIXES = 46;       // Number of prefixes
constexpr int NUM_SUFFIXES = 43;       // Number of suffixes

static_assert(LOG_K % PHASES == 0, "LOG_K must split evenly into PHASES phases");
static_assert(LOG_M > 0 && LOG_M < 31, "M = 2^LOG_M must fit in an int");
| // ============================================================================= | |
| // MAIN ENTRY: compute_message | |
| // ============================================================================= | |
| std::array<Fr, 3> compute_message(int round, Fr previous_claim) { | |
| if (round < LOG_K) { | |
| // Phase 1: First 128 rounds (prefix-suffix decomposition) | |
| return compute_prefix_suffix_prover_message(round, previous_claim); | |
| } else { | |
| // Phase 2: Last log_T rounds (degree 3 dense sumcheck using Gruen's trick) | |
| return compute_log_t_message(round, previous_claim); | |
| } | |
| } | |
| // ============================================================================= | |
| // PHASE 1: PREFIX-SUFFIX MESSAGE (rounds 0..127) | |
| // ============================================================================= | |
| std::array<Fr, 3> compute_prefix_suffix_prover_message(int round, Fr previous_claim) { | |
| std::array<Fr, 2> read_checking = {0, 0}; // [eval_at_0, eval_at_2] | |
| std::array<Fr, 2> raf = {0, 0}; // [eval_at_0, eval_at_2] | |
| read_checking = prover_msg_read_checking(round); // eval read | |
| raf = prover_msg_raf(); // eval raf | |
| Fr eval_at_0 = read_checking[0] + raf[0]; | |
| Fr eval_at_2 = read_checking[1] + raf[1]; | |
| // evals at [0,1,2] | |
| return {eval_at_0, previous_claim - eval_at_0, eval_at_2}; | |
| } | |
| // ============================================================================= | |
| // eval read: dense sumcheck | |
| // Complexity: O(M * NUM_TABLES/2 * NUM_SUFFIXES_PER_TABLE) | |
| // Memory: O(M * NUM_TABLES * NUM_SUFFIXES_PER_TABLE) | |
| // ============================================================================= | |
| std::array<Fr, 2> prover_msg_read_checking(int round) { | |
| int len = suffix_polys[0][0].len(); // =M, M/2, ..., 2 for each phase | |
| int log_len = log2(len); | |
| // Get challenge from previous round (if odd round) | |
| std::optional<Fr> r_x = (round % 2 == 1) ? std::optional<Fr>(r[round-1]) : std::nullopt; | |
| Fr eval_0 = 0; | |
| Fr eval_2_left = 0; | |
| Fr eval_2_right = 0; | |
| #pragma omp parallel for reduction(+:eval_0, eval_2_left, eval_2_right) | |
| for (int b = 0; b < len/2; b++) { | |
| LookupBits b_bits(b, log_len - 1); | |
| // Compute prefix on the fly at (c, b), O(NUM_PREFIX*M) | |
| std::array<Fr, NUM_PREFIXES> prefixes_c0 = {0}; | |
| std::array<Fr, NUM_PREFIXES> prefixes_c2 = {0}; | |
| for (int p = 0; p < NUM_PREFIXES; p++) { | |
| prefixes_c0[p] = prefix_mle(prefix_checkpoints, r_x, 0, b_bits, round); | |
| prefixes_c2[p] = prefix_mle(prefix_checkpoints, r_x, 2, b_bits, round); | |
| } | |
| // For each lookup table, combine prefixes with suffixes | |
| // O(NUM_TABLES * NUM_SUFFIXES_PER_TABLE) | |
| for (int table_idx = 0; table_idx < NUM_TABLES; table_idx++) { | |
| auto& suffixes = suffix_polys[table_idx]; | |
| // Rust: HighToLow | |
| // CHANGE to: LowToHigh | |
| std::vector<Fr> suffixes_left(suffixes.size()); | |
| std::vector<Fr> suffixes_right(suffixes.size()); | |
| for (size_t s = 0; s < suffixes.size(); s++) { | |
| suffixes_left[s] = suffixes[s][b]; | |
| suffixes_right[s] = suffixes[s][b + len/2]; | |
| } | |
| // Combine using table-specific formula | |
| // b = (b_h, b_l); | |
| // poly(b) = Σ_i prefixes_i[b_h] × suffixes_i[b_l] | |
| Fr combined_0 = table.combine(prefixes_c0, suffixes_left); | |
| Fr combined_2_left = table.combine(prefixes_c2, suffixes_left); // sum_i P_i(2)*Q_i(0) | |
| Fr combined_2_right = table.combine(prefixes_c2, suffixes_right); // sum_i P_i(2)*Q_i(1) | |
| // Accumulate | |
| eval_0 += combined_0; | |
| eval_2_left += combined_2_left; | |
| eval_2_right += combined_2_right; | |
| } | |
| } | |
| // Final result using quadratic interpolation trick | |
| return {eval_0, eval_2_right + eval_2_right - eval_2_left}; | |
| } | |
| // ============================================================================= | |
| // eval raf (Left/Right Operand + Identity) | |
| // Complexity: O(Q_len) | |
| // ============================================================================= | |
| std::array<Fr, 2> prover_msg_raf() { | |
| int len = identity_ps.Q_len(); // len = M, M/2,...,2 for each phase | |
| // Accumulators | |
| Fr left_0 = 0; | |
| Fr left_2 = 0; | |
| Fr right_0 = 0; // Actually stores identity + right | |
| Fr right_2 = 0; | |
| #pragma omp parallel for reduction(+:left_0, left_2, right_0, right_2) | |
| for (int b = 0; b < len/2; b++) { | |
| // b = (b_h, b_l); | |
| // evaluate poly(b)=sum_i P_i(b_h)*Q_i(b_l) at 0, 2 | |
| auto [i0, i2] = identity_ps.sumcheck_evals(b); | |
| auto [r0, r2] = right_operand_ps.sumcheck_evals(b); | |
| auto [l0, l2] = left_operand_ps.sumcheck_evals(b); | |
| // Accumulate: left separate, identity+right combined | |
| left_0 += l0; | |
| left_2 += l2; | |
| right_0 += i0 + r0; | |
| right_2 += i2 + r2; | |
| } | |
| // Apply gamma weights: gamma * left + gamma^2 * (identity + right) | |
| return { | |
| left_0 * gamma + right_0 * gamma_sqr, | |
| left_2 * gamma + right_2 * gamma_sqr | |
| }; | |
| } | |
| // ============================================================================= | |
| // PHASE 2: LOG_T MESSAGE (Gruen Split EQ) | |
| // Complexity: O(T/2) | |
| // ============================================================================= | |
| std::array<Fr, 3> compute_log_t_message(int round, Fr previous_claim) { | |
| int T_current = ra.len(); // Shrinks each round: T >> (round - LOG_K) | |
| Fr eval_at_0 = 0; | |
| Fr eval_at_inf = 0; | |
| // MAIN LOOP: O(T_current/2) iterations, parallelized | |
| // Uses Gruen's par_fold_out_in for efficient EQ evaluation | |
| #pragma omp parallel for reduction(+:eval_at_0, eval_at_inf) | |
| for (int j = 0; j < T_current/2; j++) { | |
| // Read polynomial values at even/odd indices | |
| Fr ra_at_0 = ra[2*j]; | |
| Fr ra_at_inf = ra[2*j + 1] - ra[2*j]; | |
| Fr val_at_0 = combined_val_polynomial[2*j]; | |
| Fr val_at_inf = combined_val_polynomial[2*j + 1] - combined_val_polynomial[2*j]; | |
| Fr raf_val_at_0 = combined_raf_val_polynomial[2*j]; | |
| Fr raf_val_at_inf = combined_raf_val_polynomial[2*j + 1] - combined_raf_val_polynomial[2*j]; | |
| // Get EQ contribution from Gruen split EQ | |
| Fr e_in = eq_r_reduction.eval_in(j); | |
| // Accumulate: ra * (val + raf_val) * eq | |
| eval_at_0 += e_in * ra_at_0 * (val_at_0 + raf_val_at_0); | |
| eval_at_inf += e_in * ra_at_inf * (val_at_inf + raf_val_at_inf); | |
| } | |
| // Gruen polynomial construction (degree 3) s(X)=eq(tau,X)*t(X) | |
| // t[0] = eval_at_0; t[inf] = eval_at_inf; prev_claim = s[0] + s[1]; | |
| // return [s[0], s[1], s[2]] | |
| return eq_r_reduction.gruen_poly_deg_3(eval_at_0, eval_at_inf, previous_claim); | |
| } | |
| // ============================================================================= | |
| // INIT_PHASE (called at the start of each phase) | |
| // Complexity: O(T) for condensation + O(T) for init_suffix_polys | |
| // ============================================================================= | |
| void init_phase(int phase) { | |
| int log_m = LOG_M; // 8 | |
| int m = M; // 256 | |
| // CONDENSATION: multiply u_evals by previous phase's v table | |
| // Complexity: O(T) | |
| if (phase != 0) { | |
| #pragma omp parallel for | |
| for (int j = 0; j < T; j++) { | |
| auto k = lookup_indices[j]; | |
| auto [prefix, _] = k.split((PHASES - phase) * log_m); | |
| int k_bound = prefix % m; | |
| u_evals[j] *= v[phase - 1][k_bound]; | |
| } | |
| } | |
| // INIT Q for raf suffix | |
| // Complexity: O(T) for each | |
| #pragma omp parallel sections | |
| { | |
| #pragma omp section | |
| { | |
| // Left and right share same indices (interleaved operands) | |
| init_Q_dual(left_operand_ps, right_operand_ps, | |
| u_evals, lookup_indices_uninterleave, lookup_indices); | |
| } | |
| #pragma omp section | |
| { | |
| // Identity uses non-interleaved indices | |
| identity_ps.init_Q(u_evals, lookup_indices_identity, lookup_indices); | |
| } | |
| } | |
| // INIT Q for read suffix | |
| // loop over T cycles, aggregated into O(L*M) buckets | |
| init_suffix_polys(phase); | |
| // INIT P for raf prefix | |
| // Complexity: O(M) per decomposition | |
| identity_ps.init_P(prefix_registry); | |
| right_operand_ps.init_P(prefix_registry); | |
| left_operand_ps.init_P(prefix_registry); | |
| // Reset expanding table for this phase | |
| v[phase].reset(1); | |
| } | |
| // ============================================================================= | |
| // INIT_SUFFIX_POLYS | |
| // Complexity: O(T) reads, O(NUM_TABLES * NUM_SUFFIXES * M) memory | |
| // ============================================================================= | |
| void init_suffix_polys(int phase) { | |
| int log_m = LOG_M; | |
| int m = M; | |
| // For each table, initialize suffix polynomials | |
| // Outer loop: O(NUM_TABLES) | |
| #pragma omp parallel for | |
| for (int table_idx = 0; table_idx < NUM_TABLES; table_idx++) { | |
| auto& lookup_indices_for_table = lookup_indices_by_table[table_idx]; | |
| auto suffixes = table.suffixes(); // NUM_SUFFIXES_PER_TABLE | |
| // Allocate result: NUM_SUFFIXES_PER_TABLE polynomials of size M each | |
| std::vector<std::vector<Fr>> result(suffixes.size(), std::vector<Fr>(m, 0)); | |
| // MAIN AGGREGATION: O(|lookup_indices_for_table|) | |
| // This is chunked and reduced in parallel | |
| for (int j : lookup_indices_for_table) { | |
| auto k = lookup_indices[j]; | |
| auto [prefix_bits, suffix_bits] = k.split((PHASES - 1 - phase) * log_m); | |
| // For each suffix polynomial, compute contribution | |
| for (size_t s = 0; s < suffixes.size(); s++) { | |
| auto t = suffix.suffix_mle(suffix_bits); // Table lookup, O(1) | |
| if (t != 0) { | |
| Fr u = u_evals[j]; | |
| result[s][prefix_bits % m] += u * t; // Bucket by prefix | |
| } | |
| } | |
| } | |
| suffix_polys[table_idx] = result; | |
| } | |
| } | |
| // ============================================================================= | |
| // INIT_LOG_T_ROUNDS: called once at the start of the remaining log_T rounds | |
| // Complexity: O(T * PHASES) for ra, O(T) for val/raf_val | |
| // ============================================================================= | |
| void init_log_t_rounds(Fr gamma, Fr gamma_sqr) { | |
| int log_m = LOG_M; | |
| int m = M; | |
| // MATERIALIZE RA POLYNOMIAL | |
| // Complexity: O(T * PHASES) | |
| std::vector<Fr> ra(T, 0); | |
| #pragma omp parallel for | |
| for (int j = 0; j < T; j++) { | |
| auto k = lookup_indices[j]; | |
| Fr product = 1; | |
| for (int phase = 0; phase < PHASES; phase++) { | |
| auto [prefix, _] = k.split((PHASES - 1 - phase) * log_m); | |
| int k_bound = prefix % m; | |
| product *= v[phase][k_bound]; | |
| } | |
| ra[j] = product; | |
| } | |
| // MATERIALIZE COMBINED_VAL_POLYNOMIAL | |
| // Complexity: O(T * NUM_SUFFIXES_PER_TABLE) | |
| std::vector<Fr> combined_val(T, 0); | |
| #pragma omp parallel for | |
| for (int j = 0; j < T; j++) { | |
| if (lookup_tables[j].has_value()) { | |
| auto table = lookup_tables[j].value(); | |
| // Get suffix values (all bound to 0 at this point) | |
| std::vector<Fr> suffixes(table.suffixes().size()); | |
| for (size_t s = 0; s < table.suffixes().size(); s++) { | |
| suffixes[s] = suffix.suffix_mle(0); | |
| } | |
| combined_val[j] = table.combine(prefixes, suffixes); | |
| } | |
| } | |
| // MATERIALIZE COMBINED_RAF_VAL_POLYNOMIAL | |
| // Complexity: O(T) | |
| std::vector<Fr> combined_raf_val(T, 0); | |
| #pragma omp parallel for | |
| for (int j = 0; j < T; j++) { | |
| if (is_interleaved_operands[j]) { | |
| combined_raf_val[j] = gamma * prefix_registry[LeftOperand] | |
| + gamma_sqr * prefix_registry[RightOperand]; | |
| } else { | |
| combined_raf_val[j] = gamma_sqr * prefix_registry[Identity]; | |
| } | |
| } | |
| } | |
| // bind | |
| void ingest_challenge(Fr r_j, int round) { | |
| r.push_back(r_j); | |
| if (round < LOG_K) { | |
| int log_m = LOG_M; | |
| int phase = round / log_m; | |
| // bind read suffix (prefix is computed on the fly) | |
| // Complexity: O(NUM_TABLES * NUM_SUFFIXES * current_len) | |
| #pragma omp parallel for collapse(2) | |
| for (int table_idx = 0; table_idx < NUM_TABLES; table_idx++) { | |
| for (size_t s = 0; s < suffix_polys[table_idx].size(); s++) { | |
| // CHANGE TO: LowToHigh | |
| suffix_polys[table_idx][s].bind_parallel(r_j, HighToLow); | |
| // bind halves the polynomial size | |
| } | |
| } | |
| // bind raf | |
| // Complexity: O(current_Q_len) each | |
| #pragma omp parallel sections | |
| { | |
| #pragma omp section | |
| identity_ps.bind(r_j); | |
| #pragma omp section | |
| right_operand_ps.bind(r_j); | |
| #pragma omp section | |
| left_operand_ps.bind(r_j); | |
| } | |
| // v[phase]=eq(r,x) grows from 1 → 2 → 4 → ... → 256 over 8 rounds | |
| v[phase].update(r_j); | |
| // UPDATE CHECKPOINTS (every 2 rounds) | |
| if ((round + 1) % 2 == 0) { | |
| Prefixes::update_checkpoints(prefix_checkpoints, r[round-1], r[round]); | |
| } | |
| // PHASE TRANSITION (every LOG_M rounds) | |
| if ((round + 1) % log_m == 0) { | |
| prefix_registry.update_checkpoints(); | |
| if (phase != PHASES - 1) { | |
| init_phase(phase + 1); // O(T) | |
| } | |
| } | |
| // TRANSITION TO LOG_T ROUNDS | |
| if (round + 1 == LOG_K) { | |
| init_log_t_rounds(gamma, gamma_sqr); // O(T * PHASES) | |
| } | |
| } else { | |
| // LOG_T ROUNDS: bind ra, val, raf_val polynomials | |
| // Complexity: O(current_T) for each | |
| eq_r_reduction.bind(r_j); | |
| ra.bind_parallel(r_j, LowToHigh); | |
| combined_val_polynomial.bind_parallel(r_j, LowToHigh); | |
| combined_raf_val_polynomial.bind_parallel(r_j, LowToHigh); | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment