Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 101 additions & 61 deletions ceno_zkvm/src/precompiles/lookup_keccakf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub const KECCAK_OUT_EVAL_SIZE: usize = size_of::<KeccakOutEvals<u8>>();

const AND_LOOKUPS_PER_ROUND: usize = 200;
const XOR_LOOKUPS_PER_ROUND: usize = 608;
const RANGE_LOOKUPS_PER_ROUND: usize = 290;
const RANGE_LOOKUPS_PER_ROUND: usize = 286;
const LOOKUP_FELTS_PER_ROUND: usize =
AND_LOOKUPS_PER_ROUND + XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND;

Expand Down Expand Up @@ -125,12 +125,13 @@ pub struct KeccakFixedCols<T> {
#[repr(C)]
pub struct KeccakWitCols<T> {
pub input8: [T; 200],
pub c_aux: [T; 200],
// Prefix-XOR witnesses for j = 1..4 (j = 0 is directly input state column 0).
pub c_aux: [T; 160],
pub c_temp: [T; 40],
pub c_rot: [T; 40],
pub d: [T; 40],
pub theta_output: [T; 200],
pub rotation_witness: [T; 196],
pub rotation_witness: [T; 192],
pub rhopi_output: [T; 200],
pub nonlinear: [T; 200],
pub chi_output: [T; 8],
Expand Down Expand Up @@ -161,10 +162,33 @@ pub struct KeccakLayout<E: ExtensionField> {
pub n_structural_witin: usize,
}

const ROTATION_WITNESS_LEN: usize = 196;
const ROTATION_WITNESS_LEN: usize = 192;
const C_TEMP_SPLIT_SIZES: [usize; 8] = [15, 1, 15, 1, 15, 1, 15, 1];
const BYTE_SPLIT_SIZES: [usize; 8] = [8; 8];

// Wiring point for a future fused `ANDN_XOR` lookup table.
// Today Chi uses 2 lookups/byte (`AND` + `XOR`) and a materialized nonlinear byte witness.
// If a fused table is added in gkr_iop, this helper can be switched with localized changes.
/// Adds the two byte lookups that constrain a single byte of the Chi step.
///
/// The first lookup ties `nonlinear` to the AND of the 8-bit complement of
/// `b1` with `b2`; the second ties `out` to the XOR of `b0` with `nonlinear`.
/// `nonlinear` is a materialized witness byte shared by both lookups.
///
/// # Errors
///
/// Propagates any error reported by the circuit builder while registering
/// either lookup. If the AND lookup fails, the XOR lookup is not attempted.
#[inline(always)]
fn constrain_chi_byte<E: ExtensionField>(
    system: &mut CircuitBuilder<E>,
    b0: Expression<E>,
    b1: Expression<E>,
    b2: Expression<E>,
    nonlinear: Expression<E>,
    out: Expression<E>,
) -> Result<(), CircuitBuilderError> {
    // AND lookup: (!b1) & b2 == nonlinear.
    let inverted_b1 = not8_expr(b1);
    let and_output = nonlinear.clone();
    system.lookup_and_byte(inverted_b1, b2, and_output)?;

    // XOR lookup: b0 ^ nonlinear == out.
    system.lookup_xor_byte(b0, nonlinear, out)?;
    Ok(())
}

/// Records the lookup multiplicities for one Chi byte on the witness side.
///
/// Registers one AND lookup on `(0xFF - b1, b2)` and one XOR lookup on
/// `(b0, (0xFF - b1) & b2)`, matching the pair of lookups issued per byte by
/// `constrain_chi_byte`.
///
/// NOTE(review): `b1` is presumably a byte value (`<= 0xFF`); a larger value
/// would make `0xFF - b1` underflow. Confirm against callers.
#[inline(always)]
fn record_chi_byte_lookups(lk_multiplicity: &mut LkMultiplicity, b0: u64, b1: u64, b2: u64) {
    // 8-bit complement of `b1`, shared by both lookups.
    let not_b1 = 0xFF - b1;
    lk_multiplicity.lookup_and_byte(not_b1, b2);

    // The AND result is the second operand of the XOR lookup.
    let and_result = not_b1 & b2;
    lk_multiplicity.lookup_xor_byte(b0, and_result);
}

#[inline(always)]
fn split_mask_to_bytes(value: u64) -> [u64; 8] {
value.to_le_bytes().map(|b| b as u64)
Expand Down Expand Up @@ -290,26 +314,24 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> {
// c[i] = XOR (state[j][i]) for j in 0..5
// We unroll it into
// c_aux[i][j] = XOR (state[k][i]) for k in 0..=j
// We use c_aux[i][4] instead of c[i]
// We store only j = 1..=4, i.e. c_aux[i][j - 1] for j in 1..=4.
// We use c_aux[i][3] instead of c[i].
// c_aux is also stored in 8-bit chunks
let c_aux: ArrayView<WitIn, Ix3> = ArrayView::from_shape((5, 5, 8), c_aux).unwrap();
let c_aux: ArrayView<WitIn, Ix3> = ArrayView::from_shape((5, 4, 8), c_aux).unwrap();

for i in 0..5 {
for k in 0..8 {
// Initialize first element
system.require_equal(
|| "init c_aux".to_string(),
state8[[0, i, k]].into(),
c_aux[[i, 0, k]].into(),
)?;
}
for j in 1..5 {
// Check xor using lookups over all chunks
Comment thread
hero78119 marked this conversation as resolved.
for k in 0..8 {
let prev = if j == 1 {
state8[[0, i, k]].into()
} else {
c_aux[[i, j - 2, k]].into()
};
system.lookup_xor_byte(
c_aux[[i, j - 1, k]].into(),
prev,
state8[[j, i, k]].into(),
c_aux[[i, j, k]].into(),
c_aux[[i, j - 1, k]].into(),
)?;
}
}
Expand All @@ -331,7 +353,7 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> {
system.require_left_rotation64(
|| format!("theta rotation_{i}"),
&c_aux
.slice(s![i, 4, ..])
.slice(s![i, 3, ..])
.iter()
.map(|e| e.expr())
.collect_vec(),
Expand All @@ -354,7 +376,7 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> {
for i in 0..5 {
for k in 0..8 {
system.lookup_xor_byte(
c_aux[[(i + 5 - 1) % 5, 4, k]].into(),
c_aux[[(i + 5 - 1) % 5, 3, k]].into(),
c_rot[[(i + 1) % 5, k]].into(),
d[[i, k]].into(),
)?;
Expand Down Expand Up @@ -390,28 +412,39 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> {
for i in 0..5 {
#[allow(clippy::needless_range_loop)]
for j in 0..5 {
let rot = ROTATION_CONSTANTS[j][i];
let arg = theta_output
.slice(s!(j, i, ..))
.iter()
.map(|e| e.expr())
.collect_vec();
let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]);
let many = sizes.len();
let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many))
.map(|(sz, wit)| (sz, wit.expr()))
.collect_vec();
let arg_rotated = rhopi_output
.slice(s!((2 * i + 3 * j) % 5, j, ..))
.iter()
.map(|e| e.expr())
.collect_vec();
system.require_left_rotation64(
|| format!("RHOPI {i}, {j}"),
&arg,
&rep_split,
&arg_rotated,
ROTATION_CONSTANTS[j][i],
)?;
if rot == 0 {
for (lhs, rhs) in izip!(arg.iter(), arg_rotated.iter()) {
system.require_equal(
|| format!("RHOPI identity {i}, {j}"),
lhs.expr(),
rhs.expr(),
)?;
}
} else {
let (sizes, _) = rotation_split(rot);
let many = sizes.len();
let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many))
.map(|(sz, wit)| (sz, wit.expr()))
.collect_vec();
system.require_left_rotation64(
|| format!("RHOPI {i}, {j}"),
&arg,
&rep_split,
&arg_rotated,
rot,
)?;
}
}
}
assert!(rotation_witness.next().is_none());
Expand All @@ -428,14 +461,11 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> {
for i in 0..5 {
for j in 0..5 {
for k in 0..8 {
system.lookup_and_byte(
not8_expr(rhopi_output[[j, (i + 1) % 5, k]].into()),
rhopi_output[[j, (i + 2) % 5, k]].into(),
nonlinear[[j, i, k]].into(),
)?;

system.lookup_xor_byte(
constrain_chi_byte(
system,
rhopi_output[[j, i, k]].into(),
rhopi_output[[j, (i + 1) % 5, k]].into(),
rhopi_output[[j, (i + 2) % 5, k]].into(),
nonlinear[[j, i, k]].into(),
chi_output[[j, i, k]].into(),
)?;
Expand Down Expand Up @@ -722,27 +752,35 @@ where
state8.into_iter().flatten().flatten(),
);

let mut c_aux64 = [[0u64; 5]; 5];
let mut c_aux8 = [[[0u64; 8]; 5]; 5];
let mut c_aux64 = [[0u64; 4]; 5];
let mut c_aux8 = [[[0u64; 8]; 4]; 5];

for i in 0..5 {
c_aux64[i][0] = state64[0][i];
c_aux8[i][0] = split_mask_to_array(c_aux64[i][0], &BYTE_SPLIT_SIZES);
for j in 1..5 {
c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1];
let prev64 = if j == 1 {
state64[0][i]
} else {
c_aux64[i][j - 2]
};
c_aux64[i][j - 1] = state64[j][i] ^ prev64;
for k in 0..8 {
lk_multiplicity
.lookup_xor_byte(c_aux8[i][j - 1][k], state8[j][i][k]);
let prev8 = if j == 1 {
state8[0][i][k]
} else {
c_aux8[i][j - 2][k]
};
lk_multiplicity.lookup_xor_byte(prev8, state8[j][i][k]);
}
c_aux8[i][j] = split_mask_to_array(c_aux64[i][j], &BYTE_SPLIT_SIZES);
c_aux8[i][j - 1] =
split_mask_to_array(c_aux64[i][j - 1], &BYTE_SPLIT_SIZES);
}
}

let mut c64 = [0u64; 5];
let mut c8 = [[0u64; 8]; 5];

for x in 0..5 {
c64[x] = c_aux64[x][4];
c64[x] = c_aux64[x][3];
c8[x] = split_mask_to_array(c64[x], &BYTE_SPLIT_SIZES);
}

Expand All @@ -768,7 +806,7 @@ where
d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1);
for k in 0..8 {
lk_multiplicity.lookup_xor_byte(
c_aux8[(x + 5 - 1) % 5][4][k],
c_aux8[(x + 5 - 1) % 5][3][k],
crot8[(x + 1) % 5][k],
);
}
Expand All @@ -789,14 +827,18 @@ where
split_mask_to_array(theta_state64[y][x], &BYTE_SPLIT_SIZES);

let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]);
let rotation_chunks =
MaskRepresentation::from_mask(Mask::new(64, theta_state64[y][x]))
.convert(&sizes)
.values();
for (chunk, size) in rotation_chunks.iter().zip(sizes.iter()) {
lk_multiplicity.assert_const_range(*chunk, *size);
if ROTATION_CONSTANTS[y][x] != 0 {
let rotation_chunks = MaskRepresentation::from_mask(Mask::new(
64,
theta_state64[y][x],
))
.convert(&sizes)
.values();
for (chunk, size) in rotation_chunks.iter().zip(sizes.iter()) {
lk_multiplicity.assert_const_range(*chunk, *size);
}
rotation_witness.extend(rotation_chunks);
}
rotation_witness.extend(rotation_chunks);
}
}
assert_eq!(rotation_witness.len(), rotation_witness_witin.len());
Expand Down Expand Up @@ -827,8 +869,10 @@ where
nonlinear64[y][x] =
!rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5];
for k in 0..8 {
lk_multiplicity.lookup_and_byte(
0xFF - rhopi_output8[y][(x + 1) % 5][k],
record_chi_byte_lookups(
&mut lk_multiplicity,
rhopi_output8[y][x][k],
rhopi_output8[y][(x + 1) % 5][k],
rhopi_output8[y][(x + 2) % 5][k],
);
}
Expand All @@ -842,10 +886,6 @@ where
for x in 0..5 {
for y in 0..5 {
chi_output64[y][x] = nonlinear64[y][x] ^ rhopi_output64[y][x];
for k in 0..8 {
lk_multiplicity
.lookup_xor_byte(rhopi_output8[y][x][k], nonlinear8[y][x][k]);
}
chi_output8[y][x] =
split_mask_to_array(chi_output64[y][x], &BYTE_SPLIT_SIZES);
}
Expand All @@ -856,9 +896,9 @@ where
let mut iota_output8 = [[[0u64; 8]; 5]; 5];
// TODO figure out how to deal with RC, since it's not a constant in rotation
iota_output64[0][0] ^= RC[round];
let rc8 = split_mask_to_array(RC[round], &BYTE_SPLIT_SIZES);

for k in 0..8 {
let rc8 = split_mask_to_array(RC[round], &BYTE_SPLIT_SIZES);
lk_multiplicity.lookup_xor_byte(chi_output8[0][0][k], rc8[k]);
}

Expand Down
Loading