diff --git a/Cargo.toml b/Cargo.toml index ae2734d..e99b8fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,7 @@ members = [ [dev-dependencies] solana-program = "2.3" serde_json = "1.0.143" + +[features] +default = ["solana"] +solana = [] \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index c8d259e..d854f10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,9 +14,99 @@ // along with this program. If not, see . // // SPDX-License-Identifier: AGPL-3.0-or-later - //! WOTS+ (Winternitz One-Time Signature Plus) implementation + +enum SignatureBuffer { + + #[cfg(feature = "solana")] + Vec(Vec), + + #[cfg(not(feature = "solana"))] + Array { + buf: [u8; constants::SIGNATURE_SIZE], + len: usize, + }, +} + +impl SignatureBuffer { + fn new() -> Self { + #[cfg(feature = "solana")] + { + Self::Vec(Vec::with_capacity(constants::SIGNATURE_SIZE)) + } + + #[cfg(not(feature = "solana"))] + { + Self::Array { + buf: [0u8; constants::SIGNATURE_SIZE], + len: 0, + } + } + } + + fn push_slice(&mut self, data: &[u8]) { + #[cfg(feature = "solana")] + { + let Self::Vec(v) = self; + let new_len = v.len() + data.len(); + assert!( + new_len <= constants::SIGNATURE_SIZE, + "SignatureBuffer overflow" + ); + v.extend_from_slice(data); + } + + #[cfg(not(feature = "solana"))] + { + let Self::Array { buf, len } = self; + let end = *len + data.len(); + buf[*len..end].copy_from_slice(data); + *len = end; + } + } + + fn as_slice(&self) -> &[u8] { + #[cfg(feature = "solana")] + { + let Self::Vec(v) = self; + v.as_slice() + } + + #[cfg(not(feature = "solana"))] + { + let Self::Array { buf, len } = self; + &buf[..*len] + } + } + + /// Consume the buffer and return its contents as a Vec + fn as_signature_chunks(self) -> Vec<[u8; constants::HASH_LEN]> { + let slice = match self { + #[cfg(feature = "solana")] + SignatureBuffer::Vec(v) => v, + + #[cfg(not(feature = "solana"))] + SignatureBuffer::Array { buf, len } => buf[..len].to_vec(), + }; + // Enforce chunk alignment invariant + assert!( + slice.len() % constants::HASH_LEN == 0, + "SignatureBuffer length is not chunk-aligned" + ); + + slice + .chunks_exact(constants::HASH_LEN) + .map(|chunk| { + let mut arr = [0u8; constants::HASH_LEN]; + arr.copy_from_slice(chunk); + arr + }) + .collect() + } + +} + /// Hash function type for WOTS+ type HashFn = fn(&[u8]) -> [u8; 32]; @@ -167,11 +257,13 @@ impl WOTSPlus { &self, public_seed: &[u8; constants::HASH_LEN] ) -> Vec<[u8; constants::HASH_LEN]> { - let mut elements = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS); + + let mut elements = SignatureBuffer::new(); for i in 0..constants::NUM_SIGNATURE_CHUNKS { - elements.push(self.prf(public_seed, i as u16)); + elements.push_slice(&self.prf(public_seed, i as u16)); } - elements + elements.as_signature_chunks() + } /// XOR two 32-byte arrays @@ -252,7 +344,7 @@ impl WOTSPlus { let randomization_elements = self.generate_randomization_elements(&public_seed); let function_key = randomization_elements[0]; - let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE); + let mut public_key_segments = SignatureBuffer::new(); for i in 0..constants::NUM_SIGNATURE_CHUNKS { let mut to_hash = vec![0u8; constants::HASH_LEN * 2]; @@ -267,10 +359,10 @@ impl WOTSPlus { (constants::CHAIN_LEN - 1) as u16, ); - public_key_segments.extend_from_slice(&segment); + public_key_segments.push_slice(&segment); } - let public_key_hash = (self.hash_fn)(&public_key_segments); + let public_key_hash = (self.hash_fn)(public_key_segments.as_slice()); PublicKey { public_seed: *public_seed, @@ -312,7 +404,7 @@ impl WOTSPlus { let function_key = randomization_elements[0]; let chain_segments = self.compute_message_hash_chain_indexes(message); - let mut signature = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS); + let mut signature = SignatureBuffer::new(); for (i, &chain_idx) in chain_segments.iter().enumerate() { let mut to_hash = vec![0u8; constants::HASH_LEN * 2]; @@ -326,10 +418,10 @@ impl WOTSPlus { 0, chain_idx as u16, ); - signature.push(sig_segment); + signature.push_slice(&sig_segment); } - signature + signature.as_signature_chunks() } /// Verify a WOTS+ signature @@ -353,7 +445,7 @@ impl WOTSPlus { let chain_segments = self.compute_message_hash_chain_indexes(message); - let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE); + let mut public_key_segments = SignatureBuffer::new(); // Compute each public key segment. These are done by taking the signature, which is prevChainOut at chainIdx, // and completing the hash chain via the chain function to recompute the public key segment. @@ -366,11 +458,11 @@ impl WOTSPlus { num_iterations, ); - public_key_segments.extend_from_slice(&segment); + public_key_segments.push_slice(&segment); } // Hash all public key segments together to recreate the original public key - let computed_hash = (self.hash_fn)(&public_key_segments); + let computed_hash = (self.hash_fn)(public_key_segments.as_slice()); // Compare computed hash with stored public key hash computed_hash == public_key.public_key_hash @@ -397,7 +489,7 @@ impl WOTSPlus { } let chain_segments = self.compute_message_hash_chain_indexes(message); - let mut public_key_segments = [0u8; constants::SIGNATURE_SIZE]; + let mut public_key_segments = SignatureBuffer::new(); // Compute each public key segment using the pre-computed randomization elements for (i, &chain_idx) in chain_segments.iter().enumerate() { @@ -409,12 +501,12 @@ impl WOTSPlus { num_iterations, ); - let offset = i * constants::HASH_LEN; - public_key_segments[offset..offset + constants::HASH_LEN].copy_from_slice(&segment); + // let offset = i * constants::HASH_LEN; + public_key_segments.push_slice(&segment); } // Hash all public key segments together and compare with the provided hash - let computed_hash = (self.hash_fn)(&public_key_segments); + let computed_hash = (self.hash_fn)(public_key_segments.as_slice()); computed_hash == *public_key_hash } } @@ -490,6 +582,61 @@ mod tests { assert_eq!(recovered.public_key_hash, public_key.public_key_hash); } + #[test] + fn signatures_are_deterministic() { + let wots = WOTSPlus::new(mock_hash); + + let seed = [9u8; 32]; + let (_, sk) = wots.generate_key_pair(&seed); + let msg = [1u8; constants::MESSAGE_LEN]; + + let sig1 = wots.sign(&sk, &msg); + let sig2 = wots.sign(&sk, &msg); + + assert_eq!(sig1, sig2); + } + + + #[test] + fn sigbuf_appends_correctly() { + let mut buf = SignatureBuffer::new(); + + let a = [1u8; constants::HASH_LEN]; + let b = [2u8; constants::HASH_LEN]; + + buf.push_slice(&a); + buf.push_slice(&b); + + let out = buf.as_slice(); + assert_eq!(out.len(), 2 * constants::HASH_LEN); + assert_eq!(&out[..constants::HASH_LEN], &a); + assert_eq!(&out[constants::HASH_LEN..], &b); + } + + #[cfg(not(feature = "solana"))] + #[test] + #[should_panic] + fn sigbuf_panics_on_overflow_non_solana() { + let mut buf = SignatureBuffer::new(); + let chunk = [0u8; constants::HASH_LEN]; + + for _ in 0..(constants::NUM_SIGNATURE_CHUNKS + 1) { + buf.push_slice(&chunk); + } + } + + #[cfg(feature = "solana")] + #[test] + #[should_panic] + fn sigbuf_panics_on_overflow_solana() { + let mut buf = SignatureBuffer::new(); + let chunk = [0u8; constants::HASH_LEN]; + + for _ in 0..(constants::NUM_SIGNATURE_CHUNKS + 1) { + buf.push_slice(&chunk); + } + } + #[cfg(test)] mod tests { use super::*;