Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ members = [
[dev-dependencies]
solana-program = "2.3"
serde_json = "1.0.143"

[features]
default = ["solana"]
solana = []
181 changes: 164 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,99 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//
// SPDX-License-Identifier: AGPL-3.0-or-later

//! WOTS+ (Winternitz One-Time Signature Plus) implementation


enum SignatureBuffer {

#[cfg(feature = "solana")]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SignatureBuffer’s #[cfg(feature = "solana")] variant stores [u8; constants::SIGNATURE_SIZE] inline, which is a stack allocation and seems to contradict the PR description/goal of avoiding large Solana stack frames (and it replaces prior Vec usages with stack storage). Was the cfg intended to be inverted (heap-backed on solana, fixed-capacity on non-Solana)?

Fix This in Augment

🤖 Was this useful? React with 👍 or 👎

Vec(Vec<u8>),

#[cfg(not(feature = "solana"))]
Array {
buf: [u8; constants::SIGNATURE_SIZE],
len: usize,
},
}

impl SignatureBuffer {
fn new() -> Self {
#[cfg(feature = "solana")]
{
Self::Vec(Vec::with_capacity(constants::SIGNATURE_SIZE))
}

#[cfg(not(feature = "solana"))]
{
Self::Array {
buf: [0u8; constants::SIGNATURE_SIZE],
len: 0,
}
}
}

fn push_slice(&mut self, data: &[u8]) {
#[cfg(feature = "solana")]
{
let Self::Vec(v) = self;
let new_len = v.len() + data.len();
assert!(
new_len <= constants::SIGNATURE_SIZE,
"SignatureBuffer overflow"
);
v.extend_from_slice(data);
}

#[cfg(not(feature = "solana"))]
{
let Self::Array { buf, len } = self;
let end = *len + data.len();
buf[*len..end].copy_from_slice(data);
*len = end;
}
}

fn as_slice(&self) -> &[u8] {
#[cfg(feature = "solana")]
{
let Self::Vec(v) = self;
v.as_slice()
}

#[cfg(not(feature = "solana"))]
{
let Self::Array { buf, len } = self;
&buf[..*len]
}
}

/// Consume the buffer and return its contents as a Vec<u8>
fn as_signature_chunks(self) -> Vec<[u8; constants::HASH_LEN]> {
let slice = match self {
#[cfg(feature = "solana")]
SignatureBuffer::Vec(v) => v,

#[cfg(not(feature = "solana"))]
SignatureBuffer::Array { buf, len } => buf[..len].to_vec(),
};
// Enforce chunk alignment invariant
assert!(
slice.len() % constants::HASH_LEN == 0,
"SignatureBuffer length is not chunk-aligned"
);

slice
.chunks_exact(constants::HASH_LEN)
.map(|chunk| {
let mut arr = [0u8; constants::HASH_LEN];
arr.copy_from_slice(chunk);
arr
})
.collect()
}

}

/// Hash function type for WOTS+
type HashFn = fn(&[u8]) -> [u8; 32];

Expand Down Expand Up @@ -167,11 +257,13 @@ impl WOTSPlus {
&self,
public_seed: &[u8; constants::HASH_LEN]
) -> Vec<[u8; constants::HASH_LEN]> {
let mut elements = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS);

let mut elements = SignatureBuffer::new();
for i in 0..constants::NUM_SIGNATURE_CHUNKS {
elements.push(self.prf(public_seed, i as u16));
elements.push_slice(&self.prf(public_seed, i as u16));
}
elements
elements.as_signature_chunks()

}

/// XOR two 32-byte arrays
Expand Down Expand Up @@ -252,7 +344,7 @@ impl WOTSPlus {
let randomization_elements = self.generate_randomization_elements(&public_seed);
let function_key = randomization_elements[0];

let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE);
let mut public_key_segments = SignatureBuffer::new();

for i in 0..constants::NUM_SIGNATURE_CHUNKS {
let mut to_hash = vec![0u8; constants::HASH_LEN * 2];
Expand All @@ -267,10 +359,10 @@ impl WOTSPlus {
(constants::CHAIN_LEN - 1) as u16,
);

public_key_segments.extend_from_slice(&segment);
public_key_segments.push_slice(&segment);
}

let public_key_hash = (self.hash_fn)(&public_key_segments);
let public_key_hash = (self.hash_fn)(public_key_segments.as_slice());

