diff --git a/Cargo.lock b/Cargo.lock index 2049508d..eb321d88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,7 +26,7 @@ dependencies = [ [[package]] name = "air" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#bd8a91199f45fcb3c21af056468340e2f9de2d1d" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#576b88463fd09b2b421235c182285756581d8673" dependencies = [ "p3-field", ] @@ -99,7 +99,7 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "backend" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#bd8a91199f45fcb3c21af056468340e2f9de2d1d" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#576b88463fd09b2b421235c182285756581d8673" dependencies = [ "itertools", "p3-field", @@ -134,9 +134,9 @@ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "clap" -version = "4.5.56" +version = "4.5.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75ca66430e33a14957acc24c5077b503e7d374151b2b4b3a10c83b4ceb4be0e" +checksum = "6899ea499e3fb9305a65d5ebf6e3d2248c5fab291f300ad0a704fbe142eae31a" dependencies = [ "clap_builder", "clap_derive", @@ -144,9 +144,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.56" +version = "4.5.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793207c7fa6300a0608d1080b858e5fdbe713cdc1c8db9fb17777d8a13e63df0" +checksum = "7b12c8b680195a62a8364d16b8447b01b6c2c8f9aaf68bee653be34d4245e238" dependencies = [ "anstream", "anstyle", @@ -190,7 +190,7 @@ dependencies = [ [[package]] name = "constraints-folder" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#bd8a91199f45fcb3c21af056468340e2f9de2d1d" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#576b88463fd09b2b421235c182285756581d8673" dependencies = [ "air 0.3.0", "backend", @@ -266,7 +266,7 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "fiat-shamir" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#bd8a91199f45fcb3c21af056468340e2f9de2d1d" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#576b88463fd09b2b421235c182285756581d8673" dependencies = [ "p3-field", "p3-koala-bear", @@ -468,7 +468,7 @@ checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] name = "multilinear-toolkit" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#bd8a91199f45fcb3c21af056468340e2f9de2d1d" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#576b88463fd09b2b421235c182285756581d8673" dependencies = [ "air 0.3.0", "backend", @@ -557,7 +557,7 @@ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "p3-challenger" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "p3-field", "p3-maybe-rayon", @@ -569,7 +569,7 @@ dependencies = [ [[package]] name = "p3-commit" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "itertools", "p3-challenger", @@ -583,7 +583,7 @@ dependencies = [ [[package]] name = "p3-dft" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "itertools", "p3-field", @@ -596,7 +596,7 @@ dependencies = [ [[package]] name = "p3-field" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "itertools", "num-bigint", @@ -611,7 +611,7 @@ dependencies = [ [[package]] name = "p3-koala-bear" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "itertools", "num-bigint", @@ -627,7 +627,7 @@ dependencies = [ [[package]] name = "p3-matrix" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "itertools", "p3-field", @@ -642,7 +642,7 @@ dependencies = [ [[package]] name = "p3-maybe-rayon" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "rayon", ] @@ -650,7 +650,7 @@ dependencies = [ [[package]] name = "p3-mds" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "p3-dft", "p3-field", @@ -662,7 +662,7 @@ dependencies = [ [[package]] name = "p3-merkle-tree" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "itertools", "p3-commit", @@ -679,7 +679,7 @@ dependencies = [ [[package]] name = "p3-monty-31" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "itertools", "num-bigint", @@ -701,7 +701,7 @@ dependencies = [ [[package]] name = "p3-poseidon2" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "p3-field", "p3-mds", @@ -713,7 +713,7 @@ dependencies = [ [[package]] name = "p3-symmetric" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "itertools", "p3-field", @@ -723,7 +723,7 @@ dependencies = [ [[package]] name = "p3-util" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1f5aca1d6845caec76fa25b00fae3ea4f07ef930" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#2d4318f6c13d808725d848c483db712a65bac33f" dependencies = [ "rayon", "serde", @@ -900,9 +900,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", @@ -911,9 +911,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" [[package]] name = "rustversion" @@ -1049,7 +1049,7 @@ dependencies = [ [[package]] name = "sumcheck" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#bd8a91199f45fcb3c21af056468340e2f9de2d1d" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#576b88463fd09b2b421235c182285756581d8673" dependencies = [ "air 0.3.0", "backend", @@ -1277,7 +1277,7 @@ dependencies = [ [[package]] name = "whir" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#bd8a91199f45fcb3c21af056468340e2f9de2d1d" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#576b88463fd09b2b421235c182285756581d8673" dependencies = [ "backend", "fiat-shamir", @@ -1287,6 +1287,7 @@ dependencies = [ "p3-field", "p3-koala-bear", "p3-matrix", + "p3-maybe-rayon", "p3-merkle-tree", "p3-symmetric", "p3-util", @@ -1384,18 +1385,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.37" +version = "0.8.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7456cf00f0685ad319c5b1693f291a650eaf345e941d082fc4e03df8a03996ac" +checksum = "57cf3aa6855b23711ee9852dfc97dfaa51c45feaba5b645d0c777414d494a961" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.37" +version = "0.8.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1328722bbf2115db7e19d69ebcc15e795719e2d66b60827c6a69a117365e37a0" +checksum = "8a616990af1a287837c4fe6596ad77ef57948f787e46ce28e166facc0cc1cb75" dependencies = [ "proc-macro2", "quote", @@ -1404,6 +1405,6 @@ dependencies = [ [[package]] name = "zmij" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1966f8ac2c1f76987d69a74d0e0f929241c10e78136434e3be70ff7f58f64214" +checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445" diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index 1d8c9cbc..d637a043 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -62,8 +62,9 @@ def pop(self): # Built-in constants ZERO_VEC_PTR = 0 -ONE_VEC_PTR = 16 -NONRESERVED_PROGRAM_INPUT_START = 58 +SAMPLING_DOMAIN_SEPARATOR_PTR = 16 +ONE_VEC_PTR = 24 +NONRESERVED_PROGRAM_INPUT_START = 66 def poseidon16(left, right, output, mode): @@ -83,6 +84,10 @@ def log2_ceil(x: int) -> int: return math.ceil(math.log2(x)) +def div_ceil(a: int, b: int) -> int: + return (a + b - 1) // b + + def next_multiple_of(x: int, n: int) -> int: return x + (n - x % n) % n diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index b4433181..bf1213bc 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -470,7 +470,7 @@ fn compile_lines( SimpleLine::Precompile { table, args, .. } => { match table { Table::DotProduct(_) => assert_eq!(args.len(), 5), - Table::Poseidon16(_) => assert_eq!(args.len(), 4), + Table::Poseidon16(_) => assert_eq!(args.len(), 3), Table::Execution(_) => unreachable!(), } // if arg_c is constant, create a variable (in memory) to hold it diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index cfd93c3d..e5aa8add 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -118,6 +118,7 @@ primary = { lambda_expr | log2_ceil_expr | next_multiple_of_expr | + div_ceil_expr | saturating_sub_expr | len_expr | array_access_expr | @@ -134,6 +135,7 @@ vec_element = { vec_literal | expression } function_call_expr = { identifier ~ "(" ~ tuple_expression? ~ ")" } log2_ceil_expr = { "log2_ceil" ~ "(" ~ expression ~ ")" } next_multiple_of_expr = { "next_multiple_of" ~ "(" ~ expression ~ "," ~ expression ~ ")" } +div_ceil_expr = { "div_ceil" ~ "(" ~ expression ~ "," ~ expression ~ ")" } saturating_sub_expr = { "saturating_sub" ~ "(" ~ expression ~ "," ~ expression ~ ")" } len_expr = { "len" ~ "(" ~ len_argument ~ ")" } len_argument = { identifier ~ ("[" ~ expression ~ "]")* } diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index c6a3ecd3..fc9a7332 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ b/crates/lean_compiler/src/ir/instruction.rs @@ -99,11 +99,7 @@ impl IntermediateInstruction { arg_b, res: arg_a, }, - MathOperation::Exp - | MathOperation::Mod - | MathOperation::NextMultipleOf - | MathOperation::SaturatingSub - | MathOperation::Log2Ceil => { + _ => { unreachable!() } } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 872e39ed..c83cfbfe 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -331,6 +331,8 @@ pub enum MathOperation { NextMultipleOf, /// saturating subtraction SaturatingSub, + /// Integer division with ceiling + DivCeil, } impl TryFrom for Operation { @@ -357,6 +359,7 @@ impl Display for MathOperation { Self::Log2Ceil => write!(f, "log2_ceil"), Self::NextMultipleOf => write!(f, "next_multiple_of"), Self::SaturatingSub => write!(f, "saturating_sub"), + Self::DivCeil => write!(f, "div_ceil"), } } } @@ -375,7 +378,8 @@ impl MathOperation { | Self::Exp | Self::Mod | Self::NextMultipleOf - | Self::SaturatingSub => 2, + | Self::SaturatingSub + | Self::DivCeil => 2, } } pub fn eval(&self, args: &[F]) -> F { @@ -397,6 +401,7 @@ impl MathOperation { F::from_usize(res) } Self::SaturatingSub => F::from_usize(args[0].to_usize().saturating_sub(args[1].to_usize())), + Self::DivCeil => F::from_usize(args[0].to_usize().div_ceil(args[1].to_usize())), } } } diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs b/crates/lean_compiler/src/parser/parsers/expression.rs index 4535c798..a055379a 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -29,6 +29,7 @@ impl Parse for ExpressionParser { Rule::exp_expr => MathOperation::Exp.parse(pair, ctx), Rule::log2_ceil_expr => MathOperation::Log2Ceil.parse(pair, ctx), Rule::next_multiple_of_expr => MathOperation::NextMultipleOf.parse(pair, ctx), + Rule::div_ceil_expr => MathOperation::DivCeil.parse(pair, ctx), Rule::saturating_sub_expr => MathOperation::SaturatingSub.parse(pair, ctx), Rule::var_or_constant => Ok(Expression::Value(VarOrConstantParser.parse(pair, ctx)?)), Rule::array_access_expr => ArrayAccessParser.parse(pair, ctx), diff --git a/crates/lean_compiler/src/parser/parsers/literal.rs b/crates/lean_compiler/src/parser/parsers/literal.rs index 5bcdb0f3..d3178b89 100644 --- a/crates/lean_compiler/src/parser/parsers/literal.rs +++ b/crates/lean_compiler/src/parser/parsers/literal.rs @@ -1,4 +1,6 @@ -use lean_vm::{NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, PRIVATE_INPUT_START_PTR, ZERO_VEC_PTR}; +use lean_vm::{ + NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, PRIVATE_INPUT_START_PTR, SAMPLING_DOMAIN_SEPARATOR_PTR, ZERO_VEC_PTR, +}; use multilinear_toolkit::prelude::*; use super::expression::ExpressionParser; @@ -134,6 +136,9 @@ impl VarOrConstantParser { "PRIVATE_INPUT_START_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(PRIVATE_INPUT_START_PTR))), "ZERO_VEC_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(ZERO_VEC_PTR))), "ONE_VEC_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(ONE_VEC_PTR))), + "SAMPLING_DOMAIN_SEPARATOR_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from( + SAMPLING_DOMAIN_SEPARATOR_PTR, + ))), _ => { // Check if it's a const array (error case - can't use array as value) if ctx.get_const_array(text).is_some() { diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 60644060..956a40d2 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -10,15 +10,12 @@ fn test_poseidon() { def main(): a = NONRESERVED_PROGRAM_INPUT_START b = a + 8 - c = Array(2*8) - poseidon16(a, b, c, 0) + c = Array(8) + poseidon16(a, b, c) for i in range(0, 8): cc = c[i] print(cc) - for i in range(0, 8): - dd = c[i+8] - print(dd) return "#; let public_input: [F; 16] = (0..16).map(F::new).collect::>().try_into().unwrap(); diff --git a/crates/lean_compiler/tests/test_data/error_34.py b/crates/lean_compiler/tests/test_data/error_34.py index 6f6acded..3575f3fb 100644 --- a/crates/lean_compiler/tests/test_data/error_34.py +++ b/crates/lean_compiler/tests/test_data/error_34.py @@ -1,9 +1,8 @@ from snark_lib import * + # Error: match_range with non-continuous ranges (gap between 2 and 5) def main(): x = 1 - result = match_range(x, - range(0, 2), lambda i: i * 10, - range(5, 8), lambda i: i * 100) + result = match_range(x, range(0, 2), lambda i: i * 10, range(5, 8), lambda i: i * 100) return diff --git a/crates/lean_compiler/tests/test_data/error_35.py b/crates/lean_compiler/tests/test_data/error_35.py index 07066aea..2eba4756 100644 --- a/crates/lean_compiler/tests/test_data/error_35.py +++ b/crates/lean_compiler/tests/test_data/error_35.py @@ -1,5 +1,6 @@ from snark_lib import * + # Error: match_range results are always immutable, cannot use Mut def main(): x = 1 diff --git a/crates/lean_compiler/tests/test_data/program_171.py b/crates/lean_compiler/tests/test_data/program_171.py index 7c88b222..bb3c47c6 100644 --- a/crates/lean_compiler/tests/test_data/program_171.py +++ b/crates/lean_compiler/tests/test_data/program_171.py @@ -7,6 +7,7 @@ # Simple inline functions with mutable variables # ============================================================================ + @inline def count_up(n): """Count from 0 to n-1, return the sum""" @@ -15,6 +16,7 @@ def count_up(n): acc = acc + 1 return acc + @inline def sum_range(start, end): """Sum integers from start to end-1""" @@ -23,6 +25,7 @@ def sum_range(start, end): total = total + i return total + @inline def double_count(n): """Two mutable variables in same function""" @@ -33,10 +36,12 @@ def double_count(n): b = b - 1 return a + b + # ============================================================================ # Nested inline functions (inline calling inline) # ============================================================================ + @inline def inner_loop(k): """Inner inline function""" @@ -45,6 +50,7 @@ def inner_loop(k): x = x + j return x + @inline def outer_with_inner(n): """Outer inline that calls inner inline""" @@ -53,6 +59,7 @@ def outer_with_inner(n): result = result + inner_loop(i) return result + @inline def deep_nested(a): """Deeply nested: calls outer_with_inner which calls inner_loop""" @@ -60,10 +67,12 @@ def deep_nested(a): base = base + outer_with_inner(a) return base + # ============================================================================ # Inline functions with multiple mutable variables and complex flow # ============================================================================ + @inline def complex_muts(n): """Multiple mutable variables with interdependencies""" @@ -77,6 +86,7 @@ def complex_muts(n): z = temp + z return x + y + z + @inline def with_immutable(n): """Mix of mutable and immutable inside inline""" @@ -87,10 +97,12 @@ def with_immutable(n): final_imm = m + 1000 return final_imm + # ============================================================================ # Inline functions with internal branching # ============================================================================ + @inline def inline_with_if(x): """Inline function that itself contains if/else""" @@ -102,6 +114,7 @@ def inline_with_if(x): result = result + x return result + @inline def inline_with_match(selector): """Inline function that itself contains match""" @@ -115,6 +128,7 @@ def inline_with_match(selector): out = 3000 return out + @inline def inline_with_nested_branch(a, b): """Inline with nested if inside match""" @@ -132,10 +146,12 @@ def inline_with_nested_branch(a, b): res = 40 return res + # ============================================================================ # Inline functions returning multiple values # ============================================================================ + @inline def multi_return_inline(n): """Inline returning multiple values""" @@ -146,6 +162,7 @@ def multi_return_inline(n): b = b + 2 return a, b + @inline def triple_return(x): """Inline returning three values with different computations""" @@ -158,10 +175,12 @@ def triple_return(x): m3 = m3 + 3 return m1, m2, m3 + # ============================================================================ # Deeper nesting of inline functions # ============================================================================ + @inline def level_d(x): """Deepest level""" @@ -170,6 +189,7 @@ def level_d(x): acc = acc + 1 return acc + @inline def level_c(x): """Calls level_d""" @@ -179,6 +199,7 @@ def level_c(x): acc = acc + 10 return acc + @inline def level_b(x): """Calls level_c""" @@ -188,6 +209,7 @@ def level_b(x): acc = acc + 100 return acc + @inline def level_a(x): """Calls level_b - 4 levels deep""" @@ -197,10 +219,12 @@ def level_a(x): acc = acc + 1000 return acc + # ============================================================================ # Inline with Array operations # ============================================================================ + @inline def inline_with_array(n): """Inline that allocates and uses an array""" @@ -214,6 +238,7 @@ def inline_with_array(n): total = total + arr[i] return total + @inline def inline_modify_array(base): """Inline that creates array and does complex operations""" @@ -224,10 +249,12 @@ def inline_modify_array(base): acc = acc * 2 return buf[0] + buf[1] + buf[2] + # ============================================================================ # Chained inline calls # ============================================================================ + @inline def chain_a(x): m: Mut = x @@ -235,6 +262,7 @@ def chain_a(x): m = m + 1 return m + @inline def chain_b(x): m: Mut = x @@ -242,6 +270,7 @@ def chain_b(x): m = m * 2 return m + @inline def chain_c(x): m: Mut = x @@ -249,10 +278,12 @@ def chain_c(x): m = m + 10 return m + # ============================================================================ # Stress test inline with many variables # ============================================================================ + @inline def many_vars(seed): """Inline with 10 mutable variables""" @@ -279,10 +310,12 @@ def many_vars(seed): v9 = v9 + 1 return v0 + v1 + v2 + v3 + v4 + v5 + v6 + v7 + v8 + v9 + # ============================================================================ # Main test function # ============================================================================ + def main(): # ------------------------------------------------------------------- # TEST 1: Basic inline in match arms (different inlined vars per arm) diff --git a/crates/lean_compiler/tests/test_data/program_172.py b/crates/lean_compiler/tests/test_data/program_172.py index f4accb9c..da813966 100644 --- a/crates/lean_compiler/tests/test_data/program_172.py +++ b/crates/lean_compiler/tests/test_data/program_172.py @@ -2,9 +2,11 @@ # Test match_range feature + def helper_const(n: Const): return n * 10 + def main(): # Test 1: Basic match_range - no forward declaration needed (auto-generated as Imu) x = 2 @@ -34,29 +36,21 @@ def main(): # Test 6: match_range with multiple continuous ranges d = 0 - r6a = match_range(d, - range(0, 2), lambda i: 100 + i, - range(2, 5), lambda i: 200 + i) + r6a = match_range(d, range(0, 2), lambda i: 100 + i, range(2, 5), lambda i: 200 + i) assert r6a == 100 # d=0 -> 100+0=100 e = 3 - r6b = match_range(e, - range(0, 2), lambda i: 100 + i, - range(2, 5), lambda i: 200 + i) + r6b = match_range(e, range(0, 2), lambda i: 100 + i, range(2, 5), lambda i: 200 + i) assert r6b == 203 # e=3 -> 200+3=203 # Test 7: match_range with different lambdas calling functions f = 1 - r7 = match_range(f, - range(0, 1), lambda i: 999, - range(1, 4), lambda i: helper_const(i)) + r7 = match_range(f, range(0, 1), lambda i: 999, range(1, 4), lambda i: helper_const(i)) assert r7 == 10 # f=1 -> helper_const(1)=10 # Test 8: match_range first range (special case) g = 0 - r8 = match_range(g, - range(0, 1), lambda i: 42, - range(1, 3), lambda i: i * 7) + r8 = match_range(g, range(0, 1), lambda i: 42, range(1, 3), lambda i: i * 7) assert r8 == 42 # g=0 -> 42 # Test 9: Results are always immutable @@ -70,36 +64,32 @@ def main(): # Test 10: Basic multiple return values (2 values) v10 = 1 a10, b10 = match_range(v10, range(0, 3), lambda i: two_values_const(i)) - assert a10 == 10 # 1 * 10 + assert a10 == 10 # 1 * 10 assert b10 == 101 # 1 + 100 # Test 11: Multiple return values with different case v11 = 2 a11, b11 = match_range(v11, range(0, 3), lambda i: two_values_const(i)) - assert a11 == 20 # 2 * 10 + assert a11 == 20 # 2 * 10 assert b11 == 102 # 2 + 100 # Test 12: Three return values v12 = 1 x12, y12, z12 = match_range(v12, range(0, 3), lambda i: three_values_const(i)) - assert x12 == 1 # i - assert y12 == 10 # i * 10 - assert z12 == 1001 # i + 1000 + assert x12 == 1 # i + assert y12 == 10 # i * 10 + assert z12 == 1001 # i + 1000 # Test 13: Multiple return values with multiple ranges v13 = 3 - a13, b13 = match_range(v13, - range(0, 2), lambda i: pair_small(i), - range(2, 5), lambda i: pair_large(i)) - assert a13 == 300 # 3 * 100 (pair_large) + a13, b13 = match_range(v13, range(0, 2), lambda i: pair_small(i), range(2, 5), lambda i: pair_large(i)) + assert a13 == 300 # 3 * 100 (pair_large) assert b13 == 3000 # 3 * 1000 # Test 14: Multiple return values with multiple ranges - different range v14 = 1 - a14, b14 = match_range(v14, - range(0, 2), lambda i: pair_small(i), - range(2, 5), lambda i: pair_large(i)) - assert a14 == 1 # 1 * 1 (pair_small) + a14, b14 = match_range(v14, range(0, 2), lambda i: pair_small(i), range(2, 5), lambda i: pair_large(i)) + assert a14 == 1 # 1 * 1 (pair_small) assert b14 == 10 # 1 * 10 # Test 15: Multiple return values - edge case first element @@ -137,11 +127,11 @@ def main(): # Test 20: Three values with multiple ranges v20 = 4 - x20, y20, z20 = match_range(v20, - range(0, 3), lambda i: three_values_const(i), - range(3, 6), lambda i: three_values_offset(i)) - assert x20 == 104 # 4 + 100 - assert y20 == 1004 # 4 + 1000 + x20, y20, z20 = match_range( + v20, range(0, 3), lambda i: three_values_const(i), range(3, 6), lambda i: three_values_offset(i) + ) + assert x20 == 104 # 4 + 100 + assert y20 == 1004 # 4 + 1000 assert z20 == 10004 # 4 + 10000 # ========== INLINED FUNCTION TESTS ========== @@ -154,29 +144,25 @@ def main(): # Test 22: Inlined function - two return values v22 = 3 a22, b22 = match_range(v22, range(0, 5), lambda i: inlined_pair(i)) - assert a22 == 30 # 3 * 10 + assert a22 == 30 # 3 * 10 assert b22 == 300 # 3 * 100 # Test 23: Inlined function with multiple ranges v23 = 4 - r23 = match_range(v23, - range(0, 3), lambda i: inlined_small(i), - range(3, 6), lambda i: inlined_large(i)) + r23 = match_range(v23, range(0, 3), lambda i: inlined_small(i), range(3, 6), lambda i: inlined_large(i)) assert r23 == 4000 # 4 * 1000 (inlined_large) # Test 24: Inlined function - first range v24 = 1 - r24 = match_range(v24, - range(0, 3), lambda i: inlined_small(i), - range(3, 6), lambda i: inlined_large(i)) + r24 = match_range(v24, range(0, 3), lambda i: inlined_small(i), range(3, 6), lambda i: inlined_large(i)) assert r24 == 10 # 1 * 10 (inlined_small) # Test 25: Inlined function with three return values v25 = 2 x25, y25, z25 = match_range(v25, range(0, 4), lambda i: inlined_triple(i)) - assert x25 == 2 # i - assert y25 == 20 # i * 10 - assert z25 == 200 # i * 100 + assert x25 == 2 # i + assert y25 == 20 # i * 10 + assert z25 == 200 # i * 100 # Test 26: Inlined function with complex body v26 = 3 @@ -185,10 +171,8 @@ def main(): # Test 27: Mix of inlined and const functions in multiple ranges v27 = 2 - a27, b27 = match_range(v27, - range(0, 2), lambda i: inlined_pair(i), - range(2, 5), lambda i: two_values_const(i)) - assert a27 == 20 # 2 * 10 (two_values_const) + a27, b27 = match_range(v27, range(0, 2), lambda i: inlined_pair(i), range(2, 5), lambda i: two_values_const(i)) + assert a27 == 20 # 2 * 10 (two_values_const) assert b27 == 102 # 2 + 100 # Test 28: Inlined with expression as match value diff --git a/crates/lean_compiler/tests/test_data/program_173.py b/crates/lean_compiler/tests/test_data/program_173.py index 906197c7..5cd16de1 100644 --- a/crates/lean_compiler/tests/test_data/program_173.py +++ b/crates/lean_compiler/tests/test_data/program_173.py @@ -2,6 +2,7 @@ # Test match_range with non-zero starting indices + def main(): # Test 1: Range starting at 1 x1 = 2 @@ -30,30 +31,26 @@ def main(): # Test 6: Multiple ranges, first starting at non-zero x6 = 2 - r6 = match_range(x6, - range(1, 3), lambda i: i * 10, - range(3, 6), lambda i: i * 100) + r6 = match_range(x6, range(1, 3), lambda i: i * 10, range(3, 6), lambda i: i * 100) assert r6 == 20 # 2 * 10 # Test 7: Multiple ranges, selecting from second range x7 = 4 - r7 = match_range(x7, - range(1, 3), lambda i: i * 10, - range(3, 6), lambda i: i * 100) + r7 = match_range(x7, range(1, 3), lambda i: i * 10, range(3, 6), lambda i: i * 100) assert r7 == 400 # 4 * 100 # Test 8: Non-zero start with multiple return values x8 = 3 a8, b8 = match_range(x8, range(1, 5), lambda i: two_vals(i)) - assert a8 == 30 # 3 * 10 + assert a8 == 30 # 3 * 10 assert b8 == 300 # 3 * 100 # Test 9: Non-zero start with three return values x9 = 2 p9, q9, r9 = match_range(x9, range(1, 4), lambda i: three_vals(i)) - assert p9 == 2 # i - assert q9 == 20 # i * 10 - assert r9 == 200 # i * 100 + assert p9 == 2 # i + assert q9 == 20 # i * 10 + assert r9 == 200 # i * 100 # Test 10: Non-zero start with expression as match value a10 = 7 @@ -73,19 +70,15 @@ def main(): # Test 13: Multiple return values with multiple non-zero ranges x13 = 5 - a13, b13 = match_range(x13, - range(2, 4), lambda i: pair_small(i), - range(4, 7), lambda i: pair_large(i)) - assert a13 == 500 # 5 * 100 + a13, b13 = match_range(x13, range(2, 4), lambda i: pair_small(i), range(4, 7), lambda i: pair_large(i)) + assert a13 == 500 # 5 * 100 assert b13 == 5000 # 5 * 1000 # Test 14: First range selected in multiple non-zero ranges x14 = 3 - a14, b14 = match_range(x14, - range(2, 4), lambda i: pair_small(i), - range(4, 7), lambda i: pair_large(i)) - assert a14 == 3 # 3 * 1 - assert b14 == 30 # 3 * 10 + a14, b14 = match_range(x14, range(2, 4), lambda i: pair_small(i), range(4, 7), lambda i: pair_large(i)) + assert a14 == 3 # 3 * 1 + assert b14 == 30 # 3 * 10 return diff --git a/crates/lean_compiler/tests/test_data/program_174.py b/crates/lean_compiler/tests/test_data/program_174.py index 7bfb0d1d..19be7961 100644 --- a/crates/lean_compiler/tests/test_data/program_174.py +++ b/crates/lean_compiler/tests/test_data/program_174.py @@ -2,6 +2,7 @@ # Test classical match statement with cases starting after 0 + def main(): # Test 1: Basic match starting at 1 r1 = match_start_at_1(2) diff --git a/crates/lean_compiler/tests/test_data/program_175.py b/crates/lean_compiler/tests/test_data/program_175.py index ba0cb120..b68d9003 100644 --- a/crates/lean_compiler/tests/test_data/program_175.py +++ b/crates/lean_compiler/tests/test_data/program_175.py @@ -27,9 +27,7 @@ def inlined_with_match_range_two_args(a, x): @inline def inlined_with_match_range_multi_range(x): - res = match_range(x, - range(0, 3), lambda i: i * 10, - range(3, 6), lambda i: helper_const(i)) + res = match_range(x, range(0, 3), lambda i: i * 10, range(3, 6), lambda i: helper_const(i)) return res @@ -66,9 +64,9 @@ def main(): a = inlined_with_match_range(1) b = inlined_with_match_range(2) c = inlined_with_match_range(3) - assert a == 1 # 1*1 - assert b == 4 # 2*2 - assert c == 9 # 3*3 + assert a == 1 # 1*1 + assert b == 4 # 2*2 + assert c == 9 # 3*3 # Test 5: Inlined function with two match_ranges r5 = inlined_nested_match_range(2, 3) @@ -77,7 +75,7 @@ def main(): # Test 6: Edge cases - first and last values first = inlined_with_match_range(1) last = inlined_with_match_range(4) - assert first == 1 # 1*1 - assert last == 16 # 4*4 + assert first == 1 # 1*1 + assert last == 16 # 4*4 return diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 4b03b762..41f3afc5 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -414,18 +414,15 @@ ONE_VEC_PTR # [1, 0, 0, ...] ## Precompiles ### poseidon16 +Always in "compression" mode ``` -COMPRESSION = 1 # (output: 8 elements) (For now this is not a real permutation in the cryptographic sense, see Plonky3 PseudoCompression trait, but it will change in the future) -PERMUTATION = 0 # full permutation (output: 16 elements) - poseidon16(left, right, output, mode) ``` - `left`, `right`: pointers to 8 field elements each -- `output`: pointer to result (8 or 16 elements depending on mode) -- Used for Merkle tree hashing and Fiat-Shamir: +- `output`: pointer to result (8 elements) ``` -poseidon16(leaf_a, leaf_b, parent_hash, COMPRESSION) -poseidon16(state, data, new_state, PERMUTATION) +poseidon16(leaf_a, leaf_b, parent_hash) +poseidon16(state, data, new_state) ``` ### dot_product diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 82146336..fad62da5 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -88,9 +88,7 @@ pub fn prove_execution( // logup (GKR) let logup_c = prover_state.sample(); - prover_state.duplexing(); let logup_alphas = prover_state.sample_vec(log2_ceil_usize(max_bus_width())); - prover_state.duplexing(); let logup_alphas_eq_poly = eval_eq(&logup_alphas); let logup_statements = prove_generic_logup( @@ -113,7 +111,6 @@ pub fn prove_execution( } let bus_beta = prover_state.sample(); - prover_state.duplexing(); let air_alpha = prover_state.sample(); let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); @@ -176,7 +173,6 @@ pub fn prove_execution( )); let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(public_memory_size))); - prover_state.duplexing(); let public_memory_eval = (&memory[..public_memory_size]).evaluate(&public_memory_random_point); let memory_acc_statements = vec![ diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 7e8971ae..6ec98735 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -19,8 +19,6 @@ fn test_zk_vm_fuzzing() { fn test_zk_vm_all_precompiles_helper(fuzzing: bool) { let program_str = r#" DIM = 5 -COMPRESSION = 1 -PERMUTATION = 0 N = 11 VECTOR_LEN = 8 @@ -30,8 +28,7 @@ EE = 0 # extension-extension def main(): pub_start = NONRESERVED_PROGRAM_INPUT_START - poseidon16(pub_start, pub_start + VECTOR_LEN, pub_start + 2 * VECTOR_LEN, PERMUTATION) - poseidon16(pub_start + 4 * VECTOR_LEN, pub_start + 5 * VECTOR_LEN, pub_start + 6 * VECTOR_LEN, COMPRESSION) + poseidon16(pub_start + 4 * VECTOR_LEN, pub_start + 5 * VECTOR_LEN, pub_start + 6 * VECTOR_LEN) dot_product(pub_start + 88, pub_start + 88 + N, pub_start + 1000, N, BE) dot_product(pub_start + 88 + N, pub_start + 88 + N * (DIM + 1), pub_start + 1000 + DIM, N, EE) c: Mut = 0 @@ -48,10 +45,6 @@ def main(): let mut rng = StdRng::seed_from_u64(0); let mut public_input = F::zero_vec(1 << 13); - let poseidon_16_perm_input: [F; 16] = rng.random(); - public_input[..16].copy_from_slice(&poseidon_16_perm_input); - public_input[16..32].copy_from_slice(&poseidon16_permute(poseidon_16_perm_input)); - let poseidon_16_compress_input: [F; 16] = rng.random(); public_input[32..48].copy_from_slice(&poseidon_16_compress_input); public_input[48..56].copy_from_slice(&poseidon16_permute(poseidon_16_compress_input)[..8]); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index b165250c..419e851e 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -24,7 +24,6 @@ pub fn verify_execution( params: &SnarkParams, ) -> Result { let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone()); - verifier_state.duplexing(); let dims = verifier_state .next_base_scalars_vec(1 + N_TABLES)? @@ -62,9 +61,7 @@ pub fn verify_execution( packed_pcs_parse_commitment(¶ms.first_whir, &mut verifier_state, log_memory, &table_n_vars)?; let logup_c = verifier_state.sample(); - verifier_state.duplexing(); let logup_alphas = verifier_state.sample_vec(log2_ceil_usize(max_bus_width())); - verifier_state.duplexing(); let logup_alphas_eq_poly = eval_eq(&logup_alphas); let logup_statements = verify_generic_logup( @@ -86,7 +83,6 @@ pub fn verify_execution( } let bus_beta = verifier_state.sample(); - verifier_state.duplexing(); let air_alpha = verifier_state.sample(); let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); @@ -145,7 +141,6 @@ pub fn verify_execution( let public_memory_random_point = MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_memory.len()))); - verifier_state.duplexing(); let public_memory_eval = public_memory.evaluate(&public_memory_random_point); let memory_acc_statements = vec![ diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index bf118c4a..93a0b7f4 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -38,19 +38,22 @@ pub const ENDING_PC: usize = 0; /// reserved_area: reserved for special constants (size = 48 field elements) /// program_input: the input of the program we want to prove /// -/// [reserved_area] = [00000000] [00000000] [10000] [01000] [00100] [00010] [00001] [poseidon_16(0) (16 field elements)] [private input start pointer] +/// [reserved_area] = [00000000] [00000000] [10000000] [10000] [01000] [00100] [00010] [00001] [poseidon_16(0) (8 field elements)] [private input start pointer] /// /// Convention: pointing to 16 zeros pub const ZERO_VEC_PTR: usize = 0; +/// Convention: pointing to [10000000] +pub const SAMPLING_DOMAIN_SEPARATOR_PTR: usize = ZERO_VEC_PTR + 2 * DIGEST_LEN; + /// Convention: pointing to [10000] [01000] [00100] [00010] [00001] -pub const EXTENSION_BASIS_PTR: usize = 2 * DIGEST_LEN; +pub const EXTENSION_BASIS_PTR: usize = SAMPLING_DOMAIN_SEPARATOR_PTR + DIGEST_LEN; -/// Convention: pointing to the 16 elements of poseidon_16(0) +/// Convention: pointing to the 8 elements of poseidon_16(0) pub const POSEIDON_16_NULL_HASH_PTR: usize = EXTENSION_BASIS_PTR + DIMENSION.pow(2); /// Pointer to start of private input -pub const PRIVATE_INPUT_START_PTR: usize = POSEIDON_16_NULL_HASH_PTR + DIGEST_LEN * 2; +pub const PRIVATE_INPUT_START_PTR: usize = POSEIDON_16_NULL_HASH_PTR + DIGEST_LEN; /// Normal pointer to start of program input pub const NONRESERVED_PROGRAM_INPUT_START: usize = PRIVATE_INPUT_START_PTR + 1; diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index d1165707..ff6bc8e2 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -9,7 +9,7 @@ use crate::isa::Bytecode; use crate::isa::instruction::InstructionContext; use crate::{ ALL_TABLES, CodeAddress, ENDING_PC, EXTENSION_BASIS_PTR, HintExecutionContext, N_TABLES, PRIVATE_INPUT_START_PTR, - STARTING_PC, SourceLocation, Table, TableTrace, + SAMPLING_DOMAIN_SEPARATOR_PTR, STARTING_PC, SourceLocation, Table, TableTrace, }; use multilinear_toolkit::prelude::*; use std::collections::{BTreeMap, BTreeSet}; @@ -29,6 +29,9 @@ pub fn build_public_memory(public_input: &[F]) -> Vec { *slot = F::ZERO; } + // sampling domain separator + public_memory[SAMPLING_DOMAIN_SEPARATOR_PTR] = F::ONE; + // extension basis for i in 0..DIMENSION { let mut vec = F::zero_vec(DIMENSION); @@ -36,7 +39,8 @@ pub fn build_public_memory(public_input: &[F]) -> Vec { public_memory[EXTENSION_BASIS_PTR + i * DIMENSION..][..DIMENSION].copy_from_slice(&vec); } - public_memory[POSEIDON_16_NULL_HASH_PTR..][..2 * DIGEST_LEN].copy_from_slice(&poseidon16_permute([F::ZERO; 16])); + public_memory[POSEIDON_16_NULL_HASH_PTR..][..DIGEST_LEN] + .copy_from_slice(&poseidon16_permute([F::ZERO; DIGEST_LEN * 2])[..DIGEST_LEN]); public_memory[PRIVATE_INPUT_START_PTR] = F::from_usize(public_memory_len); public_memory } diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 8c6a5580..315d96dc 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -12,8 +12,6 @@ use utils::{ToUsize, poseidon16_permute}; mod trace_gen; pub use trace_gen::fill_trace_poseidon_16; -pub const POSEIDON_16_DEFAULT_COMPRESSION: bool = true; - pub(super) const WIDTH: usize = 16; const HALF_INITIAL_FULL_ROUNDS: usize = KOALABEAR_RC16_EXTERNAL_INITIAL.len() / 2; const PARTIAL_ROUNDS: usize = KOALABEAR_RC16_INTERNAL.len(); @@ -22,11 +20,9 @@ const HALF_FINAL_FULL_ROUNDS: usize = KOALABEAR_RC16_EXTERNAL_FINAL.len() / 2; pub const POSEIDON_16_COL_FLAG: ColIndex = 0; pub const POSEIDON_16_COL_A: ColIndex = 1; pub const POSEIDON_16_COL_B: ColIndex = 2; -pub const POSEIDON_16_COL_COMPRESSION: ColIndex = 3; -pub const POSEIDON_16_COL_RES: ColIndex = 4; -pub const POSEIDON_16_COL_RES_BIS: ColIndex = 5; // = if compressed { 0 } else { POSEIDON_16_COL_RES + 1 } -pub const POSEIDON_16_COL_INPUT_START: ColIndex = 6; -const POSEIDON_16_COL_OUTPUT_START: ColIndex = num_cols_poseidon_16() - 16; +pub const POSEIDON_16_COL_RES: ColIndex = 3; +pub const POSEIDON_16_COL_INPUT_START: ColIndex = 4; +const POSEIDON_16_COL_OUTPUT_START: ColIndex = num_cols_poseidon_16() - 8; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Poseidon16Precompile; @@ -55,11 +51,6 @@ impl TableT for Poseidon16Precompile { index: POSEIDON_16_COL_RES, values: (POSEIDON_16_COL_OUTPUT_START..POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN).collect(), }, - LookupIntoMemory { - index: POSEIDON_16_COL_RES_BIS, - values: (POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN..POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN * 2) - .collect(), - }, ] } @@ -72,12 +63,7 @@ impl TableT for Poseidon16Precompile { table: BusTable::Constant(self.table()), direction: BusDirection::Pull, selector: POSEIDON_16_COL_FLAG, - data: vec![ - POSEIDON_16_COL_A, - POSEIDON_16_COL_B, - POSEIDON_16_COL_RES, - POSEIDON_16_COL_COMPRESSION, - ], + data: vec![POSEIDON_16_COL_A, POSEIDON_16_COL_B, POSEIDON_16_COL_RES], } } @@ -95,12 +81,10 @@ impl TableT for Poseidon16Precompile { arg_a: F, arg_b: F, index_res_a: F, - is_compression: usize, + _: usize, _: usize, ctx: &mut InstructionContext<'_>, ) -> Result<(), RunnerError> { - assert!(is_compression == 0 || is_compression == 1); - let is_compression = is_compression == 1; let trace = ctx.traces.get_mut(&self.table()).unwrap(); let arg0 = ctx.memory.get_slice(arg_a.to_usize(), DIGEST_LEN)?; @@ -119,24 +103,13 @@ impl TableT for Poseidon16Precompile { }; let res_a: [F; DIGEST_LEN] = output[..DIGEST_LEN].try_into().unwrap(); - let (index_res_b, res_b): (F, [F; DIGEST_LEN]) = if is_compression { - (F::from_usize(ZERO_VEC_PTR), [F::ZERO; DIGEST_LEN]) - } else { - ( - index_res_a + F::from_usize(DIGEST_LEN), - output[DIGEST_LEN..].try_into().unwrap(), - ) - }; ctx.memory.set_slice(index_res_a.to_usize(), &res_a)?; - ctx.memory.set_slice(index_res_b.to_usize(), &res_b)?; trace.base[POSEIDON_16_COL_FLAG].push(F::ONE); trace.base[POSEIDON_16_COL_A].push(arg_a); trace.base[POSEIDON_16_COL_B].push(arg_b); trace.base[POSEIDON_16_COL_RES].push(index_res_a); - trace.base[POSEIDON_16_COL_RES_BIS].push(index_res_b); - trace.base[POSEIDON_16_COL_COMPRESSION].push(F::from_bool(is_compression)); for (i, value) in input.iter().enumerate() { trace.base[POSEIDON_16_COL_INPUT_START + i].push(*value); } @@ -156,7 +129,7 @@ impl Air for Poseidon16Precompile { 0 } fn degree_air(&self) -> usize { - if BUS { 10 } else { 9 } + 9 } fn down_column_indexes_f(&self) -> Vec { vec![] @@ -165,7 +138,7 @@ impl Air for Poseidon16Precompile { vec![] } fn n_constraints(&self) -> usize { - BUS as usize + 87 + BUS as usize + 76 } fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { let cols: Poseidon2Cols = { @@ -182,28 +155,13 @@ impl Air for Poseidon16Precompile { extra_data, AB::F::from_usize(self.table().index()), cols.flag.clone(), - &[ - cols.index_a.clone(), - cols.index_b.clone(), - cols.index_res.clone(), - cols.compress.clone(), - ], + &[cols.index_a.clone(), cols.index_b.clone(), cols.index_res.clone()], )); } else { - builder.declare_values(&[ - cols.index_a.clone(), - cols.index_b.clone(), - cols.index_res.clone(), - cols.compress.clone(), - ]); + builder.declare_values(&[cols.index_a.clone(), cols.index_b.clone(), cols.index_res.clone()]); } builder.assert_bool(cols.flag.clone()); - builder.assert_bool(cols.compress.clone()); - builder.assert_eq( - cols.index_res_bis.clone(), - (cols.index_res.clone() + AB::F::from_usize(DIGEST_LEN)) * (AB::F::ONE - cols.compress.clone()), - ); eval(builder, &cols) } @@ -215,18 +173,17 @@ pub(super) struct Poseidon2Cols { pub flag: T, pub index_a: T, pub index_b: T, - pub compress: T, pub index_res: T, - pub index_res_bis: T, pub inputs: [T; WIDTH], pub beginning_full_rounds: [[T; WIDTH]; HALF_INITIAL_FULL_ROUNDS], pub partial_rounds: [T; PARTIAL_ROUNDS], - pub ending_full_rounds: [[T; WIDTH]; HALF_FINAL_FULL_ROUNDS], + pub ending_full_rounds: [[T; WIDTH]; HALF_FINAL_FULL_ROUNDS - 1], + pub outputs: [T; WIDTH / 2], } fn eval(builder: &mut AB, local: &Poseidon2Cols) { - let mut state: [_; WIDTH] = local.inputs.clone().map(|x| x); + let mut state: [_; WIDTH] = local.inputs.clone(); GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut state); @@ -255,11 +212,11 @@ fn eval(builder: &mut AB, local: &Poseidon2Cols) { } eval_last_2_full_rounds( + &local.inputs, &mut state, - &local.ending_full_rounds[HALF_FINAL_FULL_ROUNDS - 1], + &local.outputs, &KOALABEAR_RC16_EXTERNAL_FINAL[2 * (HALF_FINAL_FULL_ROUNDS - 1)], &KOALABEAR_RC16_EXTERNAL_FINAL[2 * (HALF_FINAL_FULL_ROUNDS - 1) + 1], - local.compress.clone(), builder, ); } @@ -294,11 +251,11 @@ fn eval_2_full_rounds( #[inline] fn eval_last_2_full_rounds( + initial_state: &[AB::F; WIDTH], state: &mut [AB::F; WIDTH], - post_full_round: &[AB::F; WIDTH], + outputs: &[AB::F; WIDTH / 2], round_constants_1: &[F; WIDTH], round_constants_2: &[F; WIDTH], - compress: AB::F, builder: &mut AB, ) { for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { @@ -311,13 +268,13 @@ fn eval_last_2_full_rounds( *s = s.cube(); } GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); - for (state_i, post_i) in state.iter_mut().zip(post_full_round).take(WIDTH / 2) { - builder.assert_eq(state_i.clone(), post_i.clone()); - *state_i = post_i.clone(); + // add inputs to outputs (for compression) + for (state_i, init_state_i) in state.iter_mut().zip(initial_state) { + *state_i += init_state_i.clone(); } - for (state_i, post_i) in state.iter_mut().zip(post_full_round).skip(WIDTH / 2) { - builder.assert_eq(state_i.clone() * -(compress.clone() - AB::F::ONE), post_i.clone()); - *state_i = post_i.clone(); + for (state_i, output_i) in state.iter_mut().zip(outputs) { + builder.assert_eq(state_i.clone(), output_i.clone()); + *state_i = output_i.clone(); } } diff --git a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs index 4330ccd5..a2c5c954 100644 --- a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs @@ -2,7 +2,7 @@ use p3_poseidon2::GenericPoseidon2LinearLayers; use tracing::instrument; use crate::{ - DIGEST_LEN, F, POSEIDON_16_DEFAULT_COMPRESSION, POSEIDON_16_NULL_HASH_PTR, ZERO_VEC_PTR, + F, POSEIDON_16_NULL_HASH_PTR, ZERO_VEC_PTR, tables::{Poseidon2Cols, WIDTH, num_cols_poseidon_16}, }; use multilinear_toolkit::prelude::*; @@ -56,18 +56,14 @@ pub fn default_poseidon_row() -> Vec { *perm.index_a = F::from_usize(ZERO_VEC_PTR); *perm.index_b = F::from_usize(ZERO_VEC_PTR); *perm.index_res = F::from_usize(POSEIDON_16_NULL_HASH_PTR); - *perm.index_res_bis = if POSEIDON_16_DEFAULT_COMPRESSION { - F::from_usize(ZERO_VEC_PTR) - } else { - F::from_usize(POSEIDON_16_NULL_HASH_PTR + DIGEST_LEN) - }; - *perm.compress = F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION); generate_trace_rows_for_perm(perm); row } + fn generate_trace_rows_for_perm + Copy>(perm: &mut Poseidon2Cols<&mut F>) { - let mut state: [F; WIDTH] = std::array::from_fn(|i| *perm.inputs[i]); + let inputs: [F; WIDTH] = std::array::from_fn(|i| *perm.inputs[i]); + let mut state = inputs; GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut state); @@ -76,30 +72,34 @@ fn generate_trace_rows_for_perm + Copy>(perm: &mut Poseido .iter_mut() .zip(KOALABEAR_RC16_EXTERNAL_INITIAL.chunks_exact(2)) { - generate_full_round(&mut state, full_round, &constants[0], &constants[1]); + generate_2_full_round(&mut state, full_round, &constants[0], &constants[1]); } for (partial_round, constant) in perm.partial_rounds.iter_mut().zip(&KOALABEAR_RC16_INTERNAL) { generate_partial_round(&mut state, partial_round, *constant); } + let n_ending_full_rounds = perm.ending_full_rounds.len(); for (full_round, constants) in perm .ending_full_rounds .iter_mut() .zip(KOALABEAR_RC16_EXTERNAL_FINAL.chunks_exact(2)) { - generate_full_round(&mut state, full_round, &constants[0], &constants[1]); + generate_2_full_round(&mut state, full_round, &constants[0], &constants[1]); } - perm.ending_full_rounds.last_mut().unwrap()[8..16] - .iter_mut() - .for_each(|x| { - **x = (F::ONE - *perm.compress) * **x; - }); + // Last 2 full rounds with compression (add inputs to outputs) + generate_last_2_full_rounds( + &mut state, + &inputs, + &mut perm.outputs, + &KOALABEAR_RC16_EXTERNAL_FINAL[2 * n_ending_full_rounds], + &KOALABEAR_RC16_EXTERNAL_FINAL[2 * n_ending_full_rounds + 1], + ); } #[inline] -fn generate_full_round + Copy>( +fn generate_2_full_round + Copy>( state: &mut [F; WIDTH], post_full_round: &mut [&mut F; WIDTH], round_constants_1: &[KoalaBear; WIDTH], @@ -123,6 +123,32 @@ fn generate_full_round + Copy>( }); } +#[inline] +fn generate_last_2_full_rounds + Copy>( + state: &mut [F; WIDTH], + inputs: &[F; WIDTH], + outputs: &mut [&mut F; WIDTH / 2], + round_constants_1: &[KoalaBear; WIDTH], + round_constants_2: &[KoalaBear; WIDTH], +) { + for (state_i, const_i) in state.iter_mut().zip(round_constants_1) { + *state_i += *const_i; + *state_i = state_i.cube(); + } + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); + + for (state_i, const_i) in state.iter_mut().zip(round_constants_2.iter()) { + *state_i += *const_i; + *state_i = state_i.cube(); + } + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); + + // Add inputs to outputs (compression) + for ((output, state_i), &input_i) in outputs.iter_mut().zip(state).zip(inputs) { + **output = *state_i + input_i; + } +} + #[inline] fn generate_partial_round + Copy>( state: &mut [F; WIDTH], diff --git a/crates/rec_aggregation/fiat_shamir.py b/crates/rec_aggregation/fiat_shamir.py index fe2614b8..5e49cf70 100644 --- a/crates/rec_aggregation/fiat_shamir.py +++ b/crates/rec_aggregation/fiat_shamir.py @@ -1,105 +1,118 @@ from snark_lib import * # FIAT SHAMIR layout: 17 field elements # 0..8 -> first half of sponge state -# 8..16 -> second half of sponge state -# 16 -> transcript pointer +# 8 -> transcript pointer from utils import * def fs_new(transcript_ptr): - fs_state = Array(17) - set_to_16_zeros(fs_state) - fs_state[16] = transcript_ptr - duplexed = duplexing(fs_state) - return duplexed - - -def duplexing(fs): - new_fs = Array(17) - poseidon16(fs, fs + 8, new_fs, PERMUTATION) - new_fs[16] = fs[16] - return new_fs + fs_state = Array(9) + set_to_8_zeros(fs_state) + fs_state[8] = transcript_ptr + return fs_state def fs_grinding(fs, bits): if bits == 0: return fs # no grinding - left = Array(8) - grinding_witness = read_memory(fs[16]) - left[0] = grinding_witness - set_to_7_zeros(left + 1) + transcript_ptr = fs[8] + set_to_7_zeros(transcript_ptr + 1) - fs_after_poseidon = Array(17) - poseidon16(left, fs + 8, fs_after_poseidon, PERMUTATION) - fs_after_poseidon[16] = fs[16] + 1 # one element read from transcript + new_fs = Array(9) + poseidon16(fs, transcript_ptr, new_fs) + new_fs[8] = transcript_ptr + 8 - sampled = fs_after_poseidon[0] + sampled = new_fs[0] _, sampled_low_bits_value = checked_decompose_bits(sampled, bits) assert sampled_low_bits_value == 0 - fs_duplexed = duplexing(fs_after_poseidon) + return new_fs - return fs_duplexed + +def fs_sample_chunks(fs, n_chunks: Const): + # return the updated fiat-shamir, and a pointer to n_chunks chunks of 8 field elements + + sampled = Array((n_chunks + 1) * 8 + 1) + for i in unroll(0, (n_chunks + 1)): + domain_sep = Array(8) + domain_sep[0] = i + set_to_7_zeros(domain_sep + 1) + poseidon16( + fs, + domain_sep, + sampled + i * 8, + ) + sampled[(n_chunks + 1) * 8] = fs[8] # same transcript pointer + new_fs = sampled + n_chunks * 8 + return new_fs, sampled def fs_sample_ef(fs): - return fs + sampled = Array(8) + poseidon16(fs, ZERO_VEC_PTR, sampled) + new_fs = Array(9) + poseidon16(fs, SAMPLING_DOMAIN_SEPARATOR_PTR, new_fs) + new_fs[8] = fs[8] # same transcript pointer + return new_fs, sampled + + +def fs_sample_many_ef(fs, n): + # return the updated fiat-shamir, and a pointer to n (continuous) extension field elements + n_chunks = div_ceil_dynamic(n * DIM, 8) + debug_assert(n_chunks <= 31) + debug_assert(1 <= n_chunks) + new_fs, sampled = match_range(n_chunks, range(1, 32), lambda nc: fs_sample_chunks(fs, nc)) + return new_fs, sampled def fs_hint(fs, n): # return the updated fiat-shamir, and a pointer to n field elements from the transcript - - transcript_ptr = fs[16] - new_fs = Array(17) - copy_16(fs, new_fs) - new_fs[16] = fs[16] + n # advance transcript pointer + transcript_ptr = fs[8] + new_fs = Array(9) + copy_8(fs, new_fs) + new_fs[8] = fs[8] + n # advance transcript pointer return new_fs, transcript_ptr def fs_receive_chunks(fs, n_chunks: Const): # each chunk = 8 field elements - new_fs = Array(1 + 16 * n_chunks) - transcript_ptr = fs[16] - new_fs[16 * n_chunks] = transcript_ptr + 8 * n_chunks # advance transcript pointer + new_fs = Array(1 + 8 * n_chunks) + transcript_ptr = fs[8] + new_fs[8 * n_chunks] = transcript_ptr + 8 * n_chunks # advance transcript pointer - poseidon16(transcript_ptr, fs + 8, new_fs, PERMUTATION) + poseidon16(fs, transcript_ptr, new_fs) for i in unroll(1, n_chunks): poseidon16( + new_fs + ((i - 1) * 8), transcript_ptr + i * 8, - new_fs + ((i - 1) * 16 + 8), - new_fs + i * 16, - PERMUTATION, + new_fs + i * 8, ) - return new_fs + 16 * (n_chunks - 1), transcript_ptr + return new_fs + 8 * (n_chunks - 1), transcript_ptr def fs_receive_ef(fs, n: Const): - new_fs, ef_ptr = fs_receive_chunks(fs, next_multiple_of(n * DIM, 8) / 8) + new_fs, ef_ptr = fs_receive_chunks(fs, div_ceil(n * DIM, 8)) for i in unroll(n * DIM, next_multiple_of(n * DIM, 8)): assert ef_ptr[i] == 0 return new_fs, ef_ptr def fs_print_state(fs_state): - for i in unroll(0, 17): + for i in unroll(0, 9): print(i, fs_state[i]) return -def sample_bits_const(fs: Mut, n_samples: Const, K): +def sample_bits_const(fs, n_samples: Const, K): # return the updated fiat-shamir, and a pointer to n pointers, each pointing to 31 (boolean) field elements, sampled_bits = Array(n_samples) - for i in unroll(0, (next_multiple_of(n_samples, 8) / 8) - 1): - for j in unroll(0, 8): - bits, _ = checked_decompose_bits(fs[j], K) - sampled_bits[i * 8 + j] = bits - fs = duplexing(fs) - # Last batch (may be partial) - for j in unroll(0, 8 - ((8 - (n_samples % 8)) % 8)): - bits, _ = checked_decompose_bits(fs[j], K) - sampled_bits[((next_multiple_of(n_samples, 8) / 8) - 1) * 8 + j] = bits - return duplexing(fs), sampled_bits + n_chunks = div_ceil(n_samples, 8) + new_fs, sampled = fs_sample_chunks(fs, n_chunks) + for i in unroll(0, n_samples): + bits, _ = checked_decompose_bits(sampled[i], K) + sampled_bits[i] = bits + return new_fs, sampled_bits def sample_bits_dynamic(fs_state, n_samples, K): diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py index bf67a9ef..09cb357b 100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -1,8 +1,5 @@ from snark_lib import * -COMPRESSION = 1 -PERMUTATION = 0 - DIM = 5 # extension degree VECTOR_LEN = 8 @@ -37,12 +34,12 @@ def batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, def slice_hash(seed, data, len: Const): states = Array(len * VECTOR_LEN) - poseidon16(ZERO_VEC_PTR, data, states, COMPRESSION) + poseidon16(ZERO_VEC_PTR, data, states) state_indexes = Array(len) state_indexes[0] = states for j in unroll(1, len): state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN - poseidon16(state_indexes[j - 1], data + j * VECTOR_LEN, state_indexes[j], COMPRESSION) + poseidon16(state_indexes[j - 1], data + j * VECTOR_LEN, state_indexes[j]) return state_indexes[len - 1] @@ -99,9 +96,9 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, root, height: Co # First merkle round match leaf_position_bits[0]: case 0: - poseidon16(leaf_digest, merkle_path, states, COMPRESSION) + poseidon16(leaf_digest, merkle_path, states) case 1: - poseidon16(merkle_path, leaf_digest, states, COMPRESSION) + poseidon16(merkle_path, leaf_digest, states) # Remaining merkle rounds state_indexes = Array(height) @@ -115,14 +112,12 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, root, height: Co state_indexes[j - 1], merkle_path + j * VECTOR_LEN, state_indexes[j], - COMPRESSION, ) case 1: poseidon16( merkle_path + j * VECTOR_LEN, state_indexes[j - 1], state_indexes[j], - COMPRESSION, ) copy_8(state_indexes[height - 1], root) return diff --git a/crates/rec_aggregation/recursion.py b/crates/rec_aggregation/recursion.py index e0791dec..b93a34d5 100644 --- a/crates/rec_aggregation/recursion.py +++ b/crates/rec_aggregation/recursion.py @@ -77,13 +77,10 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip # parse 1st whir commitment fs, whir_base_root, whir_base_ood_points, whir_base_ood_evals = parse_whir_commitment_const(fs, NUM_OOD_COMMIT_BASE) - logup_c = fs_sample_ef(fs) - fs = duplexing(fs) - - logup_alphas = Array(DIM * log2_ceil(MAX_BUS_WIDTH)) - for i in unroll(0, log2_ceil(MAX_BUS_WIDTH)): - copy_5(fs_sample_ef(fs), logup_alphas + i * DIM) # TODO avoid duplication - fs = duplexing(fs) + fs, logup_c = fs_sample_ef(fs) + + fs, logup_alphas = fs_sample_many_ef(fs, log2_ceil(MAX_BUS_WIDTH)) + logup_alphas_eq_poly = poly_eq_extension(logup_alphas, log2_ceil(MAX_BUS_WIDTH)) # GENRIC LOGUP @@ -221,10 +218,8 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip # VERIFY BUS AND AIR - bus_beta = fs_sample_ef(fs) - fs = duplexing(fs) - - air_alpha = fs_sample_ef(fs) + fs, bus_beta = fs_sample_ef(fs) + fs, air_alpha = fs_sample_ef(fs) air_alpha_powers = powers_const(air_alpha, MAX_NUM_AIR_CONSTRAINTS + 1) for table_index in unroll(0, N_TABLES): @@ -263,7 +258,7 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip copy_5(expected_outer_eval, outer_eval) if len(AIR_DOWN_COLUMNS_F[table_index]) != 0: - batching_scalar = fs_sample_ef(fs) + fs, batching_scalar = fs_sample_ef(fs) batching_scalar_powers = powers_const(batching_scalar, n_down_columns) evals_down_f = inner_evals + n_up_columns_f * DIM evals_down_ef = inner_evals + (n_up_columns_f + n_down_columns_f + n_up_columns_ef) * DIM @@ -344,11 +339,7 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip pcs_values[table_index][last_index_2][virtual_col_index].push(transposed + j * DIM) log_num_instrs = log2_ceil(NUM_BYTECODE_INSTRUCTIONS) - bytecode_compression_challenges = Array(DIM * log_num_instrs) - for i in unroll(0, log_num_instrs): - copy_5(fs_sample_ef(fs), bytecode_compression_challenges + i * DIM) # TODO avoid duplication - if i != log_num_instrs - 1: - fs = duplexing(fs) + fs, bytecode_compression_challenges = fs_sample_many_ef(fs, log_num_instrs) bytecode_air_values = Array(DIM * 2**log_num_instrs) for i in unroll(0, NUM_BYTECODE_INSTRUCTIONS): @@ -380,7 +371,7 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip fs, pushforward_eval = fs_receive_ef(fs, 1) mul_extension(table_eval, pushforward_eval, ls_sumcheck_value) - ls_c = fs_sample_ef(fs) + fs, ls_c = fs_sample_ef(fs) fs, quotient_left, claim_point_left, claim_num_left, eval_c_minus_indexes = verify_gkr_quotient(fs, log_n_cycles) fs, quotient_right, claim_point_right, pushforward_final_eval, claim_den_right = verify_gkr_quotient( @@ -422,10 +413,8 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip for i in unroll(0, next_multiple_of(NONRESERVED_PROGRAM_INPUT_START, DIM) / DIM): copy_5(i * DIM, outer_public_memory + i * DIM) - public_memory_random_point = Array(outer_public_memory_log_size * DIM) - for i in range(0, outer_public_memory_log_size): - copy_5(fs_sample_ef(fs), public_memory_random_point + i * DIM) - fs = duplexing(fs) + fs, public_memory_random_point = fs_sample_many_ef(fs, outer_public_memory_log_size) + poly_eq_public_mem = poly_eq_extension_dynamic(public_memory_random_point, outer_public_memory_log_size) public_memory_eval = Array(DIM) dot_product_be_dynamic( @@ -436,7 +425,8 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip ) # WHIR BASE - combination_randomness_gen: Mut = fs_sample_ef(fs) + combination_randomness_gen: Mut + fs, combination_randomness_gen = fs_sample_ef(fs) combination_randomness_powers: Mut = powers_const( combination_randomness_gen, NUM_OOD_COMMIT_BASE + TOTAL_WHIR_STATEMENTS_BASE ) @@ -565,7 +555,7 @@ def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcrip copy_5(mul_extension_ret(s, final_value), end_sum) # WHIR EXT (Pushforward) - combination_randomness_gen = fs_sample_ef(fs) + fs, combination_randomness_gen = fs_sample_ef(fs) combination_randomness_powers = powers_const(combination_randomness_gen, NUM_OOD_COMMIT_EXT + 2) whir_sum = dot_product_ret(whir_ext_ood_evals, combination_randomness_powers, NUM_OOD_COMMIT_EXT, EE) whir_sum = add_extension_ret( @@ -618,9 +608,8 @@ def verify_gkr_quotient(fs: Mut, n_vars): claims_num = Array(n_vars) claims_den = Array(n_vars) - points[0] = fs_sample_ef(fs) - fs = duplexing(fs) - + fs, points[0] = fs_sample_ef(fs) + point_poly_eq = poly_eq_extension(points[0], 1) first_claim_num = dot_product_ret(nums, point_poly_eq, 2, EE) @@ -643,7 +632,7 @@ def verify_gkr_quotient(fs: Mut, n_vars): def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den): - alpha = fs_sample_ef(fs) + fs, alpha = fs_sample_ef(fs) alpha_mul_claim_den = mul_extension_ret(alpha, claim_den) num_plus_alpha_mul_claim_den = add_extension_ret(claim_num, alpha_mul_claim_den) postponed_point = Array((n_vars + 1) * DIM) @@ -659,8 +648,8 @@ def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den): eq_factor = eq_mle_extension(point, postponed_point + DIM, n_vars) mul_extension(sum_num_plus_sum_den_mul_alpha, eq_factor, postponed_value) - beta = fs_sample_ef(fs) - fs = duplexing(fs) + fs, beta = fs_sample_ef(fs) + point_poly_eq = poly_eq_extension(beta, 1) new_claim_num = dot_product_ret(inner_evals, point_poly_eq, 2, EE) new_claim_den = dot_product_ret(inner_evals + 2 * DIM, point_poly_eq, 2, EE) diff --git a/crates/rec_aggregation/src/recursion.rs b/crates/rec_aggregation/src/recursion.rs index bd2b90b8..2c9ccc7c 100644 --- a/crates/rec_aggregation/src/recursion.rs +++ b/crates/rec_aggregation/src/recursion.rs @@ -30,8 +30,6 @@ pub fn run_recursion_benchmark(count: usize, tracing: bool) { }; let program_to_prove = r#" DIM = 5 -COMPRESSION = 1 -PERMUTATION = 0 POSEIDON_OF_ZERO = POSEIDON_OF_ZERO_PLACEHOLDER # Dot product precompile: BE = 1 # base-extension @@ -41,8 +39,7 @@ def main(): for i in range(0, 1000): null_ptr = ZERO_VEC_PTR # pointer to zero vector poseidon_of_zero = POSEIDON_OF_ZERO - poseidon16(null_ptr, null_ptr, poseidon_of_zero, PERMUTATION) - poseidon16(null_ptr, null_ptr, poseidon_of_zero, COMPRESSION) + poseidon16(null_ptr, null_ptr, poseidon_of_zero) dot_product(null_ptr, null_ptr, null_ptr, 2, BE) dot_product(null_ptr, null_ptr, null_ptr, 2, EE) x: Mut = 0 diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index a2e8be99..96975a93 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -15,6 +15,12 @@ LITTLE_ENDIAN = 1 +def div_ceil_dynamic(a, b: Const): + debug_assert(a <= 150) + res = match_range(a, range(0, 151), lambda i: div_ceil(i, b)) + return res + + def powers(alpha, n): # alpha: EF # n: F @@ -54,9 +60,7 @@ def unit_root_pow_const(domain_size: Const, index_bits): def poly_eq_extension_dynamic(point, n): debug_assert(n < 8) - res = match_range(n, - range(0, 1), lambda i: ONE_VEC_PTR, - range(1, 8), lambda i: poly_eq_extension(point, i)) + res = match_range(n, range(0, 1), lambda i: ONE_VEC_PTR, range(1, 8), lambda i: poly_eq_extension(point, i)) return res @@ -374,12 +378,12 @@ def set_to_7_zeros(a): @inline -def set_to_16_zeros(a): +def set_to_8_zeros(a): zero_ptr = ZERO_VEC_PTR dot_product(a, ONE_VEC_PTR, zero_ptr, 1, EE) - dot_product(a + 5, ONE_VEC_PTR, zero_ptr, 1, EE) - dot_product(a + 10, ONE_VEC_PTR, zero_ptr, 1, EE) - a[15] = 0 + a[5] = 0 + a[6] = 0 + a[7] = 0 return diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index 818299d0..06f00951 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -157,8 +157,6 @@ def whir_open_base( final_value = dot_product_ret(poly_eq_final, final_coeffcients, 2**FINAL_VARS_BASE, EE) # copy_5(mul_extension_ret(s, final_value), end_sum); - fs = duplexing(fs) - return fs, folding_randomness_global, s, final_value, end_sum @@ -289,8 +287,6 @@ def whir_open_ext( final_value = dot_product_ret(poly_eq_final, final_coeffcients, 2**FINAL_VARS_EXT, EE) # copy_5(mul_extension_ret(s, final_value), end_sum); - fs = duplexing(fs) - return fs, folding_randomness_global, s, final_value, end_sum @@ -305,7 +301,7 @@ def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, ch fs, poly = fs_receive_ef(fs, degree + 1) sum_over_boolean_hypercube = polynomial_sum_at_0_and_1(poly, degree) copy_5(sum_over_boolean_hypercube, claimed_sum) - rand = fs_sample_ef(fs) + fs, rand = fs_sample_ef(fs) claimed_sum = univariate_polynomial_eval(poly, rand, degree) copy_5(rand, challenges + sc_round * DIM) @@ -407,7 +403,7 @@ def whir_round( grinding_bits, ) - combination_randomness_gen = fs_sample_ef(fs) + fs, combination_randomness_gen = fs_sample_ef(fs) combination_randomness_powers = powers(combination_randomness_gen, num_queries + num_ood) @@ -464,10 +460,6 @@ def parse_commitment(fs: Mut, num_ood): def parse_whir_commitment_const(fs: Mut, num_ood: Const): fs, root = fs_receive_chunks(fs, 1) - ood_points = Array(num_ood * DIM) - for i in unroll(0, num_ood): - ood_point = fs_sample_ef(fs) - copy_5(ood_point, ood_points + i * DIM) - fs = duplexing(fs) + fs, ood_points = fs_sample_many_ef(fs, num_ood) fs, ood_evals = fs_receive_ef(fs, num_ood) return fs, root, ood_points, ood_evals diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index ea6101c5..a776230e 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -1,8 +1,5 @@ from snark_lib import * -COMPRESSION = 1 -PERMUTATION = 0 - V = 66 W = 4 TARGET_SUM = 118 @@ -49,7 +46,7 @@ def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): # 1) We encode message_hash + randomness into the d-th layer of the hypercube compressed = Array(VECTOR_LEN) - poseidon16(message_hash, randomness, compressed, COMPRESSION) + poseidon16(message_hash, randomness, compressed) compressed_vals = Array(6) dot_product(compressed, ONE_VEC_PTR, compressed_vals, 1, EE) compressed_vals[5] = compressed[5] @@ -99,19 +96,19 @@ def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): var_2 = public_key + i * VECTOR_LEN var_3 = Array(vector_len) var_4 = Array(vector_len) - poseidon16(var_1, ZERO_VEC_PTR, var_3, COMPRESSION) - poseidon16(var_3, ZERO_VEC_PTR, var_4, COMPRESSION) - poseidon16(var_4, ZERO_VEC_PTR, var_2, COMPRESSION) + poseidon16(var_1, ZERO_VEC_PTR, var_3) + poseidon16(var_3, ZERO_VEC_PTR, var_4) + poseidon16(var_4, ZERO_VEC_PTR, var_2) case 1: var_3 = Array(vector_len) var_1 = chain_tips + i * VECTOR_LEN var_2 = public_key + i * VECTOR_LEN - poseidon16(var_1, ZERO_VEC_PTR, var_3, COMPRESSION) - poseidon16(var_3, ZERO_VEC_PTR, var_2, COMPRESSION) + poseidon16(var_1, ZERO_VEC_PTR, var_3) + poseidon16(var_3, ZERO_VEC_PTR, var_2) case 2: var_1 = chain_tips + i * VECTOR_LEN var_2 = public_key + i * VECTOR_LEN - poseidon16(var_1, ZERO_VEC_PTR, var_2, COMPRESSION) + poseidon16(var_1, ZERO_VEC_PTR, var_2) case 3: var_1 = chain_tips + (i * VECTOR_LEN) var_2 = public_key + (i * VECTOR_LEN) @@ -202,9 +199,9 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): # First merkle round match leaf_position_bits[0]: case 0: - poseidon16(leaf_digest, merkle_path, states, COMPRESSION) + poseidon16(leaf_digest, merkle_path, states) case 1: - poseidon16(merkle_path, leaf_digest, states, COMPRESSION) + poseidon16(merkle_path, leaf_digest, states) # Remaining merkle rounds state_indexes = Array(height) @@ -218,26 +215,24 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): state_indexes[j - 1], merkle_path + j * VECTOR_LEN, state_indexes[j], - COMPRESSION, ) case 1: poseidon16( merkle_path + j * VECTOR_LEN, state_indexes[j - 1], state_indexes[j], - COMPRESSION, ) return state_indexes[height - 1] def slice_hash(seed, data, half_len: Const): states = Array(half_len * 2 * VECTOR_LEN) - poseidon16(ZERO_VEC_PTR, data, states, COMPRESSION) + poseidon16(ZERO_VEC_PTR, data, states) state_indexes = Array(half_len * 2) state_indexes[0] = states for j in unroll(1, (half_len * 2)): state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN - poseidon16(state_indexes[j - 1], data + j * VECTOR_LEN, state_indexes[j], COMPRESSION) + poseidon16(state_indexes[j - 1], data + j * VECTOR_LEN, state_indexes[j]) return state_indexes[half_len * 2 - 1] diff --git a/crates/sub_protocols/src/quotient_gkr.rs b/crates/sub_protocols/src/quotient_gkr.rs index b7c2b18a..97c8640f 100644 --- a/crates/sub_protocols/src/quotient_gkr.rs +++ b/crates/sub_protocols/src/quotient_gkr.rs @@ -42,7 +42,6 @@ pub fn prove_gkr_quotient>>( let quotient = last_numerators[0] / last_denominators[0] + last_numerators[1] / last_denominators[1]; let mut point = MultilinearPoint(vec![prover_state.sample()]); - prover_state.duplexing(); let mut claims = vec![last_numerators.evaluate(&point), last_denominators.evaluate(&point)]; for (nums, denoms) in layers.iter().rev() { @@ -90,7 +89,6 @@ fn prove_gkr_quotient_step>>( prover_state.add_extension_scalars(&inner_evals); let beta = prover_state.sample(); - prover_state.duplexing(); let next_claims = inner_evals .chunks_exact(2) @@ -112,7 +110,6 @@ pub fn verify_gkr_quotient>>( let quotient = last_nums[0] / last_dens[0] + last_nums[1] / last_dens[1]; let mut point = MultilinearPoint(vec![verifier_state.sample()]); - verifier_state.duplexing(); let mut claims_num = last_nums.evaluate(&point); let mut claims_den = last_dens.evaluate(&point); for i in 1..n_vars { @@ -152,7 +149,6 @@ fn verify_gkr_quotient_step>>( } let beta = verifier_state.sample(); - verifier_state.duplexing(); let next_claims_numerators = (&inner_evals[..2]).evaluate(&MultilinearPoint(vec![beta])); let next_claims_denominators = (&inner_evals[2..]).evaluate(&MultilinearPoint(vec![beta])); diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index 37e0dd9f..ef606cfd 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -7,17 +7,13 @@ use crate::get_poseidon16; pub type VarCount = usize; pub fn build_prover_state() -> ProverState { - let mut prover_state = ProverState::new(get_poseidon16().clone()); - prover_state.duplexing(); - prover_state + ProverState::new(get_poseidon16().clone()) } pub fn build_verifier_state( prover_state: ProverState, ) -> VerifierState { - let mut verifier_state = VerifierState::new(prover_state.raw_proof(), get_poseidon16().clone()); - verifier_state.duplexing(); - verifier_state + VerifierState::new(prover_state.raw_proof(), get_poseidon16().clone()) } pub trait ToUsize { diff --git a/src/prove_poseidons.rs b/src/prove_poseidons.rs index 714a3010..ff4e582e 100644 --- a/src/prove_poseidons.rs +++ b/src/prove_poseidons.rs @@ -1,8 +1,8 @@ use air::{check_air_validity, prove_air, verify_air}; use lean_vm::{ - EF, ExtraDataForBuses, F, POSEIDON_16_COL_A, POSEIDON_16_COL_B, POSEIDON_16_COL_COMPRESSION, POSEIDON_16_COL_FLAG, - POSEIDON_16_COL_INPUT_START, POSEIDON_16_COL_RES, POSEIDON_16_COL_RES_BIS, POSEIDON_16_DEFAULT_COMPRESSION, - POSEIDON_16_NULL_HASH_PTR, Poseidon16Precompile, ZERO_VEC_PTR, fill_trace_poseidon_16, num_cols_poseidon_16, + EF, ExtraDataForBuses, F, POSEIDON_16_COL_A, POSEIDON_16_COL_B, POSEIDON_16_COL_FLAG, POSEIDON_16_COL_INPUT_START, + POSEIDON_16_COL_RES, POSEIDON_16_NULL_HASH_PTR, Poseidon16Precompile, ZERO_VEC_PTR, fill_trace_poseidon_16, + num_cols_poseidon_16, }; use multilinear_toolkit::prelude::*; use rand::{Rng, SeedableRng, rngs::StdRng}; @@ -31,10 +31,6 @@ pub fn benchmark_prove_poseidon_16(log_n_rows: usize, tracing: bool) { } trace[POSEIDON_16_COL_FLAG] = (0..n_rows).map(|_| F::ONE).collect(); trace[POSEIDON_16_COL_RES] = (0..n_rows).map(|_| F::from_usize(POSEIDON_16_NULL_HASH_PTR)).collect(); - trace[POSEIDON_16_COL_RES_BIS] = (0..n_rows).map(|_| F::from_usize(ZERO_VEC_PTR)).collect(); - trace[POSEIDON_16_COL_COMPRESSION] = (0..n_rows) - .map(|_| F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION)) - .collect(); trace[POSEIDON_16_COL_A] = (0..n_rows).map(|_| F::from_usize(ZERO_VEC_PTR)).collect(); trace[POSEIDON_16_COL_B] = (0..n_rows).map(|_| F::from_usize(ZERO_VEC_PTR)).collect(); fill_trace_poseidon_16(&mut trace); @@ -77,7 +73,6 @@ pub fn benchmark_prove_poseidon_16(log_n_rows: usize, tracing: bool) { let witness = whir_config.commit(&mut prover_state, &committed_pol); let alpha = prover_state.sample(); - prover_state.duplexing(); let air_alpha_powers: Vec = alpha.powers().collect_n(air.n_constraints() + 1); let extra_data = ExtraDataForBuses { alpha_powers: air_alpha_powers, @@ -98,7 +93,6 @@ pub fn benchmark_prove_poseidon_16(log_n_rows: usize, tracing: bool) { assert_eq!(air_claims.evals_f.len(), air.n_columns_air()); let betas = prover_state.sample_vec(log2_ceil_usize(num_cols_poseidon_16())); - prover_state.duplexing(); let packed_point = MultilinearPoint([betas.clone(), air_claims.point.0].concat()); let packed_eval = padd_with_zero_to_next_power_of_two(&air_claims.evals_f).evaluate(&MultilinearPoint(betas)); @@ -121,7 +115,6 @@ pub fn benchmark_prove_poseidon_16(log_n_rows: usize, tracing: bool) { let parsed_commitment = whir_config.parse_commitment::(&mut verifier_state).unwrap(); let alpha = verifier_state.sample(); - verifier_state.duplexing(); let air_alpha_powers: Vec = alpha.powers().collect_n(air.n_constraints() + 1); let extra_data = ExtraDataForBuses { alpha_powers: air_alpha_powers, @@ -138,7 +131,6 @@ pub fn benchmark_prove_poseidon_16(log_n_rows: usize, tracing: bool) { .unwrap(); let betas = verifier_state.sample_vec(log2_ceil_usize(num_cols_poseidon_16())); - verifier_state.duplexing(); let packed_point = MultilinearPoint([betas.clone(), air_claims.point.0].concat()); let packed_eval = padd_with_zero_to_next_power_of_two(&air_claims.evals_f).evaluate(&MultilinearPoint(betas));