diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs
index ebd4f0c..2a44d5b 100644
--- a/ml-kem/src/algebra.rs
+++ b/ml-kem/src/algebra.rs
@@ -134,26 +134,46 @@ pub(crate) trait Ntt {
     fn ntt(&self) -> Self::Output;
 }
 
+/// One layer of the forward NTT butterfly.
+///
+/// `LEN` is the butterfly half-length and `ITERATIONS = 128 / LEN` is the number of
+/// butterfly groups in the layer. Making both compile-time constants lets the compiler
+/// eliminate the iterator length calculation (`256 / (2 * LEN)`) that `step_by` would
+/// otherwise compute with a `UDIV` instruction.
+#[inline(always)]
+fn ntt_layer<const LEN: usize, const ITERATIONS: usize>(
+    f: &mut Array<FieldElement, U256>,
+    k: &mut usize,
+) {
+    for i in 0..ITERATIONS {
+        let start = i * 2 * LEN;
+        let zeta = ZETA_POW_BITREV[*k];
+        *k += 1;
+
+        for j in start..(start + LEN) {
+            let t = zeta * f[j + LEN];
+            f[j + LEN] = f[j] - t;
+            f[j] = f[j] + t;
+        }
+    }
+}
+
 /// Algorithm 9: `NTT`
 impl Ntt for Polynomial {
     type Output = NttPolynomial;
 
     fn ntt(&self) -> NttPolynomial {
         let mut k = 1;
         let mut f = self.0;
-        for len in [128, 64, 32, 16, 8, 4, 2] {
-            for start in (0..256).step_by(2 * len) {
-                let zeta = ZETA_POW_BITREV[k];
-                k += 1;
-
-                for j in start..(start + len) {
-                    let t = zeta * f[j + len];
-                    f[j + len] = f[j] - t;
-                    f[j] = f[j] + t;
-                }
-            }
-        }
+
+        ntt_layer::<128, 1>(&mut f, &mut k);
+        ntt_layer::<64, 2>(&mut f, &mut k);
+        ntt_layer::<32, 4>(&mut f, &mut k);
+        ntt_layer::<16, 8>(&mut f, &mut k);
+        ntt_layer::<8, 16>(&mut f, &mut k);
+        ntt_layer::<4, 32>(&mut f, &mut k);
+        ntt_layer::<2, 64>(&mut f, &mut k);
 
         f.into()
     }
 }
@@ -175,26 +195,43 @@ pub(crate) trait NttInverse {
     fn ntt_inverse(&self) -> Self::Output;
 }
 
+/// One layer of the inverse NTT butterfly.
+///
+/// See [`ntt_layer`] for the rationale behind the const generics.
+#[inline(always)]
+fn ntt_inverse_layer<const LEN: usize, const ITERATIONS: usize>(
+    f: &mut Array<FieldElement, U256>,
+    k: &mut usize,
+) {
+    for i in 0..ITERATIONS {
+        let start = i * 2 * LEN;
+        let zeta = ZETA_POW_BITREV[*k];
+        *k -= 1;
+
+        for j in start..(start + LEN) {
+            let t = f[j];
+            f[j] = t + f[j + LEN];
+            f[j + LEN] = zeta * (f[j + LEN] - t);
+        }
+    }
+}
+
 /// Algorithm 10: `NTT^{-1}`
 impl NttInverse for NttPolynomial {
     type Output = Polynomial;
 
     fn ntt_inverse(&self) -> Polynomial {
         let mut f: Array<FieldElement, U256> = self.0.clone();
         let mut k = 127;
-        for len in [2, 4, 8, 16, 32, 64, 128] {
-            for start in (0..256).step_by(2 * len) {
-                let zeta = ZETA_POW_BITREV[k];
-                k -= 1;
-
-                for j in start..(start + len) {
-                    let t = f[j];
-                    f[j] = t + f[j + len];
-                    f[j + len] = zeta * (f[j + len] - t);
-                }
-            }
-        }
+
+        ntt_inverse_layer::<2, 64>(&mut f, &mut k);
+        ntt_inverse_layer::<4, 32>(&mut f, &mut k);
+        ntt_inverse_layer::<8, 16>(&mut f, &mut k);
+        ntt_inverse_layer::<16, 8>(&mut f, &mut k);
+        ntt_inverse_layer::<32, 4>(&mut f, &mut k);
+        ntt_inverse_layer::<64, 2>(&mut f, &mut k);
+        ntt_inverse_layer::<128, 1>(&mut f, &mut k);
 
         Elem::new(3303) * &Polynomial::new(f)
     }
 }
diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs
index f0350b4..113eb72 100644
--- a/module-lattice/src/algebra.rs
+++ b/module-lattice/src/algebra.rs
@@ -72,7 +72,13 @@ macro_rules! define_field {
         const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;
 
         fn small_reduce(x: Self::Int) -> Self::Int {
-            if x < Self::Q { x } else { x - Self::Q }
+            // Branchless conditional subtraction: if x >= Q, subtract Q; else
+            // leave x alone. Compilers already emit `csel` here at O2, but the
+            // explicit mask form removes the dependency on optimizer choices
+            // and keeps the generated assembly free of secret-dependent control
+            // flow at every optimization level.
+            let mask = ((x >= Self::Q) as $int).wrapping_neg();
+            x - (Self::Q & mask)
         }
 
         fn barrett_reduce(x: Self::Long) -> Self::Int {
diff --git a/module-lattice/src/encoding.rs b/module-lattice/src/encoding.rs
index 90f2f17..7bdc4c1 100644
--- a/module-lattice/src/encoding.rs
+++ b/module-lattice/src/encoding.rs
@@ -130,9 +130,13 @@ pub fn byte_decode<F: Field, D: EncodingSize>(
             let val = F::Int::truncate(x >> (D::USIZE * j));
             vj.0 = val & mask;
 
-            // Special case for FIPS 203
+            // Special case for FIPS 203. For 12-bit values (max 4095) with Q = 3329,
+            // the masked value is always in [0, 2Q), so `small_reduce` is exact and
+            // avoids the hardware UDIV that `% F::Q` would emit.
+            // NOTE(review): exactness relies on F::Q > 2047; true for ML-KEM's
+            // q = 3329, the only field decoded with D = 12 today — confirm if reused.
             if D::USIZE == 12 {
-                vj.0 = vj.0 % F::Q;
+                vj.0 = F::small_reduce(vj.0);
             }
         }
     }