diff --git a/crates/sdk/src/program/core.rs b/crates/sdk/src/program/core.rs index 45c1304..3625355 100644 --- a/crates/sdk/src/program/core.rs +++ b/crates/sdk/src/program/core.rs @@ -41,6 +41,8 @@ pub trait ProgramTrait: DynClone { input_index: usize, network: &SimplicityNetwork, ) -> Result>, ProgramError>; + + fn load(&self) -> Result; } #[derive(Clone)] @@ -142,6 +144,10 @@ impl ProgramTrait for Program { self.control_block()?.serialize(), ]) } + + fn load(&self) -> Result { + self.load() + } } impl Program { @@ -173,7 +179,7 @@ impl Program { hash_script(&self.get_script_pubkey(network)) } - fn load(&self) -> Result { + pub fn load(&self) -> Result { let compiled = CompiledProgram::new(self.source, self.arguments.build_arguments(), true) .map_err(ProgramError::Compilation)?; Ok(compiled) diff --git a/crates/sdk/src/signer/core.rs b/crates/sdk/src/signer/core.rs index 045fb15..741d7df 100644 --- a/crates/sdk/src/signer/core.rs +++ b/crates/sdk/src/signer/core.rs @@ -1,5 +1,6 @@ use std::collections::{HashMap, HashSet}; use std::str::FromStr; +use std::sync::Arc; use simplicityhl::Value; use simplicityhl::WitnessValues; @@ -31,6 +32,7 @@ use crate::constants::MIN_FEE; use crate::program::ProgramTrait; use crate::provider::ProviderTrait; use crate::provider::SimplicityNetwork; +use crate::signer::wtns_injector::WtnsInjector; use crate::transaction::{FinalTransaction, PartialInput, PartialOutput, RequiredSignature, UTXO}; use super::error::SignerError; @@ -418,18 +420,24 @@ impl Signer { for (index, input_i) in inputs.iter().enumerate() { // we need to prune the program if let Some(program_input) = &input_i.program_input { - let signed_witness: Result = match &input_i.required_sig { - // sign the program and insert the signature into the witness - RequiredSignature::Witness(witness_name) => Ok(self.get_signed_program_witness( + let signing_info: Option<(&String, &[String])> = match &input_i.required_sig { + RequiredSignature::Witness(wnts_name) => Some((wnts_name, &[])), + RequiredSignature::WitnessWithPath(wnts_name, sig_path) => Some((wnts_name, sig_path)), + _ => None, + }; + + let signed_witness: Result = match signing_info { + Some((witness_name, sig_path)) => Ok(self.get_signed_program_witness( &pst, program_input.program.as_ref(), &program_input.witness.build_witness(), witness_name, + sig_path, index, )?), - // just build the passed witness - _ => Ok(program_input.witness.build_witness()), + None => Ok(program_input.witness.build_witness()), }; + let pruned_witness = program_input .program .finalize(&pst, &signed_witness.unwrap(), index, &self.network) @@ -455,20 +463,42 @@ impl Signer { program: &dyn ProgramTrait, witness: &WitnessValues, witness_name: &str, + sig_path: &[String], index: usize, ) -> Result { let signature = self.sign_program(pst, program, index, &self.network)?; + // put signature right after wtns field name if path is not provided + let sig_val = if !sig_path.is_empty() { + let wtns_injector = WtnsInjector::new(sig_path)?; + + let compiled = program.load().map_err(SignerError::Program)?; + + let abi_meta = compiled.generate_abi_meta().map_err(SignerError::ProgramGenAbiMeta)?; + + let witness_types = abi_meta + .witness_types + .get(&WitnessName::from_str_unchecked(witness_name)) + .ok_or(SignerError::WtnsFieldNotFound(witness_name.to_string()))?; + + let local_wtns = Arc::new( + witness + .get(&WitnessName::from_str_unchecked(witness_name)) + .expect("checked above") + .clone(), + ); + + wtns_injector.inject_value(&local_wtns, witness_types, Value::byte_array(signature.serialize()))? + } else { + Value::byte_array(signature.serialize()) + }; let mut hm = HashMap::new(); witness.iter().for_each(|el| { hm.insert(el.0.clone(), el.1.clone()); }); - hm.insert( - WitnessName::from_str_unchecked(witness_name), - Value::byte_array(signature.serialize()), - ); + hm.insert(WitnessName::from_str_unchecked(witness_name), sig_val); Ok(WitnessValues::from(hm)) } diff --git a/crates/sdk/src/signer/error.rs b/crates/sdk/src/signer/error.rs index 7293936..d119780 100644 --- a/crates/sdk/src/signer/error.rs +++ b/crates/sdk/src/signer/error.rs @@ -53,4 +53,25 @@ pub enum SignerError { #[error("Failed to construct a wpkh address: {0}")] WpkhAddressConstruction(#[from] elements_miniscript::Error), + + #[error("Failed to obtain program witness types: {0}")] + ProgramGenAbiMeta(String), + + #[error("Missing such witness field: {0}")] + WtnsFieldNotFound(String), + + #[error(transparent)] + WtnsInjectError(#[from] WtnsWrappingError), +} + +#[derive(Debug, thiserror::Error)] +pub enum WtnsWrappingError { + #[error("Failed to parse path")] + ParsingError, + #[error("Unsupported path type: {0}")] + UnsupportedPathType(String), + #[error("Path index out of bounds: len is {0}, got {1}")] + IdxOutOfBounds(usize, usize), + #[error("Root type mismatch: expected {0}, got {1}")] + RootTypeMismatch(String, String), } diff --git a/crates/sdk/src/signer/mod.rs b/crates/sdk/src/signer/mod.rs index 97c19b1..9cdcfcd 100644 --- a/crates/sdk/src/signer/mod.rs +++ b/crates/sdk/src/signer/mod.rs @@ -1,5 +1,6 @@ pub mod core; pub mod error; +mod wtns_injector; pub use core::{Signer, SignerTrait}; pub use error::SignerError; diff --git a/crates/sdk/src/signer/wtns_injector.rs b/crates/sdk/src/signer/wtns_injector.rs new file mode 100644 index 0000000..aa619c6 --- /dev/null +++ b/crates/sdk/src/signer/wtns_injector.rs @@ -0,0 +1,212 @@ +use std::sync::Arc; + +use simplicityhl::{ + ResolvedType, Value, + types::TypeInner, + value::{ValueConstructible, ValueInner}, +}; + +use crate::signer::error::WtnsWrappingError; + +/// Struct for injecting specific value by given path into witness value +#[derive(Clone)] +pub struct WtnsInjector { + path: Vec, +} + +impl WtnsInjector { + /// ## Usage + /// ```rust,ignore + /// // .simf script + /// match witness::SOMETHING { + /// Left(x: u64) => ..., + /// Right([y, z]: [u64, u64]) => ... + /// } + /// // path for each variable + /// vec!["Left"] // for x + /// vec!["Right", "0"] // for y + /// vec!["Right", "1"] // for z + /// ``` + pub fn new(path: &[String]) -> Result { + let parsed_path = path + .iter() + .map(|route| match route.as_str() { + "Left" => Ok(WtnsPathRoute::Either(EitherRoute::Left)), + "Right" => Ok(WtnsPathRoute::Either(EitherRoute::Right)), + s => s + .parse::() + .map(|n| WtnsPathRoute::Enumerable(EnumerableRoute(n))) + .map_err(|_| WtnsWrappingError::ParsingError), + }) + .collect::, _>>()?; + + Ok(Self { path: parsed_path }) + } + + /// Constructs new value by intjecting given value into witness at the position described by `path`. + /// Consistency between `witness` and `witness_types` should be guaranteed by caller. + pub fn inject_value( + &self, + witness: &Arc, + witness_types: &ResolvedType, + value: Value, + ) -> Result { + enum StackItem { + Either(EitherRoute, Arc), + Array(EnumerableRoute, Arc, Arc<[Value]>), + Tuple(EnumerableRoute, Arc<[Value]>), + } + + // invocations of these functions below determined from types during traversal + // matches! guard at top of loop guarantees that types and routes are consistent + fn downcast_either(val: &Value, direction: EitherRoute) -> Arc { + match (direction, val.inner()) { + (EitherRoute::Left, ValueInner::Either(either)) => Arc::clone(either.as_ref().unwrap_left()), + (EitherRoute::Right, ValueInner::Either(either)) => Arc::clone(either.as_ref().unwrap_right()), + _ => unreachable!(), + } + } + + fn downcast_array(val: &Value) -> Arc<[Value]> { + match val.inner() { + ValueInner::Array(arr) => Arc::clone(arr), + _ => unreachable!(), + } + } + + fn downcast_tuple(val: &Value) -> Arc<[Value]> { + match val.inner() { + ValueInner::Tuple(arr) => Arc::clone(arr), + _ => unreachable!(), + } + } + + let mut stack = Vec::new(); + let mut current_val = Arc::clone(witness); + let mut current_ty = witness_types; + + for route in self.path.iter() { + if !matches!( + (route, current_ty.as_inner()), + (WtnsPathRoute::Enumerable(_), TypeInner::Array(_, _)) + | (WtnsPathRoute::Enumerable(_), TypeInner::Tuple(_)) + | (WtnsPathRoute::Either(_), TypeInner::Either(_, _)) + ) { + return Err(WtnsWrappingError::UnsupportedPathType(current_ty.to_string())); + } + + match current_ty.as_inner() { + TypeInner::Either(left_ty, right_ty) => { + let direction: EitherRoute = (*route).try_into().expect("Checked in matches! above"); + match direction { + EitherRoute::Left => { + stack.push(StackItem::Either(direction, Arc::clone(right_ty))); + current_ty = left_ty; + } + EitherRoute::Right => { + stack.push(StackItem::Either(direction, Arc::clone(left_ty))); + current_ty = right_ty; + } + } + current_val = downcast_either(¤t_val, direction); + } + TypeInner::Array(ty, len) => { + let idx: EnumerableRoute = (*route).try_into().expect("Checked in matches! above"); + + if idx.0 >= *len { + return Err(WtnsWrappingError::IdxOutOfBounds(*len, idx.0)); + } + + let arr_val = downcast_array(¤t_val); + + stack.push(StackItem::Array(idx, Arc::clone(ty), Arc::clone(&arr_val))); + + current_ty = ty; + current_val = Arc::new(arr_val[idx.0].clone()); + } + TypeInner::Tuple(tuple) => { + let idx: EnumerableRoute = (*route).try_into().expect("Checked in matches! above"); + + if idx.0 >= tuple.len() { + return Err(WtnsWrappingError::IdxOutOfBounds(tuple.len(), idx.0)); + } + + let tuple_val = downcast_tuple(¤t_val); + + stack.push(StackItem::Tuple(idx, Arc::clone(&tuple_val))); + + current_ty = &tuple[idx.0]; + current_val = Arc::new(tuple_val[idx.0].clone()); + } + _ => unreachable!("checked at the top of loop"), + } + } + + if value.ty() != current_ty { + return Err(WtnsWrappingError::RootTypeMismatch( + current_ty.to_string(), + value.ty().to_string(), + )); + } + + let mut value = value; + + for item in stack.into_iter().rev() { + value = match item { + StackItem::Either(direction, sibling_ty) => match direction { + EitherRoute::Left => Value::left(value, (*sibling_ty).clone()), + EitherRoute::Right => Value::right((*sibling_ty).clone(), value), + }, + StackItem::Array(idx, elem_ty, arr) => { + let mut elements = arr.to_vec(); + elements[idx.0] = value; + Value::array(elements, (*elem_ty).clone()) + } + StackItem::Tuple(idx, tuple_vals) => { + let mut elements = tuple_vals.to_vec(); + elements[idx.0] = value; + Value::tuple(elements) + } + }; + } + + Ok(value) + } +} + +#[derive(Clone, Copy, Debug)] +pub enum WtnsPathRoute { + Either(EitherRoute), + Enumerable(EnumerableRoute), +} + +impl TryInto for WtnsPathRoute { + type Error = WtnsPathRoute; + + fn try_into(self) -> Result { + match self { + Self::Either(direction) => Ok(direction), + _ => Err(self), + } + } +} + +impl TryInto for WtnsPathRoute { + type Error = WtnsPathRoute; + + fn try_into(self) -> Result { + match self { + Self::Enumerable(tuple) => Ok(tuple), + _ => Err(self), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub enum EitherRoute { + Left, + Right, +} + +#[derive(Clone, Copy, Debug)] +pub struct EnumerableRoute(usize); diff --git a/crates/sdk/src/transaction/final_transaction.rs b/crates/sdk/src/transaction/final_transaction.rs index 83610ba..a21df4d 100644 --- a/crates/sdk/src/transaction/final_transaction.rs +++ b/crates/sdk/src/transaction/final_transaction.rs @@ -38,9 +38,12 @@ impl FinalTransaction { } pub fn add_input(&mut self, partial_input: PartialInput, required_sig: RequiredSignature) { - if let RequiredSignature::Witness(_) = required_sig { - panic!("Requested signature is not NativeEcdsa or None"); - } + match required_sig { + RequiredSignature::Witness(_) | RequiredSignature::WitnessWithPath(_, _) => { + panic!("Requested signature is not NativeEcdsa or None") + } + _ => {} + }; self.inputs.push(FinalInput { partial_input, @@ -74,9 +77,12 @@ impl FinalTransaction { issuance_input: IssuanceInput, required_sig: RequiredSignature, ) -> AssetId { - if let RequiredSignature::Witness(_) = required_sig { - panic!("Requested signature is not NativeEcdsa or None"); - } + match required_sig { + RequiredSignature::Witness(_) | RequiredSignature::WitnessWithPath(_, _) => { + panic!("Requested signature is not NativeEcdsa or None") + } + _ => {} + }; let asset_id = AssetId::from_entropy(asset_entropy(&partial_input.outpoint(), issuance_input.asset_entropy)); diff --git a/crates/sdk/src/transaction/partial_input.rs b/crates/sdk/src/transaction/partial_input.rs index e3122e3..08538ee 100644 --- a/crates/sdk/src/transaction/partial_input.rs +++ b/crates/sdk/src/transaction/partial_input.rs @@ -12,6 +12,7 @@ pub enum RequiredSignature { None, NativeEcdsa, Witness(String), + WitnessWithPath(String, Vec), } #[derive(Debug, Clone)] diff --git a/examples/basic/Cargo.lock b/examples/basic/Cargo.lock index 30cc944..54ff1a6 100644 --- a/examples/basic/Cargo.lock +++ b/examples/basic/Cargo.lock @@ -1209,7 +1209,7 @@ dependencies = [ [[package]] name = "smplx-build" -version = "0.0.2" +version = "0.0.3" dependencies = [ "glob", "globwalk", @@ -1226,7 +1226,7 @@ dependencies = [ [[package]] name = "smplx-macros" -version = "0.0.2" +version = "0.0.3" dependencies = [ "smplx-build", "smplx-test", @@ -1235,7 +1235,7 @@ dependencies = [ [[package]] name = "smplx-regtest" -version = "0.0.2" +version = "0.0.3" dependencies = [ "electrsd", "serde", @@ -1246,7 +1246,7 @@ dependencies = [ [[package]] name = "smplx-sdk" -version = "0.0.2" +version = "0.0.3" dependencies = [ "bip39", "bitcoin_hashes", @@ -1264,7 +1264,7 @@ dependencies = [ [[package]] name = "smplx-std" -version = "0.0.2" +version = "0.0.3" dependencies = [ "either", "serde", @@ -1276,7 +1276,7 @@ dependencies = [ [[package]] name = "smplx-test" -version = "0.0.2" +version = "0.0.3" dependencies = [ "electrsd", "proc-macro2", diff --git a/examples/basic/simf/nested_sig.simf b/examples/basic/simf/nested_sig.simf new file mode 100644 index 0000000..e70d467 --- /dev/null +++ b/examples/basic/simf/nested_sig.simf @@ -0,0 +1,28 @@ +fn checksig(pk: Pubkey, sig: Signature) { + let msg: u256 = jet::sig_all_hash(); + jet::bip_0340_verify((pk, msg), sig); +} + +fn inherit_spend(inherit_data: (Signature, u256)) { + let (inheritor_sig, nonce): (Signature, u256) = inherit_data; + checksig(param::PUBLIC_KEY, inheritor_sig); +} + +fn cold_spend(cold_sig: Signature) { + checksig(param::PUBLIC_KEY, cold_sig); +} + +fn hot_spend(hot_sigs: [Signature; 2]) { + let [sig1, sig2]: [Signature; 2] = hot_sigs; + checksig(param::PUBLIC_KEY, sig1); +} + +fn main() { + match witness::INHERIT_OR_NOT { + Left(inherit_data: (Signature, u256)) => inherit_spend(inherit_data), + Right(cold_or_hot: Either) => match cold_or_hot { + Left(cold_sig: Signature) => cold_spend(cold_sig), + Right(hot_sigs: [Signature; 2]) => hot_spend(hot_sigs), + }, + } +} \ No newline at end of file diff --git a/examples/basic/tests/nested_sig.rs b/examples/basic/tests/nested_sig.rs new file mode 100644 index 0000000..94fc76a --- /dev/null +++ b/examples/basic/tests/nested_sig.rs @@ -0,0 +1,116 @@ +use simplex::constants::DUMMY_SIGNATURE; +use simplex::simplicityhl::elements::{Script, Txid}; +use simplex::transaction::{FinalTransaction, PartialInput, ProgramInput, RequiredSignature}; +use simplex::utils::tr_unspendable_key; + +use simplex_example::artifacts::nested_sig::NestedSigProgram; +use simplex_example::artifacts::nested_sig::derived_nested_sig::{NestedSigArguments, NestedSigWitness}; + +fn get_nested_sig(context: &simplex::TestContext) -> (NestedSigProgram, Script) { + let signer = context.get_default_signer(); + + let arguments = NestedSigArguments { + public_key: signer.get_schnorr_public_key().serialize(), + }; + + let program = NestedSigProgram::new(tr_unspendable_key(), arguments); + let script = program.get_program().get_script_pubkey(context.get_network()); + + (program, script) +} +fn fund_nested_sig(context: &simplex::TestContext) -> anyhow::Result { + let signer = context.get_default_signer(); + let (_, script) = get_nested_sig(context); + + let txid = signer.send(script, 50_000)?; + println!("Funded: {}", txid); + + Ok(txid) +} + +fn spend_nested_sig( + context: &simplex::TestContext, + witness: NestedSigWitness, + sig_path: Vec<&str>, +) -> anyhow::Result { + let signer = context.get_default_signer(); + let provider = context.get_default_provider(); + + let (program, script) = get_nested_sig(context); + + let mut utxos = provider.fetch_scripthash_utxos(&script)?; + utxos.retain(|utxo| utxo.explicit_asset() == context.get_network().policy_asset()); + + let mut ft = FinalTransaction::new(); + + ft.add_program_input( + PartialInput::new(utxos[0].clone()), + ProgramInput::new(Box::new(program.get_program().clone()), Box::new(witness)), + RequiredSignature::WitnessWithPath( + "INHERIT_OR_NOT".to_string(), + sig_path.iter().map(ToString::to_string).collect(), + ), + ); + + let txid = signer.broadcast(&ft)?; + println!("Broadcast: {}", txid); + + Ok(txid) +} + +#[simplex::test] +fn test_inherit_spend(context: simplex::TestContext) -> anyhow::Result<()> { + let provider = context.get_default_provider(); + + let fund_tx = fund_nested_sig(&context)?; + provider.wait(&fund_tx)?; + + // Left — inheritor sig injected by signer at path L + let witness = NestedSigWitness { + inherit_or_not: simplex::either::Either::Left((DUMMY_SIGNATURE, [0; 32])), + }; + + let spend_tx = spend_nested_sig(&context, witness, vec!["Left", "0"])?; + provider.wait(&spend_tx)?; + println!("Inherit spend confirmed"); + + Ok(()) +} + +#[simplex::test] +fn test_cold_spend(context: simplex::TestContext) -> anyhow::Result<()> { + let provider = context.get_default_provider(); + + let fund_tx = fund_nested_sig(&context)?; + provider.wait(&fund_tx)?; + + // Right Left — cold sig injected by signer at path R L + let witness = NestedSigWitness { + inherit_or_not: simplex::either::Either::Right(simplex::either::Either::Left(DUMMY_SIGNATURE)), + }; + + let spend_tx = spend_nested_sig(&context, witness, vec!["Right", "Left"])?; + provider.wait(&spend_tx)?; + println!("Cold spend confirmed"); + + Ok(()) +} + +#[simplex::test] +fn test_hot_spend(context: simplex::TestContext) -> anyhow::Result<()> { + let provider = context.get_default_provider(); + + let fund_tx = fund_nested_sig(&context)?; + provider.wait(&fund_tx)?; + + // Right Right — hot sig injected by signer at path R R + let witness = NestedSigWitness { + inherit_or_not: simplex::either::Either::Right(simplex::either::Either::Right([DUMMY_SIGNATURE, [0; 64]])), + }; + + let spend_tx = spend_nested_sig(&context, witness, vec!["Right", "Right", "0"])?; + provider.wait(&spend_tx)?; + println!("Hot spend confirmed"); + + Ok(()) +}