diff --git a/Cargo.lock b/Cargo.lock
index e8e621e..9a7a500 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -808,6 +808,7 @@ dependencies = [
 name = "module-lattice"
 version = "0.2.1"
 dependencies = [
+ "criterion",
  "ctutils",
  "getrandom",
  "hybrid-array",
diff --git a/module-lattice/Cargo.toml b/module-lattice/Cargo.toml
index 41b3d27..2640f79 100644
--- a/module-lattice/Cargo.toml
+++ b/module-lattice/Cargo.toml
@@ -24,8 +24,13 @@ ctutils = { version = "0.4", optional = true }
 zeroize = { version = "1.8.1", optional = true, default-features = false }
 
 [dev-dependencies]
+criterion = "0.7"
 getrandom = { version = "0.4", features = ["sys_rng"] }
 
+[[bench]]
+name = "algebra"
+harness = false
+
 [features]
 ctutils = ["dep:ctutils", "array/ctutils"]
 zeroize = ["array/zeroize", "dep:zeroize"]
diff --git a/module-lattice/benches/algebra.rs b/module-lattice/benches/algebra.rs
new file mode 100644
index 0000000..05752d0
--- /dev/null
+++ b/module-lattice/benches/algebra.rs
@@ -0,0 +1,84 @@
+//! Benchmarks for the NTT vector inner product.
+
+#![allow(
+    missing_docs,
+    clippy::integer_division_remainder_used,
+    clippy::cast_possible_truncation
+)]
+
+use array::typenum::{U2, U3, U4};
+use core::hint::black_box;
+use criterion::{Criterion, criterion_group, criterion_main};
+use module_lattice::{Elem, MultiplyNtt, NttPolynomial, NttVector};
+
+module_lattice::define_field!(KyberField, u16, u32, u64, 3329);
+
+// `&NttVector * &NttVector` multiplies every polynomial pair through `MultiplyNtt`, so this
+// impl is on the measured path.  It is a coefficient-wise stand-in that keeps the bench
+// self-contained; NOTE(review): presumably cheaper than ML-KEM's real NTT multiplication, so
+// treat the numbers as relative, not absolute.
+impl MultiplyNtt for KyberField {
+    fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self> {
+        NttPolynomial::new(
+            lhs.0
+                .iter()
+                .zip(rhs.0.iter())
+                .map(|(&a, &b)| a * b)
+                .collect(),
+        )
+    }
+}
+
+/// Builds a polynomial whose `i`-th coefficient is `(base + i) mod 3329`.
+fn make_ntt_poly(base: u16) -> NttPolynomial<KyberField> {
+    let coeffs: [Elem<KyberField>; 256] =
+        core::array::from_fn(|i| Elem::new((base + i as u16) % 3329));
+    NttPolynomial::new(coeffs.into())
+}
+
+/// Times `&NttVector * &NttVector` for vectors of 2, 3 and 4 polynomials.
+fn bench_ntt_vector_inner_product(c: &mut Criterion) {
+    // K=2 (ML-KEM-512)
+    let a2: NttVector<KyberField, U2> =
+        NttVector::new([make_ntt_poly(100), make_ntt_poly(200)].into());
+    let b2: NttVector<KyberField, U2> =
+        NttVector::new([make_ntt_poly(300), make_ntt_poly(400)].into());
+    c.bench_function("ntt_vector_dot_k2", |bench| {
+        bench.iter(|| black_box(black_box(&a2) * black_box(&b2)));
+    });
+
+    // K=3 (ML-KEM-768)
+    let a3: NttVector<KyberField, U3> =
+        NttVector::new([make_ntt_poly(100), make_ntt_poly(200), make_ntt_poly(300)].into());
+    let b3: NttVector<KyberField, U3> =
+        NttVector::new([make_ntt_poly(400), make_ntt_poly(500), make_ntt_poly(600)].into());
+    c.bench_function("ntt_vector_dot_k3", |bench| {
+        bench.iter(|| black_box(black_box(&a3) * black_box(&b3)));
+    });
+
+    // K=4 (ML-KEM-1024)
+    let a4: NttVector<KyberField, U4> = NttVector::new(
+        [
+            make_ntt_poly(100),
+            make_ntt_poly(200),
+            make_ntt_poly(300),
+            make_ntt_poly(400),
+        ]
+        .into(),
+    );
+    let b4: NttVector<KyberField, U4> = NttVector::new(
+        [
+            make_ntt_poly(500),
+            make_ntt_poly(600),
+            make_ntt_poly(700),
+            make_ntt_poly(800),
+        ]
+        .into(),
+    );
+    c.bench_function("ntt_vector_dot_k4", |bench| {
+        bench.iter(|| black_box(black_box(&a4) * black_box(&b4)));
+    });
+}
+
+criterion_group!(benches, bench_ntt_vector_inner_product);
+criterion_main!(benches);
diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs
index f0350b4..baacb79 100644
--- a/module-lattice/src/algebra.rs
+++ b/module-lattice/src/algebra.rs
@@ -1,6 +1,9 @@
 use super::truncate::Truncate;
 
-use array::{Array, ArraySize, typenum::U256};
+use array::{
+    Array, ArraySize,
+    typenum::{U256, Unsigned},
+};
 use core::fmt::Debug;
 use core::ops::{Add, Mul, Neg, Sub};
 use num_traits::PrimInt;
@@ -528,10 +531,26 @@ where
     type Output = NttPolynomial<F>;
 
     fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
-        self.0
-            .iter()
-            .zip(rhs.0.iter())
-            .map(|(x, y)| x * y)
-            .fold(NttPolynomial::default(), |x, y| &x + &y)
+        // Multiply each pair with `&a * &b`, exactly as the replaced fold did: NTT-domain
+        // multiplication is field-specific (`MultiplyNtt`), so a coefficient-wise `Long`
+        // product would silently change the result for fields whose multiply is not pointwise.
+        // Only the additions are deferred: sums accumulate unreduced in `F::Long`.
+        let zero = <F::Long as num_traits::Zero>::zero();
+        let mut acc = [zero; <U256 as Unsigned>::USIZE];
+        for (pa, pb) in self.0.iter().zip(rhs.0.iter()) {
+            let product = pa * pb;
+            for (slot, &c) in acc.iter_mut().zip(product.0.iter()) {
+                let lc: F::Long = c.0.into();
+                *slot = *slot + lc;
+            }
+        }
+        // Reduce once per coefficient rather than after every addition.
+        // NOTE(review): assumes `multiply_ntt` yields reduced coefficients, so each sum is
+        // at most `K * (q - 1)` and well inside `barrett_reduce`'s input range -- confirm.
+        NttPolynomial::new(
+            acc.iter()
+                .map(|&v| Elem::new(F::barrett_reduce(v)))
+                .collect(),
+        )
     }
 }