diff --git a/Cargo.lock b/Cargo.lock index 5ecb4d7d..6bba03f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -411,7 +411,6 @@ dependencies = [ "sub_protocols", "tracing", "utils", - "witness_generation", "xmss", ] @@ -1349,28 +1348,6 @@ version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" -[[package]] -name = "witness_generation" -version = "0.1.0" -dependencies = [ - "air 0.1.0", - "lean_compiler", - "lean_vm", - "multilinear-toolkit", - "p3-koala-bear", - "p3-monty-31", - "p3-poseidon2", - "p3-symmetric", - "p3-util", - "pest", - "pest_derive", - "rand", - "sub_protocols", - "tracing", - "utils", - "xmss", -] - [[package]] name = "xmss" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 9e8d3fd1..09382385 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,10 +8,7 @@ version = "0.1.0" edition = "2024" [workspace] -members = [ - "crates/*", - "crates/lean_prover/witness_generation", -] +members = ["crates/*"] [workspace.lints] rust.missing_debug_implementations = "warn" @@ -51,7 +48,6 @@ sub_protocols = { path = "crates/sub_protocols" } lean_compiler = { path = "crates/lean_compiler" } lean_prover = { path = "crates/lean_prover" } rec_aggregation = { path = "crates/rec_aggregation" } -witness_generation = { path = "crates/lean_prover/witness_generation" } # External thiserror = "2.0" diff --git a/README.md b/README.md index c89c64b4..6271a1df 100644 --- a/README.md +++ b/README.md @@ -11,25 +11,25 @@ Documentation: [PDF](minimal_zkVM.pdf) ## Proving System -- multilinear with [WHIR](https://eprint.iacr.org/2024/1586.pdf) +- multilinear with [WHIR](https://eprint.iacr.org/2024/1586.pdf), allowing polynomial stacking (reducing proof size) - [SuperSpartan](https://eprint.iacr.org/2023/552.pdf), with [AIR-specific optimizations](https://solvable.group/posts/super-air/#fnref:1) -- [Logup](https://eprint.iacr.org/2023/1284.pdf) / 
[Logup*](https://eprint.iacr.org/2025/946.pdf) +[Logup](https://eprint.iacr.org/2023/1284.pdf), with a system of buses similar to [OpenVM](https://openvm.dev/whitepaper.pdf) The VM design is inspired by the famous [Cairo paper](https://eprint.iacr.org/2021/1063.pdf). ## Security -123 bits of security. Johnson bound + degree 5 extension of koala-bear -> **no proximity gaps conjecture**. (TODO 128 bits, which requires hash digests bigger than 8 koala-bears). +123 bits of security. Johnson bound + degree 5 extension of koala-bear -> **no proximity gaps conjecture**. (TODO 128 bits? this would require hash digests bigger than 8 koala-bears). -## Benchmarks +## Benchmarks (Slightly outdated, new benchmarks incoming) Machine: M4 Max 48GB (CPU only) | Benchmark | Current | Target | | -------------------------- | -------------------- | --------------- | | Poseidon2 (16 koala-bears) | `560K Poseidon2 / s` | n/a | -| 2 -> 1 Recursion | `1.35 s` | `0.25 s ` | +| 2 -> 1 Recursion | `1.15 s` | `0.25 s ` | | XMSS aggregation | `554 XMSS / s` | `1000 XMSS / s` | *Expect incoming perf improvements.* @@ -39,11 +39,11 @@ To reproduce: - `cargo run --release -- recursion --n 2` - `cargo run --release -- xmss --n-signatures 1350` -(Small detail remaining in recursion: final (multilinear) evaluation of the guest program bytecode, there are multiple ways of handling it... TBD soon) - ## Proof size -WHIR intial rate = 1/4. Proof size ≈ 325 KiB. TODO: WHIR batch opening + [2024/108](https://eprint.iacr.org/2024/108.pdf) section 3.1 -> close to 256 KiB. (To go below 256 KiB -> rate 1/8 or 1/16 in the final recursion). +WHIR initial rate = 1/4 -> proof size ≈ 225 KiB. (150 KiB with rate 1/16, and < 100 KiB is possible with proximity gaps conjecture + rate 1/16). 
+ +(TODO: remaining optimization = [2024/108](https://eprint.iacr.org/2024/108.pdf) section 3.1) ## Credits diff --git a/TODO.md b/TODO.md index 83747478..b64b7fc7 100644 --- a/TODO.md +++ b/TODO.md @@ -5,7 +5,7 @@ - 128 bits security - Merkle pruning - the interpreter of leanISA (+ witness generation) can be partially parallelized when there are some independent loops -- Make everything "padding aware" (including WHIR, logup*, AIR, etc) +- Make everything "padding aware" (including WHIR, logup, AIR, etc) - Opti WHIR: in sumcheck we know more than f(0) + f(1), we know f(0) and f(1) - Opti WHIR https://github.com/tcoratger/whir-p3/issues/303 and https://github.com/tcoratger/whir-p3/issues/306 ? - Avoid the embedding overhead in logup, when denominators = "c - index" @@ -37,6 +37,7 @@ But we can get the bost of both worlds (suggested by Lev, TODO implement): - Fiat Shamir: add a claim tracing feature, to ensure all the claims are indeed checked (Lev) - Double Check AIR constraints, logup overflows etc - Formal Verification +- Padd with noop cycles to always ensure memory size >= bytecode size (liveness), and ensure this condition is checked by the verifier (soundness) # Ideas diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index 51cf62ca..48fdb96d 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -37,15 +37,7 @@ where "TODO handle the case UNIVARIATE_SKIPS >= log_length" ); - // crate::check_air_validity( - // air, - // &extra_data, - // &columns_f, - // &columns_ef, - // last_row_shifted_f, - // last_row_shifted_ef, - // ) - // .unwrap(); + // crate::check_air_validity(air, &extra_data, &columns_f, &columns_ef).unwrap(); assert!(extra_data.alpha_powers().len() >= air.n_constraints() + virtual_column_statement.is_some() as usize); diff --git a/crates/air/src/validity_check.rs b/crates/air/src/validity_check.rs index 1afd6007..b5c407de 100644 --- a/crates/air/src/validity_check.rs +++ b/crates/air/src/validity_check.rs @@ -8,8 +8,6 
@@ pub fn check_air_validity>>( extra_data: &A::ExtraData, columns_f: &[&[PF]], columns_ef: &[&[EF]], - last_row_f: &[PF], - last_row_ef: &[EF], ) -> Result<(), String> { let n_rows = columns_f[0].len(); assert!(columns_f.iter().all(|col| col.len() == n_rows)); @@ -67,13 +65,19 @@ pub fn check_air_validity>>( let up_ef = (0..air.n_columns_ef_air()) .map(|j| columns_ef[j][n_rows - 1]) .collect::>(); - assert_eq!(last_row_f.len(), air.n_down_columns_f()); - assert_eq!(last_row_ef.len(), air.n_down_columns_ef()); let mut constraints_checker = ConstraintChecker { up_f, up_ef, - down_f: last_row_f.to_vec(), - down_ef: last_row_ef.to_vec(), + down_f: air + .down_column_indexes_f() + .iter() + .map(|j| columns_f[*j][n_rows - 1]) + .collect::>(), + down_ef: air + .down_column_indexes_ef() + .iter() + .map(|j| columns_ef[*j][n_rows - 1]) + .collect::>(), constraint_index: 0, errors: Vec::new(), }; diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index d637a043..b09063ac 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -79,6 +79,10 @@ def hint_decompose_bits(value, bits, n_bits, endian): _ = value, bits, n_bits, endian +def hint_less_than(a, b, result_ptr): + _ = a, b, result_ptr + + def log2_ceil(x: int) -> int: assert x > 0 return math.ceil(math.log2(x)) diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index cc8f09b5..58736c24 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -409,7 +409,10 @@ fn compile_time_transform_in_program( return Err("Inlined functions with mutable arguments are not supported yet".to_string()); } if func.has_const_arguments() { - return Err("Inlined functions with constant arguments are not supported yet".to_string()); + return Err(format!( + "Inlined function should not have \"Const\" arguments (function \"{}\")", + func.name + )); } } diff --git 
a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 3cf7ec29..1583e40e 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -1,4 +1,4 @@ -use crate::{F, ir::*, lang::*}; +use crate::{F, instruction_encoder::field_representation, ir::*, lang::*}; use lean_vm::*; use multilinear_toolkit::prelude::*; use std::collections::BTreeMap; @@ -144,6 +144,15 @@ pub fn compile_to_low_level_bytecode( &mut hints, ); } + let instructions_encoded = instructions.par_iter().map(field_representation).collect::>(); + + let mut instructions_multilinear = vec![]; + for instr in &instructions_encoded { + instructions_multilinear.extend_from_slice(instr); + let padding = N_INSTRUCTION_COLUMNS.next_power_of_two() - N_INSTRUCTION_COLUMNS; + instructions_multilinear.extend(vec![F::ZERO; padding]); + } + instructions_multilinear.resize(instructions_multilinear.len().next_power_of_two(), F::ZERO); // Build pc_to_location mapping from LocationReport hints let mut pc_to_location = Vec::with_capacity(instructions.len()); @@ -164,6 +173,7 @@ pub fn compile_to_low_level_bytecode( Ok(Bytecode { instructions, + instructions_multilinear, hints, starting_frame_memory, function_locations, diff --git a/crates/lean_prover/witness_generation/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs similarity index 84% rename from crates/lean_prover/witness_generation/src/instruction_encoder.rs rename to crates/lean_compiler/src/instruction_encoder.rs index 55a2bacb..3ec29895 100644 --- a/crates/lean_prover/witness_generation/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -1,6 +1,5 @@ use lean_vm::*; use multilinear_toolkit::prelude::*; -use utils::padd_with_zero_to_next_power_of_two; pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { let mut fields = [F::ZERO; N_INSTRUCTION_COLUMNS]; @@ -32,17 +31,17 @@ pub fn 
field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { fields[instr_idx(COL_OPERAND_C)] = F::from_usize(*shift_1); match res { MemOrFpOrConstant::Constant(cst) => { - fields[instr_idx(COL_AUX_1)] = F::ONE; + fields[instr_idx(COL_AUX)] = F::ONE; fields[instr_idx(COL_FLAG_B)] = F::ONE; fields[instr_idx(COL_OPERAND_B)] = *cst; } MemOrFpOrConstant::MemoryAfterFp { offset } => { - fields[instr_idx(COL_AUX_1)] = F::ONE; + fields[instr_idx(COL_AUX)] = F::ONE; fields[instr_idx(COL_FLAG_B)] = F::ZERO; fields[instr_idx(COL_OPERAND_B)] = F::from_usize(*offset); } MemOrFpOrConstant::Fp => { - fields[instr_idx(COL_AUX_1)] = F::ZERO; + fields[instr_idx(COL_AUX)] = F::ZERO; fields[instr_idx(COL_FLAG_B)] = F::ONE; } } @@ -70,8 +69,8 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { set_nu_a(&mut fields, arg_a); set_nu_b(&mut fields, arg_b); set_nu_c(&mut fields, arg_c); - fields[instr_idx(COL_AUX_1)] = F::from_usize(*aux_1); - fields[instr_idx(COL_AUX_2)] = F::from_usize(*aux_2); + assert!(*aux_2 == 0 || *aux_2 == 1); + fields[instr_idx(COL_AUX)] = F::from_usize(2 * *aux_1 + *aux_2); } } fields @@ -114,11 +113,3 @@ fn set_nu_c(fields: &mut [F; N_INSTRUCTION_COLUMNS], c: &MemOrFp) { } } } - -pub fn bytecode_to_multilinear_polynomial(instructions: &[Instruction]) -> Vec { - let res = instructions - .par_iter() - .flat_map(|instr| padd_with_zero_to_next_power_of_two(&field_representation(instr))) - .collect::>(); - padd_with_zero_to_next_power_of_two(&res) -} diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index a59e0d7c..6728e6b4 100644 --- a/crates/lean_compiler/src/lib.rs +++ b/crates/lean_compiler/src/lib.rs @@ -11,6 +11,7 @@ use crate::{ mod a_simplify_lang; mod b_compile_intermediate; mod c_compile_final; +mod instruction_encoder; pub mod ir; mod lang; mod parser; diff --git a/crates/lean_compiler/src/parser/parsers/literal.rs b/crates/lean_compiler/src/parser/parsers/literal.rs index 
d3178b89..49dd8b48 100644 --- a/crates/lean_compiler/src/parser/parsers/literal.rs +++ b/crates/lean_compiler/src/parser/parsers/literal.rs @@ -1,6 +1,4 @@ -use lean_vm::{ - NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, PRIVATE_INPUT_START_PTR, SAMPLING_DOMAIN_SEPARATOR_PTR, ZERO_VEC_PTR, -}; +use lean_vm::{NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, SAMPLING_DOMAIN_SEPARATOR_PTR, ZERO_VEC_PTR}; use multilinear_toolkit::prelude::*; use super::expression::ExpressionParser; @@ -133,7 +131,6 @@ impl VarOrConstantParser { "NONRESERVED_PROGRAM_INPUT_START" => Ok(SimpleExpr::Constant(ConstExpression::from( NONRESERVED_PROGRAM_INPUT_START, ))), - "PRIVATE_INPUT_START_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(PRIVATE_INPUT_START_PTR))), "ZERO_VEC_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(ZERO_VEC_PTR))), "ONE_VEC_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(ONE_VEC_PTR))), "SAMPLING_DOMAIN_SEPARATOR_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from( diff --git a/crates/lean_prover/Cargo.toml b/crates/lean_prover/Cargo.toml index caf9d445..48ec0a47 100644 --- a/crates/lean_prover/Cargo.toml +++ b/crates/lean_prover/Cargo.toml @@ -21,7 +21,6 @@ air.workspace = true sub_protocols.workspace = true lean_vm.workspace = true lean_compiler.workspace = true -witness_generation.workspace = true multilinear-toolkit.workspace = true itertools.workspace = true diff --git a/crates/lean_prover/src/common.rs b/crates/lean_prover/src/common.rs deleted file mode 100644 index 411cf433..00000000 --- a/crates/lean_prover/src/common.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::*; -use lean_vm::*; - -pub(crate) fn fold_bytecode(bytecode: &Bytecode, folding_challenges: &MultilinearPoint) -> Vec { - let encoded_bytecode = padd_with_zero_to_next_power_of_two( - &bytecode - .instructions - .par_iter() - .flat_map(|i| padd_with_zero_to_next_power_of_two(&field_representation(i))) - .collect::>(), - ); - fold_multilinear_chunks(&encoded_bytecode, 
folding_challenges) -} - -fn split_at(stmt: &MultiEvaluation, start: usize, end: usize) -> Vec> { - vec![MultiEvaluation::new( - stmt.point.clone(), - stmt.values[start..end].to_vec(), - )] -} diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index 92339a1c..b235a92e 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -4,56 +4,32 @@ use lean_vm::{EF, F}; use multilinear_toolkit::prelude::*; use utils::*; -use lean_vm::execute_bytecode; -use witness_generation::*; +mod trace_gen; -mod common; pub mod prove_execution; +pub mod verify_execution; + #[cfg(test)] mod test_zkvm; -pub mod verify_execution; -pub use witness_generation::bytecode_to_multilinear_polynomial; +use trace_gen::*; // Right now, hash digests = 8 koala-bear (p = 2^31 - 2^24 + 1, i.e. ≈ 30.98 bits per field element) // so ≈ 123.92 bits of security against collisions -pub const SECURITY_BITS: usize = 123; // TODO 128 bits security (with Poseidon over 20 field elements) +pub const SECURITY_BITS: usize = 123; // TODO 128 bits security? (with Poseidon over 20 field elements or with a more subtle soundness analysis (cf. 
https://eprint.iacr.org/2021/188.pdf)) // Provable security (no proximity gaps conjectures) pub const SECURITY_REGIME: SecurityAssumption = SecurityAssumption::JohnsonBound; -pub const GRINDING_BITS: usize = 16; - -pub const STARTING_LOG_INV_RATE_BASE: usize = 2; - -pub const STARTING_LOG_INV_RATE_EXTENSION: usize = 3; - -#[derive(Debug)] -pub struct SnarkParams { - pub first_whir: WhirConfigBuilder, - pub second_whir: WhirConfigBuilder, -} - -impl Default for SnarkParams { - fn default() -> Self { - Self { - first_whir: whir_config_builder(STARTING_LOG_INV_RATE_BASE, 7, 5), - second_whir: whir_config_builder(STARTING_LOG_INV_RATE_EXTENSION, 4, 1), - } - } -} +pub const GRINDING_BITS: usize = 18; -pub fn whir_config_builder( - starting_log_inv_rate: usize, - first_folding_factor: usize, - rs_domain_initial_reduction_factor: usize, -) -> WhirConfigBuilder { +pub fn default_whir_config(starting_log_inv_rate: usize) -> WhirConfigBuilder { WhirConfigBuilder { - folding_factor: FoldingFactor::new(first_folding_factor, 4), + folding_factor: FoldingFactor::new(7, 5), soundness_type: SECURITY_REGIME, pow_bits: GRINDING_BITS, - max_num_variables_to_send_coeffs: 6, - rs_domain_initial_reduction_factor, + max_num_variables_to_send_coeffs: 9, + rs_domain_initial_reduction_factor: 5, security_level: SECURITY_BITS, starting_log_inv_rate, } diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index fad62da5..b4e0902f 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -1,14 +1,12 @@ use std::collections::BTreeMap; -use crate::common::*; use crate::*; use air::prove_air; use lean_vm::*; -use p3_util::log2_ceil_usize; use sub_protocols::*; use tracing::info_span; -use utils::{build_prover_state, padd_with_zero_to_next_power_of_two}; +use utils::build_prover_state; use xmss::Poseidon16History; #[derive(Debug)] @@ -23,7 +21,7 @@ pub fn prove_execution( bytecode: &Bytecode, 
(public_input, private_input): (&[F], &[F]), poseidons_16_precomputed: &Poseidon16History, - params: &SnarkParams, + whir_config: &WhirConfigBuilder, vm_profiler: bool, ) -> ExecutionProof { let mut exec_summary = String::new(); @@ -61,30 +59,56 @@ pub fn prove_execution( .collect::>(), ); + let mut table_log = String::new(); + for (table, trace) in &traces { + table_log.push_str(&format!( + "{}: 2^{:.2} rows |", + table.name(), + f64::log2(trace.non_padded_n_rows as f64) + )); + } + table_log.pop(); // remove last '|' + info_span!("Trace tables sizes: {}", table_log).in_scope(|| {}); + // TODO parrallelize - let mut acc = F::zero_vec(memory.len()); + let mut memory_acc = F::zero_vec(memory.len()); info_span!("Building memory access count").in_scope(|| { for (table, trace) in &traces { for lookup in table.lookups_f() { for i in &trace.base[lookup.index] { for j in 0..lookup.values.len() { - acc[i.to_usize() + j] += F::ONE; + memory_acc[i.to_usize() + j] += F::ONE; } } } for lookup in table.lookups_ef() { for i in &trace.base[lookup.index] { for j in 0..DIMENSION { - acc[i.to_usize() + j] += F::ONE; + memory_acc[i.to_usize() + j] += F::ONE; } } } } }); + // // TODO parrallelize + let mut bytecode_acc = F::zero_vec(bytecode.padded_size()); + info_span!("Building bytecode access count").in_scope(|| { + for pc in traces[&Table::execution()].base[COL_PC].iter() { + bytecode_acc[pc.to_usize()] += F::ONE; + } + }); + // 1st Commitment - let packed_pcs_witness_base = packed_pcs_commit(&mut prover_state, ¶ms.first_whir, &memory, &acc, &traces); - let first_whir_n_vars = packed_pcs_witness_base.packed_polynomial.by_ref().n_vars(); + let packed_pcs_witness = packed_pcs_commit( + &mut prover_state, + whir_config, + &memory, + &memory_acc, + &bytecode_acc, + &traces, + ); + let first_whir_n_vars = packed_pcs_witness.packed_polynomial.by_ref().n_vars(); // logup (GKR) let logup_c = prover_state.sample(); @@ -96,7 +120,9 @@ pub fn prove_execution( logup_c, &logup_alphas_eq_poly, 
&memory, - &acc, + &memory_acc, + &bytecode.instructions_multilinear, + &bytecode_acc, &traces, ); let mut committed_statements: CommittedStatements = Default::default(); @@ -130,97 +156,50 @@ pub fn prove_execution( committed_statements.get_mut(table).unwrap().extend(this_air_claims); } - let bytecode_compression_challenges = - MultilinearPoint(prover_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS))); - - let folded_bytecode = fold_bytecode(bytecode, &bytecode_compression_challenges); - - let bytecode_air_entry = &mut committed_statements.get_mut(&Table::execution()).unwrap()[2]; - let bytecode_air_point = bytecode_air_entry.0.clone(); - let mut bytecode_air_values = vec![]; - for bytecode_col_index in N_COMMITTED_EXEC_COLUMNS..N_COMMITTED_EXEC_COLUMNS + N_INSTRUCTION_COLUMNS { - bytecode_air_values.push(bytecode_air_entry.1.remove(&bytecode_col_index).unwrap()); - } - - let bytecode_lookup_claim = Evaluation::new( - bytecode_air_point.clone(), - padd_with_zero_to_next_power_of_two(&bytecode_air_values).evaluate(&bytecode_compression_challenges), - ); - let bytecode_poly_eq_point = eval_eq(&bytecode_lookup_claim.point); - let bytecode_pushforward = MleOwned::Extension(compute_pushforward( - &traces[&Table::execution()].base[COL_PC], - folded_bytecode.len(), - &bytecode_poly_eq_point, - )); - - let bytecode_pushforward_commitment = - WhirConfig::new(¶ms.second_whir, log2_ceil_usize(bytecode.instructions.len())) - .commit(&mut prover_state, &bytecode_pushforward); - - let bytecode_logup_star_statements = prove_logup_star( - &mut prover_state, - &MleRef::Extension(&folded_bytecode), - &traces[&Table::execution()].base[COL_PC], - bytecode_lookup_claim.value, - &bytecode_poly_eq_point, - &bytecode_pushforward.by_ref(), - Some(bytecode.instructions.len()), - ); - - committed_statements.get_mut(&Table::execution()).unwrap().push(( - bytecode_logup_star_statements.on_indexes.point.clone(), - BTreeMap::from_iter([(COL_PC, 
bytecode_logup_star_statements.on_indexes.value)]), - )); - let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(public_memory_size))); let public_memory_eval = (&memory[..public_memory_size]).evaluate(&public_memory_random_point); - let memory_acc_statements = vec![ + let previous_statements = vec![ SparseStatement::new( - packed_pcs_witness_base.packed_n_vars, - logup_statements.memory_acc_point, + packed_pcs_witness.packed_n_vars, + logup_statements.memory_and_acc_point, vec![ SparseValue::new(0, logup_statements.value_memory), - SparseValue::new(1, logup_statements.value_acc), + SparseValue::new(1, logup_statements.value_memory_acc), ], ), SparseStatement::new( - packed_pcs_witness_base.packed_n_vars, + packed_pcs_witness.packed_n_vars, public_memory_random_point, vec![SparseValue::new(0, public_memory_eval)], ), + SparseStatement::new( + packed_pcs_witness.packed_n_vars, + logup_statements.bytecode_and_acc_point, + vec![SparseValue::new( + (2 * memory.len()) >> bytecode.log_size(), + logup_statements.value_bytecode_acc, + )], + ), ]; let table_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); let global_statements_base = packed_pcs_global_statements( - packed_pcs_witness_base.packed_n_vars, + packed_pcs_witness.packed_n_vars, log2_strict_usize(memory.len()), - memory_acc_statements, + bytecode.log_size(), + previous_statements, &table_heights, &committed_statements, ); - WhirConfig::new( - ¶ms.first_whir, - packed_pcs_witness_base.packed_polynomial.by_ref().n_vars(), - ) - .prove( + WhirConfig::new(whir_config, packed_pcs_witness.packed_polynomial.by_ref().n_vars()).prove( &mut prover_state, global_statements_base, - packed_pcs_witness_base.inner_witness, - &packed_pcs_witness_base.packed_polynomial.by_ref(), + packed_pcs_witness.inner_witness, + &packed_pcs_witness.packed_polynomial.by_ref(), ); - WhirConfig::new(¶ms.second_whir, log2_ceil_usize(bytecode.instructions.len())).prove( - 
&mut prover_state, - bytecode_logup_star_statements - .on_pushforward - .into_iter() - .map(|smt| SparseStatement::dense(smt.point, smt.value)) - .collect::>(), - bytecode_pushforward_commitment, - &bytecode_pushforward.by_ref(), - ); let proof_size_fe = prover_state.pruned_proof().proof_size_fe(); ExecutionProof { proof: prover_state.raw_proof(), diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 6ec98735..4f0dedea 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -1,9 +1,9 @@ -use crate::{prove_execution::prove_execution, verify_execution::verify_execution}; +use crate::{default_whir_config, prove_execution::prove_execution, verify_execution::verify_execution}; use lean_compiler::*; use lean_vm::*; use multilinear_toolkit::prelude::*; use rand::{Rng, SeedableRng, rngs::StdRng}; -use utils::poseidon16_permute; +use utils::{init_tracing, poseidon16_permute}; #[test] fn test_zk_vm_all_precompiles() { @@ -37,7 +37,6 @@ def main(): assert c == 100 return - "#; const N: usize = 11; @@ -107,6 +106,9 @@ def main(): #[test] fn test_prove_fibonacci() { + if std::env::var("FIB_TRACING") == Ok("true".to_string()) { + init_tracing(); + } let n = std::env::var("FIB_N") .unwrap_or("10000".to_string()) .parse::() @@ -147,15 +149,22 @@ fn test_zk_vm_helper(program_str: &str, (public_input, private_input): (&[F], &[ } let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); let time = std::time::Instant::now(); + let starting_log_inv_rate = 1; let proof = prove_execution( &bytecode, (public_input, private_input), &vec![], - &Default::default(), + &default_whir_config(starting_log_inv_rate), false, ); let proof_time = time.elapsed(); - verify_execution(&bytecode, public_input, proof.proof.clone(), &Default::default()).unwrap(); + verify_execution( + &bytecode, + public_input, + proof.proof.clone(), + &default_whir_config(starting_log_inv_rate), + ) + .unwrap(); println!("{}", 
proof.exec_summary); println!("Proof time: {:.3} s", proof_time.as_secs_f32()); @@ -170,7 +179,12 @@ fn test_zk_vm_helper(program_str: &str, (public_input, private_input): (&[F], &[ } let mut fuzzed_proof = proof.proof.clone(); fuzzed_proof[i] += F::ONE; - let verify_result = verify_execution(&bytecode, public_input, fuzzed_proof, &Default::default()); + let verify_result = verify_execution( + &bytecode, + public_input, + fuzzed_proof, + &default_whir_config(starting_log_inv_rate), + ); assert!(verify_result.is_err(), "Fuzzing failed at index {}", i); } } diff --git a/crates/lean_prover/witness_generation/src/execution_trace.rs b/crates/lean_prover/src/trace_gen.rs similarity index 89% rename from crates/lean_prover/witness_generation/src/execution_trace.rs rename to crates/lean_prover/src/trace_gen.rs index d945bd32..dcf6a1b2 100644 --- a/crates/lean_prover/witness_generation/src/execution_trace.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -1,4 +1,3 @@ -use crate::instruction_encoder::field_representation; use lean_vm::*; use multilinear_toolkit::prelude::*; use std::{array, collections::BTreeMap, iter::repeat_n}; @@ -17,7 +16,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul let n_cycles = execution_result.pcs.len(); let memory = &execution_result.memory; - let mut main_trace: [Vec; N_EXEC_AIR_COLUMNS + N_TEMPORARY_EXEC_COLUMNS] = + let mut main_trace: [Vec; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS] = array::from_fn(|_| F::zero_vec(n_cycles.next_power_of_two())); for col in &mut main_trace { unsafe { @@ -30,20 +29,21 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul .zip(execution_result.fps.par_iter()) .for_each(|((trace_row, &pc), &fp)| { let instruction = &bytecode.instructions[pc]; - let field_repr = field_representation(instruction); + let field_repr = &bytecode.instructions_multilinear[pc * N_INSTRUCTION_COLUMNS.next_power_of_two()..] 
+ [..N_INSTRUCTION_COLUMNS]; let mut addr_a = F::ZERO; if field_repr[instr_idx(COL_FLAG_A)].is_zero() { // flag_a == 0 addr_a = F::from_usize(fp) + field_repr[instr_idx(COL_OPERAND_A)]; // fp + operand_a } - let value_a = memory.0[addr_a.to_usize()].unwrap(); + let value_a = memory.0[addr_a.to_usize()].unwrap_or_default(); let mut addr_b = F::ZERO; if field_repr[instr_idx(COL_FLAG_B)].is_zero() { // flag_b == 0 addr_b = F::from_usize(fp) + field_repr[instr_idx(COL_OPERAND_B)]; // fp + operand_b } - let value_b = memory.0[addr_b.to_usize()].unwrap(); + let value_b = memory.0[addr_b.to_usize()].unwrap_or_default(); let mut addr_c = F::ZERO; if field_repr[instr_idx(COL_FLAG_C)].is_zero() { @@ -54,10 +54,10 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul assert_eq!(field_repr[instr_idx(COL_OPERAND_C)], operand_c); // debug purpose addr_c = value_a + operand_c; } - let value_c = memory.0[addr_c.to_usize()].unwrap(); + let value_c = memory.0[addr_c.to_usize()].unwrap_or_default(); for (j, field) in field_repr.iter().enumerate() { - *trace_row[j + N_COMMITTED_EXEC_COLUMNS] = *field; + *trace_row[j + N_RUNTIME_COLUMNS] = *field; } let nu_a = field_repr[instr_idx(COL_FLAG_A)] * field_repr[instr_idx(COL_OPERAND_A)] @@ -101,6 +101,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul TableTrace { base: Vec::from(main_trace), ext: vec![], + non_padded_n_rows: n_cycles, log_n_rows: log2_ceil_usize(n_cycles), }, ); @@ -125,6 +126,7 @@ fn padd_table(table: &Table, traces: &mut BTreeMap) { .enumerate() .for_each(|(i, col)| assert_eq!(col.len(), h, "column {}, table {}", i, table.name())); + trace.non_padded_n_rows = h; trace.log_n_rows = log2_ceil_usize(h + 1).max(MIN_LOG_N_ROWS_PER_TABLE); let padding_len = (1 << trace.log_n_rows) - h; let padding_row_f = table.padding_row_f(); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 419e851e..afcf9cf9 100644 --- 
a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -1,11 +1,9 @@ use std::collections::BTreeMap; use crate::*; -use crate::{SnarkParams, common::*}; use air::verify_air; use lean_vm::*; -use p3_util::{log2_ceil_usize, log2_strict_usize}; -use sub_protocols::verify_logup_star; +use p3_util::log2_strict_usize; use sub_protocols::*; use utils::ToUsize; @@ -14,14 +12,15 @@ pub struct ProofVerificationDetails { pub log_memory: usize, pub table_n_vars: BTreeMap, pub first_quotient_gkr_n_vars: usize, - pub total_whir_statements_base: usize, + pub total_whir_statements: usize, + pub bytecode_evaluation: Evaluation, } pub fn verify_execution( bytecode: &Bytecode, public_input: &[F], proof: Vec, - params: &SnarkParams, + whir_config: &WhirConfigBuilder, ) -> Result { let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone()); @@ -47,7 +46,7 @@ pub fn verify_execution( } } // check memory is bigger than any other table - if log_memory < *table_n_vars.values().max().unwrap() { + if log_memory < (*table_n_vars.values().max().unwrap()).max(bytecode.log_size()) { return Err(ProofError::InvalidProof); } @@ -57,8 +56,13 @@ pub fn verify_execution( return Err(ProofError::InvalidProof); } - let parsed_commitment_base = - packed_pcs_parse_commitment(¶ms.first_whir, &mut verifier_state, log_memory, &table_n_vars)?; + let parsed_commitment = packed_pcs_parse_commitment( + whir_config, + &mut verifier_state, + log_memory, + bytecode.log_size(), + &table_n_vars, + )?; let logup_c = verifier_state.sample(); let logup_alphas = verifier_state.sample_vec(log2_ceil_usize(max_bus_width())); @@ -67,8 +71,10 @@ pub fn verify_execution( let logup_statements = verify_generic_logup( &mut verifier_state, logup_c, + &logup_alphas, &logup_alphas_eq_poly, log_memory, + &bytecode.instructions_multilinear, &table_n_vars, )?; let mut committed_statements: CommittedStatements = Default::default(); @@ -102,92 +108,55 @@ pub fn 
verify_execution( committed_statements.get_mut(table).unwrap().extend(this_air_claims); } - let bytecode_compression_challenges = - MultilinearPoint(verifier_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS))); - - let bytecode_air_entry = &mut committed_statements.get_mut(&Table::execution()).unwrap()[2]; - let bytecode_air_point = bytecode_air_entry.0.clone(); - let mut bytecode_air_values = vec![]; - for bytecode_col_index in N_COMMITTED_EXEC_COLUMNS..N_COMMITTED_EXEC_COLUMNS + N_INSTRUCTION_COLUMNS { - bytecode_air_values.push(bytecode_air_entry.1.remove(&bytecode_col_index).unwrap()); - } - - let bytecode_lookup_claim = Evaluation::new( - bytecode_air_point.clone(), - padd_with_zero_to_next_power_of_two(&bytecode_air_values).evaluate(&bytecode_compression_challenges), - ); - - let bytecode_pushforward_parsed_commitment = - WhirConfig::new(¶ms.second_whir, log2_ceil_usize(bytecode.instructions.len())) - .parse_commitment::(&mut verifier_state)?; - - let bytecode_logup_star_statements = verify_logup_star( - &mut verifier_state, - log2_ceil_usize(bytecode.instructions.len()), - table_n_vars[&Table::execution()], - bytecode_lookup_claim, - )?; - let folded_bytecode = fold_bytecode(bytecode, &bytecode_compression_challenges); - if folded_bytecode.evaluate(&bytecode_logup_star_statements.on_table.point) - != bytecode_logup_star_statements.on_table.value - { - return Err(ProofError::InvalidProof); - } - - committed_statements.get_mut(&Table::execution()).unwrap().push(( - bytecode_logup_star_statements.on_indexes.point.clone(), - BTreeMap::from_iter([(COL_PC, bytecode_logup_star_statements.on_indexes.value)]), - )); - let public_memory_random_point = MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_memory.len()))); let public_memory_eval = public_memory.evaluate(&public_memory_random_point); - let memory_acc_statements = vec![ + let previous_statements = vec![ SparseStatement::new( - parsed_commitment_base.num_variables, - 
logup_statements.memory_acc_point, + parsed_commitment.num_variables, + logup_statements.memory_and_acc_point, vec![ SparseValue::new(0, logup_statements.value_memory), - SparseValue::new(1, logup_statements.value_acc), + SparseValue::new(1, logup_statements.value_memory_acc), ], ), SparseStatement::new( - parsed_commitment_base.num_variables, + parsed_commitment.num_variables, public_memory_random_point, vec![SparseValue::new(0, public_memory_eval)], ), + SparseStatement::new( + parsed_commitment.num_variables, + logup_statements.bytecode_and_acc_point, + vec![SparseValue::new( + (2 << log_memory) >> bytecode.log_size(), + logup_statements.value_bytecode_acc, + )], + ), ]; let global_statements_base = packed_pcs_global_statements( - parsed_commitment_base.num_variables, + parsed_commitment.num_variables, log_memory, - memory_acc_statements, + bytecode.log_size(), + previous_statements, &table_n_vars, &committed_statements, ); - let total_whir_statements_base = global_statements_base.iter().map(|s| s.values.len()).sum(); - WhirConfig::new(¶ms.first_whir, parsed_commitment_base.num_variables).verify( + let total_whir_statements = global_statements_base.iter().map(|s| s.values.len()).sum(); + WhirConfig::new(whir_config, parsed_commitment.num_variables).verify( &mut verifier_state, - &parsed_commitment_base, + &parsed_commitment, global_statements_base, )?; - WhirConfig::new(¶ms.second_whir, log2_ceil_usize(bytecode.instructions.len())).verify( - &mut verifier_state, - &bytecode_pushforward_parsed_commitment, - bytecode_logup_star_statements - .on_pushforward - .into_iter() - .map(|smt| SparseStatement::dense(smt.point, smt.value)) - .collect::>(), - )?; - Ok(ProofVerificationDetails { log_memory, table_n_vars, - first_quotient_gkr_n_vars: logup_statements.total_n_vars, - total_whir_statements_base, + first_quotient_gkr_n_vars: logup_statements.total_gkr_n_vars, + total_whir_statements, + bytecode_evaluation: logup_statements.bytecode_evaluation.unwrap(), }) } diff 
--git a/crates/lean_prover/witness_generation/Cargo.toml b/crates/lean_prover/witness_generation/Cargo.toml deleted file mode 100644 index 714476e9..00000000 --- a/crates/lean_prover/witness_generation/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[package] -name = "witness_generation" -version.workspace = true -edition.workspace = true - -[lints] -workspace = true - -[dependencies] -pest.workspace = true -pest_derive.workspace = true -utils.workspace = true -xmss.workspace = true -rand.workspace = true -p3-poseidon2.workspace = true -p3-koala-bear.workspace = true -p3-symmetric.workspace = true -p3-util.workspace = true -tracing.workspace = true -air.workspace = true -sub_protocols.workspace = true -lean_vm.workspace = true -lean_compiler.workspace = true -multilinear-toolkit.workspace = true -p3-monty-31.workspace = true \ No newline at end of file diff --git a/crates/lean_prover/witness_generation/src/lib.rs b/crates/lean_prover/witness_generation/src/lib.rs deleted file mode 100644 index e665ab96..00000000 --- a/crates/lean_prover/witness_generation/src/lib.rs +++ /dev/null @@ -1,7 +0,0 @@ -#![cfg_attr(not(test), allow(unused_crate_dependencies))] - -mod execution_trace; -mod instruction_encoder; - -pub use execution_trace::*; -pub use instruction_encoder::*; diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index 93a0b7f4..21dd3836 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -15,10 +15,10 @@ pub const MAX_RUNNER_MEMORY_SIZE: usize = 1 << 24; /// Minimum and maximum number of rows per table (as powers of two), both inclusive pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution. 
pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ - (Table::execution(), 29), // 3 lookups - (Table::dot_product(), 25), // 4 lookups - (Table::poseidon16(), 25), // 4 lookups -]; // No overflow in logup: (TODO triple check) 3.2^29 + 4.2^25 + 4.2^25 < p = 2^31 - 2^24 + 1 + (Table::execution(), 29), + (Table::dot_product(), 24), + (Table::poseidon16(), 23), +]; /// Starting program counter pub const STARTING_PC: usize = 1; @@ -52,11 +52,43 @@ pub const EXTENSION_BASIS_PTR: usize = SAMPLING_DOMAIN_SEPARATOR_PTR + DIGEST_LE /// Convention: pointing to the 8 elements of poseidon_16(0) pub const POSEIDON_16_NULL_HASH_PTR: usize = EXTENSION_BASIS_PTR + DIMENSION.pow(2); -/// Pointer to start of private input -pub const PRIVATE_INPUT_START_PTR: usize = POSEIDON_16_NULL_HASH_PTR + DIGEST_LEN; - /// Normal pointer to start of program input -pub const NONRESERVED_PROGRAM_INPUT_START: usize = PRIVATE_INPUT_START_PTR + 1; +pub const NONRESERVED_PROGRAM_INPUT_START: usize = (POSEIDON_16_NULL_HASH_PTR + DIGEST_LEN).next_multiple_of(DIMENSION); /// The first element of basis corresponds to one pub const ONE_VEC_PTR: usize = EXTENSION_BASIS_PTR; + +#[cfg(test)] +mod tests { + use multilinear_toolkit::prelude::PrimeField64; + use p3_util::log2_ceil_usize; + + use crate::{DIMENSION, F, MAX_LOG_N_ROWS_PER_TABLE, Table, TableT}; + + /// CRITICAL FOUR SOUNDNESS: TODO tripple check + #[test] + fn ensure_no_overflow_in_logup() { + fn memory_lookups_count(t: &T) -> usize { + t.lookups_f().iter().map(|l| l.values.len()).sum::() + t.lookups_ef().len() * DIMENSION + } + // memory lookup + let mut max_memory_logup_sum: u64 = 0; + for (table, max_log_n_rows) in MAX_LOG_N_ROWS_PER_TABLE { + let n_rows = 1 << max_log_n_rows; + let num_lookups = memory_lookups_count(&table); + max_memory_logup_sum += (num_lookups * n_rows) as u64; + println!("Table {} has {} memory lookups", table.name(), num_lookups * n_rows); + } + assert!(max_memory_logup_sum < F::ORDER_U64); + + // bytecode lookup + 
assert!( + MAX_LOG_N_ROWS_PER_TABLE + .iter() + .find(|(table, _)| *table == Table::execution()) + .unwrap() + .1 + < log2_ceil_usize(F::ORDER_U64 as usize) + ); + } +} diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index ff6bc8e2..85c01a0d 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -8,7 +8,7 @@ use crate::execution::{ExecutionHistory, Memory}; use crate::isa::Bytecode; use crate::isa::instruction::InstructionContext; use crate::{ - ALL_TABLES, CodeAddress, ENDING_PC, EXTENSION_BASIS_PTR, HintExecutionContext, N_TABLES, PRIVATE_INPUT_START_PTR, + ALL_TABLES, CodeAddress, ENDING_PC, EXTENSION_BASIS_PTR, HintExecutionContext, N_TABLES, SAMPLING_DOMAIN_SEPARATOR_PTR, STARTING_PC, SourceLocation, Table, TableTrace, }; use multilinear_toolkit::prelude::*; @@ -41,7 +41,6 @@ pub fn build_public_memory(public_input: &[F]) -> Vec { public_memory[POSEIDON_16_NULL_HASH_PTR..][..DIGEST_LEN] .copy_from_slice(&poseidon16_permute([F::ZERO; DIGEST_LEN * 2])[..DIGEST_LEN]); - public_memory[PRIVATE_INPUT_START_PTR] = F::from_usize(public_memory_len); public_memory } diff --git a/crates/lean_vm/src/isa/bytecode.rs b/crates/lean_vm/src/isa/bytecode.rs index b332288e..38994cf7 100644 --- a/crates/lean_vm/src/isa/bytecode.rs +++ b/crates/lean_vm/src/isa/bytecode.rs @@ -1,6 +1,8 @@ //! 
Bytecode representation and management -use crate::{CodeAddress, FileId, FunctionName, Hint, SourceLocation}; +use p3_util::log2_ceil_usize; + +use crate::{CodeAddress, F, FileId, FunctionName, Hint, SourceLocation}; use super::Instruction; use std::collections::BTreeMap; @@ -10,6 +12,7 @@ use std::fmt::{Display, Formatter}; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Bytecode { pub instructions: Vec, + pub instructions_multilinear: Vec, pub hints: BTreeMap>, // pc -> hints pub starting_frame_memory: usize, // debug @@ -20,6 +23,20 @@ pub struct Bytecode { pub pc_to_location: Vec, } +impl Bytecode { + pub fn size(&self) -> usize { + self.instructions.len() + } + + pub fn padded_size(&self) -> usize { + self.size().next_power_of_two() + } + + pub fn log_size(&self) -> usize { + log2_ceil_usize(self.size()) + } +} + impl Display for Bytecode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { for (pc, instruction) in self.instructions.iter().enumerate() { diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 7625aa11..603445ef 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -76,6 +76,7 @@ pub enum CustomHint { /// The decomposition is unique, and always exists (except for x = -1) DecomposeBitsXMSS, DecomposeBits, + LessThan, } impl CustomHint { @@ -83,6 +84,7 @@ impl CustomHint { match self { Self::DecomposeBitsXMSS => "hint_decompose_bits_xmss", Self::DecomposeBits => "hint_decompose_bits", + Self::LessThan => "hint_less_than", } } @@ -90,6 +92,7 @@ impl CustomHint { match self { Self::DecomposeBitsXMSS => 3..usize::MAX, Self::DecomposeBits => 4..5, + Self::LessThan => 3..4, } } @@ -132,6 +135,13 @@ impl CustomHint { .set_slice(memory_index, &to_little_endian_in_field::(to_decompose, num_bits))? 
} } + Self::LessThan => { + let a = args[0].read_value(ctx.memory, ctx.fp)?; + let b = args[1].read_value(ctx.memory, ctx.fp)?; + let res_ptr = args[2].memory_address(ctx.fp)?; + let result = if a.to_usize() < b.to_usize() { F::ONE } else { F::ZERO }; + ctx.memory.set(res_ptr, result)?; + } } Ok(()) } diff --git a/crates/lean_vm/src/tables/dot_product/air.rs b/crates/lean_vm/src/tables/dot_product/air.rs index 3d9e9157..1b0c0899 100644 --- a/crates/lean_vm/src/tables/dot_product/air.rs +++ b/crates/lean_vm/src/tables/dot_product/air.rs @@ -24,12 +24,13 @@ pub(super) const DOT_COL_IS_BE: usize = 0; pub(super) const DOT_COL_FLAG: usize = 1; pub(super) const DOT_COL_START: usize = 2; pub(super) const DOT_COL_LEN: usize = 3; -pub const DOT_COL_A: usize = 4; -pub(super) const DOT_COL_B: usize = 5; -pub(super) const DOT_COL_RES: usize = 6; -pub const DOT_COL_VALUE_A_F: usize = 7; +pub(super) const DOT_COL_AUX: usize = 4; // = DOT_COL_IS_BE + DOT_COL_LEN * 2 (used for data flow with execution table. 
Soundness: if an adversary tries to cheat and switch the value of DOT_COL_IS_BE, then DOT_COL_LEN would be > (p-1)/2 (otherwise no overflow, and thus no cheating), which force a an AIR table of height > 2^29, which would contradict MAX_LOG_N_ROWS_PER_TABLE (in constants.rs)) +pub(super) const DOT_COL_A: usize = 5; +pub(super) const DOT_COL_B: usize = 6; +pub(super) const DOT_COL_RES: usize = 7; +pub(super) const DOT_COL_VALUE_A_F: usize = 8; // EF columns -pub const DOT_COL_VALUE_A_EF: usize = 0; +pub(super) const DOT_COL_VALUE_A_EF: usize = 0; pub(super) const DOT_COL_VALUE_B: usize = 1; pub(super) const DOT_COL_VALUE_RES: usize = 2; pub(super) const DOT_COL_COMPUTATION: usize = 3; @@ -38,7 +39,7 @@ impl Air for DotProductPrecompile { type ExtraData = ExtraDataForBuses; fn n_columns_f_air(&self) -> usize { - 8 + 9 } fn n_columns_ef_air(&self) -> usize { 4 @@ -47,7 +48,7 @@ impl Air for DotProductPrecompile { 3 } fn n_constraints(&self) -> usize { - 15 // TODO: update + 16 // TODO: update } fn down_column_indexes_f(&self) -> Vec { vec![DOT_COL_START, DOT_COL_IS_BE, DOT_COL_LEN, DOT_COL_A, DOT_COL_B] @@ -67,6 +68,7 @@ impl Air for DotProductPrecompile { let flag = up_f[DOT_COL_FLAG].clone(); let start = up_f[DOT_COL_START].clone(); let len = up_f[DOT_COL_LEN].clone(); + let aux = up_f[DOT_COL_AUX].clone(); let index_a = up_f[DOT_COL_A].clone(); let index_b = up_f[DOT_COL_B].clone(); let index_res = up_f[DOT_COL_RES].clone(); @@ -90,23 +92,11 @@ impl Air for DotProductPrecompile { extra_data, AB::F::from_usize(self.table().index()), flag.clone(), - &[ - index_a.clone(), - index_b.clone(), - index_res.clone(), - len.clone(), - is_be.clone(), - ], + &[index_a.clone(), index_b.clone(), index_res.clone(), aux.clone()], )); } else { builder.declare_values(std::slice::from_ref(&flag)); - builder.declare_values(&[ - index_a.clone(), - index_b.clone(), - index_res.clone(), - len.clone(), - is_be.clone(), - ]); + builder.declare_values(&[index_a.clone(), index_b.clone(), 
index_res.clone(), aux.clone()]); } let is_ee = AB::F::ONE - is_be.clone(); @@ -116,6 +106,8 @@ impl Air for DotProductPrecompile { builder.assert_zero(flag.clone() * (AB::F::ONE - start.clone())); builder.assert_bool(is_be.clone()); + builder.assert_eq(aux, is_be.clone() + len.double()); + let mode_switch = (is_be_down.clone() - is_be.clone()).square(); builder.assert_zero(mode_switch.clone() * (AB::F::ONE - start_down.clone())); diff --git a/crates/lean_vm/src/tables/dot_product/exec.rs b/crates/lean_vm/src/tables/dot_product/exec.rs index 2333c72b..f1d1ad59 100644 --- a/crates/lean_vm/src/tables/dot_product/exec.rs +++ b/crates/lean_vm/src/tables/dot_product/exec.rs @@ -40,6 +40,7 @@ pub(super) fn exec_dot_product_be( trace.base[DOT_COL_START].push(F::ONE); trace.base[DOT_COL_START].extend(F::zero_vec(size - 1)); trace.base[DOT_COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); + trace.base[DOT_COL_AUX].extend(((1..=size).rev()).map(|x| F::from_bool(true) + F::from_usize(2 * x))); trace.base[DOT_COL_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); trace.base[DOT_COL_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); trace.base[DOT_COL_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); @@ -126,6 +127,7 @@ pub(super) fn exec_dot_product_ee( trace.base[DOT_COL_START].push(F::ONE); trace.base[DOT_COL_START].extend(F::zero_vec(size - 1)); trace.base[DOT_COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); + trace.base[DOT_COL_AUX].extend(((1..=size).rev()).map(|x| F::from_bool(false) + F::from_usize(2 * x))); trace.base[DOT_COL_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i * DIMENSION))); trace.base[DOT_COL_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); trace.base[DOT_COL_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); diff --git a/crates/lean_vm/src/tables/dot_product/mod.rs b/crates/lean_vm/src/tables/dot_product/mod.rs index 
0c3a4d8e..9b6752f6 100644 --- a/crates/lean_vm/src/tables/dot_product/mod.rs +++ b/crates/lean_vm/src/tables/dot_product/mod.rs @@ -51,7 +51,7 @@ impl TableT for DotProductPrecompile { table: BusTable::Constant(self.table()), direction: BusDirection::Pull, selector: DOT_COL_FLAG, - data: vec![DOT_COL_A, DOT_COL_B, DOT_COL_RES, DOT_COL_LEN, DOT_COL_IS_BE], + data: vec![DOT_COL_A, DOT_COL_B, DOT_COL_RES, DOT_COL_AUX], } } @@ -62,8 +62,9 @@ impl TableT for DotProductPrecompile { F::ZERO, // Flag F::ONE, // Start F::ONE, // Len + F::TWO, // Aux (0 + 2*1) ], - vec![F::ZERO; self.n_columns_f_air() - 4], + vec![F::ZERO; self.n_columns_f_air() - 5], ] .concat() } diff --git a/crates/lean_vm/src/tables/execution/air.rs b/crates/lean_vm/src/tables/execution/air.rs index ede04616..4964207e 100644 --- a/crates/lean_vm/src/tables/execution/air.rs +++ b/crates/lean_vm/src/tables/execution/air.rs @@ -1,9 +1,9 @@ use crate::{ALL_TABLES, EF, ExecutionTable, ExtraDataForBuses, F, eval_virtual_bus_column}; use multilinear_toolkit::prelude::*; -pub const N_COMMITTED_EXEC_COLUMNS: usize = 8; -pub const N_INSTRUCTION_COLUMNS: usize = 13; -pub const N_EXEC_AIR_COLUMNS: usize = N_INSTRUCTION_COLUMNS + N_COMMITTED_EXEC_COLUMNS; +pub const N_RUNTIME_COLUMNS: usize = 8; +pub const N_INSTRUCTION_COLUMNS: usize = 12; +pub const N_TOTAL_EXECUTION_COLUMNS: usize = N_INSTRUCTION_COLUMNS + N_RUNTIME_COLUMNS; // Committed columns (IMPORTANT: they must be the first columns) pub const COL_PC: usize = 0; @@ -15,7 +15,7 @@ pub const COL_MEM_VALUE_A: usize = 5; pub const COL_MEM_VALUE_B: usize = 6; pub const COL_MEM_VALUE_C: usize = 7; -// Decoded instruction columns (lookup into bytecode with logup*) +// Decoded instruction columns pub const COL_OPERAND_A: usize = 8; pub const COL_OPERAND_B: usize = 9; pub const COL_OPERAND_C: usize = 10; @@ -26,21 +26,20 @@ pub const COL_ADD: usize = 14; pub const COL_MUL: usize = 15; pub const COL_DEREF: usize = 16; pub const COL_JUMP: usize = 17; -pub const 
COL_AUX_1: usize = 18; -pub const COL_AUX_2: usize = 19; -pub const COL_PRECOMPILE_INDEX: usize = 20; +pub const COL_AUX: usize = 18; +pub const COL_PRECOMPILE_INDEX: usize = 19; // Temporary columns (stored to avoid duplicate computations) pub const N_TEMPORARY_EXEC_COLUMNS: usize = 4; -pub const COL_IS_PRECOMPILE: usize = 21; -pub const COL_EXEC_NU_A: usize = 22; -pub const COL_EXEC_NU_B: usize = 23; -pub const COL_EXEC_NU_C: usize = 24; +pub const COL_IS_PRECOMPILE: usize = 20; +pub const COL_EXEC_NU_A: usize = 21; +pub const COL_EXEC_NU_B: usize = 22; +pub const COL_EXEC_NU_C: usize = 23; const PRECOMPILE_A_INDEX: F = F::new(ALL_TABLES[1].index() as u32); const PRECOMPILE_B_INDEX: F = F::new(ALL_TABLES[2].index() as u32); -const MINUS_ONE_OVER_AB_PRECOMPILES: usize = 1775588694; -const MINUS_A_MINUS_B_PRECOMPILES: usize = 2130706428; +const MINUS_ONE_OVER_AB_PRECOMPILES: usize = 1065353216; +const MINUS_A_MINUS_B_PRECOMPILES: usize = 2130706430; #[test] fn test_precompile_indices() { @@ -60,7 +59,7 @@ impl Air for ExecutionTable { type ExtraData = ExtraDataForBuses; fn n_columns_f_air(&self) -> usize { - N_EXEC_AIR_COLUMNS + N_TOTAL_EXECUTION_COLUMNS } fn n_columns_ef_air(&self) -> usize { 0 @@ -96,8 +95,7 @@ impl Air for ExecutionTable { let mul = up[COL_MUL].clone(); let deref = up[COL_DEREF].clone(); let jump = up[COL_JUMP].clone(); - let aux_1 = up[COL_AUX_1].clone(); - let aux_2 = up[COL_AUX_2].clone(); + let aux = up[COL_AUX].clone(); let precompile_index = up[COL_PRECOMPILE_INDEX].clone(); let (value_a, value_b, value_c) = ( @@ -142,11 +140,11 @@ impl Air for ExecutionTable { extra_data, precompile_index.clone(), is_precompile.clone(), - &[nu_a.clone(), nu_b.clone(), nu_c.clone(), aux_1.clone(), aux_2.clone()], + &[nu_a.clone(), nu_b.clone(), nu_c.clone(), aux.clone()], )); } else { builder.declare_values(&[is_precompile]); - builder.declare_values(&[nu_a.clone(), nu_b.clone(), nu_c.clone(), aux_1.clone(), aux_2.clone()]); + 
builder.declare_values(&[nu_a.clone(), nu_b.clone(), nu_c.clone(), aux.clone()]); } builder.assert_zero(flag_a_minus_one * (addr_a.clone() - fp_plus_operand_a)); @@ -157,8 +155,8 @@ impl Air for ExecutionTable { builder.assert_zero(mul * (nu_b.clone() - nu_a.clone() * nu_c.clone())); builder.assert_zero(deref.clone() * (addr_c.clone() - (value_a.clone() + operand_c.clone()))); - builder.assert_zero(deref.clone() * aux_1.clone() * (value_c.clone() - nu_b.clone())); - builder.assert_zero(deref.clone() * (aux_1.clone() - AB::F::ONE) * (value_c.clone() - fp.clone())); + builder.assert_zero(deref.clone() * aux.clone() * (value_c.clone() - nu_b.clone())); + builder.assert_zero(deref.clone() * (aux.clone() - AB::F::ONE) * (value_c.clone() - fp.clone())); builder.assert_zero((jump.clone() - AB::F::ONE) * (next_pc.clone() - pc_plus_one.clone())); builder.assert_zero((jump.clone() - AB::F::ONE) * (next_fp.clone() - fp.clone())); @@ -172,5 +170,5 @@ impl Air for ExecutionTable { } pub const fn instr_idx(col_index_in_air: usize) -> usize { - col_index_in_air - N_COMMITTED_EXEC_COLUMNS + col_index_in_air - N_RUNTIME_COLUMNS } diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index d0c799e4..0677ce07 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -21,7 +21,7 @@ impl TableT for ExecutionTable { } fn n_columns_f_total(&self) -> usize { - N_EXEC_AIR_COLUMNS + N_TEMPORARY_EXEC_COLUMNS + N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS } fn lookups_f(&self) -> Vec { @@ -50,12 +50,12 @@ impl TableT for ExecutionTable { table: BusTable::Variable(COL_PRECOMPILE_INDEX), direction: BusDirection::Push, selector: COL_IS_PRECOMPILE, - data: vec![COL_EXEC_NU_A, COL_EXEC_NU_B, COL_EXEC_NU_C, COL_AUX_1, COL_AUX_2], + data: vec![COL_EXEC_NU_A, COL_EXEC_NU_B, COL_EXEC_NU_C, COL_AUX], } } fn padding_row_f(&self) -> Vec { - let mut padding_row = vec![F::ZERO; N_EXEC_AIR_COLUMNS + 
N_TEMPORARY_EXEC_COLUMNS]; + let mut padding_row = vec![F::ZERO; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS]; padding_row[COL_PC] = F::from_usize(ENDING_PC); padding_row[COL_JUMP] = F::ONE; padding_row[COL_FLAG_A] = F::ONE; diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index a74acd6b..12887ef4 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -1,5 +1,4 @@ use multilinear_toolkit::prelude::*; -use utils::MEMORY_TABLE_INDEX; use crate::*; @@ -48,7 +47,7 @@ impl Table { PF::from_usize(self.index()) } pub const fn index(&self) -> usize { - unsafe { *(self as *const Self as *const usize) + MEMORY_TABLE_INDEX + 1 } + unsafe { *(self as *const Self as *const usize) } } } @@ -122,7 +121,8 @@ impl Air for Table { } pub fn max_bus_width() -> usize { - 1 + ALL_TABLES.iter().map(|table| table.bus().data.len()).max().unwrap() + let max_bus_in_table = ALL_TABLES.iter().map(|table| table.bus().data.len()).max().unwrap(); + 1 + max_bus_in_table.max(N_INSTRUCTION_COLUMNS) } pub fn max_air_constraints() -> usize { @@ -131,12 +131,16 @@ pub fn max_air_constraints() -> usize { #[cfg(test)] mod tests { + use utils::{BYTECODE_TABLE_INDEX, MEMORY_TABLE_INDEX}; + use super::*; #[test] fn test_table_indices() { for (i, table) in ALL_TABLES.iter().enumerate() { - assert_eq!(table.index(), i + MEMORY_TABLE_INDEX + 1); + assert_eq!(table.index(), i); + assert_ne!(table.index(), MEMORY_TABLE_INDEX); + assert_ne!(table.index(), BYTECODE_TABLE_INDEX); } } } diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index 14a48b69..ea4f4272 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -1,4 +1,4 @@ -use crate::{DIMENSION, EF, F, InstructionContext, N_COMMITTED_EXEC_COLUMNS, RunnerError, Table}; +use crate::{DIMENSION, EF, F, InstructionContext, RunnerError, Table}; use 
multilinear_toolkit::prelude::*; use std::{any::TypeId, cmp::Reverse, collections::BTreeMap, mem::transmute}; @@ -54,6 +54,7 @@ pub enum BusTable { pub struct TableTrace { pub base: Vec>, pub ext: Vec>, + pub non_padded_n_rows: usize, pub log_n_rows: VarCount, } @@ -62,7 +63,8 @@ impl TableTrace { Self { base: vec![Vec::new(); air.n_columns_f_total()], ext: vec![Vec::new(); air.n_columns_ef_total()], - log_n_rows: 0, // filled later + non_padded_n_rows: 0, // filled later + log_n_rows: 0, // filled later } } } @@ -144,31 +146,8 @@ pub trait TableT: Air { false } - fn n_commited_columns_f(&self) -> usize { - if self.is_execution_table() { - N_COMMITTED_EXEC_COLUMNS - } else { - self.n_columns_f_air() - } - } - - fn n_commited_columns_ef(&self) -> usize { - self.n_columns_ef_air() - } - - fn n_commited_columns(&self) -> usize { - self.n_commited_columns_ef() * DIMENSION + self.n_commited_columns_f() - } - - fn commited_air_values(&self, air_evals: &[EF]) -> BTreeMap { - // the intermidiate columns are not commited - // (they correspond to decoded instructions, in execution table, obtained via logup* into the bytecode) - air_evals - .iter() - .copied() - .enumerate() - .filter(|(i, _)| *i < self.n_commited_columns_f() || *i >= self.n_columns_f_air()) - .collect::>() + fn n_committed_columns(&self) -> usize { + self.n_columns_ef_air() * DIMENSION + self.n_columns_f_air() } fn lookup_index_columns_f<'a>(&'a self, trace: &'a TableTrace) -> Vec<&'a [F]> { diff --git a/crates/lean_vm/src/tables/utils.rs b/crates/lean_vm/src/tables/utils.rs index e9e3e5f9..4d8f3412 100644 --- a/crates/lean_vm/src/tables/utils.rs +++ b/crates/lean_vm/src/tables/utils.rs @@ -11,12 +11,12 @@ pub(crate) fn eval_virtual_bus_column> let (logup_alphas_eq_poly, bus_beta) = extra_data.transmute_bus_data::(); assert!(data.len() < logup_alphas_eq_poly.len()); - (logup_alphas_eq_poly[1..] 
+ (logup_alphas_eq_poly .iter() .zip(data) .map(|(c, d)| c.clone() * d.clone()) .sum::() - + bus_index) + + logup_alphas_eq_poly.last().unwrap().clone() * bus_index) * bus_beta.clone() + flag } diff --git a/crates/rec_aggregation/fiat_shamir.py b/crates/rec_aggregation/fiat_shamir.py index 5e49cf70..a76e7037 100644 --- a/crates/rec_aggregation/fiat_shamir.py +++ b/crates/rec_aggregation/fiat_shamir.py @@ -47,7 +47,7 @@ def fs_sample_chunks(fs, n_chunks: Const): new_fs = sampled + n_chunks * 8 return new_fs, sampled - +@inline def fs_sample_ef(fs): sampled = Array(8) poseidon16(fs, ZERO_VEC_PTR, sampled) @@ -66,6 +66,7 @@ def fs_sample_many_ef(fs, n): return new_fs, sampled +@inline def fs_hint(fs, n): # return the updated fiat-shamir, and a pointer to n field elements from the transcript transcript_ptr = fs[8] @@ -90,8 +91,8 @@ def fs_receive_chunks(fs, n_chunks: Const): ) return new_fs + 8 * (n_chunks - 1), transcript_ptr - -def fs_receive_ef(fs, n: Const): +@inline +def fs_receive_ef(fs, n): new_fs, ef_ptr = fs_receive_chunks(fs, div_ceil(n * DIM, 8)) for i in unroll(n * DIM, next_multiple_of(n * DIM, 8)): assert ef_ptr[i] == 0 @@ -118,12 +119,8 @@ def sample_bits_const(fs, n_samples: Const, K): def sample_bits_dynamic(fs_state, n_samples, K): new_fs_state: Imu sampled_bits: Imu - for r in unroll(0, N_ROUNDS_BASE + 1): - if n_samples == NUM_QUERIES_BASE[r]: - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_BASE[r], K) - return new_fs_state, sampled_bits - for r in unroll(0, N_ROUNDS_EXT + 1): - if n_samples == NUM_QUERIES_EXT[r]: - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_EXT[r], K) + for r in unroll(0, WHIR_N_ROUNDS + 1): + if n_samples == WHIR_NUM_QUERIES[r]: + new_fs_state, sampled_bits = sample_bits_const(fs_state, WHIR_NUM_QUERIES[r], K) return new_fs_state, sampled_bits assert False, "sample_bits_dynamic called with unsupported n_samples" diff --git a/crates/rec_aggregation/hashing.py 
b/crates/rec_aggregation/hashing.py index 09cb357b..eb3acc56 100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -3,12 +3,9 @@ DIM = 5 # extension degree VECTOR_LEN = 8 -MERKLE_HEIGHTS_BASE = MERKLE_HEIGHTS_BASE_PLACEHOLDER -MERKLE_HEIGHTS_EXT = MERKLE_HEIGHTS_EXT_PLACEHOLDER -NUM_QUERIES_BASE = NUM_QUERIES_BASE_PLACEHOLDER -NUM_QUERIES_EXT = NUM_QUERIES_EXT_PLACEHOLDER -N_ROUNDS_BASE = len(NUM_QUERIES_BASE) - 1 -N_ROUNDS_EXT = len(NUM_QUERIES_EXT) - 1 +WHIR_MERKLE_HEIGHTS = WHIR_MERKLE_HEIGHTS_PLACEHOLDER +WHIR_NUM_QUERIES = WHIR_NUM_QUERIES_PLACEHOLDER +WHIR_N_ROUNDS = len(WHIR_NUM_QUERIES) - 1 def batch_hash_slice(num_queries, all_data_to_hash, all_resulting_hashes, len): @@ -18,21 +15,34 @@ def batch_hash_slice(num_queries, all_data_to_hash, all_resulting_hashes, len): if len == 16: batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 16) return + if len == 8: + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 8) + return + if len == 20: + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 20) + return if len == 1: batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 1) return + if len == 4: + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 4) + return + if len == 5: + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 5) + return + print(len) assert False, "batch_hash_slice called with unsupported len" def batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, len: Const): for i in range(0, num_queries): data = all_data_to_hash[i] - res = slice_hash(ZERO_VEC_PTR, data, len) + res = slice_hash(data, len) all_resulting_hashes[i] = res return - -def slice_hash(seed, data, len: Const): +@inline +def slice_hash(data, len): states = Array(len * VECTOR_LEN) poseidon16(ZERO_VEC_PTR, data, states) state_indexes = Array(len) @@ -44,26 +54,15 @@ def 
slice_hash(seed, data, len: Const): def merkle_verif_batch(n_paths, merkle_paths, leaves_digests, leave_positions, root, height, num_queries): - for i in unroll(0, N_ROUNDS_BASE + 1): - if height + num_queries * 1000 == MERKLE_HEIGHTS_BASE[i] + NUM_QUERIES_BASE[i] * 1000: - merkle_verif_batch_const( - NUM_QUERIES_BASE[i], - merkle_paths, - leaves_digests, - leave_positions, - root, - MERKLE_HEIGHTS_BASE[i], - ) - return - for i in unroll(0, N_ROUNDS_EXT + 1): - if height + num_queries * 1000 == MERKLE_HEIGHTS_EXT[i] + NUM_QUERIES_EXT[i] * 1000: + for i in unroll(0, WHIR_N_ROUNDS + 1): + if height + num_queries * 1000 == WHIR_MERKLE_HEIGHTS[i] + WHIR_NUM_QUERIES[i] * 1000: merkle_verif_batch_const( - NUM_QUERIES_EXT[i], + WHIR_NUM_QUERIES[i], merkle_paths, leaves_digests, leave_positions, root, - MERKLE_HEIGHTS_EXT[i], + WHIR_MERKLE_HEIGHTS[i], ) return print(12345555) diff --git a/crates/rec_aggregation/recursion.py b/crates/rec_aggregation/recursion.py index b93a34d5..a47fba50 100644 --- a/crates/rec_aggregation/recursion.py +++ b/crates/rec_aggregation/recursion.py @@ -6,10 +6,13 @@ MAX_LOG_N_ROWS_PER_TABLE = MAX_LOG_N_ROWS_PER_TABLE_PLACEHOLDER MIN_LOG_MEMORY_SIZE = MIN_LOG_MEMORY_SIZE_PLACEHOLDER MAX_LOG_MEMORY_SIZE = MAX_LOG_MEMORY_SIZE_PLACEHOLDER -N_VARS_FIRST_GKR = N_VARS_FIRST_GKR_PLACEHOLDER +N_VARS_LOGUP_GKR = N_VARS_LOGUP_GKR_PLACEHOLDER MAX_BUS_WIDTH = MAX_BUS_WIDTH_PLACEHOLDER MAX_NUM_AIR_CONSTRAINTS = MAX_NUM_AIR_CONSTRAINTS_PLACEHOLDER + MEMORY_TABLE_INDEX = MEMORY_TABLE_INDEX_PLACEHOLDER +BYTECODE_TABLE_INDEX = BYTECODE_TABLE_INDEX_PLACEHOLDER +EXECUTION_TABLE_INDEX = EXECUTION_TABLE_INDEX_PLACEHOLDER LOOKUPS_F_INDEXES = LOOKUPS_F_INDEXES_PLACEHOLDER # [[_; ?]; N_TABLES] LOOKUPS_F_VALUES = LOOKUPS_F_VALUES_PLACEHOLDER # [[[_; ?]; ?]; N_TABLES] @@ -20,46 +23,45 @@ NUM_COLS_F_AIR = NUM_COLS_F_AIR_PLACEHOLDER NUM_COLS_EF_AIR = NUM_COLS_EF_AIR_PLACEHOLDER -NUM_COLS_F_COMMITED = NUM_COLS_F_COMMITED_PLACEHOLDER +NUM_COLS_F_COMMITTED = 
NUM_COLS_F_COMMITTED_PLACEHOLDER -EXECUTION_TABLE_INDEX = EXECUTION_TABLE_INDEX_PLACEHOLDER AIR_DEGREES = AIR_DEGREES_PLACEHOLDER # [_; N_TABLES] N_AIR_COLUMNS_F = N_AIR_COLUMNS_F_PLACEHOLDER # [_; N_TABLES] N_AIR_COLUMNS_EF = N_AIR_COLUMNS_EF_PLACEHOLDER # [_; N_TABLES] AIR_DOWN_COLUMNS_F = AIR_DOWN_COLUMNS_F_PLACEHOLDER # [[_; ?]; N_TABLES] AIR_DOWN_COLUMNS_EF = AIR_DOWN_COLUMNS_EF_PLACEHOLDER # [[_; _]; N_TABLES] -NUM_BYTECODE_INSTRUCTIONS = NUM_BYTECODE_INSTRUCTIONS_PLACEHOLDER +N_INSTRUCTION_COLUMNS = N_INSTRUCTION_COLUMNS_PLACEHOLDER N_COMMITTED_EXEC_COLUMNS = N_COMMITTED_EXEC_COLUMNS_PLACEHOLDER GUEST_BYTECODE_LEN = GUEST_BYTECODE_LEN_PLACEHOLDER COL_PC = COL_PC_PLACEHOLDER -TOTAL_WHIR_STATEMENTS_BASE = TOTAL_WHIR_STATEMENTS_BASE_PLACEHOLDER +TOTAL_WHIR_STATEMENTS = TOTAL_WHIR_STATEMENTS_PLACEHOLDER STARTING_PC = STARTING_PC_PLACEHOLDER ENDING_PC = ENDING_PC_PLACEHOLDER NONRESERVED_PROGRAM_INPUT_START = NONRESERVED_PROGRAM_INPUT_START_PLACEHOLDER def main(): - mem = 0 - priv_start = mem[PRIVATE_INPUT_START_PTR] + pub_mem = NONRESERVED_PROGRAM_INPUT_START + priv_start = pub_mem[0] proof_size = priv_start[0] - outer_public_memory_log_size = priv_start[1] - outer_public_memory_size = powers_of_two(outer_public_memory_log_size) + inner_public_memory_log_size = priv_start[1] + inner_public_memory_size = powers_of_two(inner_public_memory_log_size) n_recursions = priv_start[2] - outer_public_memory = priv_start + 3 - proofs_start = outer_public_memory + outer_public_memory_size + inner_public_memory = priv_start + 3 + proofs_start = inner_public_memory + inner_public_memory_size for i in range(0, n_recursions): proof_transcript = proofs_start + i * proof_size - recursion(outer_public_memory_log_size, outer_public_memory, proof_transcript) + recursion(inner_public_memory_log_size, inner_public_memory, proof_transcript) return -def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcript): +def recursion(inner_public_memory_log_size, 
inner_public_memory, proof_transcript): fs: Mut = fs_new(proof_transcript) # table dims - debug_assert(N_TABLES + 1 < VECTOR_LEN) # (because duplex only once bellow) + debug_assert(N_TABLES + 1 < VECTOR_LEN) fs, mem_and_table_dims = fs_receive_chunks(fs, 1) for i in unroll(N_TABLES + 1, 8): assert mem_and_table_dims[i] == 0 @@ -73,9 +75,9 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip assert n_vars_for_table <= MAX_LOG_N_ROWS_PER_TABLE[i] assert MIN_LOG_MEMORY_SIZE <= log_memory assert log_memory <= MAX_LOG_MEMORY_SIZE + assert log_memory <= GUEST_BYTECODE_LEN - # parse 1st whir commitment - fs, whir_base_root, whir_base_ood_points, whir_base_ood_evals = parse_whir_commitment_const(fs, NUM_OOD_COMMIT_BASE) + fs, whir_base_root, whir_base_ood_points, whir_base_ood_evals = parse_whir_commitment_const(fs, WHIR_NUM_OOD_COMMIT) fs, logup_c = fs_sample_ef(fs) @@ -84,23 +86,86 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip logup_alphas_eq_poly = poly_eq_extension(logup_alphas, log2_ceil(MAX_BUS_WIDTH)) # GENRIC LOGUP - fs, quotient_gkr, point_gkr, numerators_value, denominators_value = verify_gkr_quotient(fs, N_VARS_FIRST_GKR) + fs, quotient_gkr, point_gkr, numerators_value, denominators_value = verify_gkr_quotient(fs, N_VARS_LOGUP_GKR) set_to_5_zeros(quotient_gkr) - memory_and_acc_prefix = multilinear_location_prefix(0, N_VARS_FIRST_GKR - log_memory, point_gkr) + memory_and_acc_prefix = multilinear_location_prefix(0, N_VARS_LOGUP_GKR - log_memory, point_gkr) fs, value_acc = fs_receive_ef(fs, 1) fs, value_memory = fs_receive_ef(fs, 1) retrieved_numerators_value: Mut = opposite_extension_ret(mul_extension_ret(memory_and_acc_prefix, value_acc)) - value_index = mle_of_01234567_etc(point_gkr + (N_VARS_FIRST_GKR - log_memory) * DIM, log_memory) - fingerprint_memory = fingerprint_2(MEMORY_TABLE_INDEX, value_index, value_memory, logup_alphas_eq_poly) + value_index = mle_of_01234567_etc(point_gkr + 
(N_VARS_LOGUP_GKR - log_memory) * DIM, log_memory) + fingerprint_memory = fingerprint_2(MEMORY_TABLE_INDEX, value_memory, value_index, logup_alphas_eq_poly) retrieved_denominators_value: Mut = mul_extension_ret( memory_and_acc_prefix, sub_extension_ret(logup_c, fingerprint_memory) ) offset: Mut = powers_of_two(log_memory) + + log_bytecode = log2_ceil(GUEST_BYTECODE_LEN) + log_n_cycles = table_dims[EXECUTION_TABLE_INDEX] + log_bytecode_padded = maximum(log_bytecode, log_n_cycles) + bytecode_and_acc_point = point_gkr + (N_VARS_LOGUP_GKR - log_bytecode) * DIM + bytecode_multilinear_location_prefix = multilinear_location_prefix( + offset / 2 ** log2_ceil(GUEST_BYTECODE_LEN), N_VARS_LOGUP_GKR - log_bytecode, point_gkr + ) + bytecode_padded_multilinear_location_prefix = multilinear_location_prefix( + offset / powers_of_two(log_bytecode_padded), N_VARS_LOGUP_GKR - log_bytecode_padded, point_gkr + ) + pub_mem = NONRESERVED_PROGRAM_INPUT_START + assert pub_mem[1] == log_bytecode + log2_ceil(N_INSTRUCTION_COLUMNS) + copy_many_ef(bytecode_and_acc_point, pub_mem + 2, log_bytecode) + copy_many_ef( + logup_alphas + (log2_ceil(MAX_BUS_WIDTH) - log2_ceil(N_INSTRUCTION_COLUMNS)) * DIM, + pub_mem + 2 + log_bytecode * DIM, + log2_ceil(N_INSTRUCTION_COLUMNS), + ) + bytecode_value = pub_mem + 2 + (log_bytecode + log2_ceil(N_INSTRUCTION_COLUMNS)) * DIM + bytecode_value_corrected: Mut = bytecode_value + for i in unroll(0, log2_ceil(MAX_BUS_WIDTH) - log2_ceil(N_INSTRUCTION_COLUMNS)): + bytecode_value_corrected = mul_extension_ret( + bytecode_value_corrected, one_minus_self_extension_ret(logup_alphas + i * DIM) + ) + + fs, value_bytecode_acc = fs_receive_ef(fs, 1) + retrieved_numerators_value = sub_extension_ret( + retrieved_numerators_value, mul_extension_ret(bytecode_multilinear_location_prefix, value_bytecode_acc) + ) + + bytecode_index_value = mle_of_01234567_etc(bytecode_and_acc_point, log_bytecode) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + 
mul_extension_ret( + bytecode_multilinear_location_prefix, + sub_extension_ret( + logup_c, + add_extension_ret( + bytecode_value_corrected, + add_extension_ret( + mul_extension_ret(bytecode_index_value, logup_alphas_eq_poly + N_INSTRUCTION_COLUMNS * DIM), + mul_base_extension_ret( + BYTECODE_TABLE_INDEX, logup_alphas_eq_poly + (2 ** log2_ceil(MAX_BUS_WIDTH) - 1) * DIM + ), + ), + ), + ), + ), + ) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + mul_extension_ret( + bytecode_padded_multilinear_location_prefix, + mle_of_zeros_then_ones( + point_gkr + (N_VARS_LOGUP_GKR - log_bytecode_padded) * DIM, + 2 ** log2_ceil(GUEST_BYTECODE_LEN), + log_bytecode_padded, + ), + ), + ) + offset += powers_of_two(log_bytecode_padded) + bus_numerators_values = DynArray([]) bus_denominators_values = DynArray([]) pcs_points = DynArray([]) # [[_; N]; N_TABLES] @@ -109,15 +174,38 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip pcs_values = DynArray([]) # [[[[] or [_]; num cols]; N]; N_TABLES] for i in unroll(0, N_TABLES): pcs_values.push(DynArray([])) + pcs_values[i].push(DynArray([])) + total_num_cols = NUM_COLS_F_AIR[i] + DIM * NUM_COLS_EF_AIR[i] + for col in unroll(0, total_num_cols): + pcs_values[i][0].push(DynArray([])) + for table_index in unroll(0, N_TABLES): # I] Bus (data flow between tables) log_n_rows = table_dims[table_index] n_rows = powers_of_two(log_n_rows) - inner_point = point_gkr + (N_VARS_FIRST_GKR - log_n_rows) * DIM + inner_point = point_gkr + (N_VARS_LOGUP_GKR - log_n_rows) * DIM pcs_points[table_index].push(inner_point) - prefix = multilinear_location_prefix(offset / n_rows, N_VARS_FIRST_GKR - log_n_rows, point_gkr) + if table_index == EXECUTION_TABLE_INDEX: + # 0] Bytecode lookup + bytecode_prefix = multilinear_location_prefix(offset / n_rows, N_VARS_LOGUP_GKR - log_n_rows, point_gkr) + + fs, eval_on_pc = fs_receive_ef(fs, 1) + pcs_values[EXECUTION_TABLE_INDEX][0][COL_PC].push(eval_on_pc) + 
fs, instr_evals = fs_receive_ef(fs, N_INSTRUCTION_COLUMNS) + for i in unroll(0, N_INSTRUCTION_COLUMNS): + global_index = N_COMMITTED_EXEC_COLUMNS + i + pcs_values[EXECUTION_TABLE_INDEX][0][global_index].push(instr_evals + i * DIM) + retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, bytecode_prefix) + fingerp = fingerprint_bytecode(instr_evals, eval_on_pc, logup_alphas_eq_poly) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + mul_extension_ret(bytecode_prefix, sub_extension_ret(logup_c, fingerp)), + ) + offset += n_rows + + prefix = multilinear_location_prefix(offset / n_rows, N_VARS_LOGUP_GKR - log_n_rows, point_gkr) fs, eval_on_selector = fs_receive_ef(fs, 1) retrieved_numerators_value = add_extension_ret( @@ -137,11 +225,6 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip # II] Lookup into memory - pcs_values[table_index].push(DynArray([])) - total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index] - for col in unroll(0, total_num_cols): - pcs_values[table_index][0].push(DynArray([])) - for lookup_f_index in unroll(0, len(LOOKUPS_F_INDEXES[table_index])): col_index = LOOKUPS_F_INDEXES[table_index][lookup_f_index] fs, index_eval = fs_receive_ef(fs, 1) @@ -154,13 +237,13 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip pcs_values[table_index][0][col_index].push(value_eval) pref = multilinear_location_prefix( - offset / n_rows, N_VARS_FIRST_GKR - log_n_rows, point_gkr + offset / n_rows, N_VARS_LOGUP_GKR - log_n_rows, point_gkr ) # TODO there is some duplication here retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref) fingerp = fingerprint_2( MEMORY_TABLE_INDEX, - add_base_extension_ret(i, index_eval), value_eval, + add_base_extension_ret(i, index_eval), logup_alphas_eq_poly, ) retrieved_denominators_value = add_extension_ret( @@ -182,13 +265,13 @@ def 
recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip for i in unroll(0, DIM): fs, value_eval = fs_receive_ef(fs, 1) pref = multilinear_location_prefix( - offset / n_rows, N_VARS_FIRST_GKR - log_n_rows, point_gkr + offset / n_rows, N_VARS_LOGUP_GKR - log_n_rows, point_gkr ) # TODO there is some duplication here retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref) fingerp = fingerprint_2( MEMORY_TABLE_INDEX, - add_base_extension_ret(i, index_eval), value_eval, + add_base_extension_ret(i, index_eval), logup_alphas_eq_poly, ) retrieved_denominators_value = add_extension_ret( @@ -197,7 +280,7 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip ) global_index = ( - NUM_COLS_F_COMMITED[table_index] + LOOKUPS_EF_VALUES[table_index][lookup_ef_index] * DIM + i + NUM_COLS_F_COMMITTED[table_index] + LOOKUPS_EF_VALUES[table_index][lookup_ef_index] * DIM + i ) debug_assert(len(pcs_values[table_index][0][global_index]) == 0) pcs_values[table_index][0][global_index].push(value_eval) @@ -206,13 +289,13 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip retrieved_denominators_value = add_extension_ret( retrieved_denominators_value, - mle_of_zeros_then_ones(point_gkr, offset, N_VARS_FIRST_GKR), + mle_of_zeros_then_ones(point_gkr, offset, N_VARS_LOGUP_GKR), ) copy_5(retrieved_numerators_value, numerators_value) copy_5(retrieved_denominators_value, denominators_value) - memory_acc_point = point_gkr + (N_VARS_FIRST_GKR - log_memory) * DIM + memory_and_acc_point = point_gkr + (N_VARS_LOGUP_GKR - log_memory) * DIM # END OF GENERIC LOGUP @@ -229,7 +312,7 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index] bus_final_value: Mut = bus_numerator_value - if table_index != EXECUTION_TABLE_INDEX - 1: # -1 because shift due to memory + if table_index != 
EXECUTION_TABLE_INDEX: bus_final_value = opposite_extension_ret(bus_final_value) bus_final_value = add_extension_ret( bus_final_value, @@ -338,100 +421,28 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip virtual_col_index = n_up_columns_f + i * DIM + j pcs_values[table_index][last_index_2][virtual_col_index].push(transposed + j * DIM) - log_num_instrs = log2_ceil(NUM_BYTECODE_INSTRUCTIONS) - fs, bytecode_compression_challenges = fs_sample_many_ef(fs, log_num_instrs) - - bytecode_air_values = Array(DIM * 2**log_num_instrs) - for i in unroll(0, NUM_BYTECODE_INSTRUCTIONS): - col = N_COMMITTED_EXEC_COLUMNS + i - copy_5( - pcs_values[EXECUTION_TABLE_INDEX - 1][2][col][0], - bytecode_air_values + i * DIM, - ) - pcs_values[EXECUTION_TABLE_INDEX - 1][2][col].pop() - for i in unroll(NUM_BYTECODE_INSTRUCTIONS, 2**log_num_instrs): - set_to_5_zeros(bytecode_air_values + i * DIM) - - bytecode_air_point = pcs_points[EXECUTION_TABLE_INDEX - 1][2] - bytecode_lookup_claim = dot_product_ret( - bytecode_air_values, - poly_eq_extension(bytecode_compression_challenges, log_num_instrs), - 2**log_num_instrs, - EE, - ) + # verify the inner public memory is well constructed (with the conventions) (NONRESERVED_PROGRAM_INPUT_START is a multiple of DIM) + for i in unroll(0, NONRESERVED_PROGRAM_INPUT_START / DIM): + copy_5(i * DIM, inner_public_memory + i * DIM) - fs, whir_ext_root, whir_ext_ood_points, whir_ext_ood_evals = parse_whir_commitment_const(fs, NUM_OOD_COMMIT_EXT) - - # VERIFY LOGUP* - - log_table_len = log2_ceil(GUEST_BYTECODE_LEN) - log_n_cycles = table_dims[EXECUTION_TABLE_INDEX - 1] - fs, ls_sumcheck_point, ls_sumcheck_value = sumcheck_verify(fs, log_table_len, bytecode_lookup_claim, 2) - fs, table_eval = fs_receive_ef(fs, 1) - fs, pushforward_eval = fs_receive_ef(fs, 1) - mul_extension(table_eval, pushforward_eval, ls_sumcheck_value) - - fs, ls_c = fs_sample_ef(fs) - - fs, quotient_left, claim_point_left, claim_num_left, eval_c_minus_indexes = 
verify_gkr_quotient(fs, log_n_cycles) - fs, quotient_right, claim_point_right, pushforward_final_eval, claim_den_right = verify_gkr_quotient( - fs, log_table_len - ) - - copy_5(quotient_left, quotient_right) - - copy_5( - eq_mle_extension(claim_point_left, bytecode_air_point, log_n_cycles), - claim_num_left, - ) - copy_5( - sub_extension_ret(ls_c, mle_of_01234567_etc(claim_point_right, log_table_len)), - claim_den_right, - ) - - # logupstar statements: - ls_on_indexes_point = claim_point_left - ls_on_indexes_eval = sub_extension_ret(ls_c, eval_c_minus_indexes) - ls_on_table_point = ls_sumcheck_point - ls_on_table_eval = table_eval - ls_on_pushforward_point_1 = ls_sumcheck_point - ls_on_pushforward_eval_1 = pushforward_eval - ls_on_pushforward_point_2 = claim_point_right - ls_on_pushforward_eval_2 = pushforward_final_eval - - # TODO evaluate the folded bytecode - - pcs_points[EXECUTION_TABLE_INDEX - 1].push(ls_on_indexes_point) - pcs_values[EXECUTION_TABLE_INDEX - 1].push(DynArray([])) - last_len = len(pcs_values[EXECUTION_TABLE_INDEX - 1]) - 1 - total_exec_cols = NUM_COLS_F_AIR[EXECUTION_TABLE_INDEX - 1] + DIM * NUM_COLS_EF_AIR[EXECUTION_TABLE_INDEX - 1] - for _ in unroll(0, total_exec_cols): - pcs_values[EXECUTION_TABLE_INDEX - 1][last_len].push(DynArray([])) - pcs_values[EXECUTION_TABLE_INDEX - 1][last_len][COL_PC].push(ls_on_indexes_eval) - - # verify the outer public memory is well constructed (with the conventions) - for i in unroll(0, next_multiple_of(NONRESERVED_PROGRAM_INPUT_START, DIM) / DIM): - copy_5(i * DIM, outer_public_memory + i * DIM) - - fs, public_memory_random_point = fs_sample_many_ef(fs, outer_public_memory_log_size) - - poly_eq_public_mem = poly_eq_extension_dynamic(public_memory_random_point, outer_public_memory_log_size) + fs, public_memory_random_point = fs_sample_many_ef(fs, inner_public_memory_log_size) + poly_eq_public_mem = poly_eq_extension_dynamic(public_memory_random_point, inner_public_memory_log_size) public_memory_eval = 
Array(DIM) dot_product_be_dynamic( - outer_public_memory, + inner_public_memory, poly_eq_public_mem, public_memory_eval, - powers_of_two(outer_public_memory_log_size), + powers_of_two(inner_public_memory_log_size), ) # WHIR BASE combination_randomness_gen: Mut fs, combination_randomness_gen = fs_sample_ef(fs) combination_randomness_powers: Mut = powers_const( - combination_randomness_gen, NUM_OOD_COMMIT_BASE + TOTAL_WHIR_STATEMENTS_BASE + combination_randomness_gen, WHIR_NUM_OOD_COMMIT + TOTAL_WHIR_STATEMENTS ) - whir_sum: Mut = dot_product_ret(whir_base_ood_evals, combination_randomness_powers, NUM_OOD_COMMIT_BASE, EE) - curr_randomness: Mut = combination_randomness_powers + NUM_OOD_COMMIT_BASE * DIM + whir_sum: Mut = dot_product_ret(whir_base_ood_evals, combination_randomness_powers, WHIR_NUM_OOD_COMMIT, EE) + curr_randomness: Mut = combination_randomness_powers + WHIR_NUM_OOD_COMMIT * DIM whir_sum = add_extension_ret(mul_extension_ret(value_memory, curr_randomness), whir_sum) curr_randomness += DIM @@ -439,6 +450,8 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip curr_randomness += DIM whir_sum = add_extension_ret(mul_extension_ret(public_memory_eval, curr_randomness), whir_sum) curr_randomness += DIM + whir_sum = add_extension_ret(mul_extension_ret(value_bytecode_acc, curr_randomness), whir_sum) + curr_randomness += DIM whir_sum = add_extension_ret(mul_extension_ret(embed_in_ef(STARTING_PC), curr_randomness), whir_sum) curr_randomness += DIM @@ -461,7 +474,7 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip s: Mut final_value: Mut end_sum: Mut - fs, folding_randomness_global, s, final_value, end_sum = whir_open_base( + fs, folding_randomness_global, s, final_value, end_sum = whir_open( fs, whir_base_root, whir_base_ood_points, @@ -469,34 +482,34 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip whir_sum, ) - curr_randomness = combination_randomness_powers + 
NUM_OOD_COMMIT_BASE * DIM + curr_randomness = combination_randomness_powers + WHIR_NUM_OOD_COMMIT * DIM - eq_memory_acc_point = eq_mle_extension( - folding_randomness_global + (N_VARS_BASE - log_memory) * DIM, - memory_acc_point, + eq_memory_and_acc_point = eq_mle_extension( + folding_randomness_global + (WHIR_N_VARS - log_memory) * DIM, + memory_and_acc_point, log_memory, ) - prefix_mem = multilinear_location_prefix(0, N_VARS_BASE - log_memory, folding_randomness_global) + prefix_memory = multilinear_location_prefix(0, WHIR_N_VARS - log_memory, folding_randomness_global) s = add_extension_ret( s, - mul_extension_ret(mul_extension_ret(curr_randomness, prefix_mem), eq_memory_acc_point), + mul_extension_ret(mul_extension_ret(curr_randomness, prefix_memory), eq_memory_and_acc_point), ) curr_randomness += DIM - prefix_acc = multilinear_location_prefix(1, N_VARS_BASE - log_memory, folding_randomness_global) + prefix_acc_memory = multilinear_location_prefix(1, WHIR_N_VARS - log_memory, folding_randomness_global) s = add_extension_ret( s, - mul_extension_ret(mul_extension_ret(curr_randomness, prefix_acc), eq_memory_acc_point), + mul_extension_ret(mul_extension_ret(curr_randomness, prefix_acc_memory), eq_memory_and_acc_point), ) curr_randomness += DIM eq_pub_mem = eq_mle_extension( - folding_randomness_global + (N_VARS_BASE - outer_public_memory_log_size) * DIM, + folding_randomness_global + (WHIR_N_VARS - inner_public_memory_log_size) * DIM, public_memory_random_point, - outer_public_memory_log_size, + inner_public_memory_log_size, ) prefix_pub_mem = multilinear_location_prefix( - 0, N_VARS_BASE - outer_public_memory_log_size, folding_randomness_global + 0, WHIR_N_VARS - inner_public_memory_log_size, folding_randomness_global ) s = add_extension_ret( s, @@ -504,11 +517,28 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip ) curr_randomness += DIM - offset = powers_of_two(log_memory) * 2 # memory and acc + offset = powers_of_two(log_memory) 
* 2 # memory and acc_memory + + eq_bytecode_acc = eq_mle_extension( + folding_randomness_global + (WHIR_N_VARS - log2_ceil(GUEST_BYTECODE_LEN)) * DIM, + bytecode_and_acc_point, + log2_ceil(GUEST_BYTECODE_LEN), + ) + prefix_bytecode_acc = multilinear_location_prefix( + offset / 2 ** log2_ceil(GUEST_BYTECODE_LEN), + WHIR_N_VARS - log2_ceil(GUEST_BYTECODE_LEN), + folding_randomness_global, + ) + s = add_extension_ret( + s, + mul_extension_ret(mul_extension_ret(curr_randomness, prefix_bytecode_acc), eq_bytecode_acc), + ) + curr_randomness += DIM + offset += powers_of_two(log_bytecode_padded) prefix_pc_start = multilinear_location_prefix( offset + COL_PC * powers_of_two(log_n_cycles), - N_VARS_BASE, + WHIR_N_VARS, folding_randomness_global, ) s = add_extension_ret(s, mul_extension_ret(curr_randomness, prefix_pc_start)) @@ -516,7 +546,7 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip prefix_pc_end = multilinear_location_prefix( offset + (COL_PC + 1) * powers_of_two(log_n_cycles) - 1, - N_VARS_BASE, + WHIR_N_VARS, folding_randomness_global, ) s = add_extension_ret(s, mul_extension_ret(curr_randomness, prefix_pc_end)) @@ -530,14 +560,14 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip point = pcs_points[table_index][i] eq_factor = eq_mle_extension( point, - folding_randomness_global + (N_VARS_BASE - log_n_rows) * DIM, + folding_randomness_global + (WHIR_N_VARS - log_n_rows) * DIM, log_n_rows, ) for j in unroll(0, total_num_cols): if len(pcs_values[table_index][i][j]) == 1: prefix = multilinear_location_prefix( offset / n_rows + j, - N_VARS_BASE - log_n_rows, + WHIR_N_VARS - log_n_rows, folding_randomness_global, ) s = add_extension_ret( @@ -545,39 +575,10 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip mul_extension_ret(mul_extension_ret(curr_randomness, prefix), eq_factor), ) curr_randomness += DIM - num_commited_cols: Imu - if table_index == EXECUTION_TABLE_INDEX - 
1: - num_commited_cols = N_COMMITTED_EXEC_COLUMNS - else: - num_commited_cols = total_num_cols - offset += n_rows * num_commited_cols + offset += n_rows * total_num_cols copy_5(mul_extension_ret(s, final_value), end_sum) - # WHIR EXT (Pushforward) - fs, combination_randomness_gen = fs_sample_ef(fs) - combination_randomness_powers = powers_const(combination_randomness_gen, NUM_OOD_COMMIT_EXT + 2) - whir_sum = dot_product_ret(whir_ext_ood_evals, combination_randomness_powers, NUM_OOD_COMMIT_EXT, EE) - whir_sum = add_extension_ret( - whir_sum, - mul_extension_ret( - combination_randomness_powers + NUM_OOD_COMMIT_EXT * DIM, - ls_on_pushforward_eval_1, - ), - ) - whir_sum = add_extension_ret( - whir_sum, - mul_extension_ret( - combination_randomness_powers + (NUM_OOD_COMMIT_EXT + 1) * DIM, - ls_on_pushforward_eval_2, - ), - ) - fs, folding_randomness_global, s, final_value, end_sum = whir_open_ext( - fs, whir_ext_root, whir_ext_ood_points, combination_randomness_powers, whir_sum - ) - - # Last TODO = Opening on the guest bytecode, but there are multiple ways to handle this - return @@ -587,12 +588,24 @@ def multilinear_location_prefix(offset, n_vars, point): return res -def fingerprint_2(table_index, data_1, data_2, alpha_powers): +def fingerprint_2(table_index, data_1, data_2, logup_alphas_eq_poly): buff = Array(DIM * 2) copy_5(data_1, buff) copy_5(data_2, buff + DIM) - res: Mut = dot_product_ret(buff, alpha_powers + DIM, 2, EE) - res = add_base_extension_ret(table_index, res) + res: Mut = dot_product_ret(buff, logup_alphas_eq_poly, 2, EE) + res = add_extension_ret( + res, mul_base_extension_ret(table_index, logup_alphas_eq_poly + (2 ** log2_ceil(MAX_BUS_WIDTH) - 1) * DIM) + ) + return res + + +def fingerprint_bytecode(instr_evals, eval_on_pc, logup_alphas_eq_poly): + res: Mut = dot_product_ret(instr_evals, logup_alphas_eq_poly, N_INSTRUCTION_COLUMNS, EE) + res = add_extension_ret(res, mul_extension_ret(eval_on_pc, logup_alphas_eq_poly + N_INSTRUCTION_COLUMNS * DIM)) + 
res = add_extension_ret( + res, + mul_base_extension_ret(BYTECODE_TABLE_INDEX, logup_alphas_eq_poly + (2 ** log2_ceil(MAX_BUS_WIDTH) - 1) * DIM), + ) return res @@ -659,16 +672,16 @@ def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den): return fs, postponed_point, new_claim_num, new_claim_den -def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, bus_beta, bus_alpha_powers): +def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, bus_beta, logup_alphas_eq_poly): res: Imu debug_assert(table_index < 3) match table_index: case 0: - res = evaluate_air_constraints_table_0(inner_evals, air_alpha_powers, bus_beta, bus_alpha_powers) + res = evaluate_air_constraints_table_0(inner_evals, air_alpha_powers, bus_beta, logup_alphas_eq_poly) case 1: - res = evaluate_air_constraints_table_1(inner_evals, air_alpha_powers, bus_beta, bus_alpha_powers) + res = evaluate_air_constraints_table_1(inner_evals, air_alpha_powers, bus_beta, logup_alphas_eq_poly) case 2: - res = evaluate_air_constraints_table_2(inner_evals, air_alpha_powers, bus_beta, bus_alpha_powers) + res = evaluate_air_constraints_table_2(inner_evals, air_alpha_powers, bus_beta, logup_alphas_eq_poly) return res diff --git a/crates/rec_aggregation/src/recursion.rs b/crates/rec_aggregation/src/recursion.rs index d3da7601..5fe842ae 100644 --- a/crates/rec_aggregation/src/recursion.rs +++ b/crates/rec_aggregation/src/recursion.rs @@ -4,30 +4,26 @@ use std::rc::Rc; use std::time::Instant; use lean_compiler::{CompilationFlags, ProgramSource, compile_program, compile_program_with_flags}; +use lean_prover::default_whir_config; use lean_prover::prove_execution::prove_execution; use lean_prover::verify_execution::verify_execution; -use lean_prover::{STARTING_LOG_INV_RATE_BASE, STARTING_LOG_INV_RATE_EXTENSION, SnarkParams, whir_config_builder}; use lean_vm::*; use multilinear_toolkit::prelude::symbolic::{ SymbolicExpression, SymbolicOperation, 
get_symbolic_constraints_and_bus_data_values, }; use multilinear_toolkit::prelude::*; -use utils::{Counter, MEMORY_TABLE_INDEX}; +use utils::{BYTECODE_TABLE_INDEX, Counter, MEMORY_TABLE_INDEX}; + +const LOG_INV_RATE: usize = 2; pub fn run_recursion_benchmark(count: usize, tracing: bool) { - if tracing { - utils::init_tracing(); - } let filepath = Path::new(env!("CARGO_MANIFEST_DIR")) .join("recursion.py") .to_str() .unwrap() .to_string(); - let snark_params = SnarkParams { - first_whir: whir_config_builder(STARTING_LOG_INV_RATE_BASE, 3, 1), - second_whir: whir_config_builder(STARTING_LOG_INV_RATE_EXTENSION, 4, 1), - }; + let inner_whir_config = default_whir_config(LOG_INV_RATE); let program_to_prove = r#" DIM = 5 POSEIDON_OF_ZERO = POSEIDON_OF_ZERO_PLACEHOLDER @@ -58,22 +54,24 @@ def main(): .replace("POSEIDON_OF_ZERO_PLACEHOLDER", &POSEIDON_16_NULL_HASH_PTR.to_string()); let bytecode_to_prove = compile_program(&ProgramSource::Raw(program_to_prove.to_string())); precompute_dft_twiddles::(1 << 24); - let outer_public_input = vec![]; - let outer_private_input = vec![]; + let inner_public_input = vec![]; + let inner_private_input = vec![]; let proof_to_prove = prove_execution( &bytecode_to_prove, - (&outer_public_input, &outer_private_input), + (&inner_public_input, &inner_private_input), &vec![], - &snark_params, + &inner_whir_config, false, ); - let verif_details = verify_execution(&bytecode_to_prove, &[], proof_to_prove.proof.clone(), &snark_params).unwrap(); + let verif_details = verify_execution( + &bytecode_to_prove, + &inner_public_input, + proof_to_prove.proof.clone(), + &inner_whir_config, + ) + .unwrap(); - let base_whir = WhirConfig::::new(&snark_params.first_whir, proof_to_prove.first_whir_n_vars); - let ext_whir = WhirConfig::::new( - &snark_params.second_whir, - log2_ceil_usize(bytecode_to_prove.instructions.len()), - ); + let outer_whir_config = WhirConfig::::new(&inner_whir_config, proof_to_prove.first_whir_n_vars); // let guest_program_commitment = { 
// let mut prover_state = build_prover_state(); @@ -83,8 +81,7 @@ def main(): // assert_eq!(commitment_transcript.len(), ext_whir.committment_ood_samples * DIMENSION + VECTOR_LEN); // }; - let mut replacements = whir_recursion_placeholder_replacements(&base_whir, true); - replacements.extend(whir_recursion_placeholder_replacements(&ext_whir, false)); + let mut replacements = whir_recursion_placeholder_replacements(&outer_whir_config); assert!( verif_details.log_memory >= verif_details.table_n_vars[&Table::execution()] @@ -103,7 +100,7 @@ def main(): // VM recursion parameters (different from WHIR) replacements.insert( - "N_VARS_FIRST_GKR_PLACEHOLDER".to_string(), + "N_VARS_LOGUP_GKR_PLACEHOLDER".to_string(), verif_details.first_quotient_gkr_n_vars.to_string(), ); replacements.insert("N_TABLES_PLACEHOLDER".to_string(), N_TABLES.to_string()); @@ -139,6 +136,10 @@ def main(): "MEMORY_TABLE_INDEX_PLACEHOLDER".to_string(), MEMORY_TABLE_INDEX.to_string(), ); + replacements.insert( + "BYTECODE_TABLE_INDEX_PLACEHOLDER".to_string(), + BYTECODE_TABLE_INDEX.to_string(), + ); replacements.insert( "GUEST_BYTECODE_LEN_PLACEHOLDER".to_string(), bytecode_to_prove.instructions.len().to_string(), @@ -176,7 +177,7 @@ def main(): lookup_ef_indexes_str.push(format!("[{}]", this_look_ef_indexes_str.join(", "))); num_cols_f_air.push(table.n_columns_f_air().to_string()); num_cols_ef_air.push(table.n_columns_ef_air().to_string()); - num_cols_f_committed.push(table.n_commited_columns_f().to_string()); + num_cols_f_committed.push(table.n_columns_f_air().to_string()); let this_lookup_f_values_str = table .lookups_f() .iter() @@ -238,7 +239,7 @@ def main(): format!("[{}]", num_cols_ef_air.join(", ")), ); replacements.insert( - "NUM_COLS_F_COMMITED_PLACEHOLDER".to_string(), + "NUM_COLS_F_COMMITTED_PLACEHOLDER".to_string(), format!("[{}]", num_cols_f_committed.join(", ")), ); replacements.insert( @@ -282,50 +283,65 @@ def main(): all_air_evals_in_zk_dsl(), ); replacements.insert( - 
"NUM_BYTECODE_INSTRUCTIONS_PLACEHOLDER".to_string(), + "N_INSTRUCTION_COLUMNS_PLACEHOLDER".to_string(), N_INSTRUCTION_COLUMNS.to_string(), ); replacements.insert( "N_COMMITTED_EXEC_COLUMNS_PLACEHOLDER".to_string(), - N_COMMITTED_EXEC_COLUMNS.to_string(), + N_RUNTIME_COLUMNS.to_string(), ); replacements.insert( - "TOTAL_WHIR_STATEMENTS_BASE_PLACEHOLDER".to_string(), - verif_details.total_whir_statements_base.to_string(), + "TOTAL_WHIR_STATEMENTS_PLACEHOLDER".to_string(), + verif_details.total_whir_statements.to_string(), ); replacements.insert("STARTING_PC_PLACEHOLDER".to_string(), STARTING_PC.to_string()); replacements.insert("ENDING_PC_PLACEHOLDER".to_string(), ENDING_PC.to_string()); - let inner_public_input = vec![]; - let outer_public_memory = build_public_memory(&outer_public_input); - let mut inner_private_input = vec![ + let mut outer_public_input = vec![F::from_usize(verif_details.bytecode_evaluation.point.num_variables())]; + outer_public_input.extend( + verif_details + .bytecode_evaluation + .point + .0 + .iter() + .flat_map(|c| c.as_basis_coefficients_slice()), + ); + outer_public_input.extend(verif_details.bytecode_evaluation.value.as_basis_coefficients_slice()); + let outer_private_input_start = + (NONRESERVED_PROGRAM_INPUT_START + 1 + outer_public_input.len()).next_power_of_two(); + outer_public_input.insert(0, F::from_usize(outer_private_input_start)); + let inner_public_memory = build_public_memory(&inner_public_input); + let mut outer_private_input = vec![ F::from_usize(proof_to_prove.proof.len()), - F::from_usize(log2_strict_usize(outer_public_memory.len())), + F::from_usize(log2_strict_usize(inner_public_memory.len())), F::from_usize(count), ]; - inner_private_input.extend(outer_public_memory); + outer_private_input.extend(inner_public_memory); for _ in 0..count { - inner_private_input.extend(proof_to_prove.proof.to_vec()); + outer_private_input.extend(proof_to_prove.proof.to_vec()); } let recursion_bytecode = 
compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements }); + if tracing { + utils::init_tracing(); + } let time = Instant::now(); let recursion_proof = prove_execution( &recursion_bytecode, - (&inner_public_input, &inner_private_input), + (&outer_public_input, &outer_private_input), &vec![], // TODO precompute poseidons - &Default::default(), + &default_whir_config(LOG_INV_RATE), false, ); let proving_time = time.elapsed(); verify_execution( &recursion_bytecode, - &inner_public_input, + &outer_public_input, recursion_proof.proof, - &Default::default(), + &default_whir_config(LOG_INV_RATE), ) .unwrap(); println!( @@ -345,10 +361,7 @@ def main(): ); } -pub(crate) fn whir_recursion_placeholder_replacements( - whir_config: &WhirConfig, - base: bool, -) -> BTreeMap { +pub(crate) fn whir_recursion_placeholder_replacements(whir_config: &WhirConfig) -> BTreeMap { let mut num_queries = vec![]; let mut ood_samples = vec![]; let mut grinding_bits = vec![]; @@ -366,37 +379,40 @@ pub(crate) fn whir_recursion_placeholder_replacements( grinding_bits.push(whir_config.final_pow_bits.to_string()); num_queries.push(whir_config.final_queries.to_string()); - let end = if base { "_BASE_PLACEHOLDER" } else { "_EXT_PLACEHOLDER" }; + let end = "_PLACEHOLDER"; let mut replacements = BTreeMap::new(); replacements.insert( - format!("MERKLE_HEIGHTS{}", end), + format!("WHIR_MERKLE_HEIGHTS{}", end), format!("[{}]", merkle_heights.join(", ")), ); - replacements.insert(format!("NUM_QUERIES{}", end), format!("[{}]", num_queries.join(", "))); replacements.insert( - format!("NUM_OOD_COMMIT{}", end), + format!("WHIR_NUM_QUERIES{}", end), + format!("[{}]", num_queries.join(", ")), + ); + replacements.insert( + format!("WHIR_NUM_OOD_COMMIT{}", end), whir_config.committment_ood_samples.to_string(), ); - replacements.insert(format!("NUM_OODS{}", end), format!("[{}]", ood_samples.join(", "))); + replacements.insert(format!("WHIR_NUM_OODS{}", end), format!("[{}]", 
ood_samples.join(", "))); replacements.insert( - format!("GRINDING_BITS{}", end), + format!("WHIR_GRINDING_BITS{}", end), format!("[{}]", grinding_bits.join(", ")), ); replacements.insert( - format!("FOLDING_FACTORS{}", end), + format!("WHIR_FOLDING_FACTORS{}", end), format!("[{}]", folding_factors.join(", ")), ); - replacements.insert(format!("N_VARS{}", end), whir_config.num_variables.to_string()); + replacements.insert(format!("WHIR_N_VARS{}", end), whir_config.num_variables.to_string()); replacements.insert( - format!("LOG_INV_RATE{}", end), + format!("WHIR_LOG_INV_RATE{}", end), whir_config.starting_log_inv_rate.to_string(), ); replacements.insert( - format!("FINAL_VARS{}", end), + format!("WHIR_FINAL_VARS{}", end), whir_config.n_vars_of_final_polynomial().to_string(), ); replacements.insert( - format!("FIRST_RS_REDUCTION_FACTOR{}", end), + format!("WHIR_FIRST_RS_REDUCTION_FACTOR{}", end), whir_config.rs_domain_initial_reduction_factor.to_string(), ); replacements @@ -421,8 +437,8 @@ where let mut cache: HashMap<*const (), String> = HashMap::new(); let mut res = format!( - "def evaluate_air_constraints_table_{}({}, air_alpha_powers, bus_beta, bus_alpha_powers):\n", - table.table().index() - 1, + "def evaluate_air_constraints_table_{}({}, air_alpha_powers, bus_beta, logup_alphas_eq_poly):\n", + table.table().index(), AIR_INNER_VALUES_VAR ); @@ -448,10 +464,14 @@ where res += &format!("\n copy_5({}, buff + DIM * {})", data_str, i); } res += &format!( - "\n bus_res: Mut = dot_product_ret(buff, bus_alpha_powers + DIM, {}, EE)", + "\n bus_res: Mut = dot_product_ret(buff, logup_alphas_eq_poly, {}, EE)", bus_data.len() ); - res += &format!("\n bus_res = add_extension_ret({}, bus_res)", table_index); + res += &format!( + "\n bus_res = add_extension_ret(mul_extension_ret({}, logup_alphas_eq_poly + {} * DIM), bus_res)", + table_index, + max_bus_width().next_power_of_two() - 1 + ); res += "\n bus_res = mul_extension_ret(bus_res, bus_beta)"; res += &format!("\n sum: Mut = 
add_extension_ret(bus_res, {})", flag); diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index c26070d0..c61b2a08 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -1,5 +1,5 @@ use lean_compiler::*; -use lean_prover::{SnarkParams, prove_execution::prove_execution, verify_execution::verify_execution}; +use lean_prover::{default_whir_config, prove_execution::prove_execution, verify_execution::verify_execution}; use lean_vm::*; use multilinear_toolkit::prelude::*; use rand::{Rng, SeedableRng, rngs::StdRng}; @@ -14,6 +14,7 @@ use xmss::{ }; static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); +const LOG_INV_RATE: usize = 1; fn get_xmss_aggregation_program() -> &'static Bytecode { XMSS_AGGREGATION_PROGRAM.get_or_init(compile_xmss_aggregation_program) @@ -40,6 +41,8 @@ fn build_public_input(xmss_pub_keys: &[XmssPublicKey], message_hash: [F; 8], slo public_input.push(acc); acc += F::from_usize((1 + V + pk.log_lifetime) * DIGEST_LEN); // size of the signature } + let private_input_start = (NONRESERVED_PROGRAM_INPUT_START + 1 + public_input.len()).next_power_of_two(); + public_input.insert(0, F::from_usize(private_input_start)); public_input } @@ -140,7 +143,7 @@ fn xmss_aggregate_signatures_helper( program, (&public_input, &private_input), &poseidons_16_precomputed, - &SnarkParams::default(), + &default_whir_config(LOG_INV_RATE), false, ); @@ -164,7 +167,7 @@ pub fn xmss_verify_aggregated_signatures( let public_input = build_public_input(xmss_pub_keys, message_hash, slot); - verify_execution(program, &public_input, proof, &SnarkParams::default()).map(|_| ()) + verify_execution(program, &public_input, proof, &default_whir_config(LOG_INV_RATE)).map(|_| ()) } #[instrument(skip_all)] diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 96975a93..9dc7fea6 100644 --- a/crates/rec_aggregation/utils.py +++ 
b/crates/rec_aggregation/utils.py @@ -21,14 +21,13 @@ def div_ceil_dynamic(a, b: Const): return res +@inline def powers(alpha, n): # alpha: EF # n: F - - res = Array(n * DIM) - set_to_one(res) - for i in range(0, n - 1): - mul_extension(res + i * DIM, alpha, res + (i + 1) * DIM) + assert n < 128 + assert 0 < n + res = match_range(n, range(1, 128), lambda i: powers_const(alpha, i)) return res @@ -43,6 +42,7 @@ def powers_const(alpha, n: Const): return res +@inline def unit_root_pow_dynamic(domain_size, index_bits): # index_bits is a pointer to domain_size bits debug_assert(domain_size < 26) @@ -82,7 +82,8 @@ def poly_eq_extension(point, n: Const): return res + (2**n - 1) * DIM -def poly_eq_base(point, n: Const): +@inline +def poly_eq_base(point, n): # Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] res = Array((2 ** (n + 1) - 1)) @@ -95,18 +96,16 @@ def poly_eq_base(point, n: Const): return res + (2**n - 1) -def pow(a, b): - if b == 0: - return 1 # a^0 = 1 - else: - p = pow(a, b - 1) - return a * p - - def eq_mle_extension(a, b, n): + debug_assert(n < 30) + debug_assert(0 < n) + res = match_range(n, range(1, 30), lambda i: eq_mle_extension_const(a, b, i)) + return res + +def eq_mle_extension_const(a, b, n: Const): buff = Array(n * DIM) - for i in range(0, n): + for i in unroll(0, n): shift = i * DIM ai = a + shift bi = b + shift @@ -117,7 +116,7 @@ def eq_mle_extension(a, b, n): buffi[j] = 2 * ab[j] - ai[j] - bi[j] current_prod: Mut = buff - for i in range(0, n - 1): + for i in unroll(0, n - 1): next_prod = Array(DIM) mul_extension(current_prod, buff + (i + 1) * DIM, next_prod) current_prod = next_prod @@ -125,6 +124,7 @@ def eq_mle_extension(a, b, n): return current_prod +@inline def eq_mle_base_extension(a, b, n): debug_assert(n < 26) debug_assert(0 < n) @@ -154,6 +154,7 @@ def eq_mle_extension_base_const(a, b, n: Const): return prods + (n - 1) * DIM +@inline def expand_from_univariate_base(alpha, n): debug_assert(n < 23) debug_assert(0 < 
n) @@ -190,6 +191,9 @@ def dot_product_be_dynamic(a, b, res, n): def dot_product_ee_dynamic(a, b, res, n): + if n == 32: + dot_product(a, b, res, 32, EE) + return if n == 16: dot_product(a, b, res, 16, EE) return @@ -200,21 +204,16 @@ def dot_product_ee_dynamic(a, b, res, n): dot_product(a, b, res, 2, EE) return - for i in unroll(0, N_ROUNDS_BASE + 1): - if n == NUM_QUERIES_BASE[i]: - dot_product(a, b, res, NUM_QUERIES_BASE[i], EE) + for i in unroll(0, WHIR_N_ROUNDS + 1): + if n == WHIR_NUM_QUERIES[i]: + dot_product(a, b, res, WHIR_NUM_QUERIES[i], EE) return - if n == NUM_QUERIES_BASE[i] + 1: - dot_product(a, b, res, NUM_QUERIES_BASE[i] + 1, EE) + if n == WHIR_NUM_QUERIES[i] + 1: + dot_product(a, b, res, WHIR_NUM_QUERIES[i] + 1, EE) return - for i in unroll(0, N_ROUNDS_EXT + 1): - if n == NUM_QUERIES_EXT[i]: - dot_product(a, b, res, NUM_QUERIES_EXT[i], EE) - return - if n == NUM_QUERIES_EXT[i] + 1: - dot_product(a, b, res, NUM_QUERIES_EXT[i] + 1, EE) - return - + if n == 8: + dot_product(a, b, res, 8, EE) + return assert False, "dot_product_ee_dynamic called with unsupported n" @@ -231,7 +230,28 @@ def mle_of_01234567_etc(point, n): res = add_extension_ret(b, d) return res +@inline +def checked_less_than(a, b): + res: Imu + hint_less_than(a, b, res) + assert res * (1 - res) == 0 + if res == 1: + assert a < b + else: + assert b <= a + return res + +@inline +def maximum(a, b): + is_a_less_than_b = checked_less_than(a, b) + res: Imu + if is_a_less_than_b == 1: + res = b + else: + res = a + return res +@inline def powers_of_two(n): debug_assert(n < 32) res = match_range(n, range(0, 32), lambda i: 2**i) @@ -308,6 +328,7 @@ def mul_base_extension_ret(a, b): return res +@inline def div_extension_ret(n, d): quotient = Array(DIM) dot_product(d, quotient, n, 1, EE) @@ -390,9 +411,7 @@ def set_to_8_zeros(a): @inline def copy_8(a, b): dot_product(a, ONE_VEC_PTR, b, 1, EE) - assert a[5] == b[5] - assert a[6] == b[6] - assert a[7] == b[7] + dot_product(a + (8 - DIM), 
ONE_VEC_PTR, b + (8 - DIM), 1, EE) return @@ -404,18 +423,16 @@ def copy_16(a, b): a[15] = b[15] return - +@inline def copy_many_ef(a, b, n): - for i in range(0, n): + for i in unroll(0, n): dot_product(a + i * DIM, ONE_VEC_PTR, b + i * DIM, 1, EE) return @inline def set_to_one(a): - a[0] = 1 - for i in unroll(1, DIM): - a[i] = 0 + dot_product(ONE_VEC_PTR, ONE_VEC_PTR, a, 1, EE) return @@ -431,29 +448,20 @@ def print_vec(a): return -def print_many(a, n): - for i in range(0, n): - print(a[i]) - return - - -def next_multiple_of_8(a: Const): - return a + (8 - (a % 8)) % 8 - - @inline def read_memory(ptr): mem = 0 return mem[ptr] -def univariate_polynomial_eval(coeffs, point, degree: Const): - powers = powers(point, degree + 1) # TODO use a parameter: Const version +@inline +def univariate_polynomial_eval(coeffs, point, degree): + powers = powers_const(point, degree + 1) res = Array(DIM) dot_product(coeffs, powers, res, degree + 1, EE) return res - +@inline def sum_2_ef_fractions(a_num, a_den, b_num, b_den): common_den = mul_extension_ret(a_den, b_den) a_num_mul_b_den = mul_extension_ret(a_num, b_den) @@ -496,18 +504,31 @@ def checked_decompose_bits(a, k): return bits, partial_sum -def checked_decompose_bits_small_value(to_decompose, n_bits): +def checked_decompose_bits_small_value_const(to_decompose, n_bits: Const): bits = Array(n_bits) hint_decompose_bits(to_decompose, bits, n_bits, BIG_ENDIAN) sum: Mut = bits[n_bits - 1] power_of_2: Mut = 1 - for i in range(1, n_bits): + for i in unroll(1, n_bits): power_of_2 *= 2 sum += bits[n_bits - 1 - i] * power_of_2 assert to_decompose == sum return bits +@inline +def checked_decompose_bits_small_value(to_decompose, n_bits): + debug_assert(n_bits < 30) + debug_assert(0 < n_bits) + return match_range( + n_bits, + range(0, 1), + lambda _: 0, + range(1, 30), + lambda i: checked_decompose_bits_small_value_const(to_decompose, i), + ) + + @inline def dot_product_ret(a, b, n, mode): res = Array(DIM) diff --git 
a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index 06f00951..098ca832 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -1,39 +1,30 @@ from snark_lib import * from fiat_shamir import * -N_VARS_BASE = N_VARS_BASE_PLACEHOLDER -LOG_INV_RATE_BASE = LOG_INV_RATE_BASE_PLACEHOLDER -FOLDING_FACTORS_BASE = FOLDING_FACTORS_BASE_PLACEHOLDER -FINAL_VARS_BASE = FINAL_VARS_BASE_PLACEHOLDER -FIRST_RS_REDUCTION_FACTOR_BASE = FIRST_RS_REDUCTION_FACTOR_BASE_PLACEHOLDER -NUM_OOD_COMMIT_BASE = NUM_OOD_COMMIT_BASE_PLACEHOLDER -NUM_OODS_BASE = NUM_OODS_BASE_PLACEHOLDER -GRINDING_BITS_BASE = GRINDING_BITS_BASE_PLACEHOLDER - -N_VARS_EXT = N_VARS_EXT_PLACEHOLDER -LOG_INV_RATE_EXT = LOG_INV_RATE_EXT_PLACEHOLDER -FOLDING_FACTORS_EXT = FOLDING_FACTORS_EXT_PLACEHOLDER -FINAL_VARS_EXT = FINAL_VARS_EXT_PLACEHOLDER -FIRST_RS_REDUCTION_FACTOR_EXT = FIRST_RS_REDUCTION_FACTOR_EXT_PLACEHOLDER -NUM_OOD_COMMIT_EXT = NUM_OOD_COMMIT_EXT_PLACEHOLDER -NUM_OODS_EXT = NUM_OODS_EXT_PLACEHOLDER -GRINDING_BITS_EXT = GRINDING_BITS_EXT_PLACEHOLDER - - -def whir_open_base( +WHIR_N_VARS = WHIR_N_VARS_PLACEHOLDER +WHIR_LOG_INV_RATE = WHIR_LOG_INV_RATE_PLACEHOLDER +WHIR_FOLDING_FACTORS = WHIR_FOLDING_FACTORS_PLACEHOLDER +WHIR_FINAL_VARS = WHIR_FINAL_VARS_PLACEHOLDER +WHIR_FIRST_RS_REDUCTION_FACTOR = WHIR_FIRST_RS_REDUCTION_FACTOR_PLACEHOLDER +WHIR_NUM_OOD_COMMIT = WHIR_NUM_OOD_COMMIT_PLACEHOLDER +WHIR_NUM_OODS = WHIR_NUM_OODS_PLACEHOLDER +WHIR_GRINDING_BITS = WHIR_GRINDING_BITS_PLACEHOLDER + + +def whir_open( fs: Mut, root: Mut, ood_points_commit, combination_randomness_powers_0, claimed_sum: Mut, ): - all_folding_randomness = Array(N_ROUNDS_BASE + 2) - all_ood_points = Array(N_ROUNDS_BASE) - all_circle_values = Array(N_ROUNDS_BASE + 1) - all_combination_randomness_powers = Array(N_ROUNDS_BASE) + all_folding_randomness = Array(WHIR_N_ROUNDS + 2) + all_ood_points = Array(WHIR_N_ROUNDS) + all_circle_values = Array(WHIR_N_ROUNDS + 1) + 
all_combination_randomness_powers = Array(WHIR_N_ROUNDS) - domain_sz: Mut = N_VARS_BASE + LOG_INV_RATE_BASE - for r in unroll(0, N_ROUNDS_BASE): + domain_sz: Mut = WHIR_N_VARS + WHIR_LOG_INV_RATE + for r in unroll(0, WHIR_N_ROUNDS): is_first_round: Imu if r == 0: is_first_round = 1 @@ -50,214 +41,84 @@ def whir_open_base( ) = whir_round( fs, root, - FOLDING_FACTORS_BASE[r], - 2 ** FOLDING_FACTORS_BASE[r], + WHIR_FOLDING_FACTORS[r], + 2 ** WHIR_FOLDING_FACTORS[r], is_first_round, - NUM_QUERIES_BASE[r], + WHIR_NUM_QUERIES[r], domain_sz, claimed_sum, - GRINDING_BITS_BASE[r], - NUM_OODS_BASE[r], + WHIR_GRINDING_BITS[r], + WHIR_NUM_OODS[r], ) if r == 0: - domain_sz -= FIRST_RS_REDUCTION_FACTOR_BASE + domain_sz -= WHIR_FIRST_RS_REDUCTION_FACTOR else: domain_sz -= 1 - fs, all_folding_randomness[N_ROUNDS_BASE], claimed_sum = sumcheck_verify( - fs, FOLDING_FACTORS_BASE[N_ROUNDS_BASE], claimed_sum, 2 + fs, all_folding_randomness[WHIR_N_ROUNDS], claimed_sum = sumcheck_verify( + fs, WHIR_FOLDING_FACTORS[WHIR_N_ROUNDS], claimed_sum, 2 ) - fs, final_coeffcients = fs_receive_ef(fs, 2**FINAL_VARS_BASE) + fs, final_coeffcients = fs_receive_ef(fs, 2**WHIR_FINAL_VARS) - fs, all_circle_values[N_ROUNDS_BASE], final_folds = sample_stir_indexes_and_fold( + fs, all_circle_values[WHIR_N_ROUNDS], final_folds = sample_stir_indexes_and_fold( fs, - NUM_QUERIES_BASE[N_ROUNDS_BASE], + WHIR_NUM_QUERIES[WHIR_N_ROUNDS], 0, - FOLDING_FACTORS_BASE[N_ROUNDS_BASE], - 2 ** FOLDING_FACTORS_BASE[N_ROUNDS_BASE], + WHIR_FOLDING_FACTORS[WHIR_N_ROUNDS], + 2 ** WHIR_FOLDING_FACTORS[WHIR_N_ROUNDS], domain_sz, root, - all_folding_randomness[N_ROUNDS_BASE], - GRINDING_BITS_BASE[N_ROUNDS_BASE], + all_folding_randomness[WHIR_N_ROUNDS], + WHIR_GRINDING_BITS[WHIR_N_ROUNDS], ) - final_circle_values = all_circle_values[N_ROUNDS_BASE] - for i in range(0, NUM_QUERIES_BASE[N_ROUNDS_BASE]): - powers_of_2_rev = expand_from_univariate_base_const(final_circle_values[i], FINAL_VARS_BASE) - poly_eq = 
poly_eq_base(powers_of_2_rev, FINAL_VARS_BASE) + final_circle_values = all_circle_values[WHIR_N_ROUNDS] + for i in range(0, WHIR_NUM_QUERIES[WHIR_N_ROUNDS]): + powers_of_2_rev = expand_from_univariate_base_const(final_circle_values[i], WHIR_FINAL_VARS) + poly_eq = poly_eq_base(powers_of_2_rev, WHIR_FINAL_VARS) final_pol_evaluated_on_circle = Array(DIM) dot_product( poly_eq, final_coeffcients, final_pol_evaluated_on_circle, - 2**FINAL_VARS_BASE, + 2**WHIR_FINAL_VARS, BE, ) copy_5(final_pol_evaluated_on_circle, final_folds + i * DIM) - fs, all_folding_randomness[N_ROUNDS_BASE + 1], end_sum = sumcheck_verify(fs, FINAL_VARS_BASE, claimed_sum, 2) + fs, all_folding_randomness[WHIR_N_ROUNDS + 1], end_sum = sumcheck_verify(fs, WHIR_FINAL_VARS, claimed_sum, 2) - folding_randomness_global = Array(N_VARS_BASE * DIM) + folding_randomness_global = Array(WHIR_N_VARS * DIM) start: Mut = folding_randomness_global - for i in unroll(0, N_ROUNDS_BASE + 1): - for j in unroll(0, FOLDING_FACTORS_BASE[i]): + for i in unroll(0, WHIR_N_ROUNDS + 1): + for j in unroll(0, WHIR_FOLDING_FACTORS[i]): copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) - start += FOLDING_FACTORS_BASE[i] * DIM - for j in unroll(0, FINAL_VARS_BASE): - copy_5(all_folding_randomness[N_ROUNDS_BASE + 1] + j * DIM, start + j * DIM) - - all_ood_recovered_evals = Array(NUM_OOD_COMMIT_BASE * DIM) - for i in unroll(0, NUM_OOD_COMMIT_BASE): - expanded_from_univariate = expand_from_univariate_ext(ood_points_commit + i * DIM, N_VARS_BASE) - ood_rec = eq_mle_extension(expanded_from_univariate, folding_randomness_global, N_VARS_BASE) + start += WHIR_FOLDING_FACTORS[i] * DIM + for j in unroll(0, WHIR_FINAL_VARS): + copy_5(all_folding_randomness[WHIR_N_ROUNDS + 1] + j * DIM, start + j * DIM) + + all_ood_recovered_evals = Array(WHIR_NUM_OOD_COMMIT * DIM) + for i in unroll(0, WHIR_NUM_OOD_COMMIT): + expanded_from_univariate = expand_from_univariate_ext(ood_points_commit + i * DIM, WHIR_N_VARS) + ood_rec = 
eq_mle_extension(expanded_from_univariate, folding_randomness_global, WHIR_N_VARS) copy_5(ood_rec, all_ood_recovered_evals + i * DIM) s: Mut = dot_product_ret( all_ood_recovered_evals, combination_randomness_powers_0, - NUM_OOD_COMMIT_BASE, + WHIR_NUM_OOD_COMMIT, EE, ) - n_vars: Mut = N_VARS_BASE - my_folding_randomness: Mut = folding_randomness_global - for i in unroll(0, N_ROUNDS_BASE): - n_vars -= FOLDING_FACTORS_BASE[i] - my_ood_recovered_evals = Array(NUM_OODS_BASE[i] * DIM) - combination_randomness_powers = all_combination_randomness_powers[i] - my_folding_randomness += FOLDING_FACTORS_BASE[i] * DIM - for j in unroll(0, NUM_OODS_BASE[i]): - expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars) - ood_rec = eq_mle_extension(expanded_from_univariate, my_folding_randomness, n_vars) - copy_5(ood_rec, my_ood_recovered_evals + j * DIM) - summed_ood = Array(DIM) - dot_product_ee_dynamic( - my_ood_recovered_evals, - combination_randomness_powers, - summed_ood, - NUM_OODS_BASE[i], - ) - - s6s = Array((NUM_QUERIES_BASE[i]) * DIM) - circle_value_i = all_circle_values[i] - for j in range(0, NUM_QUERIES_BASE[i]): # unroll ? 
- expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars) - temp = eq_mle_base_extension(expanded_from_univariate, my_folding_randomness, n_vars) - copy_5(temp, s6s + j * DIM) - s7 = dot_product_ret( - s6s, - combination_randomness_powers + NUM_OODS_BASE[i] * DIM, - NUM_QUERIES_BASE[i], - EE, - ) - s = add_extension_ret(s, s7) - s = add_extension_ret(summed_ood, s) - poly_eq_final = poly_eq_extension(all_folding_randomness[N_ROUNDS_BASE + 1], FINAL_VARS_BASE) - final_value = dot_product_ret(poly_eq_final, final_coeffcients, 2**FINAL_VARS_BASE, EE) - # copy_5(mul_extension_ret(s, final_value), end_sum); - - return fs, folding_randomness_global, s, final_value, end_sum - - -def whir_open_ext( - fs: Mut, - root: Mut, - ood_points_commit, - combination_randomness_powers_0, - claimed_sum: Mut, -): - all_folding_randomness = Array(N_ROUNDS_EXT + 2) - all_ood_points = Array(N_ROUNDS_EXT) - all_circle_values = Array(N_ROUNDS_EXT + 1) - all_combination_randomness_powers = Array(N_ROUNDS_EXT) - - domain_sz: Mut = N_VARS_EXT + LOG_INV_RATE_EXT - for r in unroll(0, N_ROUNDS_EXT): - ( - fs, - all_folding_randomness[r], - all_ood_points[r], - root, - all_circle_values[r], - all_combination_randomness_powers[r], - claimed_sum, - ) = whir_round( - fs, - root, - FOLDING_FACTORS_EXT[r], - 2 ** FOLDING_FACTORS_EXT[r], - 0, - NUM_QUERIES_EXT[r], - domain_sz, - claimed_sum, - GRINDING_BITS_EXT[r], - NUM_OODS_EXT[r], - ) - if r == 0: - domain_sz -= FIRST_RS_REDUCTION_FACTOR_EXT - else: - domain_sz -= 1 - - fs, all_folding_randomness[N_ROUNDS_EXT], claimed_sum = sumcheck_verify( - fs, FOLDING_FACTORS_EXT[N_ROUNDS_EXT], claimed_sum, 2 - ) - - fs, final_coeffcients = fs_receive_ef(fs, 2**FINAL_VARS_EXT) - - fs, all_circle_values[N_ROUNDS_EXT], final_folds = sample_stir_indexes_and_fold( - fs, - NUM_QUERIES_EXT[N_ROUNDS_EXT], - 0, - FOLDING_FACTORS_EXT[N_ROUNDS_EXT], - 2 ** FOLDING_FACTORS_EXT[N_ROUNDS_EXT], - domain_sz, - root, - 
all_folding_randomness[N_ROUNDS_EXT], - GRINDING_BITS_EXT[N_ROUNDS_EXT], - ) - - final_circle_values = all_circle_values[N_ROUNDS_EXT] - for i in range(0, NUM_QUERIES_EXT[N_ROUNDS_EXT]): - powers_of_2_rev = expand_from_univariate_base_const(final_circle_values[i], FINAL_VARS_EXT) - poly_eq = poly_eq_base(powers_of_2_rev, FINAL_VARS_EXT) - final_pol_evaluated_on_circle = Array(DIM) - dot_product( - poly_eq, - final_coeffcients, - final_pol_evaluated_on_circle, - 2**FINAL_VARS_EXT, - BE, - ) - copy_5(final_pol_evaluated_on_circle, final_folds + i * DIM) - - fs, all_folding_randomness[N_ROUNDS_EXT + 1], end_sum = sumcheck_verify(fs, FINAL_VARS_EXT, claimed_sum, 2) - - folding_randomness_global = Array(N_VARS_EXT * DIM) - - start: Mut = folding_randomness_global - for i in unroll(0, N_ROUNDS_EXT + 1): - for j in unroll(0, FOLDING_FACTORS_EXT[i]): - copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) - start += FOLDING_FACTORS_EXT[i] * DIM - for j in unroll(0, FINAL_VARS_EXT): - copy_5(all_folding_randomness[N_ROUNDS_EXT + 1] + j * DIM, start + j * DIM) - - all_ood_recovered_evals = Array(NUM_OOD_COMMIT_EXT * DIM) - for i in unroll(0, NUM_OOD_COMMIT_EXT): - expanded_from_univariate = expand_from_univariate_ext(ood_points_commit + i * DIM, N_VARS_EXT) - ood_rec = eq_mle_extension(expanded_from_univariate, folding_randomness_global, N_VARS_EXT) - copy_5(ood_rec, all_ood_recovered_evals + i * DIM) - s: Mut = dot_product_ret(all_ood_recovered_evals, combination_randomness_powers_0, NUM_OOD_COMMIT_EXT, EE) - - n_vars: Mut = N_VARS_EXT + n_vars: Mut = WHIR_N_VARS my_folding_randomness: Mut = folding_randomness_global - for i in unroll(0, N_ROUNDS_EXT): - n_vars -= FOLDING_FACTORS_EXT[i] - my_ood_recovered_evals = Array(NUM_OODS_EXT[i] * DIM) + for i in unroll(0, WHIR_N_ROUNDS): + n_vars -= WHIR_FOLDING_FACTORS[i] + my_ood_recovered_evals = Array(WHIR_NUM_OODS[i] * DIM) combination_randomness_powers = all_combination_randomness_powers[i] - my_folding_randomness += 
FOLDING_FACTORS_EXT[i] * DIM - for j in unroll(0, NUM_OODS_EXT[i]): + my_folding_randomness += WHIR_FOLDING_FACTORS[i] * DIM + for j in unroll(0, WHIR_NUM_OODS[i]): expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars) ood_rec = eq_mle_extension(expanded_from_univariate, my_folding_randomness, n_vars) copy_5(ood_rec, my_ood_recovered_evals + j * DIM) @@ -266,25 +127,25 @@ def whir_open_ext( my_ood_recovered_evals, combination_randomness_powers, summed_ood, - NUM_OODS_EXT[i], + WHIR_NUM_OODS[i], ) - s6s = Array((NUM_QUERIES_EXT[i]) * DIM) + s6s = Array((WHIR_NUM_QUERIES[i]) * DIM) circle_value_i = all_circle_values[i] - for j in range(0, NUM_QUERIES_EXT[i]): # unroll ? + for j in range(0, WHIR_NUM_QUERIES[i]): # unroll ? expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars) temp = eq_mle_base_extension(expanded_from_univariate, my_folding_randomness, n_vars) copy_5(temp, s6s + j * DIM) s7 = dot_product_ret( s6s, - combination_randomness_powers + NUM_OODS_EXT[i] * DIM, - NUM_QUERIES_EXT[i], + combination_randomness_powers + WHIR_NUM_OODS[i] * DIM, + WHIR_NUM_QUERIES[i], EE, ) s = add_extension_ret(s, s7) s = add_extension_ret(summed_ood, s) - poly_eq_final = poly_eq_extension(all_folding_randomness[N_ROUNDS_EXT + 1], FINAL_VARS_EXT) - final_value = dot_product_ret(poly_eq_final, final_coeffcients, 2**FINAL_VARS_EXT, EE) + poly_eq_final = poly_eq_extension(all_folding_randomness[WHIR_N_ROUNDS + 1], WHIR_FINAL_VARS) + final_value = dot_product_ret(poly_eq_final, final_coeffcients, 2**WHIR_FINAL_VARS, EE) # copy_5(mul_extension_ret(s, final_value), end_sum); return fs, folding_randomness_global, s, final_value, end_sum @@ -361,7 +222,7 @@ def sample_stir_indexes_and_fold( if merkle_leaves_in_basefield == 1: for i in range(0, num_queries): - dot_product(answers[i], poly_eq, folds + i * DIM, 2 ** FOLDING_FACTORS_BASE[0], BE) + dot_product(answers[i], poly_eq, folds + i * DIM, 2 ** WHIR_FOLDING_FACTORS[0], 
BE) else: for i in range(0, num_queries): dot_product_ee_dynamic(answers[i], poly_eq, folds + i * DIM, two_pow_folding_factor) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index a776230e..86c720c8 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -15,17 +15,15 @@ def main(): - NONRESERVED_PROGRAM_INPUT_START_ = NONRESERVED_PROGRAM_INPUT_START - n_signatures = NONRESERVED_PROGRAM_INPUT_START_[0] - message_hash = NONRESERVED_PROGRAM_INPUT_START + 1 + pub_mem = NONRESERVED_PROGRAM_INPUT_START + signatures_start = pub_mem[0] + n_signatures = pub_mem[1] + message_hash = pub_mem + 2 all_public_keys = message_hash + VECTOR_LEN all_log_lifetimes = all_public_keys + n_signatures * VECTOR_LEN all_merkle_indexes = all_log_lifetimes + n_signatures sig_sizes = all_merkle_indexes + n_signatures * MAX_LOG_LIFETIME - mem = 0 - signatures_start = mem[PRIVATE_INPUT_START_PTR] - for i in range(0, n_signatures): xmss_public_key = all_public_keys + i * VECTOR_LEN signature = signatures_start + sig_sizes[i] diff --git a/crates/sub_protocols/src/generic_logup.rs b/crates/sub_protocols/src/generic_logup.rs index 8f73e22f..84eb3ce0 100644 --- a/crates/sub_protocols/src/generic_logup.rs +++ b/crates/sub_protocols/src/generic_logup.rs @@ -1,86 +1,132 @@ use crate::{prove_gkr_quotient, verify_gkr_quotient}; -use lean_vm::BusDirection; -use lean_vm::BusTable; -use lean_vm::ColIndex; -use lean_vm::DIMENSION; -use lean_vm::EF; -use lean_vm::F; -use lean_vm::Table; -use lean_vm::TableT; -use lean_vm::TableTrace; -use lean_vm::sort_tables_by_height; +use lean_vm::*; use multilinear_toolkit::prelude::*; use std::collections::BTreeMap; -use utils::MEMORY_TABLE_INDEX; -use utils::VarCount; -use utils::VecOrSlice; -use utils::finger_print; -use utils::from_end; -use utils::mle_of_01234567_etc; -use utils::to_big_endian_in_field; -use utils::transpose_slice_to_basis_coefficients; +use 
tracing::instrument; +use utils::*; #[derive(Debug, PartialEq, Hash, Clone)] pub struct GenericLogupStatements { - pub memory_acc_point: MultilinearPoint, + pub memory_and_acc_point: MultilinearPoint, pub value_memory: EF, - pub value_acc: EF, + pub value_memory_acc: EF, + pub bytecode_and_acc_point: MultilinearPoint, + pub value_bytecode_acc: EF, pub bus_numerators_values: BTreeMap, pub bus_denominators_values: BTreeMap, pub points: BTreeMap>, pub columns_values: BTreeMap>, // Used in recursion - pub total_n_vars: usize, + pub total_gkr_n_vars: usize, + pub bytecode_evaluation: Option>, } #[allow(clippy::too_many_arguments)] +#[instrument(skip_all)] pub fn prove_generic_logup( prover_state: &mut impl FSProver, c: EF, alphas_eq_poly: &[EF], memory: &[F], - acc: &[F], + memory_acc: &[F], + bytecode_multilinear: &[F], + bytecode_acc: &[F], traces: &BTreeMap, ) -> GenericLogupStatements { assert!(memory[0].is_zero()); assert!(memory.len().is_power_of_two()); - assert_eq!(memory.len(), acc.len()); + assert_eq!(memory.len(), memory_acc.len()); assert!(memory.len() >= traces.values().map(|t| 1 << t.log_n_rows).max().unwrap()); + let log_bytecode = log2_strict_usize(bytecode_multilinear.len() / N_INSTRUCTION_COLUMNS.next_power_of_two()); let tables_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); let tables_heights_sorted = sort_tables_by_height(&tables_heights); - let total_n_vars = compute_total_n_vars( + let total_gkr_n_vars = compute_total_gkr_n_vars( log2_strict_usize(memory.len()), + log_bytecode, &tables_heights_sorted.iter().cloned().collect(), ); - let mut numerators = EF::zero_vec(1 << total_n_vars); - let mut denominators = EF::zero_vec(1 << total_n_vars); + let mut numerators = EF::zero_vec(1 << total_gkr_n_vars); + let mut denominators = EF::zero_vec(1 << total_gkr_n_vars); + + let mut offset = 0; // Memory: ... 
- numerators[..memory.len()] + assert_eq!(memory.len(), memory_acc.len()); + numerators[offset..][..memory.len()] .par_iter_mut() - .zip(acc) // TODO embedding overhead + .zip(memory_acc) // TODO embedding overhead .for_each(|(num, a)| *num = EF::from(-*a)); // Note the negative sign here - denominators[..memory.len()] + denominators[offset..][..memory.len()] .par_iter_mut() .zip(memory.par_iter().enumerate()) .for_each(|(denom, (i, &mem_value))| { *denom = c - finger_print( F::from_usize(MEMORY_TABLE_INDEX), - &[F::from_usize(i), mem_value], + &[mem_value, F::from_usize(i)], alphas_eq_poly, ) }); + offset += memory.len(); + // Bytecode + assert_eq!(1 << log_bytecode, bytecode_acc.len()); + numerators[offset..][..bytecode_acc.len()] + .par_iter_mut() + .zip(bytecode_acc) // TODO embedding overhead + .for_each(|(num, a)| *num = EF::from(-*a)); // Note the negative sign here + denominators[offset..][..1 << log_bytecode] + .par_iter_mut() + .zip( + bytecode_multilinear + .par_chunks_exact(N_INSTRUCTION_COLUMNS.next_power_of_two()) + .enumerate(), + ) + .for_each(|(denom, (i, instr))| { + *denom = c - finger_print( + F::from_usize(BYTECODE_TABLE_INDEX), + &[instr[..N_INSTRUCTION_COLUMNS].to_vec(), vec![F::from_usize(i)]].concat(), + alphas_eq_poly, + ) + }); + let max_table_height = 1 << tables_heights_sorted[0].1; + if 1 << log_bytecode < max_table_height { + // padding + denominators[offset + (1 << log_bytecode)..offset + max_table_height] + .par_iter_mut() + .for_each(|d| *d = EF::ONE); + } + offset += max_table_height.max(1 << log_bytecode); // ... 
Rest of the tables: - let mut offset = memory.len(); for (table, _) in &tables_heights_sorted { let trace = &traces[table]; let log_n_rows = trace.log_n_rows; - // I] Bus (data flow between tables) + if *table == Table::execution() { + // 0] bytecode lookup + let pc_column = &trace.base[COL_PC]; + let bytecode_columns = trace.base[N_RUNTIME_COLUMNS..][..N_INSTRUCTION_COLUMNS] + .iter() + .collect::>(); + numerators[offset..][..1 << log_n_rows].par_iter_mut().for_each(|num| { + *num = EF::ONE; + }); // TODO embedding overhead + denominators[offset..][..1 << log_n_rows] + .par_iter_mut() + .enumerate() + .for_each(|(i, denom)| { + let mut data = vec![]; + for col in &bytecode_columns { + data.push(col[i]); + } + data.push(pc_column[i]); + *denom = c - finger_print(F::from_usize(BYTECODE_TABLE_INDEX), &data, alphas_eq_poly) + }); + offset += 1 << log_n_rows; + } + // I] Bus (data flow between tables) let bus = table.bus(); numerators[offset..][..1 << log_n_rows] .par_iter_mut() @@ -114,7 +160,6 @@ pub fn prove_generic_logup( offset += 1 << log_n_rows; // II] Lookup into memory - let mut value_columns_f = Vec::>::new(); for cols_f in table.lookup_f_value_columns(trace) { value_columns_f.push(cols_f.iter().map(|s| VecOrSlice::Slice(s)).collect()); @@ -147,7 +192,7 @@ pub fn prove_generic_logup( let index = col_index[j] + i_field; let mem_value = col_values[i].as_slice()[j]; *denom = - c - finger_print(F::from_usize(MEMORY_TABLE_INDEX), &[index, mem_value], alphas_eq_poly) + c - finger_print(F::from_usize(MEMORY_TABLE_INDEX), &[mem_value, index], alphas_eq_poly) }); }); offset += col_values.len() << log_n_rows; @@ -155,7 +200,7 @@ pub fn prove_generic_logup( } } - assert_eq!(log2_ceil_usize(offset), total_n_vars); + assert_eq!(log2_ceil_usize(offset), total_gkr_n_vars); tracing::info!("Logup data: {} = 2^{:.2}", offset, (offset as f64).log2()); denominators[offset..].par_iter_mut().for_each(|d| *d = EF::ONE); // padding @@ -173,28 +218,60 @@ pub fn prove_generic_logup( 
assert_eq!(sum, EF::ZERO); // Memory: ... - let memory_acc_point = MultilinearPoint(from_end(&claim_point_gkr, log2_strict_usize(memory.len())).to_vec()); - let value_acc = acc.evaluate(&memory_acc_point); - prover_state.add_extension_scalar(value_acc); + let memory_and_acc_point = MultilinearPoint(from_end(&claim_point_gkr, log2_strict_usize(memory.len())).to_vec()); + let value_memory_acc = memory_acc.evaluate(&memory_and_acc_point); + prover_state.add_extension_scalar(value_memory_acc); - let value_memory = memory.evaluate(&memory_acc_point); + let value_memory = memory.evaluate(&memory_and_acc_point); prover_state.add_extension_scalar(value_memory); + let bytecode_and_acc_point = MultilinearPoint(from_end(&claim_point_gkr, log_bytecode).to_vec()); + let value_bytecode_acc = bytecode_acc.evaluate(&bytecode_and_acc_point); + prover_state.add_extension_scalar(value_bytecode_acc); + + // evaluation on bytecode itself can be done directly by the verifier + // ... Rest of the tables: let mut points = BTreeMap::new(); let mut bus_numerators_values = BTreeMap::new(); let mut bus_denominators_values = BTreeMap::new(); let mut columns_values = BTreeMap::new(); - let mut offset = memory.len(); + let mut offset = memory.len() + max_table_height.max(1 << log_bytecode); for (table, _) in &tables_heights_sorted { let trace = &traces[table]; let log_n_rows = trace.log_n_rows; let inner_point = MultilinearPoint(from_end(&claim_point_gkr, log_n_rows).to_vec()); points.insert(*table, inner_point.clone()); + let mut table_values = BTreeMap::::new(); - // I] Bus (data flow between tables) + if table == &Table::execution() { + // 0] bytecode lookup + let pc_column = &trace.base[COL_PC]; + let bytecode_columns = trace.base[N_RUNTIME_COLUMNS..][..N_INSTRUCTION_COLUMNS] + .iter() + .collect::>(); + + let eval_on_pc = pc_column.evaluate(&inner_point); + prover_state.add_extension_scalar(eval_on_pc); + assert!(!table_values.contains_key(&COL_PC)); + table_values.insert(COL_PC, 
eval_on_pc); + + let instr_evals = bytecode_columns + .iter() + .map(|col| col.evaluate(&inner_point)) + .collect::>(); + prover_state.add_extension_scalars(&instr_evals); + for (i, eval_on_instr_col) in instr_evals.iter().enumerate() { + let global_index = N_RUNTIME_COLUMNS + i; + assert!(!table_values.contains_key(&global_index)); + table_values.insert(global_index, *eval_on_instr_col); + } + offset += 1 << log_n_rows; + } + + // I] Bus (data flow between tables) let eval_on_selector = trace.base[table.bus().selector].evaluate(&inner_point) * table.bus().direction.to_field_flag(); prover_state.add_extension_scalar(eval_on_selector); @@ -206,8 +283,6 @@ pub fn prove_generic_logup( bus_denominators_values.insert(*table, eval_on_data); // II] Lookup into memory - - let mut table_values = BTreeMap::::new(); for lookup_f in table.lookups_f() { let index_eval = trace.base[lookup_f.index].evaluate(&inner_point); prover_state.add_extension_scalar(index_eval); @@ -236,7 +311,7 @@ pub fn prove_generic_logup( { let value_eval = col.evaluate(&inner_point); prover_state.add_extension_scalar(value_eval); - let global_index = table.n_commited_columns_f() + lookup_ef.values * DIMENSION + i; + let global_index = table.n_columns_f_air() + lookup_ef.values * DIMENSION + i; assert!(!table_values.contains_key(&global_index)); table_values.insert(global_index, value_eval); } @@ -248,14 +323,17 @@ pub fn prove_generic_logup( } GenericLogupStatements { - memory_acc_point, + memory_and_acc_point, value_memory, - value_acc, + value_memory_acc, + bytecode_and_acc_point, + value_bytecode_acc, bus_numerators_values, bus_denominators_values, points, columns_values, - total_n_vars, + total_gkr_n_vars, + bytecode_evaluation: None, } } @@ -263,15 +341,21 @@ pub fn prove_generic_logup( pub fn verify_generic_logup( verifier_state: &mut impl FSVerifier, c: EF, + alphas: &[EF], alphas_eq_poly: &[EF], log_memory: usize, + bytecode_multilinear: &[F], table_log_n_rows: &BTreeMap, ) -> ProofResult { let 
tables_heights_sorted = sort_tables_by_height(table_log_n_rows); + let log_bytecode = log2_strict_usize(bytecode_multilinear.len() / N_INSTRUCTION_COLUMNS.next_power_of_two()); + let total_gkr_n_vars = compute_total_gkr_n_vars( + log_memory, + log_bytecode, + &tables_heights_sorted.iter().cloned().collect(), + ); - let total_n_vars = compute_total_n_vars(log_memory, &tables_heights_sorted.iter().cloned().collect()); - - let (sum, point_gkr, numerators_value, denominators_value) = verify_gkr_quotient(verifier_state, total_n_vars)?; + let (sum, point_gkr, numerators_value, denominators_value) = verify_gkr_quotient(verifier_state, total_gkr_n_vars)?; if sum != EF::ZERO { return Err(ProofError::InvalidProof); @@ -281,38 +365,98 @@ pub fn verify_generic_logup( let mut retrieved_denominators_value = EF::ZERO; // Memory ... - let memory_acc_point = MultilinearPoint(from_end(&point_gkr, log_memory).to_vec()); - let bits = to_big_endian_in_field::(0, total_n_vars - log_memory); + let memory_and_acc_point = MultilinearPoint(from_end(&point_gkr, log_memory).to_vec()); + let bits = to_big_endian_in_field::(0, total_gkr_n_vars - log_memory); let pref = - MultilinearPoint(bits).eq_poly_outside(&MultilinearPoint(point_gkr[..total_n_vars - log_memory].to_vec())); + MultilinearPoint(bits).eq_poly_outside(&MultilinearPoint(point_gkr[..total_gkr_n_vars - log_memory].to_vec())); - let value_acc = verifier_state.next_extension_scalar()?; - retrieved_numerators_value -= pref * value_acc; + let value_memory_acc = verifier_state.next_extension_scalar()?; + retrieved_numerators_value -= pref * value_memory_acc; let value_memory = verifier_state.next_extension_scalar()?; - let value_index = mle_of_01234567_etc(&memory_acc_point); + let value_index = mle_of_01234567_etc(&memory_and_acc_point); retrieved_denominators_value += pref * (c - finger_print( F::from_usize(MEMORY_TABLE_INDEX), - &[value_index, value_memory], + &[value_memory, value_index], alphas_eq_poly, )); + let mut offset = 1 << 
log_memory; + + // Bytecode + let log_bytecode_padded = log_bytecode.max(tables_heights_sorted[0].1); + let bytecode_and_acc_point = MultilinearPoint(from_end(&point_gkr, log_bytecode).to_vec()); + let bits = to_big_endian_in_field::(offset >> log_bytecode, total_gkr_n_vars - log_bytecode); + let pref = MultilinearPoint(bits) + .eq_poly_outside(&MultilinearPoint(point_gkr[..total_gkr_n_vars - log_bytecode].to_vec())); + let bits_padded = + to_big_endian_in_field::(offset >> log_bytecode_padded, total_gkr_n_vars - log_bytecode_padded); + let pref_padded = MultilinearPoint(bits_padded).eq_poly_outside(&MultilinearPoint( + point_gkr[..total_gkr_n_vars - log_bytecode_padded].to_vec(), + )); + + let value_bytecode_acc = verifier_state.next_extension_scalar()?; + retrieved_numerators_value -= pref * value_bytecode_acc; + + // Bytecode denominator - computed directly by verifier + let bytecode_index_value = mle_of_01234567_etc(&bytecode_and_acc_point); + + let mut bytecode_point = bytecode_and_acc_point.0.clone(); + bytecode_point.extend(from_end(alphas, log2_ceil_usize(N_INSTRUCTION_COLUMNS))); + let bytecode_point = MultilinearPoint(bytecode_point); + let bytecode_value = bytecode_multilinear.evaluate(&bytecode_point); + let bytecode_value_corrected = bytecode_value + * alphas[..alphas.len() - log2_ceil_usize(N_INSTRUCTION_COLUMNS)] + .iter() + .map(|x| EF::ONE - *x) + .product::(); + retrieved_denominators_value += pref + * (c - (bytecode_value_corrected + + bytecode_index_value * alphas_eq_poly[N_INSTRUCTION_COLUMNS] + + *alphas_eq_poly.last().unwrap() * F::from_usize(BYTECODE_TABLE_INDEX))); + // Padding for bytecode + retrieved_denominators_value += + pref_padded * mle_of_zeros_then_ones(1 << log_bytecode, from_end(&point_gkr, log_bytecode_padded)); + offset += 1 << log_bytecode_padded; // ... 
Rest of the tables: let mut points = BTreeMap::new(); let mut bus_numerators_values = BTreeMap::new(); let mut bus_denominators_values = BTreeMap::new(); let mut columns_values = BTreeMap::new(); - let mut offset = 1 << log_memory; for &(table, log_n_rows) in &tables_heights_sorted { - let n_missing_vars = total_n_vars - log_n_rows; + let n_missing_vars = total_gkr_n_vars - log_n_rows; let inner_point = MultilinearPoint(from_end(&point_gkr, log_n_rows).to_vec()); let missing_point = MultilinearPoint(point_gkr[..n_missing_vars].to_vec()); points.insert(table, inner_point.clone()); + let mut table_values = BTreeMap::::new(); - // I] Bus (data flow between tables) + if table == Table::execution() { + // 0] bytecode lookup + let eval_on_pc = verifier_state.next_extension_scalar()?; + table_values.insert(COL_PC, eval_on_pc); + + let instr_evals = verifier_state.next_extension_scalars_vec(N_INSTRUCTION_COLUMNS)?; + for (i, eval_on_instr_col) in instr_evals.iter().enumerate() { + let global_index = N_RUNTIME_COLUMNS + i; + table_values.insert(global_index, *eval_on_instr_col); + } + let bits = to_big_endian_in_field::(offset >> log_n_rows, n_missing_vars); + let pref = MultilinearPoint(bits).eq_poly_outside(&missing_point); + retrieved_numerators_value += pref; // numerator is 1 + retrieved_denominators_value += pref + * (c - finger_print( + F::from_usize(BYTECODE_TABLE_INDEX), + &[instr_evals, vec![eval_on_pc]].concat(), + alphas_eq_poly, + )); + + offset += 1 << log_n_rows; + } + + // I] Bus (data flow between tables) let eval_on_selector = verifier_state.next_extension_scalar()?; let bits = to_big_endian_in_field::(offset >> log_n_rows, n_missing_vars); @@ -328,8 +472,6 @@ pub fn verify_generic_logup( offset += 1 << log_n_rows; // II] Lookup into memory - - let mut table_values = BTreeMap::::new(); for lookup_f in table.lookups_f() { let index_eval = verifier_state.next_extension_scalar()?; assert!(!table_values.contains_key(&lookup_f.index)); @@ -342,11 +484,11 @@ pub 
fn verify_generic_logup( let bits = to_big_endian_in_field::(offset >> log_n_rows, n_missing_vars); let pref = MultilinearPoint(bits).eq_poly_outside(&missing_point); - retrieved_numerators_value += pref; + retrieved_numerators_value += pref; // numerator is 1 retrieved_denominators_value += pref * (c - finger_print( F::from_usize(MEMORY_TABLE_INDEX), - &[index_eval + F::from_usize(i), value_eval], + &[value_eval, index_eval + F::from_usize(i)], alphas_eq_poly, )); offset += 1 << log_n_rows; @@ -367,10 +509,10 @@ pub fn verify_generic_logup( retrieved_denominators_value += pref * (c - finger_print( F::from_usize(MEMORY_TABLE_INDEX), - &[index_eval + F::from_usize(i), value_eval], + &[value_eval, index_eval + F::from_usize(i)], alphas_eq_poly, )); - let global_index = table.n_commited_columns_f() + lookup_ef.values * DIMENSION + i; + let global_index = table.n_columns_f_air() + lookup_ef.values * DIMENSION + i; assert!(!table_values.contains_key(&global_index)); table_values.insert(global_index, value_eval); offset += 1 << log_n_rows; @@ -381,21 +523,24 @@ pub fn verify_generic_logup( retrieved_denominators_value += mle_of_zeros_then_ones(offset, &point_gkr); // to compensate for the final padding: XYZ111111...1 if retrieved_numerators_value != numerators_value { - return Err(ProofError::InvalidProof); + panic!() } if retrieved_denominators_value != denominators_value { - return Err(ProofError::InvalidProof); + panic!() } Ok(GenericLogupStatements { - memory_acc_point, + memory_and_acc_point, value_memory, - value_acc, + value_memory_acc, + bytecode_and_acc_point, + value_bytecode_acc, bus_numerators_values, bus_denominators_values, points, columns_values, - total_n_vars, + total_gkr_n_vars, + bytecode_evaluation: Some(Evaluation::new(bytecode_point, bytecode_value)), }) } @@ -405,8 +550,14 @@ fn offset_for_table(table: &Table, log_n_rows: usize) -> usize { num_cols << log_n_rows } -fn compute_total_n_vars(log_memory: usize, tables_heights: &BTreeMap) -> usize { +fn 
compute_total_gkr_n_vars( + log_memory: usize, + log_bytecode: usize, + tables_heights: &BTreeMap, +) -> usize { + let max_table_height = 1 << tables_heights.values().copied().max().unwrap(); let total_len = (1 << log_memory) + + (1 << log_bytecode).max(max_table_height) + (1 << tables_heights[&Table::execution()]) // bytecode + tables_heights .iter() .map(|(table, log_n_rows)| offset_for_table(table, *log_n_rows)) diff --git a/crates/sub_protocols/src/lib.rs b/crates/sub_protocols/src/lib.rs index 32ffb821..3f43e4b7 100644 --- a/crates/sub_protocols/src/lib.rs +++ b/crates/sub_protocols/src/lib.rs @@ -7,7 +7,4 @@ pub use packed_pcs::*; mod quotient_gkr; pub use quotient_gkr::*; -mod logup_star; -pub use logup_star::*; - pub(crate) const MIN_VARS_FOR_PACKING: usize = 8; diff --git a/crates/sub_protocols/src/logup_star.rs b/crates/sub_protocols/src/logup_star.rs deleted file mode 100644 index cbf34e19..00000000 --- a/crates/sub_protocols/src/logup_star.rs +++ /dev/null @@ -1,285 +0,0 @@ -/* -Logup* (Lev Soukhanov) - -https://eprint.iacr.org/2025/946.pdf - -*/ - -use multilinear_toolkit::prelude::*; -use utils::{ToUsize, mle_of_01234567_etc}; - -use tracing::{info_span, instrument}; - -use crate::{ - MIN_VARS_FOR_PACKING, - quotient_gkr::{prove_gkr_quotient, verify_gkr_quotient}, -}; - -#[derive(Debug, PartialEq)] -pub struct LogupStarStatements { - pub on_indexes: Evaluation, - pub on_table: Evaluation, - pub on_pushforward: Vec>, -} - -#[instrument(skip_all)] -pub fn prove_logup_star( - prover_state: &mut impl FSProver, - table: &MleRef<'_, EF>, - indexes: &[PF], - claimed_value: EF, - poly_eq_point: &[EF], - pushforward: &MleRef<'_, EF>, // already commited - max_index: Option, -) -> LogupStarStatements -where - EF: ExtensionField>, - PF: PrimeField64, -{ - let table_length = table.unpacked_len(); - let indexes_length = indexes.len(); - let packing = log2_strict_usize(table_length) >= MIN_VARS_FOR_PACKING - && log2_strict_usize(indexes_length) >= 
MIN_VARS_FOR_PACKING; - let mut max_index = max_index.unwrap_or(table_length); - if packing { - max_index = max_index.div_ceil(packing_width::()); - } - // TODO use max_index - let _ = max_index; - - let (poly_eq_point_packed, pushforward_packed, table_packed) = info_span!("packing").in_scope(|| { - ( - MleRef::Extension(poly_eq_point).pack_if(packing), - pushforward.pack_if(packing), - table.pack_if(packing), - ) - }); - - let (sc_point, inner_evals, prod) = - info_span!("logup_star sumcheck", table_length, indexes_length).in_scope(|| { - let (sc_point, prod, table_folded, pushforward_folded) = run_product_sumcheck( - &table_packed.by_ref(), - &pushforward_packed.by_ref(), - prover_state, - claimed_value, - table.n_vars(), - ); - let inner_evals = vec![ - table_folded.as_extension().unwrap()[0], - pushforward_folded.as_extension().unwrap()[0], - ]; - (sc_point, inner_evals, prod) - }); - - let table_eval = inner_evals[0]; - prover_state.add_extension_scalar(table_eval); - // delayed opening - let on_table = Evaluation::new(sc_point.clone(), table_eval); - - let pushforwardt_eval = inner_evals[1]; - prover_state.add_extension_scalar(pushforwardt_eval); - // delayed opening - let mut on_pushforward = vec![Evaluation::new(sc_point, pushforwardt_eval)]; - - // sanity check - assert_eq!(prod, table_eval * pushforwardt_eval); - - let c = prover_state.sample(); - - let c_minus_indexes = indexes - .par_iter() - .map(|i| c - PF::::from_usize(i.to_usize())) - .collect::>(); - let c_minus_indexes_packed = MleRef::Extension(&c_minus_indexes).pack_if(packing); - - let (_, claim_point_left, _, eval_c_minus_indexes) = prove_gkr_quotient( - prover_state, - &poly_eq_point_packed.by_ref(), - &c_minus_indexes_packed.by_ref(), - ); - - let c_minus_increments = MleRef::Extension( - &(0..table.unpacked_len()) - .into_par_iter() - .map(|i| c - PF::::from_usize(i)) - .collect::>(), - ); - let c_minus_increments_packed = c_minus_increments.pack_if(packing); - let (_, claim_point_right, 
pushforward_final_eval, _) = prove_gkr_quotient( - prover_state, - &pushforward_packed.by_ref(), - &c_minus_increments_packed.by_ref(), - ); - - let on_indexes = Evaluation::new(claim_point_left, c - eval_c_minus_indexes); - - on_pushforward.push(Evaluation::new(claim_point_right, pushforward_final_eval)); - - // These statements remained to be proven - LogupStarStatements { - on_indexes, - on_table, - on_pushforward, - } -} - -pub fn verify_logup_star( - verifier_state: &mut impl FSVerifier, - log_table_len: usize, - log_indexes_len: usize, - claim: Evaluation, -) -> Result, ProofError> -where - EF: ExtensionField>, - PF: PrimeField64, -{ - let (sum, postponed) = sumcheck_verify(verifier_state, log_table_len, 2).map_err(|_| ProofError::InvalidProof)?; - - if sum != claim.value { - return Err(ProofError::InvalidProof); - } - - let table_eval = verifier_state.next_extension_scalar()?; - let pushforward_eval = verifier_state.next_extension_scalar()?; - - let on_table = Evaluation::new(postponed.point.clone(), table_eval); - let mut on_pushforward = vec![Evaluation::new(postponed.point, pushforward_eval)]; - - if table_eval * pushforward_eval != postponed.value { - return Err(ProofError::InvalidProof); - } - - let c = verifier_state.sample(); - - let (quotient_left, claim_point_left, claim_num_left, eval_c_minus_indexes) = - verify_gkr_quotient(verifier_state, log_indexes_len)?; - let (quotient_right, claim_point_right, pushforward_final_eval, claim_den_right) = - verify_gkr_quotient(verifier_state, log_table_len)?; - - if quotient_left != quotient_right { - return Err(ProofError::InvalidProof); - } - - let on_indexes = Evaluation::new(claim_point_left.clone(), c - eval_c_minus_indexes); - if claim_num_left != claim_point_left.eq_poly_outside(&claim.point) { - return Err(ProofError::InvalidProof); - } - - on_pushforward.push(Evaluation::new(claim_point_right.clone(), pushforward_final_eval)); - - if claim_den_right != c - mle_of_01234567_etc(&claim_point_right) { - 
return Err(ProofError::InvalidProof); - } - - // these statements remained to be verified - Ok(LogupStarStatements { - on_indexes, - on_table, - on_pushforward, - }) -} - -#[instrument(skip_all)] -pub fn compute_pushforward>( - indexes: &[F], - table_length: usize, - poly_eq_point: &[EF], -) -> Vec { - assert_eq!(indexes.len(), poly_eq_point.len()); - // TODO there are a lot of fun optimizations here - let mut pushforward = EF::zero_vec(table_length); - for (index, value) in indexes.iter().zip(poly_eq_point) { - let index_usize = index.to_usize(); - pushforward[index_usize] += *value; - } - pushforward -} - -#[cfg(test)] -mod tests { - use super::*; - use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; - use rand::{Rng, SeedableRng, rngs::StdRng}; - use utils::{build_prover_state, build_verifier_state, init_tracing}; - - type F = KoalaBear; - type EF = QuinticExtensionFieldKB; - - #[test] - fn test_logup_star() { - for log_table_len in [3, 10] { - for log_indexes_len in 3..10 { - test_logup_star_helper(log_table_len, log_indexes_len); - } - } - - test_logup_star_helper(12, 14); - } - - fn test_logup_star_helper(log_table_len: usize, log_indexes_len: usize) { - init_tracing(); - - let table_length = 1 << log_table_len; - - let indexes_len = 1 << log_indexes_len; - - let mut rng = StdRng::seed_from_u64(0); - - let table = (0..table_length).map(|_| rng.random()).collect::>(); - - let mut indexes = vec![]; - let mut values = vec![]; - let max_index = table_length * 3 / 4; - for _ in 0..indexes_len { - let index = rng.random_range(0..max_index); - indexes.push(F::from_usize(index)); - values.push(table[index]); - } - - // Commit to the table - let commited_table = table.clone(); // Phony commitment for the example - // commit to the indexes - let commited_indexes = indexes.clone(); // Phony commitment for the example - - let point = MultilinearPoint((0..log_indexes_len).map(|_| rng.random()).collect::>()); - - let mut prover_state = build_prover_state(); - let eval 
= values.evaluate(&point); - - let time = std::time::Instant::now(); - let poly_eq_point = info_span!("eval_eq").in_scope(|| eval_eq(&point)); - let pushforward = compute_pushforward(&indexes, table_length, &poly_eq_point); - let claim = Evaluation::new(point, eval); - - let prover_statements = prove_logup_star( - &mut prover_state, - &MleRef::Base(&commited_table), - &commited_indexes, - claim.value, - &poly_eq_point, - &MleRef::Extension(&pushforward), - Some(max_index), - ); - println!("Proving logup_star took {} ms", time.elapsed().as_millis()); - - let last_prover_state = prover_state.state(); - let mut verifier_state = build_verifier_state(prover_state); - let verifier_statements = - verify_logup_star(&mut verifier_state, log_table_len, log_indexes_len, claim).unwrap(); - - assert_eq!(&verifier_statements, &prover_statements); - assert_eq!(last_prover_state, verifier_state.state()); - - assert_eq!( - indexes.evaluate(&verifier_statements.on_indexes.point), - verifier_statements.on_indexes.value - ); - assert_eq!( - table.evaluate(&verifier_statements.on_table.point), - verifier_statements.on_table.value - ); - for eval in &verifier_statements.on_pushforward { - assert_eq!(pushforward.evaluate(&eval.point), eval.value); - } - } -} diff --git a/crates/sub_protocols/src/packed_pcs.rs b/crates/sub_protocols/src/packed_pcs.rs index 68c7c47b..80e6af87 100644 --- a/crates/sub_protocols/src/packed_pcs.rs +++ b/crates/sub_protocols/src/packed_pcs.rs @@ -17,7 +17,8 @@ pub struct MultiCommitmentWitness { pub fn packed_pcs_global_statements( packed_n_vars: usize, memory_n_vars: usize, - memory_acc_statements: Vec>, + bytecode_n_vars: usize, + previous_statements: Vec>, tables_heights: &BTreeMap, committed_statements: &CommittedStatements, ) -> Vec> { @@ -25,8 +26,11 @@ pub fn packed_pcs_global_statements( let tables_heights_sorted = sort_tables_by_height(tables_heights); - let mut global_statements = memory_acc_statements; - let mut offset = 2 << memory_n_vars; + let mut 
global_statements = previous_statements; + let mut offset = 2 << memory_n_vars; // memory + memory_acc + + let max_table_n_vars = tables_heights_sorted[0].1; + offset += 1 << bytecode_n_vars.max(max_table_n_vars); // bytecode acc for (table, n_vars) in tables_heights_sorted { if table.is_execution_table() { @@ -52,7 +56,7 @@ pub fn packed_pcs_global_statements( .collect(), )); } - offset += table.n_commited_columns() << n_vars; + offset += table.n_committed_columns() << n_vars; } global_statements } @@ -62,34 +66,41 @@ pub fn packed_pcs_commit( prover_state: &mut impl FSProver, whir_config_builder: &WhirConfigBuilder, memory: &[F], - acc: &[F], + memory_acc: &[F], + bytecode_acc: &[F], traces: &BTreeMap, ) -> MultiCommitmentWitness { - assert_eq!(memory.len(), acc.len()); + assert_eq!(memory.len(), memory_acc.len()); let tables_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); let tables_heights_sorted = sort_tables_by_height(&tables_heights); - + assert!(memory.len() >= 1 << tables_heights_sorted.last().unwrap().1); // memory must be at least as large as the largest table (TODO add some padding at execution when this is not the case) let packed_n_vars = compute_total_n_vars( log2_strict_usize(memory.len()), + log2_strict_usize(bytecode_acc.len()), &tables_heights_sorted.iter().cloned().collect(), ); let mut packed_polynomial = F::zero_vec(1 << packed_n_vars); // TODO avoid cloning all witness data packed_polynomial[..memory.len()].copy_from_slice(memory); let mut offset = memory.len(); - packed_polynomial[offset..offset + acc.len()].copy_from_slice(acc); - offset += acc.len(); + packed_polynomial[offset..][..memory_acc.len()].copy_from_slice(memory_acc); + offset += memory_acc.len(); + + packed_polynomial[offset..][..bytecode_acc.len()].copy_from_slice(bytecode_acc); + let largest_table_height = 1 << tables_heights_sorted[0].1; + offset += largest_table_height.max(bytecode_acc.len()); // we may pad bytecode_acc to match largest 
table height + for (table, log_n_rows) in &tables_heights_sorted { let n_rows = 1 << *log_n_rows; - for col_index_f in 0..table.n_commited_columns_f() { + for col_index_f in 0..table.n_columns_f_air() { let col = &traces[table].base[col_index_f]; - packed_polynomial[offset..offset + n_rows].copy_from_slice(&col[..n_rows]); + packed_polynomial[offset..][..n_rows].copy_from_slice(&col[..n_rows]); offset += n_rows; } - for col_index_ef in 0..table.n_commited_columns_ef() { + for col_index_ef in 0..table.n_columns_ef_air() { let col = &traces[table].ext[col_index_ef]; let transposed = transpose_slice_to_basis_coefficients(col); for basis_col in transposed { - packed_polynomial[offset..offset + n_rows].copy_from_slice(&basis_col); + packed_polynomial[offset..][..n_rows].copy_from_slice(&basis_col); offset += n_rows; } } @@ -111,17 +122,20 @@ pub fn packed_pcs_parse_commitment( whir_config_builder: &WhirConfigBuilder, verifier_state: &mut impl FSVerifier, log_memory: usize, + log_bytecode: usize, tables_heights: &BTreeMap, ) -> Result, ProofError> { - let packed_n_vars = compute_total_n_vars(log_memory, tables_heights); + let packed_n_vars = compute_total_n_vars(log_memory, log_bytecode, tables_heights); WhirConfig::new(whir_config_builder, packed_n_vars).parse_commitment(verifier_state) } -fn compute_total_n_vars(log_memory: usize, tables_heights: &BTreeMap) -> usize { +fn compute_total_n_vars(log_memory: usize, log_bytecode: usize, tables_heights: &BTreeMap) -> usize { + let max_table_log_n_rows = tables_heights.values().copied().max().unwrap(); let total_len = (2 << log_memory) + + (1 << log_bytecode.max(max_table_log_n_rows)) + tables_heights .iter() - .map(|(table, log_n_rows)| table.n_commited_columns() << log_n_rows) + .map(|(table, log_n_rows)| table.n_committed_columns() << log_n_rows) .sum::(); log2_ceil_usize(total_len) } diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index fd516733..b0959c5b 100644 --- 
a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -116,16 +116,19 @@ pub fn mle_of_01234567_etc(point: &[F]) -> F { } } -/// table = 0 is reversed for memory -pub const MEMORY_TABLE_INDEX: usize = 0; +/// table = 3 is reserved for memory lookup +pub const MEMORY_TABLE_INDEX: usize = 3; +/// table = 4 is reserved for bytecode lookup +pub const BYTECODE_TABLE_INDEX: usize = 4; pub fn finger_print>, EF: ExtensionField + ExtensionField>( table: F, data: &[IF], - alpha_powers: &[EF], + alphas_eq_poly: &[EF], ) -> EF { - assert!(alpha_powers.len() > data.len()); - dot_product::(alpha_powers[1..].iter().copied(), data.iter().copied()) + table + assert!(alphas_eq_poly.len() > data.len()); + dot_product::(alphas_eq_poly.iter().copied(), data.iter().copied()) + + *alphas_eq_poly.last().unwrap() * table } #[cfg(test)] diff --git a/minimal_zkVM.pdf b/minimal_zkVM.pdf index 60d3102f..0756e1b7 100644 Binary files a/minimal_zkVM.pdf and b/minimal_zkVM.pdf differ diff --git a/misc/minimal_zkVM.tex b/misc/minimal_zkVM.tex index ac80d6d9..44b8a064 100644 --- a/misc/minimal_zkVM.tex +++ b/misc/minimal_zkVM.tex @@ -45,7 +45,7 @@ \newtheorem{lemma}{Lemma} -\title{Minimal zkVM for Lean Ethereum (draft 0.5.0)} +\title{Minimal zkVM for Lean Ethereum (draft 0.6.0)} \date{} \begin{document} @@ -335,15 +335,10 @@ \section{Proving system} \subsection{Execution table} -\subsubsection{Reduced commitment via logup*} - -In Cairo each instruction is encoded with 15 boolean flags, and 3 offsets. In the execution trace, this leads to committing to 18 field elements at each instruction. - -We can significantly reduce the commitments cost using logup*\cite{logup_star}. In the the execution table, we only need to commit to the pc column, and all the flags / offsets describing the current instruction can be fetched by an indexed lookup argument (for which logup* drastically reduces commitment costs). 
\subsubsection{Commitment} -\fbox{At each cycle, we commit to 8 (base) field elements:} +\fbox{At each cycle, we commit to 20 (base) field elements:} \begin{itemize} \item pc (program counter) @@ -351,18 +346,19 @@ \subsubsection{Commitment} % \item jump (non zero when a jump occurs) \item $\text{addr}_A$, $\text{addr}_B$, $\text{addr}_C$ \item $\text{value}_A = \textbf{m}[\text{addr}_A]$, $\text{value}_B = \textbf{m}[\text{addr}_B]$, $\text{value}_C = \textbf{m}[\text{addr}_C]$ + \item 12 field elements describing the instruction being executed (see below) \end{itemize} \subsubsection{Instruction Encoding} -Each instruction is described by 14 field elements: +Each instruction is described by 12 field elements: \begin{itemize} \item 3 operands ($\in \Fp$): $\text{operand}_A$, $\text{operand}_B$, $\text{operand}_C$ \item 3 associated flags ($\in \{0, 1\}$): $\text{flag}_A$, $\text{flag}_B$, $\text{flag}_C$ - \item 6 opcode flags ($\in \{0, 1\}$): ADD, MUL, DEREF, JUMP, IS\_PRECOMPILE, PRECOMPILE\_INDEX - \item 2 multi-purpose operands: AUX\_1, AUX\_2 + \item 5 opcode flags ($\in \{0, 1\}$): ADD, MUL, DEREF, JUMP, PRECOMPILE\_INDEX + \item 1 multi-purpose operand: AUX \end{itemize} @@ -525,7 +521,7 @@ \subsubsection{Buses: Data flow between tables} A detailled soundness analysis can be found in \href{https://github.com/openvm-org/stark-backend/blob/main/docs/Soundness_of_Interactions_via_LogUp.pdf}{Soundness of Interactions via LogUp}. -\section{Annex: simple packing of multilinear polynomials} +\section{Annex: simple stacking of multilinear polynomials} \textit{Note 1}: It's always possible to reduce $n$ claims about a multilinear polynomial to a single one, using sumcheck. But this trick is not necessary with WHIR, which natively supports an arbitrary number of claims about the committed polynomial. 
\vspace{3mm} @@ -535,7 +531,7 @@ \section{Annex: simple packing of multilinear polynomials} \vspace{3mm} % One of the advantage of multilinear polynomials versus univariate polynomials is the ability to efficienty commit to multiple polynomials at once. -In order to commit to multiple univariate polynomials with FRI, each polynomial must be FFT-ed + Merkle-commited. +In order to commit to multiple univariate polynomials with FRI, each polynomial must be FFT-ed + Merkle-committed. Even if it's possible to have some batching at the Merkle tree level (see 'MMCS' in \href{https://github.com/Plonky3/Plonky3}{Plonky3}), the proof size for multiple, complex AIR tables quickly reach the megabyte scale. With a multilinear PCS (such as WHIR), we can "concatenate" multiple multilinear polynomials into a single one, and commit to it once (offering significant proof size savings). @@ -590,8 +586,8 @@ \section{Annex: simple packing of multilinear polynomials} \node[above] at (20,1.6) {$P({\mathbf{1}}, {\mathbf{0}}, x_1, x_2, x_3)$}; \node[above] at (26,1.6) {$P({\mathbf{1}}, {\mathbf{1}}, {\mathbf{0}}, x_1, x_2)$}; \end{tikzpicture} -\caption{Simple packing of $P_1, P_2, P_3$ into a single polynomial $P$} -\label{fig:packing} +\caption{Simple stacking of $P_1, P_2, P_3$ into a single polynomial $P$} +\label{fig:stacking} \end{figure} Advantage of this approach: simplicity. @@ -600,7 +596,7 @@ \section{Annex: simple packing of multilinear polynomials} \vspace{5mm} -There are alterntive ways to handle the packing of multiple multilinear polynomials: +There are alternative ways to handle the stacking of multiple multilinear polynomials: \begin{itemize} \item \textbf{Jagged PCS} \cite{jagged_pcs}: No padding overhead, at the cost of an additional sumcheck. 
diff --git a/src/prove_poseidons.rs b/src/prove_poseidons.rs index ff4e582e..ec863b76 100644 --- a/src/prove_poseidons.rs +++ b/src/prove_poseidons.rs @@ -52,8 +52,6 @@ pub fn benchmark_prove_poseidon_16(log_n_rows: usize, tracing: bool) { &ExtraDataForBuses::default(), &collect_refs(&trace), &[] as &[&[EF]], - &[], - &[], ) .unwrap();