PublicKey {
public_seed: *public_seed,
Expand Down Expand Up @@ -312,7 +404,7 @@ impl WOTSPlus {
let function_key = randomization_elements[0];

let chain_segments = self.compute_message_hash_chain_indexes(message);
let mut signature = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS);
let mut signature = SignatureBuffer::new();

for (i, &chain_idx) in chain_segments.iter().enumerate() {
let mut to_hash = vec![0u8; constants::HASH_LEN * 2];
Expand All @@ -326,10 +418,10 @@ impl WOTSPlus {
0,
chain_idx as u16,
);
signature.push(sig_segment);
signature.push_slice(&sig_segment);
}

signature
signature.as_signature_chunks()
}

/// Verify a WOTS+ signature
Expand All @@ -353,7 +445,7 @@ impl WOTSPlus {

let chain_segments = self.compute_message_hash_chain_indexes(message);

let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE);
let mut public_key_segments = SignatureBuffer::new();

// Compute each public key segment. These are done by taking the signature, which is prevChainOut at chainIdx,
// and completing the hash chain via the chain function to recompute the public key segment.
Expand All @@ -366,11 +458,11 @@ impl WOTSPlus {
num_iterations,
);

public_key_segments.extend_from_slice(&segment);
public_key_segments.push_slice(&segment);
}

// Hash all public key segments together to recreate the original public key
let computed_hash = (self.hash_fn)(&public_key_segments);
let computed_hash = (self.hash_fn)(public_key_segments.as_slice());

// Compare computed hash with stored public key hash
computed_hash == public_key.public_key_hash
Expand All @@ -397,7 +489,7 @@ impl WOTSPlus {
}

let chain_segments = self.compute_message_hash_chain_indexes(message);
let mut public_key_segments = [0u8; constants::SIGNATURE_SIZE];
let mut public_key_segments = SignatureBuffer::new();

// Compute each public key segment using the pre-computed randomization elements
for (i, &chain_idx) in chain_segments.iter().enumerate() {
Expand All @@ -409,12 +501,12 @@ impl WOTSPlus {
num_iterations,
);

let offset = i * constants::HASH_LEN;
public_key_segments[offset..offset + constants::HASH_LEN].copy_from_slice(&segment);
// let offset = i * constants::HASH_LEN;
public_key_segments.push_slice(&segment);
}

// Hash all public key segments together and compare with the provided hash
let computed_hash = (self.hash_fn)(&public_key_segments);
let computed_hash = (self.hash_fn)(public_key_segments.as_slice());
computed_hash == *public_key_hash
}
}
Expand Down Expand Up @@ -490,6 +582,61 @@ mod tests {
assert_eq!(recovered.public_key_hash, public_key.public_key_hash);
}

#[test]
fn signatures_are_deterministic() {
let wots = WOTSPlus::new(mock_hash);

let seed = [9u8; 32];
let (_, sk) = wots.generate_key_pair(&seed);
let msg = [1u8; constants::MESSAGE_LEN];

let sig1 = wots.sign(&sk, &msg);
let sig2 = wots.sign(&sk, &msg);

assert_eq!(sig1, sig2);
}


#[test]
fn sigbuf_appends_correctly() {
let mut buf = SignatureBuffer::new();

let a = [1u8; constants::HASH_LEN];
let b = [2u8; constants::HASH_LEN];

buf.push_slice(&a);
buf.push_slice(&b);

let out = buf.as_slice();
assert_eq!(out.len(), 2 * constants::HASH_LEN);
assert_eq!(&out[..constants::HASH_LEN], &a);
assert_eq!(&out[constants::HASH_LEN..], &b);
}

#[cfg(not(feature = "solana"))]
#[test]
#[should_panic]
fn sigbuf_panics_on_overflow_non_solana() {
let mut buf = SignatureBuffer::new();
let chunk = [0u8; constants::HASH_LEN];

for _ in 0..(constants::NUM_SIGNATURE_CHUNKS + 1) {
buf.push_slice(&chunk);
}
}

#[cfg(feature = "solana")]
#[test]
#[should_panic]
fn sigbuf_panics_on_overflow_solana() {
let mut buf = SignatureBuffer::new();
let chunk = [0u8; constants::HASH_LEN];

for _ in 0..(constants::NUM_SIGNATURE_CHUNKS + 1) {
buf.push_slice(&chunk);
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down