diff --git a/Cargo.toml b/Cargo.toml
index ae2734d..e99b8fb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,3 +25,7 @@ members = [
[dev-dependencies]
solana-program = "2.3"
serde_json = "1.0.143"
+
+[features]
+default = ["solana"]
+solana = []
\ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index c8d259e..d854f10 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -14,9 +14,99 @@
// along with this program. If not, see .
//
// SPDX-License-Identifier: AGPL-3.0-or-later
-
//! WOTS+ (Winternitz One-Time Signature Plus) implementation
+
+enum SignatureBuffer {
+
+ #[cfg(feature = "solana")]
+ Vec(Vec),
+
+ #[cfg(not(feature = "solana"))]
+ Array {
+ buf: [u8; constants::SIGNATURE_SIZE],
+ len: usize,
+ },
+}
+
+impl SignatureBuffer {
+ fn new() -> Self {
+ #[cfg(feature = "solana")]
+ {
+ Self::Vec(Vec::with_capacity(constants::SIGNATURE_SIZE))
+ }
+
+ #[cfg(not(feature = "solana"))]
+ {
+ Self::Array {
+ buf: [0u8; constants::SIGNATURE_SIZE],
+ len: 0,
+ }
+ }
+ }
+
+ fn push_slice(&mut self, data: &[u8]) {
+ #[cfg(feature = "solana")]
+ {
+ let Self::Vec(v) = self;
+ let new_len = v.len() + data.len();
+ assert!(
+ new_len <= constants::SIGNATURE_SIZE,
+ "SignatureBuffer overflow"
+ );
+ v.extend_from_slice(data);
+ }
+
+ #[cfg(not(feature = "solana"))]
+ {
+ let Self::Array { buf, len } = self;
+ let end = *len + data.len();
+ buf[*len..end].copy_from_slice(data);
+ *len = end;
+ }
+ }
+
+ fn as_slice(&self) -> &[u8] {
+ #[cfg(feature = "solana")]
+ {
+ let Self::Vec(v) = self;
+ v.as_slice()
+ }
+
+ #[cfg(not(feature = "solana"))]
+ {
+ let Self::Array { buf, len } = self;
+ &buf[..*len]
+ }
+ }
+
+ /// Consume the buffer and return its contents as a Vec
+ fn as_signature_chunks(self) -> Vec<[u8; constants::HASH_LEN]> {
+ let slice = match self {
+ #[cfg(feature = "solana")]
+ SignatureBuffer::Vec(v) => v,
+
+ #[cfg(not(feature = "solana"))]
+ SignatureBuffer::Array { buf, len } => buf[..len].to_vec(),
+ };
+ // Enforce chunk alignment invariant
+ assert!(
+ slice.len() % constants::HASH_LEN == 0,
+ "SignatureBuffer length is not chunk-aligned"
+ );
+
+ slice
+ .chunks_exact(constants::HASH_LEN)
+ .map(|chunk| {
+ let mut arr = [0u8; constants::HASH_LEN];
+ arr.copy_from_slice(chunk);
+ arr
+ })
+ .collect()
+ }
+
+}
+
/// Hash function type for WOTS+
type HashFn = fn(&[u8]) -> [u8; 32];
@@ -167,11 +257,13 @@ impl WOTSPlus {
&self,
public_seed: &[u8; constants::HASH_LEN]
) -> Vec<[u8; constants::HASH_LEN]> {
- let mut elements = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS);
+
+ let mut elements = SignatureBuffer::new();
for i in 0..constants::NUM_SIGNATURE_CHUNKS {
- elements.push(self.prf(public_seed, i as u16));
+ elements.push_slice(&self.prf(public_seed, i as u16));
}
- elements
+ elements.as_signature_chunks()
+
}
/// XOR two 32-byte arrays
@@ -252,7 +344,7 @@ impl WOTSPlus {
let randomization_elements = self.generate_randomization_elements(&public_seed);
let function_key = randomization_elements[0];
- let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE);
+ let mut public_key_segments = SignatureBuffer::new();
for i in 0..constants::NUM_SIGNATURE_CHUNKS {
let mut to_hash = vec![0u8; constants::HASH_LEN * 2];
@@ -267,10 +359,10 @@ impl WOTSPlus {
(constants::CHAIN_LEN - 1) as u16,
);
- public_key_segments.extend_from_slice(&segment);
+ public_key_segments.push_slice(&segment);
}
- let public_key_hash = (self.hash_fn)(&public_key_segments);
+ let public_key_hash = (self.hash_fn)(public_key_segments.as_slice());
PublicKey {
public_seed: *public_seed,
@@ -312,7 +404,7 @@ impl WOTSPlus {
let function_key = randomization_elements[0];
let chain_segments = self.compute_message_hash_chain_indexes(message);
- let mut signature = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS);
+ let mut signature = SignatureBuffer::new();
for (i, &chain_idx) in chain_segments.iter().enumerate() {
let mut to_hash = vec![0u8; constants::HASH_LEN * 2];
@@ -326,10 +418,10 @@ impl WOTSPlus {
0,
chain_idx as u16,
);
- signature.push(sig_segment);
+ signature.push_slice(&sig_segment);
}
- signature
+ signature.as_signature_chunks()
}
/// Verify a WOTS+ signature
@@ -353,7 +445,7 @@ impl WOTSPlus {
let chain_segments = self.compute_message_hash_chain_indexes(message);
- let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE);
+ let mut public_key_segments = SignatureBuffer::new();
// Compute each public key segment. These are done by taking the signature, which is prevChainOut at chainIdx,
// and completing the hash chain via the chain function to recompute the public key segment.
@@ -366,11 +458,11 @@ impl WOTSPlus {
num_iterations,
);
- public_key_segments.extend_from_slice(&segment);
+ public_key_segments.push_slice(&segment);
}
// Hash all public key segments together to recreate the original public key
- let computed_hash = (self.hash_fn)(&public_key_segments);
+ let computed_hash = (self.hash_fn)(public_key_segments.as_slice());
// Compare computed hash with stored public key hash
computed_hash == public_key.public_key_hash
@@ -397,7 +489,7 @@ impl WOTSPlus {
}
let chain_segments = self.compute_message_hash_chain_indexes(message);
- let mut public_key_segments = [0u8; constants::SIGNATURE_SIZE];
+ let mut public_key_segments = SignatureBuffer::new();
// Compute each public key segment using the pre-computed randomization elements
for (i, &chain_idx) in chain_segments.iter().enumerate() {
@@ -409,12 +501,12 @@ impl WOTSPlus {
num_iterations,
);
- let offset = i * constants::HASH_LEN;
- public_key_segments[offset..offset + constants::HASH_LEN].copy_from_slice(&segment);
+ // let offset = i * constants::HASH_LEN;
+ public_key_segments.push_slice(&segment);
}
// Hash all public key segments together and compare with the provided hash
- let computed_hash = (self.hash_fn)(&public_key_segments);
+ let computed_hash = (self.hash_fn)(public_key_segments.as_slice());
computed_hash == *public_key_hash
}
}
@@ -490,6 +582,61 @@ mod tests {
assert_eq!(recovered.public_key_hash, public_key.public_key_hash);
}
+ #[test]
+ fn signatures_are_deterministic() {
+ let wots = WOTSPlus::new(mock_hash);
+
+ let seed = [9u8; 32];
+ let (_, sk) = wots.generate_key_pair(&seed);
+ let msg = [1u8; constants::MESSAGE_LEN];
+
+ let sig1 = wots.sign(&sk, &msg);
+ let sig2 = wots.sign(&sk, &msg);
+
+ assert_eq!(sig1, sig2);
+ }
+
+
+ #[test]
+ fn sigbuf_appends_correctly() {
+ let mut buf = SignatureBuffer::new();
+
+ let a = [1u8; constants::HASH_LEN];
+ let b = [2u8; constants::HASH_LEN];
+
+ buf.push_slice(&a);
+ buf.push_slice(&b);
+
+ let out = buf.as_slice();
+ assert_eq!(out.len(), 2 * constants::HASH_LEN);
+ assert_eq!(&out[..constants::HASH_LEN], &a);
+ assert_eq!(&out[constants::HASH_LEN..], &b);
+ }
+
+ #[cfg(not(feature = "solana"))]
+ #[test]
+ #[should_panic]
+ fn sigbuf_panics_on_overflow_non_solana() {
+ let mut buf = SignatureBuffer::new();
+ let chunk = [0u8; constants::HASH_LEN];
+
+ for _ in 0..(constants::NUM_SIGNATURE_CHUNKS + 1) {
+ buf.push_slice(&chunk);
+ }
+ }
+
+ #[cfg(feature = "solana")]
+ #[test]
+ #[should_panic]
+ fn sigbuf_panics_on_overflow_solana() {
+ let mut buf = SignatureBuffer::new();
+ let chunk = [0u8; constants::HASH_LEN];
+
+ for _ in 0..(constants::NUM_SIGNATURE_CHUNKS + 1) {
+ buf.push_slice(&chunk);
+ }
+ }
+
#[cfg(test)]
mod tests {
use super::*;