diff --git a/external/photon b/external/photon
index a52fd36570..52ca110cf8 160000
--- a/external/photon
+++ b/external/photon
@@ -1 +1 @@
-Subproject commit a52fd365706e235f538689c3afdf94c8371db80f
+Subproject commit 52ca110cf8e3d5aca6e65e1ef8e98b7632d3a16f
diff --git a/forester/CHANGELOG.md b/forester/CHANGELOG.md
index 388f633a93..8b2af02dfc 100644
--- a/forester/CHANGELOG.md
+++ b/forester/CHANGELOG.md
@@ -2,6 +2,16 @@
 
 ## [Unreleased]
 
+### Added
+
+- **Graceful shutdown signaling** via `watch::channel`. Shutdown requests are now race-free regardless of when the run loop subscribes.
+- **Panic isolation for `process_epoch`.** A panicking epoch no longer kills the run loop; the panic message is logged and processing continues.
+
+### Fixed
+
+- **`bigint_to_u8_32` now rejects negative `BigInt` inputs** (`light-prover-client`). Previously, negative inputs were silently converted to `[u8; 32]` from their magnitude bytes alone, producing wrong-sign output and corrupting proof inputs.
+- **`pathIndex` widened from `u32` to `u64`** on both the Rust client and the Go prover server. The Gnark circuit already constrains `pathIndex` by tree height (up to 40 bits for v2 address trees); only the JSON marshalling and runtime struct types were artificially narrow. This prevents proof generation failures once a v2 address tree exceeds ~4.3 billion entries.
+
 ### Breaking Changes
 
 - **Removed `--photon-api-key` CLI arg and `PHOTON_API_KEY` env var.** The API key should now be included in `--indexer-url` as a query parameter:
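Reviewer note: the two `Added` entries describe a standard tokio pattern. Below is a minimal, self-contained sketch of it, not the forester's actual run loop; `run_loop` and `process_epoch` here are hypothetical stand-ins. `watch::Receiver` always holds the latest value, so a shutdown sent before the loop subscribes is still observed, and awaiting the epoch's `JoinHandle` turns a panic into an inspectable `JoinError` instead of unwinding the loop.

```rust
use std::time::Duration;

use tokio::sync::watch;

// Illustrative stand-in for the forester's epoch work; a panic here is
// contained by the JoinHandle in run_loop rather than killing the loop.
async fn process_epoch(epoch: u64) {
    let _ = epoch;
}

async fn run_loop(mut shutdown: watch::Receiver<bool>) {
    let mut epoch = 0u64;
    // `watch` retains the most recent value, so a shutdown signalled before
    // this loop first checks the receiver is still seen: no subscribe race.
    while !*shutdown.borrow() {
        match tokio::spawn(process_epoch(epoch)).await {
            Ok(()) => {}
            // A panicking epoch surfaces as JoinError::is_panic(); log and move on.
            Err(err) if err.is_panic() => eprintln!("epoch {epoch} panicked: {err}"),
            Err(err) => eprintln!("epoch {epoch} cancelled: {err}"),
        }
        epoch += 1;
        // Wake on shutdown or after a tick, whichever comes first.
        let _ = tokio::time::timeout(Duration::from_secs(1), shutdown.changed()).await;
    }
}

#[tokio::main]
async fn main() {
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let handle = tokio::spawn(run_loop(shutdown_rx));
    // ... later, from a signal handler or supervisor:
    let _ = shutdown_tx.send(true);
    let _ = handle.await;
}
```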
diff --git a/forester/src/cli.rs b/forester/src/cli.rs
index 359ce9f0f0..0011e7d55b 100644
--- a/forester/src/cli.rs
+++ b/forester/src/cli.rs
@@ -306,7 +306,7 @@ pub struct StartArgs {
     #[arg(
         long,
         env = "WORK_ITEM_BATCH_SIZE",
-        value_parser = clap::value_parser!(usize).range(1..),
+        value_parser = parse_nonzero_usize,
        help = "Number of queue items to process per batch cycle. Smaller values reduce blockhash expiry risk, larger values reduce per-batch overhead."
    )]
    pub work_item_batch_size: Option<usize>,
@@ -392,6 +392,16 @@ impl StartArgs {
     }
 }
 
+fn parse_nonzero_usize(value: &str) -> Result<usize, String> {
+    let parsed = value
+        .parse::<usize>()
+        .map_err(|err| format!("invalid positive integer: {err}"))?;
+    if parsed == 0 {
+        return Err("value must be at least 1".to_string());
+    }
+    Ok(parsed)
+}
+
 impl StatusArgs {
     pub fn enable_metrics(&self) -> bool {
         self.push_gateway_url.is_some()
diff --git a/forester/src/compressible/ctoken/compressor.rs b/forester/src/compressible/ctoken/compressor.rs
index 0296ac1747..867042bbcf 100644
--- a/forester/src/compressible/ctoken/compressor.rs
+++ b/forester/src/compressible/ctoken/compressor.rs
@@ -30,16 +30,16 @@ use crate::{
 pub struct CTokenCompressor<R: Rpc> {
     rpc_pool: Arc<SolanaRpcPool<R>>,
     tracker: Arc<CTokenAccountTracker>,
-    payer_keypair: Keypair,
+    payer_keypair: Arc<Keypair>,
     transaction_policy: TransactionPolicy,
 }
 
 impl<R: Rpc> Clone for CTokenCompressor<R> {
     fn clone(&self) -> Self {
         Self {
-            rpc_pool: Arc::clone(&self.rpc_pool),
-            tracker: Arc::clone(&self.tracker),
-            payer_keypair: self.payer_keypair.insecure_clone(),
+            rpc_pool: self.rpc_pool.clone(),
+            tracker: self.tracker.clone(),
+            payer_keypair: self.payer_keypair.clone(),
             transaction_policy: self.transaction_policy,
         }
     }
@@ -49,7 +49,7 @@ impl<R: Rpc> CTokenCompressor<R> {
     pub fn new(
         rpc_pool: Arc<SolanaRpcPool<R>>,
         tracker: Arc<CTokenAccountTracker>,
-        payer_keypair: Keypair,
+        payer_keypair: Arc<Keypair>,
         transaction_policy: TransactionPolicy,
     ) -> Self {
         Self {
diff --git a/forester/src/compressible/ctoken/state.rs b/forester/src/compressible/ctoken/state.rs
index dfcf269140..b2969fc61d 100644
--- a/forester/src/compressible/ctoken/state.rs
+++ b/forester/src/compressible/ctoken/state.rs
@@ -91,9 +91,13 @@ impl CTokenAccountTracker {
     /// Returns all tracked token accounts (not mints), ignoring compressible_slot.
     /// Use `get_ready_to_compress(current_slot)` to get only accounts ready for compression.
     pub fn get_all_token_accounts(&self) -> Vec<CTokenAccountState> {
-        self.get_ready_to_compress(u64::MAX)
-            .into_iter()
-            .filter(|state| state.account.is_token_account())
+        let pending = self.pending();
+        self.accounts()
+            .iter()
+            .filter(|entry| {
+                entry.value().account.is_token_account() && !pending.contains(entry.key())
+            })
+            .map(|entry| entry.value().clone())
             .collect()
     }
 
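Reviewer note: `get_all_token_accounts` previously reused `get_ready_to_compress(u64::MAX)` as a sentinel; the rewrite above scans the live map directly and skips pubkeys already claimed by a concurrent cycle. A minimal sketch of that scan pattern, with hypothetical `Tracker` and `AccountState` types standing in for the forester's `DashMap`-backed tracker internals:

```rust
use std::{collections::HashSet, sync::Mutex};

use dashmap::DashMap;
use solana_sdk::pubkey::Pubkey;

// Hypothetical stand-ins for the forester's tracker internals.
#[derive(Clone)]
struct AccountState {
    compressible_slot: u64,
    is_token_account: bool,
}

struct Tracker {
    accounts: DashMap<Pubkey, AccountState>,
    pending: Mutex<HashSet<Pubkey>>,
}

impl Tracker {
    /// Keys that are due for compression and not already claimed by a
    /// concurrent cycle. Iterates the live map; no snapshot clone needed.
    fn get_ready_to_compress(&self, current_slot: u64) -> Vec<Pubkey> {
        let pending = self.pending.lock().unwrap();
        self.accounts
            .iter()
            .filter(|entry| {
                entry.value().compressible_slot <= current_slot
                    && !pending.contains(entry.key())
            })
            .map(|entry| *entry.key())
            .collect()
    }

    /// All token accounts regardless of slot, mirroring the rewrite above:
    /// filter on the map itself instead of passing u64::MAX as a sentinel.
    fn get_all_token_accounts(&self) -> Vec<AccountState> {
        let pending = self.pending.lock().unwrap();
        self.accounts
            .iter()
            .filter(|entry| entry.value().is_token_account && !pending.contains(entry.key()))
            .map(|entry| entry.value().clone())
            .collect()
    }
}
```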
diff --git a/forester/src/compressible/mint/compressor.rs b/forester/src/compressible/mint/compressor.rs
index 889a28d5e1..934f1aa46f 100644
--- a/forester/src/compressible/mint/compressor.rs
+++ b/forester/src/compressible/mint/compressor.rs
@@ -30,16 +30,16 @@ use crate::{
 pub struct MintCompressor<R: Rpc> {
     rpc_pool: Arc<SolanaRpcPool<R>>,
     tracker: Arc<MintAccountTracker>,
-    payer_keypair: Keypair,
+    payer_keypair: Arc<Keypair>,
     transaction_policy: TransactionPolicy,
 }
 
 impl<R: Rpc> Clone for MintCompressor<R> {
     fn clone(&self) -> Self {
         Self {
-            rpc_pool: Arc::clone(&self.rpc_pool),
-            tracker: Arc::clone(&self.tracker),
-            payer_keypair: self.payer_keypair.insecure_clone(),
+            rpc_pool: self.rpc_pool.clone(),
+            tracker: self.tracker.clone(),
+            payer_keypair: self.payer_keypair.clone(),
             transaction_policy: self.transaction_policy,
         }
     }
@@ -49,7 +49,7 @@ impl<R: Rpc> MintCompressor<R> {
     pub fn new(
         rpc_pool: Arc<SolanaRpcPool<R>>,
         tracker: Arc<MintAccountTracker>,
-        payer_keypair: Keypair,
+        payer_keypair: Arc<Keypair>,
         transaction_policy: TransactionPolicy,
     ) -> Self {
         Self {
@@ -133,21 +133,20 @@ impl<R: Rpc> MintCompressor<R> {
     /// Use this when you need fine-grained control over individual compressions.
     pub async fn compress_batch_concurrent(
         &self,
-        mint_states: &[MintAccountState],
+        pubkeys: &[Pubkey],
         max_concurrent: usize,
         cancelled: Arc<AtomicBool>,
-    ) -> CompressionOutcomes<MintAccountState> {
-        if mint_states.is_empty() {
+    ) -> CompressionOutcomes {
+        if pubkeys.is_empty() {
             return Vec::new();
         }
 
         // Guard against max_concurrent == 0 to avoid buffer_unordered panic
         if max_concurrent == 0 {
-            return mint_states
+            return pubkeys
                 .iter()
-                .cloned()
-                .map(|mint_state| CompressionOutcome::Failed {
-                    state: mint_state,
+                .map(|&pubkey| CompressionOutcome::Failed {
+                    pubkey,
                     error: CompressionTaskError::Failed(anyhow::anyhow!(
                         "max_concurrent must be > 0"
                     )),
@@ -156,30 +155,48 @@ impl<R: Rpc> MintCompressor<R> {
         }
 
         // Mark all as pending upfront
-        let all_pubkeys: Vec<Pubkey> = mint_states.iter().map(|s| s.pubkey).collect();
-        self.tracker.mark_pending(&all_pubkeys);
+        self.tracker.mark_pending(pubkeys);
 
         // Create futures for each mint
-        let compression_futures = mint_states.iter().cloned().map(|mint_state| {
+        let compression_futures = pubkeys.iter().copied().map(|pubkey| {
             let compressor = self.clone();
             let cancelled = cancelled.clone();
 
             async move {
                 // Check cancellation before processing
                 if cancelled.load(Ordering::Relaxed) {
-                    compressor.tracker.unmark_pending(&[mint_state.pubkey]);
+                    compressor.tracker.unmark_pending(&[pubkey]);
                     return CompressionOutcome::Failed {
-                        state: mint_state,
+                        pubkey,
                         error: CompressionTaskError::Cancelled,
                     };
                 }
+                let mint_state = match compressor
+                    .tracker
+                    .accounts()
+                    .get(&pubkey)
+                    .map(|r| r.clone())
+                {
+                    Some(state) => state,
+                    None => {
+                        compressor.tracker.unmark_pending(&[pubkey]);
+                        return CompressionOutcome::Failed {
+                            pubkey,
+                            error: CompressionTaskError::Failed(anyhow::anyhow!(
+                                "mint {} removed from tracker before compression",
+                                pubkey
+                            )),
+                        };
+                    }
+                };
+
                 match compressor.compress(&mint_state).await {
                     Ok(sig) => CompressionOutcome::Compressed {
                         signature: sig,
-                        state: mint_state,
+                        pubkey,
                     },
                     Err(e) => CompressionOutcome::Failed {
-                        state: mint_state,
+                        pubkey,
                         error: e.into(),
                     },
                 }
             }
         }
@@ -195,11 +212,11 @@
         // Remove successfully compressed mints; unmark failed ones
         for result in &results {
             match result {
-                CompressionOutcome::Compressed { state, .. } => {
-                    self.tracker.remove_compressed(&state.pubkey);
+                CompressionOutcome::Compressed { pubkey, .. } => {
+                    self.tracker.remove_compressed(pubkey);
                 }
-                CompressionOutcome::Failed { state, .. } => {
-                    self.tracker.unmark_pending(&[*pubkey]);
+                CompressionOutcome::Failed { pubkey, .. } => {
+                    self.tracker.unmark_pending(&[*pubkey]);
                 }
             }
         }
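Reviewer note: the hunks above (and the PDA hunks below) change `CompressionOutcome` to carry a `Pubkey` and re-read the tracked state when the task actually runs. A sketch of why the re-lookup matters, with hypothetical types; the real code does this inside the spawned future:

```rust
use dashmap::DashMap;
use solana_sdk::pubkey::Pubkey;

// Hypothetical stand-in for MintAccountState / PdaAccountState.
#[derive(Clone)]
struct AccountState;

enum Outcome {
    Compressed { pubkey: Pubkey },
    Failed { pubkey: Pubkey, reason: String },
}

// Re-reading the state when the task runs (instead of cloning it when the
// batch was built) means an account removed in the meantime fails cleanly
// rather than being compressed from a stale snapshot.
fn compress_one(tracker: &DashMap<Pubkey, AccountState>, pubkey: Pubkey) -> Outcome {
    let state = match tracker.get(&pubkey).map(|entry| entry.value().clone()) {
        Some(state) => state,
        None => {
            return Outcome::Failed {
                pubkey,
                reason: "removed from tracker before compression".to_string(),
            }
        }
    };
    let _ = state; // the real compress(&state) call goes here
    Outcome::Compressed { pubkey }
}
```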
diff --git a/forester/src/compressible/pda/compressor.rs b/forester/src/compressible/pda/compressor.rs
index 941f0d3e53..5465963837 100644
--- a/forester/src/compressible/pda/compressor.rs
+++ b/forester/src/compressible/pda/compressor.rs
@@ -55,16 +55,16 @@ pub struct CachedProgramConfig {
 pub struct PdaCompressor<R: Rpc> {
     rpc_pool: Arc<SolanaRpcPool<R>>,
     tracker: Arc<PdaAccountTracker>,
-    payer_keypair: Keypair,
+    payer_keypair: Arc<Keypair>,
     transaction_policy: TransactionPolicy,
 }
 
 impl<R: Rpc> Clone for PdaCompressor<R> {
     fn clone(&self) -> Self {
         Self {
-            rpc_pool: Arc::clone(&self.rpc_pool),
-            tracker: Arc::clone(&self.tracker),
-            payer_keypair: self.payer_keypair.insecure_clone(),
+            rpc_pool: self.rpc_pool.clone(),
+            tracker: self.tracker.clone(),
+            payer_keypair: self.payer_keypair.clone(),
             transaction_policy: self.transaction_policy,
         }
     }
@@ -74,7 +74,7 @@ impl<R: Rpc> PdaCompressor<R> {
     pub fn new(
         rpc_pool: Arc<SolanaRpcPool<R>>,
         tracker: Arc<PdaAccountTracker>,
-        payer_keypair: Keypair,
+        payer_keypair: Arc<Keypair>,
         transaction_policy: TransactionPolicy,
     ) -> Self {
         Self {
@@ -156,22 +156,21 @@ impl<R: Rpc> PdaCompressor<R> {
     /// Successfully compressed accounts are removed from the tracker.
     pub async fn compress_batch_concurrent(
         &self,
-        account_states: &[PdaAccountState],
+        pubkeys: &[Pubkey],
         program_config: &PdaProgramConfig,
         cached_config: &CachedProgramConfig,
         max_concurrent: usize,
         cancelled: Arc<AtomicBool>,
-    ) -> CompressionOutcomes<PdaAccountState> {
-        if account_states.is_empty() {
+    ) -> CompressionOutcomes {
+        if pubkeys.is_empty() {
             return Vec::new();
         }
 
         // Mark all accounts as pending upfront so concurrent cycles skip them
-        let all_pubkeys: Vec<Pubkey> = account_states.iter().map(|s| s.pubkey).collect();
-        self.tracker.mark_pending(&all_pubkeys);
+        self.tracker.mark_pending(pubkeys);
 
         // Create futures for each account
-        let compression_futures = account_states.iter().cloned().map(|account_state| {
+        let compression_futures = pubkeys.iter().copied().map(|pubkey| {
             let compressor = self.clone();
             let program_config = program_config.clone();
             let cached_config = cached_config.clone();
@@ -180,24 +179,43 @@
             async move {
                 // Check cancellation before processing
                 if cancelled.load(Ordering::Relaxed) {
-                    // Unmark since we won't process this account
-                    compressor.tracker.unmark_pending(&[account_state.pubkey]);
+                    compressor.tracker.unmark_pending(&[pubkey]);
                     return CompressionOutcome::Failed {
-                        state: account_state,
+                        pubkey,
                         error: CompressionTaskError::Cancelled,
                     };
                 }
+                // Look up account state from tracker; it may have been removed
+                let account_state = match compressor
+                    .tracker
+                    .accounts()
+                    .get(&pubkey)
+                    .map(|r| r.clone())
+                {
+                    Some(state) => state,
+                    None => {
+                        compressor.tracker.unmark_pending(&[pubkey]);
+                        return CompressionOutcome::Failed {
+                            pubkey,
+                            error: CompressionTaskError::Failed(anyhow::anyhow!(
+                                "account {} removed from tracker before compression",
+                                pubkey
+                            )),
+                        };
+                    }
+                };
+
                 match compressor
                     .compress(&account_state, &program_config, &cached_config)
                     .await
                 {
                     Ok(sig) => CompressionOutcome::Compressed {
                         signature: sig,
-                        state: account_state,
+                        pubkey,
                     },
                     Err(e) => CompressionOutcome::Failed {
-                        state: account_state,
+                        pubkey,
                         error: e.into(),
                     },
                 }
             }
         }
@@ -213,11 +231,11 @@
         // Remove successfully compressed PDAs; unmark failed ones
         for result in &results {
             match result {
-                CompressionOutcome::Compressed { state, .. } => {
-                    self.tracker.remove_compressed(&state.pubkey);
+                CompressionOutcome::Compressed { pubkey, .. } => {
+                    self.tracker.remove_compressed(pubkey);
                 }
-                CompressionOutcome::Failed { state, .. } => {
-                    self.tracker.unmark_pending(&[*pubkey]);
+                CompressionOutcome::Failed { pubkey, .. } => {
+                    self.tracker.unmark_pending(&[*pubkey]);
                 }
             }
         }
@@ -396,7 +414,7 @@ impl<R: Rpc> PdaCompressor<R> {
         );
 
         let payer_pubkey = self.payer_keypair.pubkey();
-        let signers = [&self.payer_keypair];
+        let signers = [self.payer_keypair.as_ref()];
         let instructions = vec![ix];
         let priority_fee_accounts = collect_priority_fee_accounts(payer_pubkey, &instructions);
         let signature = send_transaction_with_policy(
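Reviewer note: both compressors drive their batches with `buffer_unordered`, and both guard `max_concurrent == 0` up front (per the diff's own comment, `buffer_unordered(0)` would panic). A self-contained sketch of that driver shape, with plain `u64` items standing in for the per-account futures and an `assert!` where the real code returns `Failed` outcomes instead:

```rust
use futures::stream::{self, StreamExt};

// Sketch of the batch-driver shape shared by both compressors: map each item
// to a future, bound parallelism with buffer_unordered, collect all outcomes
// (in completion order, like the real code).
async fn drive_batch(items: Vec<u64>, max_concurrent: usize) -> Vec<u64> {
    assert!(max_concurrent > 0, "max_concurrent must be > 0");
    stream::iter(items)
        .map(|item| async move { item * 2 }) // stand-in for compress(item).await
        .buffer_unordered(max_concurrent)
        .collect()
        .await
}

#[tokio::main]
async fn main() {
    let results = drive_batch(vec![1, 2, 3, 4], 2).await;
    println!("{results:?}");
}
```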
diff --git a/forester/src/compressible/pda/state.rs b/forester/src/compressible/pda/state.rs
index 92f9fe0a61..2c1d2e6711 100644
--- a/forester/src/compressible/pda/state.rs
+++ b/forester/src/compressible/pda/state.rs
@@ -13,7 +13,7 @@ use super::types::PdaAccountState;
 use crate::{
     compressible::{
         config::PdaProgramConfig,
-        traits::{CompressibleTracker, SubscriptionHandler},
+        traits::{CompressibleState, CompressibleTracker, SubscriptionHandler},
     },
     Result,
 };
@@ -72,10 +72,33 @@ impl PdaAccountTracker {
         &self,
         program_id: &Pubkey,
         current_slot: u64,
+    ) -> Vec<Pubkey> {
+        let pending = self.pending();
+        self.accounts()
+            .iter()
+            .filter(|entry| {
+                entry.value().program_id == *program_id
+                    && entry.value().is_ready_to_compress(current_slot)
+                    && !pending.contains(entry.key())
+            })
+            .map(|entry| *entry.key())
+            .collect()
+    }
+
+    pub fn get_ready_states_for_program(
+        &self,
+        program_id: &Pubkey,
+        current_slot: u64,
     ) -> Vec<PdaAccountState> {
-        self.get_ready_to_compress(current_slot)
-            .into_iter()
-            .filter(|state| state.program_id == *program_id)
+        let pending = self.pending();
+        self.accounts()
+            .iter()
+            .filter(|entry| {
+                entry.value().program_id == *program_id
+                    && entry.value().is_ready_to_compress(current_slot)
+                    && !pending.contains(entry.key())
+            })
+            .map(|entry| entry.value().clone())
             .collect()
     }
 
diff --git a/forester/src/compressible/traits.rs b/forester/src/compressible/traits.rs
index ccc7f43386..39c33e97d7 100644
--- a/forester/src/compressible/traits.rs
+++ b/forester/src/compressible/traits.rs
@@ -47,18 +47,18 @@ pub enum CompressionTaskError {
 }
 
 #[derive(Debug)]
-pub enum CompressionOutcome<S> {
+pub enum CompressionOutcome {
     Compressed {
         signature: Signature,
-        state: S,
+        pubkey: Pubkey,
     },
     Failed {
-        state: S,
+        pubkey: Pubkey,
         error: CompressionTaskError,
     },
 }
 
-pub type CompressionOutcomes<S> = Vec<CompressionOutcome<S>>;
+pub type CompressionOutcomes = Vec<CompressionOutcome>;
 
 pub trait CompressibleState: Clone + Send + Sync {
     fn pubkey(&self) -> &Pubkey;
@@ -128,7 +128,20 @@ pub trait CompressibleTracker<S: CompressibleState>: Send + Sync {
         self.len() == 0
     }
 
-    fn get_ready_to_compress(&self, current_slot: u64) -> Vec<S> {
+    fn get_ready_to_compress(&self, current_slot: u64) -> Vec<Pubkey> {
+        let pending = self.pending();
+        self.accounts()
+            .iter()
+            .filter(|entry| {
+                entry.value().is_ready_to_compress(current_slot) && !pending.contains(entry.key())
+            })
+            .map(|entry| *entry.key())
+            .collect()
+    }
+
+    /// Clone of ready-to-compress states. Prefer `get_ready_to_compress` in
+    /// production; this helper exists for tests that need full state fields.
+ fn get_ready_states(&self, current_slot: u64) -> Vec { let pending = self.pending(); self.accounts() .iter() diff --git a/forester/src/epoch_manager.rs b/forester/src/epoch_manager.rs deleted file mode 100644 index a253b4bad3..0000000000 --- a/forester/src/epoch_manager.rs +++ /dev/null @@ -1,4896 +0,0 @@ -use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, - Arc, - }, - time::{Duration, SystemTime, UNIX_EPOCH}, -}; - -use anyhow::{anyhow, Context}; -use borsh::BorshSerialize; -use dashmap::DashMap; -use forester_utils::{ - forester_epoch::{get_epoch_phases, Epoch, ForesterSlot, TreeAccounts, TreeForesterSchedule}, - rpc_pool::SolanaRpcPool, -}; -use futures::future::join_all; -use light_client::{ - indexer::{Indexer, MerkleProof, NewAddressProofWithContext}, - rpc::{LightClient, LightClientConfig, RetryConfig, Rpc, RpcError}, -}; -use light_compressed_account::TreeType; -use light_registry::{ - account_compression_cpi::sdk::{ - create_batch_append_instruction, create_batch_nullify_instruction, - create_batch_update_address_tree_instruction, - }, - protocol_config::state::{EpochState, ProtocolConfig}, - sdk::{create_finalize_registration_instruction, create_report_work_instruction}, - utils::{get_epoch_pda_address, get_forester_epoch_pda_from_authority}, - EpochPda, ForesterEpochPda, -}; -use solana_program::{ - instruction::InstructionError, native_token::LAMPORTS_PER_SOL, pubkey::Pubkey, -}; -use solana_sdk::{ - address_lookup_table::AddressLookupTableAccount, - signature::{Keypair, Signer}, - transaction::TransactionError, -}; -use tokio::{ - sync::{mpsc, oneshot, Mutex}, - task::JoinHandle, - time::{sleep, Instant, MissedTickBehavior}, -}; -use tracing::{debug, error, info, info_span, instrument, trace, warn}; - -use crate::{ - compressible::{ - traits::{Cancelled, CompressibleTracker, CompressionOutcome, CompressionTaskError}, - CTokenAccountTracker, CTokenCompressor, CompressibleConfig, - }, - errors::{ - rpc_is_already_processed, ChannelError, ForesterError, InitializationError, - RegistrationError, WorkReportError, - }, - logging::{should_emit_rate_limited_warning, ServiceHeartbeat}, - metrics::{ - push_metrics, queue_metric_update, update_epoch_detected, update_epoch_registered, - update_forester_sol_balance, - }, - pagerduty::send_pagerduty_alert, - priority_fee::PriorityFeeConfig, - processor::{ - tx_cache::ProcessedHashCache, - v1::{ - config::{BuildTransactionBatchConfig, SendBatchedTransactionsConfig}, - send_transaction::send_batched_transactions, - tx_builder::EpochManagerTransactions, - }, - v2::{ - strategy::{AddressTreeStrategy, StateTreeStrategy}, - BatchContext, BatchInstruction, ProcessingResult, ProverConfig, QueueProcessor, - SharedProofCache, - }, - }, - queue_helpers::QueueItemData, - rollover::{ - is_tree_ready_for_rollover, perform_address_merkle_tree_rollover, - perform_state_merkle_tree_rollover_forester, - }, - slot_tracker::{slot_duration, wait_until_slot_reached, SlotTracker}, - smart_transaction::{ - send_smart_transaction, ComputeBudgetConfig, ConfirmationConfig, - SendSmartTransactionConfig, TransactionPolicy, - }, - transaction_timing::{scheduled_confirmation_deadline, scheduled_v1_batch_timeout}, - tree_data_sync::{fetch_protocol_group_authority, fetch_trees}, - ForesterConfig, ForesterEpochInfo, Result, -}; - -type StateBatchProcessorMap = - Arc>>)>>; -type AddressBatchProcessorMap = - Arc>>)>>; -type ProcessorInitLockMap = Arc>>>; -type TreeProcessingTask = JoinHandle>; - -/// Coordinates 
re-finalization across parallel `process_queue` tasks when new -/// foresters register mid-epoch. Only one task performs the on-chain -/// `finalize_registration` tx; others wait for it to complete. -#[derive(Debug)] -pub(crate) struct RegistrationTracker { - cached_registered_weight: AtomicU64, - refinalize_in_progress: AtomicBool, - refinalized: tokio::sync::Notify, -} - -impl RegistrationTracker { - fn new(weight: u64) -> Self { - Self { - cached_registered_weight: AtomicU64::new(weight), - refinalize_in_progress: AtomicBool::new(false), - refinalized: tokio::sync::Notify::new(), - } - } - - fn cached_weight(&self) -> u64 { - self.cached_registered_weight.load(Ordering::Acquire) - } - - /// Returns `true` if this caller won the race to perform re-finalization. - fn try_claim_refinalize(&self) -> bool { - self.refinalize_in_progress - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) - .is_ok() - } - - /// Called by the winner after the on-chain tx succeeds. - fn complete_refinalize(&self, new_weight: u64) { - self.cached_registered_weight - .store(new_weight, Ordering::Release); - self.refinalize_in_progress.store(false, Ordering::Release); - self.refinalized.notify_waiters(); - } - - /// Called by non-winners to block until re-finalization is done. - async fn wait_for_refinalize(&self) { - if !self.refinalize_in_progress.load(Ordering::Acquire) { - return; - } - let fut = self.refinalized.notified(); - if !self.refinalize_in_progress.load(Ordering::Acquire) { - return; - } - fut.await; - } -} - -/// Timing for a single circuit type (circuit inputs + proof generation) -#[derive(Copy, Clone, Debug, Default)] -pub struct CircuitMetrics { - /// Time spent building circuit inputs - pub circuit_inputs_duration: std::time::Duration, - /// Time spent generating ZK proofs (pure prover server time) - pub proof_generation_duration: std::time::Duration, - /// Total round-trip time (submit to result, includes queue wait) - pub round_trip_duration: std::time::Duration, -} - -impl CircuitMetrics { - pub fn total(&self) -> std::time::Duration { - self.circuit_inputs_duration + self.proof_generation_duration - } -} - -impl std::ops::AddAssign for CircuitMetrics { - fn add_assign(&mut self, rhs: Self) { - self.circuit_inputs_duration += rhs.circuit_inputs_duration; - self.proof_generation_duration += rhs.proof_generation_duration; - self.round_trip_duration += rhs.round_trip_duration; - } -} - -/// Timing breakdown by circuit type -#[derive(Copy, Clone, Debug, Default)] -pub struct ProcessingMetrics { - /// State append circuit (output queue processing) - pub append: CircuitMetrics, - /// State nullify circuit (input queue processing) - pub nullify: CircuitMetrics, - /// Address append circuit - pub address_append: CircuitMetrics, - /// Time spent sending transactions (overlapped with proof gen) - pub tx_sending_duration: std::time::Duration, -} - -impl ProcessingMetrics { - pub fn total(&self) -> std::time::Duration { - self.append.total() - + self.nullify.total() - + self.address_append.total() - + self.tx_sending_duration - } - - pub fn total_circuit_inputs(&self) -> std::time::Duration { - self.append.circuit_inputs_duration - + self.nullify.circuit_inputs_duration - + self.address_append.circuit_inputs_duration - } - - pub fn total_proof_generation(&self) -> std::time::Duration { - self.append.proof_generation_duration - + self.nullify.proof_generation_duration - + self.address_append.proof_generation_duration - } - - pub fn total_round_trip(&self) -> std::time::Duration { - 
self.append.round_trip_duration - + self.nullify.round_trip_duration - + self.address_append.round_trip_duration - } -} - -impl std::ops::AddAssign for ProcessingMetrics { - fn add_assign(&mut self, rhs: Self) { - self.append += rhs.append; - self.nullify += rhs.nullify; - self.address_append += rhs.address_append; - self.tx_sending_duration += rhs.tx_sending_duration; - } -} - -#[derive(Copy, Clone, Debug)] -pub struct WorkReport { - pub epoch: u64, - pub processed_items: usize, - pub metrics: ProcessingMetrics, -} - -#[derive(Debug, Clone)] -pub struct WorkItem { - pub tree_account: TreeAccounts, - pub queue_item_data: QueueItemData, -} - -impl WorkItem { - pub fn is_address_tree(&self) -> bool { - self.tree_account.tree_type == TreeType::AddressV1 - } - pub fn is_state_tree(&self) -> bool { - self.tree_account.tree_type == TreeType::StateV1 - } -} - -#[allow(clippy::large_enum_variant)] -#[derive(Debug, Clone)] -pub enum MerkleProofType { - AddressProof(NewAddressProofWithContext), - StateProof(MerkleProof), -} - -#[derive(Debug)] -pub struct EpochManager { - config: Arc, - protocol_config: Arc, - rpc_pool: Arc>, - authority: Arc, - work_report_sender: mpsc::Sender, - processed_items_per_epoch_count: Arc>>, - processing_metrics_per_epoch: Arc>>, - trees: Arc>>, - slot_tracker: Arc, - processing_epochs: Arc>>, - tx_cache: Arc>, - ops_cache: Arc>, - /// Proof caches for pre-warming during idle slots - proof_caches: Arc>>, - state_processors: StateBatchProcessorMap, - address_processors: AddressBatchProcessorMap, - state_processor_init_locks: ProcessorInitLockMap, - address_processor_init_locks: ProcessorInitLockMap, - compressible_tracker: Option>, - pda_tracker: Option>, - mint_tracker: Option>, - /// Cached zkp_batch_size per tree to filter queue updates below threshold - zkp_batch_sizes: Arc>, - address_lookup_tables: Arc>, - heartbeat: Arc, - run_id: Arc, - /// Per-epoch registration trackers to coordinate re-finalization when new foresters register mid-epoch - registration_trackers: Arc>>, -} - -impl Clone for EpochManager { - fn clone(&self) -> Self { - Self { - config: self.config.clone(), - protocol_config: self.protocol_config.clone(), - rpc_pool: self.rpc_pool.clone(), - authority: self.authority.clone(), - work_report_sender: self.work_report_sender.clone(), - processed_items_per_epoch_count: self.processed_items_per_epoch_count.clone(), - processing_metrics_per_epoch: self.processing_metrics_per_epoch.clone(), - trees: self.trees.clone(), - slot_tracker: self.slot_tracker.clone(), - processing_epochs: self.processing_epochs.clone(), - tx_cache: self.tx_cache.clone(), - ops_cache: self.ops_cache.clone(), - proof_caches: self.proof_caches.clone(), - state_processors: self.state_processors.clone(), - address_processors: self.address_processors.clone(), - state_processor_init_locks: self.state_processor_init_locks.clone(), - address_processor_init_locks: self.address_processor_init_locks.clone(), - compressible_tracker: self.compressible_tracker.clone(), - pda_tracker: self.pda_tracker.clone(), - mint_tracker: self.mint_tracker.clone(), - zkp_batch_sizes: self.zkp_batch_sizes.clone(), - address_lookup_tables: self.address_lookup_tables.clone(), - heartbeat: self.heartbeat.clone(), - run_id: self.run_id.clone(), - registration_trackers: self.registration_trackers.clone(), - } - } -} - -impl EpochManager { - #[allow(clippy::too_many_arguments)] - pub async fn new( - config: Arc, - protocol_config: Arc, - rpc_pool: Arc>, - work_report_sender: mpsc::Sender, - trees: Vec, - 
slot_tracker: Arc, - tx_cache: Arc>, - ops_cache: Arc>, - compressible_tracker: Option>, - pda_tracker: Option>, - mint_tracker: Option>, - address_lookup_tables: Arc>, - heartbeat: Arc, - run_id: String, - ) -> Result { - let authority = Arc::new(config.payer_keypair.insecure_clone()); - Ok(Self { - config, - protocol_config, - rpc_pool, - authority, - work_report_sender, - processed_items_per_epoch_count: Arc::new(Mutex::new(HashMap::new())), - processing_metrics_per_epoch: Arc::new(Mutex::new(HashMap::new())), - trees: Arc::new(Mutex::new(trees)), - slot_tracker, - processing_epochs: Arc::new(DashMap::new()), - tx_cache, - ops_cache, - proof_caches: Arc::new(DashMap::new()), - state_processors: Arc::new(DashMap::new()), - address_processors: Arc::new(DashMap::new()), - state_processor_init_locks: Arc::new(DashMap::new()), - address_processor_init_locks: Arc::new(DashMap::new()), - compressible_tracker, - pda_tracker, - mint_tracker, - zkp_batch_sizes: Arc::new(DashMap::new()), - address_lookup_tables, - heartbeat, - run_id: Arc::::from(run_id), - registration_trackers: Arc::new(DashMap::new()), - }) - } - - pub async fn run(self: Arc) -> Result<()> { - let (tx, mut rx) = mpsc::channel(100); - let tx = Arc::new(tx); - - let mut monitor_handle = tokio::spawn({ - let self_clone = Arc::clone(&self); - let tx_clone = Arc::clone(&tx); - async move { self_clone.monitor_epochs(tx_clone).await } - }); - - // Process current and previous epochs - let current_previous_handle = tokio::spawn({ - let self_clone = Arc::clone(&self); - let tx_clone = Arc::clone(&tx); - async move { - self_clone - .process_current_and_previous_epochs(tx_clone) - .await - } - }); - - let tree_discovery_handle = tokio::spawn({ - let self_clone = Arc::clone(&self); - async move { self_clone.discover_trees_periodically().await } - }); - - let balance_check_handle = tokio::spawn({ - let self_clone = Arc::clone(&self); - async move { self_clone.check_sol_balance_periodically().await } - }); - - let queue_metrics_handle = tokio::spawn({ - let self_clone = Arc::clone(&self); - async move { self_clone.update_queue_metrics_periodically().await } - }); - - let _guard = scopeguard::guard( - ( - current_previous_handle, - tree_discovery_handle, - balance_check_handle, - queue_metrics_handle, - ), - |(h2, h3, h4, h5)| { - info!( - event = "background_tasks_aborting", - run_id = %self.run_id, - "Aborting EpochManager background tasks" - ); - h2.abort(); - h3.abort(); - h4.abort(); - h5.abort(); - }, - ); - - let result = loop { - tokio::select! 
{ - epoch_opt = rx.recv() => { - match epoch_opt { - Some(epoch) => { - debug!( - event = "epoch_queued_for_processing", - run_id = %self.run_id, - epoch, - "Received epoch from monitor" - ); - let self_clone = Arc::clone(&self); - tokio::spawn(async move { - if let Err(e) = self_clone.process_epoch(epoch).await { - error!( - event = "epoch_processing_failed", - run_id = %self_clone.run_id, - epoch, - error = ?e, - "Error processing epoch" - ); - } - }); - } - None => { - error!( - event = "epoch_monitor_channel_closed", - run_id = %self.run_id, - "Epoch monitor channel closed unexpectedly" - ); - break Err(anyhow!( - "Epoch monitor channel closed - forester cannot function without it" - )); - } - } - } - result = &mut monitor_handle => { - match result { - Ok(Ok(())) => { - error!( - event = "epoch_monitor_exited_unexpected_ok", - run_id = %self.run_id, - "Epoch monitor exited unexpectedly with Ok(())" - ); - } - Ok(Err(e)) => { - error!( - event = "epoch_monitor_exited_with_error", - run_id = %self.run_id, - error = ?e, - "Epoch monitor exited with error" - ); - } - Err(e) => { - error!( - event = "epoch_monitor_task_failed", - run_id = %self.run_id, - error = ?e, - "Epoch monitor task panicked or was cancelled" - ); - } - } - if let Some(pagerduty_key) = &self.config.external_services.pagerduty_routing_key { - let _ = send_pagerduty_alert( - pagerduty_key, - &format!("Forester epoch monitor died unexpectedly on {}", self.config.payer_keypair.pubkey()), - "critical", - "epoch_monitor_dead", - ).await; - } - break Err(anyhow!("Epoch monitor exited unexpectedly - forester cannot function without it")); - } - } - }; - - // Abort monitor_handle on exit - monitor_handle.abort(); - result - } - - /// Periodically updates queue_length and queue_capacity Prometheus gauges - /// so Grafana dashboards can show queue trends over time. 
- async fn update_queue_metrics_periodically(self: Arc) -> Result<()> { - let interval_secs = self.config.general_config.tree_discovery_interval_seconds; - if interval_secs == 0 { - return Ok(()); - } - // Use same interval as tree discovery (default 30s) - let mut interval = tokio::time::interval(Duration::from_secs(interval_secs)); - // Skip first tick — let tree discovery populate the tree list first - interval.tick().await; - - loop { - interval.tick().await; - - let trees = self.trees.lock().await; - let trees_snapshot: Vec<_> = trees.clone(); - drop(trees); - - if trees_snapshot.is_empty() { - continue; - } - - for tree_type in [ - TreeType::StateV1, - TreeType::AddressV1, - TreeType::StateV2, - TreeType::AddressV2, - ] { - if let Err(e) = - crate::run_queue_info(self.config.clone(), &trees_snapshot, tree_type).await - { - debug!( - event = "queue_metrics_update_failed", - run_id = %self.run_id, - tree_type = ?tree_type, - error = ?e, - "Failed to update queue metrics" - ); - } - } - } - } - - async fn check_sol_balance_periodically(self: Arc) -> Result<()> { - let interval_duration = Duration::from_secs(300); - let mut interval = tokio::time::interval(interval_duration); - - loop { - interval.tick().await; - match self.rpc_pool.get_connection().await { - Ok(rpc) => match rpc.get_balance(&self.config.payer_keypair.pubkey()).await { - Ok(balance) => { - let balance_in_sol = balance as f64 / (LAMPORTS_PER_SOL as f64); - update_forester_sol_balance( - &self.config.payer_keypair.pubkey().to_string(), - balance_in_sol, - ); - debug!( - event = "forester_balance_updated", - run_id = %self.run_id, - balance_sol = balance_in_sol, - "Current SOL balance updated" - ); - } - Err(e) => error!( - event = "forester_balance_fetch_failed", - run_id = %self.run_id, - error = ?e, - "Failed to get balance" - ), - }, - Err(e) => error!( - event = "forester_balance_rpc_connection_failed", - run_id = %self.run_id, - error = ?e, - "Failed to get RPC connection for balance check" - ), - } - } - } - - /// Periodically fetches trees from on-chain and adds newly discovered ones. - async fn discover_trees_periodically(self: Arc) -> Result<()> { - let interval_secs = self.config.general_config.tree_discovery_interval_seconds; - if interval_secs == 0 { - info!(event = "tree_discovery_disabled", run_id = %self.run_id, "Tree discovery disabled (interval=0)"); - return Ok(()); - } - let mut interval = tokio::time::interval(Duration::from_secs(interval_secs)); - // Skip the first immediate tick — initial trees are already loaded at startup - interval.tick().await; - - info!( - event = "tree_discovery_started", - run_id = %self.run_id, - interval_secs, - "Starting periodic tree discovery" - ); - - let mut group_authority: Option = self.config.general_config.group_authority; - - loop { - interval.tick().await; - - let rpc = match self.rpc_pool.get_connection().await { - Ok(rpc) => rpc, - Err(e) => { - warn!(event = "tree_discovery_rpc_failed", run_id = %self.run_id, error = ?e, "Tree discovery: failed to get RPC connection"); - continue; - } - }; - - // Lazily resolve group authority (retry each tick until successful) - if group_authority.is_none() { - if let Ok(ga) = fetch_protocol_group_authority(&*rpc, &self.run_id).await { - group_authority = Some(ga); - // Retroactively filter already-tracked trees that were added - // before group_authority was resolved. 
- let mut trees = self.trees.lock().await; - let before = trees.len(); - trees.retain(|t| t.owner == ga); - if !self.config.general_config.tree_ids.is_empty() { - let tree_ids = &self.config.general_config.tree_ids; - trees.retain(|t| tree_ids.contains(&t.merkle_tree)); - } - if trees.len() < before { - info!( - event = "tree_discovery_retroactive_filter", - run_id = %self.run_id, - group_authority = %ga, - trees_before = before, - trees_after = trees.len(), - "Filtered existing trees after resolving group authority" - ); - } - } - } - - let mut fetched_trees = match fetch_trees(&*rpc).await { - Ok(trees) => trees, - Err(e) => { - warn!(event = "tree_discovery_fetch_failed", run_id = %self.run_id, error = ?e, "Tree discovery: failed to fetch trees"); - continue; - } - }; - - if let Some(ga) = group_authority { - fetched_trees.retain(|tree| tree.owner == ga); - } - if !self.config.general_config.tree_ids.is_empty() { - let tree_ids = &self.config.general_config.tree_ids; - fetched_trees.retain(|tree| tree_ids.contains(&tree.merkle_tree)); - } - - let known_trees = self.trees.lock().await; - let known_pubkeys: std::collections::HashSet = - known_trees.iter().map(|t| t.merkle_tree).collect(); - drop(known_trees); - - for tree in fetched_trees { - if known_pubkeys.contains(&tree.merkle_tree) { - continue; - } - if should_skip_tree(&self.config, &tree.tree_type) { - debug!( - event = "tree_discovery_skipped", - run_id = %self.run_id, - tree = %tree.merkle_tree, - tree_type = ?tree.tree_type, - "Skipping tree due to fee filter config" - ); - continue; - } - info!( - event = "tree_discovery_new_tree", - run_id = %self.run_id, - tree = %tree.merkle_tree, - tree_type = ?tree.tree_type, - queue = %tree.queue, - "Discovered new tree" - ); - if let Err(e) = self.add_new_tree(tree).await { - error!( - event = "tree_discovery_add_failed", - run_id = %self.run_id, - error = ?e, - "Failed to add discovered tree" - ); - } - } - } - } - - async fn add_new_tree(&self, new_tree: TreeAccounts) -> Result<()> { - info!( - event = "new_tree_add_started", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - tree_type = ?new_tree.tree_type, - "Adding new tree" - ); - let mut trees = self.trees.lock().await; - trees.push(new_tree); - drop(trees); - - info!( - event = "new_tree_added", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - "New tree added to tracked list" - ); - - let (current_slot, current_epoch) = self.get_current_slot_and_epoch().await?; - let phases = get_epoch_phases(&self.protocol_config, current_epoch); - - // Check if we're currently in the active phase - if current_slot >= phases.active.start && current_slot < phases.active.end { - info!( - event = "new_tree_active_phase_injection", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - current_slot, - active_phase_start_slot = phases.active.start, - active_phase_end_slot = phases.active.end, - "In active phase; attempting immediate processing for new tree" - ); - info!( - event = "new_tree_recover_registration_started", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - epoch = current_epoch, - "Recovering registration info for new tree" - ); - match self - .recover_registration_info_if_exists(current_epoch) - .await - { - Ok(Some(mut epoch_info)) => { - info!( - event = "new_tree_recover_registration_succeeded", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - epoch = current_epoch, - "Recovered registration info for current epoch" - ); - let tree_schedule = TreeForesterSchedule::new_with_schedule( - 
&new_tree, - current_slot, - &epoch_info.forester_epoch_pda, - &epoch_info.epoch_pda, - )?; - epoch_info.trees.push(tree_schedule.clone()); - - let self_clone = Arc::new(self.clone()); - let tracker = self - .registration_trackers - .entry(current_epoch) - .or_insert_with(|| { - Arc::new(RegistrationTracker::new( - epoch_info.epoch_pda.registered_weight, - )) - }) - .value() - .clone(); - - info!( - event = "new_tree_processing_task_spawned", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - epoch = current_epoch, - "Spawning task to process new tree in current epoch" - ); - tokio::spawn(async move { - let tree_pubkey = tree_schedule.tree_accounts.merkle_tree; - if let Err(e) = self_clone - .process_queue( - &epoch_info.epoch, - epoch_info.forester_epoch_pda.clone(), - tree_schedule, - tracker, - ) - .await - { - error!( - event = "new_tree_process_queue_failed", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - error = ?e, - "Error processing queue for new tree" - ); - } else { - info!( - event = "new_tree_process_queue_succeeded", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - "Successfully processed new tree in current epoch" - ); - } - }); - } - Ok(None) => { - debug!( - "Not registered for current epoch yet, new tree will be picked up during next registration" - ); - } - Err(e) => { - warn!( - event = "new_tree_recover_registration_failed", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - epoch = current_epoch, - error = ?e, - "Failed to recover registration info for new tree" - ); - } - } - - info!( - event = "new_tree_injected_into_current_epoch", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - epoch = current_epoch, - "Injected new tree into current epoch" - ); - } else { - info!( - event = "new_tree_queued_for_next_registration", - run_id = %self.run_id, - tree = %new_tree.merkle_tree, - current_slot, - active_phase_start_slot = phases.active.start, - "Not in active phase; new tree will be picked up in next registration" - ); - } - - Ok(()) - } - - #[instrument(level = "debug", skip(self, tx))] - async fn monitor_epochs(&self, tx: Arc>) -> Result<()> { - let mut last_epoch: Option = None; - let mut consecutive_failures = 0u32; - const MAX_BACKOFF_SECS: u64 = 60; - - info!( - event = "epoch_monitor_started", - run_id = %self.run_id, - "Starting epoch monitor" - ); - - loop { - let (slot, current_epoch) = match self.get_current_slot_and_epoch().await { - Ok(result) => { - if consecutive_failures > 0 { - info!( - event = "epoch_monitor_recovered", - run_id = %self.run_id, - consecutive_failures, "Epoch monitor recovered after failures" - ); - } - consecutive_failures = 0; - result - } - Err(e) => { - consecutive_failures += 1; - let backoff_secs = 2u64.pow(consecutive_failures.min(6)).min(MAX_BACKOFF_SECS); - let backoff = Duration::from_secs(backoff_secs); - - if consecutive_failures == 1 { - warn!( - event = "epoch_monitor_slot_epoch_failed", - run_id = %self.run_id, - consecutive_failures, - error = ?e, - backoff_ms = backoff.as_millis() as u64, - "Epoch monitor failed to get slot/epoch; retrying" - ); - } else if consecutive_failures.is_multiple_of(10) { - error!( - event = "epoch_monitor_slot_epoch_failed_repeated", - run_id = %self.run_id, - consecutive_failures, - error = ?e, - backoff_ms = backoff.as_millis() as u64, - "Epoch monitor still failing repeatedly" - ); - } - - tokio::time::sleep(backoff).await; - continue; - } - }; - - debug!( - event = "epoch_monitor_tick", - run_id = %self.run_id, - last_epoch = ?last_epoch, - 
current_epoch, - slot, - "Epoch monitor tick" - ); - - if last_epoch.is_none_or(|last| current_epoch > last) { - debug!( - event = "epoch_monitor_new_epoch_detected", - run_id = %self.run_id, - epoch = current_epoch, - "New epoch detected; sending for processing" - ); - if let Err(e) = tx.send(current_epoch).await { - error!( - event = "epoch_monitor_send_current_epoch_failed", - run_id = %self.run_id, - epoch = current_epoch, - error = ?e, - "Failed to send current epoch for processing; channel closed" - ); - return Err(anyhow!("Epoch channel closed: {}", e)); - } - last_epoch = Some(current_epoch); - } - - // Find the next epoch to process - let target_epoch = current_epoch + 1; - if last_epoch.is_none_or(|last| target_epoch > last) { - let target_phases = get_epoch_phases(&self.protocol_config, target_epoch); - - // If registration hasn't started yet, wait for it - if slot < target_phases.registration.start { - let mut rpc = match self.rpc_pool.get_connection().await { - Ok(rpc) => rpc, - Err(e) => { - warn!( - event = "epoch_monitor_wait_rpc_connection_failed", - run_id = %self.run_id, - target_epoch, - error = ?e, - "Failed to get RPC connection while waiting for registration slot" - ); - tokio::time::sleep(Duration::from_secs(1)).await; - continue; - } - }; - - const REGISTRATION_BUFFER_SLOTS: u64 = 30; - let wait_target = target_phases - .registration - .start - .saturating_sub(REGISTRATION_BUFFER_SLOTS); - let slots_to_wait = wait_target.saturating_sub(slot); - - debug!( - event = "epoch_monitor_wait_for_registration", - run_id = %self.run_id, - target_epoch, - current_slot = slot, - wait_target_slot = wait_target, - registration_start_slot = target_phases.registration.start, - slots_to_wait, - "Waiting for target epoch registration phase" - ); - - if let Err(e) = - wait_until_slot_reached(&mut *rpc, &self.slot_tracker, wait_target).await - { - error!( - event = "epoch_monitor_wait_for_registration_failed", - run_id = %self.run_id, - target_epoch, - error = ?e, - "Error waiting for registration phase" - ); - continue; - } - } - - debug!( - event = "epoch_monitor_send_target_epoch", - run_id = %self.run_id, - target_epoch, - "Sending target epoch for processing" - ); - if let Err(e) = tx.send(target_epoch).await { - error!( - event = "epoch_monitor_send_target_epoch_failed", - run_id = %self.run_id, - target_epoch, - error = ?e, - "Failed to send target epoch for processing; channel closed" - ); - return Err(anyhow!("Epoch channel closed: {}", e)); - } - last_epoch = Some(target_epoch); - continue; // Re-check state after processing - } else { - // we've already sent the next epoch, wait a bit before checking again - tokio::time::sleep(Duration::from_secs(10)).await; - } - } - } - - async fn get_processed_items_count(&self, epoch: u64) -> usize { - let counts = self.processed_items_per_epoch_count.lock().await; - counts - .get(&epoch) - .map_or(0, |count| count.load(Ordering::Relaxed)) - } - - async fn increment_processed_items_count(&self, epoch: u64, increment_by: usize) { - let mut counts = self.processed_items_per_epoch_count.lock().await; - counts - .entry(epoch) - .or_insert_with(|| AtomicUsize::new(0)) - .fetch_add(increment_by, Ordering::Relaxed); - } - - async fn get_processing_metrics(&self, epoch: u64) -> ProcessingMetrics { - let metrics = self.processing_metrics_per_epoch.lock().await; - metrics.get(&epoch).copied().unwrap_or_default() - } - - async fn add_processing_metrics(&self, epoch: u64, new_metrics: ProcessingMetrics) { - let mut metrics = 
self.processing_metrics_per_epoch.lock().await; - *metrics.entry(epoch).or_default() += new_metrics; - } - - async fn recover_registration_info_if_exists( - &self, - epoch: u64, - ) -> std::result::Result, ForesterError> { - debug!("Recovering registration info for epoch {}", epoch); - - let forester_epoch_pda_pubkey = - get_forester_epoch_pda_from_authority(&self.config.derivation_pubkey, epoch).0; - - let existing_pda = { - let rpc = self.rpc_pool.get_connection().await?; - rpc.get_anchor_account::(&forester_epoch_pda_pubkey) - .await? - }; - - match existing_pda { - Some(pda) => self - .recover_registration_info_internal(epoch, forester_epoch_pda_pubkey, pda) - .await - .map(Some) - .map_err(ForesterError::from), - None => Ok(None), - } - } - - async fn process_current_and_previous_epochs(&self, tx: Arc>) -> Result<()> { - let (slot, current_epoch) = self.get_current_slot_and_epoch().await?; - let current_phases = get_epoch_phases(&self.protocol_config, current_epoch); - let previous_epoch = current_epoch.saturating_sub(1); - - // Process the previous epoch if still in active or later phase - if slot > current_phases.registration.start { - debug!("Processing previous epoch: {}", previous_epoch); - if let Err(e) = tx.send(previous_epoch).await { - error!( - event = "initial_epoch_send_previous_failed", - run_id = %self.run_id, - epoch = previous_epoch, - error = ?e, - "Failed to send previous epoch for processing" - ); - return Ok(()); - } - } - - // Always process the current epoch (registration is allowed at any time) - debug!("Processing current epoch: {}", current_epoch); - if let Err(e) = tx.send(current_epoch).await { - error!( - event = "initial_epoch_send_current_failed", - run_id = %self.run_id, - epoch = current_epoch, - error = ?e, - "Failed to send current epoch for processing" - ); - return Ok(()); // Channel closed, exit gracefully - } - - debug!("Finished processing current and previous epochs"); - Ok(()) - } - - #[instrument(level = "debug", skip(self), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch))] - async fn process_epoch(&self, epoch: u64) -> Result<()> { - // Clone the Arc immediately to release the DashMap shard lock. - // Without .clone(), the RefMut guard would be held across async operations, - // blocking other epochs from accessing the DashMap if they hash to the same shard. - let processing_flag = self - .processing_epochs - .entry(epoch) - .or_insert_with(|| Arc::new(AtomicBool::new(false))) - .clone(); - - if processing_flag - .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) - .is_err() - { - // Another task is already processing this epoch - debug!("Epoch {} is already being processed, skipping", epoch); - return Ok(()); - } - - // Ensure we reset the processing flag when this scope exits - // (whether by normal return, early return, or panic). 
- let _reset_guard = scopeguard::guard((), |_| { - processing_flag.store(false, Ordering::SeqCst); - }); - - let phases = get_epoch_phases(&self.protocol_config, epoch); - update_epoch_detected(epoch); - - // Attempt to recover registration info - debug!("Recovering registration info for epoch {}", epoch); - let mut registration_info = match self.recover_registration_info_if_exists(epoch).await { - Ok(Some(info)) => info, - Ok(None) => { - debug!( - "No existing registration found for epoch {}, will register fresh", - epoch - ); - match self - .register_for_epoch_with_retry(epoch, 100, Duration::from_millis(1000)) - .await - { - Ok(info) => info, - Err(e) => return Err(e.into()), - } - } - Err(e) => { - warn!( - event = "recover_registration_info_failed", - run_id = %self.run_id, - epoch, - error = ?e, - "Failed to recover registration info" - ); - return Err(e.into()); - } - }; - debug!("Recovered registration info for epoch {}", epoch); - update_epoch_registered(epoch); - - // Wait for the active phase - registration_info = match self.wait_for_active_phase(®istration_info).await? { - Some(info) => info, - None => { - let current_slot = self.slot_tracker.estimated_current_slot(); - debug!( - event = "epoch_processing_skipped_finalize_registration_phase_ended", - run_id = %self.run_id, - epoch, - current_slot, - active_phase_end_slot = registration_info.epoch.phases.active.end, - "Skipping epoch processing because FinalizeRegistration is no longer possible" - ); - return Ok(()); - } - }; - - // Perform work - if self.sync_slot().await? < phases.active.end { - self.perform_active_work(®istration_info).await?; - } - // Wait for report work phase - if self.sync_slot().await? < phases.report_work.start { - self.wait_for_report_work_phase(®istration_info).await?; - } - - // Always send metrics report to channel for monitoring/testing - // This ensures metrics are captured even if we missed the report_work phase - self.send_work_report(®istration_info).await?; - - // Report work on-chain only if within the report_work phase - if self.sync_slot().await? < phases.report_work.end { - self.report_work_onchain(®istration_info).await?; - } else { - let current_slot = self.slot_tracker.estimated_current_slot(); - info!( - event = "skip_onchain_work_report_phase_ended", - run_id = %self.run_id, - epoch = registration_info.epoch.epoch, - current_slot, - report_work_end_slot = phases.report_work.end, - "Skipping on-chain work report because report_work phase has ended" - ); - } - - // TODO: implement - // self.claim(®istration_info).await?; - - // Clean up per-epoch state now that this epoch is complete. - // In-flight tasks still hold their own Arc clones, so removal is safe. 
- self.registration_trackers.remove(&epoch); - self.processing_epochs.remove(&epoch); - self.processed_items_per_epoch_count - .lock() - .await - .remove(&epoch); - self.processing_metrics_per_epoch - .lock() - .await - .remove(&epoch); - - info!( - event = "process_epoch_completed", - run_id = %self.run_id, - epoch, "Exiting process_epoch" - ); - Ok(()) - } - - async fn get_current_slot_and_epoch(&self) -> Result<(u64, u64)> { - let slot = self.slot_tracker.estimated_current_slot(); - Ok((slot, self.protocol_config.get_current_epoch(slot))) - } - - #[instrument(level = "debug", skip(self), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch - ))] - async fn register_for_epoch_with_retry( - &self, - epoch: u64, - max_retries: u32, - retry_delay: Duration, - ) -> std::result::Result { - let rpc = LightClient::new(LightClientConfig { - url: self.config.external_services.rpc_url.to_string(), - photon_url: self.config.external_services.indexer_url.clone(), - commitment_config: Some(solana_sdk::commitment_config::CommitmentConfig::confirmed()), - fetch_active_tree: false, - }) - .await - .map_err(ForesterError::Rpc)?; - let slot = rpc.get_slot().await.map_err(ForesterError::Rpc)?; - let phases = get_epoch_phases(&self.protocol_config, epoch); - - if slot < phases.registration.start { - let slots_to_wait = phases.registration.start.saturating_sub(slot); - info!( - event = "registration_wait_for_window", - run_id = %self.run_id, - epoch, - current_slot = slot, - registration_start_slot = phases.registration.start, - slots_to_wait, - "Registration window not open yet; waiting" - ); - let wait_duration = slot_duration() * slots_to_wait as u32; - sleep(wait_duration).await; - } - - for attempt in 0..max_retries { - match self.recover_registration_info_if_exists(epoch).await { - Ok(Some(registration_info)) => return Ok(registration_info), - Ok(None) => {} - Err(e) => return Err(e), - } - - match self.register_for_epoch(epoch).await { - Ok(registration_info) => return Ok(registration_info), - Err(e) => { - warn!( - event = "registration_attempt_failed", - run_id = %self.run_id, - epoch, - attempt = attempt + 1, - max_attempts = max_retries, - error = ?e, - "Failed to register for epoch; retrying" - ); - if attempt < max_retries - 1 { - sleep(retry_delay).await; - } else { - if let Some(pagerduty_key) = - self.config.external_services.pagerduty_routing_key.clone() - { - if let Err(alert_err) = send_pagerduty_alert( - &pagerduty_key, - &format!( - "Forester failed to register for epoch {} after {} attempts", - epoch, max_retries - ), - "critical", - &format!("Forester {}", self.config.payer_keypair.pubkey()), - ) - .await - { - error!( - event = "pagerduty_alert_failed", - run_id = %self.run_id, - epoch, - error = ?alert_err, - "Failed to send PagerDuty alert" - ); - } - } - return Err(ForesterError::Other(e)); - } - } - } - } - Err(RegistrationError::MaxRetriesExceeded { - epoch, - attempts: max_retries, - } - .into()) - } - - #[instrument(level = "debug", skip(self), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch - ))] - async fn register_for_epoch(&self, epoch: u64) -> Result { - info!( - event = "registration_attempt_started", - run_id = %self.run_id, - epoch, "Registering for epoch" - ); - let mut rpc = LightClient::new(LightClientConfig { - url: self.config.external_services.rpc_url.to_string(), - photon_url: self.config.external_services.indexer_url.clone(), - commitment_config: Some(solana_sdk::commitment_config::CommitmentConfig::processed()), - 
fetch_active_tree: false, - }) - .await?; - let slot = rpc.get_slot().await?; - let phases = get_epoch_phases(&self.protocol_config, epoch); - - if slot >= phases.registration.start { - let forester_epoch_pda_pubkey = - get_forester_epoch_pda_from_authority(&self.config.derivation_pubkey, epoch).0; - let existing_registration = rpc - .get_anchor_account::(&forester_epoch_pda_pubkey) - .await?; - - if let Some(existing_pda) = existing_registration { - info!( - event = "registration_already_exists", - run_id = %self.run_id, - epoch, "Already registered for epoch; recovering registration info" - ); - let registration_info = self - .recover_registration_info_internal( - epoch, - forester_epoch_pda_pubkey, - existing_pda, - ) - .await?; - return Ok(registration_info); - } - - let registration_info = { - debug!("Registering epoch {}", epoch); - let registered_epoch = match Epoch::register( - &mut rpc, - &self.protocol_config, - &self.config.payer_keypair, - &self.config.derivation_pubkey, - Some(epoch), - ) - .await - .with_context(|| { - format!("Failed to execute epoch registration for epoch {}", epoch) - })? { - Some(epoch) => { - debug!("Registered epoch: {:?}", epoch); - epoch - } - None => { - return Err(RegistrationError::EmptyRegistration.into()); - } - }; - - let forester_epoch_pda = rpc - .get_anchor_account::(®istered_epoch.forester_epoch_pda) - .await - .with_context(|| { - format!( - "Failed to fetch ForesterEpochPda from RPC for address {}", - registered_epoch.forester_epoch_pda - ) - })? - .ok_or(RegistrationError::ForesterEpochPdaNotFound { - epoch, - pda_address: registered_epoch.forester_epoch_pda, - })?; - - let epoch_pda_address = get_epoch_pda_address(epoch); - let epoch_pda = rpc - .get_anchor_account::(&epoch_pda_address) - .await - .with_context(|| { - format!( - "Failed to fetch EpochPda from RPC for address {}", - epoch_pda_address - ) - })? - .ok_or(RegistrationError::EpochPdaNotFound { - epoch, - pda_address: epoch_pda_address, - })?; - - ForesterEpochInfo { - epoch: registered_epoch, - epoch_pda, - forester_epoch_pda, - trees: Vec::new(), - } - }; - debug!("Registered: {:?}", registration_info); - Ok(registration_info) - } else { - warn!( - event = "registration_too_early", - run_id = %self.run_id, - epoch, - current_slot = slot, - registration_start_slot = phases.registration.start, - "Too early to register for epoch" - ); - Err(RegistrationError::RegistrationPhaseNotStarted { - epoch, - current_slot: slot, - registration_start: phases.registration.start, - } - .into()) - } - } - - async fn recover_registration_info_internal( - &self, - epoch: u64, - forester_epoch_pda_address: Pubkey, - forester_epoch_pda: ForesterEpochPda, - ) -> Result { - let rpc = self.rpc_pool.get_connection().await?; - - let phases = get_epoch_phases(&self.protocol_config, epoch); - let slot = rpc.get_slot().await?; - let state = phases.get_current_epoch_state(slot); - - let epoch_pda_address = get_epoch_pda_address(epoch); - let epoch_pda = rpc - .get_anchor_account::(&epoch_pda_address) - .await - .with_context(|| format!("Failed to fetch EpochPda for epoch {}", epoch))? 
- .ok_or(RegistrationError::EpochPdaNotFound { - epoch, - pda_address: epoch_pda_address, - })?; - - let epoch_info = Epoch { - epoch, - epoch_pda: epoch_pda_address, - forester_epoch_pda: forester_epoch_pda_address, - phases, - state, - merkle_trees: Vec::new(), - }; - - let forester_epoch_info = ForesterEpochInfo { - epoch: epoch_info, - epoch_pda, - forester_epoch_pda, - trees: Vec::new(), - }; - - Ok(forester_epoch_info) - } - - #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch - ))] - async fn wait_for_active_phase( - &self, - epoch_info: &ForesterEpochInfo, - ) -> std::result::Result, ForesterError> { - let mut rpc = self.rpc_pool.get_connection().await?; - let active_phase_start_slot = epoch_info.epoch.phases.active.start; - let active_phase_end_slot = epoch_info.epoch.phases.active.end; - let current_slot = self.slot_tracker.estimated_current_slot(); - - if current_slot >= active_phase_start_slot { - info!( - event = "active_phase_already_started", - run_id = %self.run_id, - current_slot, - active_phase_start_slot, - active_phase_end_slot, - slots_left = active_phase_end_slot.saturating_sub(current_slot), - "Active phase has already started" - ); - } else { - let waiting_slots = active_phase_start_slot - current_slot; - let waiting_secs = waiting_slots / 2; - info!( - event = "wait_for_active_phase", - run_id = %self.run_id, - current_slot, - active_phase_start_slot, - waiting_slots, - approx_wait_seconds = waiting_secs, - "Waiting for active phase to start" - ); - } - - self.prewarm_all_trees_during_wait(epoch_info, active_phase_start_slot) - .await; - - wait_until_slot_reached(&mut *rpc, &self.slot_tracker, active_phase_start_slot).await?; - - let forester_epoch_pda_pubkey = get_forester_epoch_pda_from_authority( - &self.config.derivation_pubkey, - epoch_info.epoch.epoch, - ) - .0; - let existing_registration = rpc - .get_anchor_account::(&forester_epoch_pda_pubkey) - .await?; - - if let Some(registration) = existing_registration { - if registration.total_epoch_weight.is_none() { - let current_slot = rpc.get_slot().await?; - if current_slot > epoch_info.epoch.phases.active.end { - info!( - event = "skip_finalize_registration_phase_ended", - run_id = %self.run_id, - epoch = epoch_info.epoch.epoch, - current_slot, - active_phase_end_slot = epoch_info.epoch.phases.active.end, - "Skipping FinalizeRegistration because active phase ended" - ); - return Ok(None); - } - - // TODO: we can put this ix into every tx of the first batch of the current active phase - let ix = create_finalize_registration_instruction( - &self.config.payer_keypair.pubkey(), - &self.config.derivation_pubkey, - epoch_info.epoch.epoch, - ); - let priority_fee = self - .resolve_epoch_priority_fee(&*rpc, epoch_info.epoch.epoch) - .await?; - let Some(confirmation_deadline) = scheduled_confirmation_deadline( - epoch_info - .epoch - .phases - .active - .end - .saturating_sub(current_slot), - ) else { - info!( - event = "skip_finalize_registration_confirmation_budget_exhausted", - run_id = %self.run_id, - epoch = epoch_info.epoch.epoch, - current_slot, - active_phase_end_slot = epoch_info.epoch.phases.active.end, - "Skipping FinalizeRegistration because not enough active-phase time remains for confirmation" - ); - return Ok(None); - }; - let payer = self.config.payer_keypair.pubkey(); - let signers = [&self.config.payer_keypair]; - send_smart_transaction( - &mut *rpc, - SendSmartTransactionConfig { - instructions: vec![ix], - payer: 
&payer, - signers: &signers, - address_lookup_tables: &self.address_lookup_tables, - compute_budget: ComputeBudgetConfig { - compute_unit_price: priority_fee, - compute_unit_limit: Some(self.config.transaction_config.cu_limit), - }, - confirmation: Some(self.confirmation_config()), - confirmation_deadline: Some(confirmation_deadline), - }, - ) - .await - .map_err(RpcError::from)?; - } - } - - let mut epoch_info = (*epoch_info).clone(); - epoch_info.forester_epoch_pda = rpc - .get_anchor_account::(&epoch_info.epoch.forester_epoch_pda) - .await - .with_context(|| { - format!( - "Failed to fetch ForesterEpochPda for epoch {} at address {}", - epoch_info.epoch.epoch, epoch_info.epoch.forester_epoch_pda - ) - })? - .ok_or(RegistrationError::ForesterEpochPdaNotFound { - epoch: epoch_info.epoch.epoch, - pda_address: epoch_info.epoch.forester_epoch_pda, - })?; - - let slot = rpc.get_slot().await?; - let trees = self.trees.lock().await; - trace!("Adding schedule for trees: {:?}", *trees); - epoch_info.add_trees_with_schedule(&trees, slot)?; - - if self.compressible_tracker.is_some() && self.config.compressible_config.is_some() { - let compression_tree_accounts = TreeAccounts { - merkle_tree: solana_sdk::pubkey::Pubkey::default(), - queue: solana_sdk::pubkey::Pubkey::default(), - tree_type: TreeType::Unknown, - is_rolledover: false, - owner: solana_sdk::pubkey::Pubkey::default(), - }; - let tree_schedule = TreeForesterSchedule::new_with_schedule( - &compression_tree_accounts, - slot, - &epoch_info.forester_epoch_pda, - &epoch_info.epoch_pda, - ) - .map_err(anyhow::Error::from)?; - epoch_info.trees.insert(0, tree_schedule); - debug!("Added compression tree to epoch {}", epoch_info.epoch.epoch); - } - - info!( - event = "active_phase_ready", - run_id = %self.run_id, - epoch = epoch_info.epoch.epoch, - "Finished waiting for active phase" - ); - Ok(Some(epoch_info)) - } - - // TODO: add receiver for new tree discovered -> spawn new task to process this tree derive schedule etc. - // TODO: optimize active phase startup time - #[instrument( - level = "debug", - skip(self, epoch_info), - fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch - ))] - async fn perform_active_work(&self, epoch_info: &ForesterEpochInfo) -> Result<()> { - self.heartbeat.increment_active_cycle(); - - let current_slot = self.slot_tracker.estimated_current_slot(); - let active_phase_end = epoch_info.epoch.phases.active.end; - - if !self.is_in_active_phase(current_slot, epoch_info)? { - info!( - event = "active_work_skipped_not_in_phase", - run_id = %self.run_id, - current_slot, - active_phase_end, - "No longer in active phase. Skipping work." 
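// A minimal, self-contained sketch (not forester code) of the spawn-and-join
// accounting pattern that perform_active_work uses below: one tokio task per
// tree, then join_all over the handles, counting Ok(Ok), Ok(Err), and
// JoinError (panic) cases separately so one panicking tree task cannot take
// down the whole work cycle. Names and the toy payload are illustrative.
use futures::future::join_all;

async fn fan_out_example() {
    let handles: Vec<tokio::task::JoinHandle<Result<(), String>>> = (0..3)
        .map(|i| {
            tokio::spawn(async move {
                if i == 2 {
                    panic!("isolated: only this task dies");
                }
                Ok(())
            })
        })
        .collect();

    let (mut ok, mut err, mut panicked) = (0usize, 0usize, 0usize);
    for res in join_all(handles).await {
        match res {
            Ok(Ok(())) => ok += 1,           // task completed successfully
            Ok(Err(_)) => err += 1,          // task returned an application error
            Err(_join_err) => panicked += 1, // task panicked; surfaced as JoinError
        }
    }
    assert_eq!((ok, err, panicked), (2, 0, 1));
}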
- ); - return Ok(()); - } - - self.sync_slot().await?; - - let trees_to_process: Vec<_> = epoch_info - .trees - .iter() - .filter(|tree| !should_skip_tree(&self.config, &tree.tree_accounts.tree_type)) - .cloned() - .collect(); - - info!( - event = "active_work_cycle_started", - run_id = %self.run_id, - current_slot, - active_phase_end, - tree_count = trees_to_process.len(), - "Starting active work cycle" - ); - - let self_arc = Arc::new(self.clone()); - let registration_tracker = self - .registration_trackers - .entry(epoch_info.epoch.epoch) - .or_insert_with(|| { - Arc::new(RegistrationTracker::new( - epoch_info.epoch_pda.registered_weight, - )) - }) - .value() - .clone(); - - let mut handles: Vec = Vec::with_capacity(trees_to_process.len()); - - for tree in trees_to_process { - debug!( - event = "tree_processing_task_spawned", - run_id = %self.run_id, - tree = %tree.tree_accounts.merkle_tree, - tree_type = ?tree.tree_accounts.tree_type, - "Spawning tree processing task" - ); - self.heartbeat.add_tree_tasks_spawned(1); - - let self_clone = self_arc.clone(); - let epoch_clone = epoch_info.epoch.clone(); - let forester_epoch_pda = epoch_info.forester_epoch_pda.clone(); - let tracker = registration_tracker.clone(); - - let handle = tokio::spawn(async move { - self_clone - .process_queue(&epoch_clone, forester_epoch_pda, tree, tracker) - .await - }); - - handles.push(handle); - } - - debug!("Waiting for {} tree processing tasks", handles.len()); - let results = join_all(handles).await; - let mut success_count = 0usize; - let mut error_count = 0usize; - let mut panic_count = 0usize; - for result in results { - match result { - Ok(Ok(())) => success_count += 1, - Ok(Err(e)) => { - error_count += 1; - error!( - event = "tree_processing_task_failed", - run_id = %self.run_id, - error = ?e, - "Error processing queue" - ); - } - Err(e) => { - panic_count += 1; - error!( - event = "tree_processing_task_panicked", - run_id = %self.run_id, - error = ?e, - "Tree processing task panicked" - ); - } - } - } - info!( - event = "active_work_cycle_completed", - run_id = %self.run_id, - tree_tasks = success_count + error_count + panic_count, - succeeded = success_count, - failed = error_count, - panicked = panic_count, - "Active work cycle completed" - ); - - debug!("Waiting for active phase to end"); - let mut rpc = self.rpc_pool.get_connection().await?; - wait_until_slot_reached(&mut *rpc, &self.slot_tracker, active_phase_end).await?; - Ok(()) - } - - async fn sync_slot(&self) -> Result { - let rpc = self.rpc_pool.get_connection().await?; - let current_slot = rpc.get_slot().await?; - self.slot_tracker.update(current_slot); - Ok(current_slot) - } - - #[instrument( - level = "debug", - skip(self, epoch_info, forester_epoch_pda, tree_schedule, registration_tracker), - fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, - tree = %tree_schedule.tree_accounts.merkle_tree) - )] - pub(crate) async fn process_queue( - &self, - epoch_info: &Epoch, - mut forester_epoch_pda: ForesterEpochPda, - mut tree_schedule: TreeForesterSchedule, - registration_tracker: Arc, - ) -> Result<()> { - self.heartbeat.increment_queue_started(); - let mut current_slot = self.slot_tracker.estimated_current_slot(); - - let total_slots = tree_schedule.slots.len(); - let eligible_slots = tree_schedule.slots.iter().filter(|s| s.is_some()).count(); - let tree_type = tree_schedule.tree_accounts.tree_type; - - debug!( - event = "process_queue_started", - run_id = %self.run_id, - tree = 
%tree_schedule.tree_accounts.merkle_tree, - tree_type = ?tree_type, - total_slots, - eligible_slots, - current_slot, - active_phase_end = epoch_info.phases.active.end, - "Processing queue for tree" - ); - - let mut last_weight_check = Instant::now(); - const WEIGHT_CHECK_INTERVAL: Duration = Duration::from_secs(30); - - 'outer_slot_loop: while current_slot < epoch_info.phases.active.end { - let next_slot_to_process = tree_schedule - .slots - .iter_mut() - .enumerate() - .find_map(|(idx, opt_slot)| opt_slot.as_ref().map(|s| (idx, s.clone()))); - - if let Some((slot_idx, light_slot_details)) = next_slot_to_process { - let result = match tree_type { - TreeType::StateV1 | TreeType::AddressV1 | TreeType::Unknown => { - self.process_light_slot( - epoch_info, - &forester_epoch_pda, - &tree_schedule.tree_accounts, - &light_slot_details, - ) - .await - } - TreeType::StateV2 | TreeType::AddressV2 => { - let consecutive_end = tree_schedule - .get_consecutive_eligibility_end(slot_idx) - .unwrap_or(light_slot_details.end_solana_slot); - self.process_light_slot_v2( - epoch_info, - &forester_epoch_pda, - &tree_schedule.tree_accounts, - &light_slot_details, - consecutive_end, - ) - .await - } - }; - - let mut force_refinalize = false; - match result { - Ok(_) => { - trace!( - "Successfully processed light slot {:?}", - light_slot_details.slot - ); - } - Err(e) => { - force_refinalize = e.is_forester_not_eligible(); - if force_refinalize { - warn!( - event = "light_slot_processing_stale_eligibility", - run_id = %self.run_id, - tree = %tree_schedule.tree_accounts.merkle_tree, - light_slot = light_slot_details.slot, - "Detected ForesterNotEligible; forcing immediate re-finalization" - ); - } - error!( - event = "light_slot_processing_error", - run_id = %self.run_id, - light_slot = light_slot_details.slot, - error = ?e, - "Error processing light slot" - ); - } - } - tree_schedule.slots[slot_idx] = None; - - // Check if re-finalization is needed: either forced (after - // ForesterNotEligible) or periodic (every WEIGHT_CHECK_INTERVAL). - // force=true bypasses the weight-change check to handle the case - // where cached_weight is correct but schedule was never recomputed. - if force_refinalize || last_weight_check.elapsed() >= WEIGHT_CHECK_INTERVAL { - last_weight_check = Instant::now(); - if let Err(e) = self - .maybe_refinalize( - epoch_info, - &mut forester_epoch_pda, - &mut tree_schedule, - &registration_tracker, - force_refinalize, - ) - .await - { - warn!( - event = "refinalize_check_failed", - run_id = %self.run_id, - forced = force_refinalize, - error = ?e, - "Failed to check/perform re-finalization" - ); - } - } - } else { - debug!( - event = "process_queue_no_eligible_slots", - run_id = %self.run_id, - tree = %tree_schedule.tree_accounts.merkle_tree, - "No further eligible slots in schedule" - ); - break 'outer_slot_loop; - } - - current_slot = self.slot_tracker.estimated_current_slot(); - } - - self.heartbeat.increment_queue_finished(); - debug!( - event = "process_queue_finished", - run_id = %self.run_id, - tree = %tree_schedule.tree_accounts.merkle_tree, - "Exiting process_queue" - ); - Ok(()) - } - - /// Check if `EpochPda.registered_weight` changed on-chain. If so, - /// one task sends a `finalize_registration` tx while others wait, - /// then all tasks refresh their `ForesterEpochPda` and recompute schedules. - /// - /// When `force` is true (e.g. after a ForesterNotEligible error), skips - /// the weight-change check and unconditionally refreshes the schedule.
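// Hypothetical sketch of the claim/wait coordination described in the doc
// comment above. The real RegistrationTracker is defined elsewhere in the
// crate; only its call sites (try_claim_refinalize, complete_refinalize,
// wait_for_refinalize, cached_weight) appear in this diff. Assumptions: one
// winner claims via compare_exchange, losers park on a tokio Notify, and the
// winner publishes the new weight before waking them.
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use tokio::sync::Notify;

struct RegistrationTrackerSketch {
    cached_weight: AtomicU64,
    refinalize_in_flight: AtomicBool,
    notify: Notify,
}

impl RegistrationTrackerSketch {
    fn new(initial_weight: u64) -> Self {
        Self {
            cached_weight: AtomicU64::new(initial_weight),
            refinalize_in_flight: AtomicBool::new(false),
            notify: Notify::new(),
        }
    }

    fn cached_weight(&self) -> u64 {
        self.cached_weight.load(Ordering::Acquire)
    }

    /// Exactly one caller wins the right to send the finalize tx.
    fn try_claim_refinalize(&self) -> bool {
        self.refinalize_in_flight
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    }

    /// Winner publishes the post-finalize weight and releases waiters.
    fn complete_refinalize(&self, new_weight: u64) {
        self.cached_weight.store(new_weight, Ordering::Release);
        self.refinalize_in_flight.store(false, Ordering::Release);
        self.notify.notify_waiters();
    }

    /// Losers wait here; enabling the Notified future before re-checking the
    /// flag avoids a lost wakeup between the check and the await.
    async fn wait_for_refinalize(&self) {
        loop {
            let notified = self.notify.notified();
            tokio::pin!(notified);
            notified.as_mut().enable();
            if !self.refinalize_in_flight.load(Ordering::Acquire) {
                return;
            }
            notified.await;
        }
    }
}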
- async fn maybe_refinalize( - &self, - epoch_info: &Epoch, - forester_epoch_pda: &mut ForesterEpochPda, - tree_schedule: &mut TreeForesterSchedule, - registration_tracker: &RegistrationTracker, - force: bool, - ) -> Result<()> { - let mut rpc = self.rpc_pool.get_connection().await?; - let epoch_pda_address = get_epoch_pda_address(epoch_info.epoch); - let on_chain_epoch_pda: EpochPda = rpc - .get_anchor_account::(&epoch_pda_address) - .await? - .ok_or_else(|| anyhow!("EpochPda not found for epoch {}", epoch_info.epoch))?; - - let on_chain_weight = on_chain_epoch_pda.registered_weight; - let cached_weight = registration_tracker.cached_weight(); - let weight_changed = on_chain_weight != cached_weight; - - if !weight_changed && !force { - return Ok(()); - } - - if weight_changed { - info!( - event = "registered_weight_changed", - run_id = %self.run_id, - epoch = epoch_info.epoch, - old_weight = cached_weight, - new_weight = on_chain_weight, - "Detected new forester registration, re-finalizing" - ); - - if registration_tracker.try_claim_refinalize() { - // This task sends the finalize_registration tx - let ix = create_finalize_registration_instruction( - &self.config.payer_keypair.pubkey(), - &self.config.derivation_pubkey, - epoch_info.epoch, - ); - let priority_fee = self - .resolve_epoch_priority_fee(&*rpc, epoch_info.epoch) - .await?; - let current_slot = rpc.get_slot().await?; - let Some(confirmation_deadline) = scheduled_confirmation_deadline( - epoch_info.phases.active.end.saturating_sub(current_slot), - ) else { - info!( - event = "refinalize_registration_skipped_confirmation_budget_exhausted", - run_id = %self.run_id, - epoch = epoch_info.epoch, - current_slot, - active_phase_end_slot = epoch_info.phases.active.end, - "Skipping re-finalization because not enough active-phase time remains for confirmation" - ); - registration_tracker.complete_refinalize(cached_weight); - return Ok(()); - }; - let payer = self.config.payer_keypair.pubkey(); - let signers = [&self.config.payer_keypair]; - match send_smart_transaction( - &mut *rpc, - SendSmartTransactionConfig { - instructions: vec![ix], - payer: &payer, - signers: &signers, - address_lookup_tables: &self.address_lookup_tables, - compute_budget: ComputeBudgetConfig { - compute_unit_price: priority_fee, - compute_unit_limit: Some(self.config.transaction_config.cu_limit), - }, - confirmation: Some(self.confirmation_config()), - confirmation_deadline: Some(confirmation_deadline), - }, - ) - .await - .map_err(RpcError::from) - { - Ok(_) => { - // Re-fetch EpochPda after finalize to get authoritative - // post-finalize weight (another forester may have registered - // between our initial read and the finalize tx). - // Fallback to on_chain_weight if re-fetch fails to avoid - // deadlocking the RegistrationTracker. 
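// Both this function and wait_for_active_phase gate their sends on
// scheduled_confirmation_deadline, which is defined elsewhere in the crate.
// A plausible shape, assuming ~400ms Solana slots and a fixed minimum
// confirmation budget (both constants are illustrative, not the crate's
// actual values):
use std::time::{Duration, Instant};

const APPROX_SLOT: Duration = Duration::from_millis(400); // assumption
const MIN_CONFIRMATION_SLOTS: u64 = 8; // assumption

fn scheduled_confirmation_deadline_sketch(slots_remaining: u64) -> Option<Instant> {
    // Refuse to start a send that cannot plausibly confirm before the
    // eligibility window closes; otherwise cap the wait at the window's end.
    if slots_remaining < MIN_CONFIRMATION_SLOTS {
        return None;
    }
    Some(Instant::now() + APPROX_SLOT * slots_remaining as u32)
}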
- let post_finalize_weight = - match rpc.get_anchor_account::(&epoch_pda_address).await { - Ok(Some(pda)) => pda.registered_weight, - _ => on_chain_weight, - }; - info!( - event = "refinalize_registration_success", - run_id = %self.run_id, - epoch = epoch_info.epoch, - new_weight = post_finalize_weight, - "Re-finalized registration on-chain" - ); - registration_tracker.complete_refinalize(post_finalize_weight); - } - Err(e) => { - // Release the claim so a future check can retry - registration_tracker.complete_refinalize(cached_weight); - return Err(e.into()); - } - } - } else { - // Another task is already re-finalizing; wait for it - registration_tracker.wait_for_refinalize().await; - } - } - - // All tasks: re-fetch both PDAs to get latest on-chain state and - // recompute schedule (after finalize or forced refresh). - let refreshed_epoch_pda: EpochPda = rpc - .get_anchor_account::(&epoch_pda_address) - .await? - .ok_or_else(|| anyhow!("EpochPda not found for epoch {}", epoch_info.epoch))?; - let updated_pda: ForesterEpochPda = rpc - .get_anchor_account::(&epoch_info.forester_epoch_pda) - .await? - .ok_or_else(|| { - anyhow!( - "ForesterEpochPda not found at {} after re-finalization", - epoch_info.forester_epoch_pda - ) - })?; - - let current_slot = self.slot_tracker.estimated_current_slot(); - let new_schedule = TreeForesterSchedule::new_with_schedule( - &tree_schedule.tree_accounts, - current_slot, - &updated_pda, - &refreshed_epoch_pda, - )?; - - *forester_epoch_pda = updated_pda; - *tree_schedule = new_schedule; - - info!( - event = "schedule_recomputed_after_refinalize", - run_id = %self.run_id, - epoch = epoch_info.epoch, - tree = %tree_schedule.tree_accounts.merkle_tree, - new_eligible_slots = tree_schedule.slots.iter().filter(|s| s.is_some()).count(), - "Recomputed schedule after re-finalization" - ); - - Ok(()) - } - - #[instrument( - level = "debug", - skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details), - fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, - tree = %tree_accounts.merkle_tree) - )] - async fn process_light_slot( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, - ) -> std::result::Result<(), ForesterError> { - debug!( - event = "light_slot_processing_started", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - epoch = epoch_info.epoch, - light_slot = forester_slot_details.slot, - slot_start = forester_slot_details.start_solana_slot, - slot_end = forester_slot_details.end_solana_slot, - "Processing light slot" - ); - let mut rpc = self.rpc_pool.get_connection().await?; - wait_until_slot_reached( - &mut *rpc, - &self.slot_tracker, - forester_slot_details.start_solana_slot, - ) - .await?; - let mut estimated_slot = self.slot_tracker.estimated_current_slot(); - - 'inner_processing_loop: loop { - if estimated_slot >= forester_slot_details.end_solana_slot { - trace!( - "Ending processing for slot {:?} due to time limit.", - forester_slot_details.slot - ); - break 'inner_processing_loop; - } - - let current_light_slot = (estimated_slot - epoch_info.phases.active.start) - / epoch_pda.protocol_config.slot_length; - if current_light_slot != forester_slot_details.slot { - warn!( - event = "light_slot_mismatch", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - expected_light_slot = forester_slot_details.slot, - actual_light_slot = current_light_slot, - estimated_slot, - "Light slot mismatch; exiting processing 
for this slot" - ); - break 'inner_processing_loop; - } - - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &tree_accounts.queue, - epoch_info.epoch, - epoch_info, - ) - .await? - { - break 'inner_processing_loop; - } - - let processing_start_time = Instant::now(); - let items_processed_this_iteration = match self - .dispatch_tree_processing( - epoch_info, - epoch_pda, - tree_accounts, - forester_slot_details, - forester_slot_details.end_solana_slot, - estimated_slot, - ) - .await - { - Ok(count) => count, - Err(e) => { - if e.is_forester_not_eligible() { - return Err(e); - } - error!( - event = "light_slot_processing_failed", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - light_slot = forester_slot_details.slot, - error = ?e, - "Failed processing in light slot" - ); - break 'inner_processing_loop; - } - }; - if items_processed_this_iteration > 0 { - debug!( - event = "light_slot_items_processed", - run_id = %self.run_id, - light_slot = forester_slot_details.slot, - items = items_processed_this_iteration, - "Processed items in light slot" - ); - } - - self.update_metrics_and_counts( - epoch_info.epoch, - items_processed_this_iteration, - processing_start_time.elapsed(), - ) - .await; - - if let Err(e) = push_metrics(&self.config.external_services.pushgateway_url).await { - if should_emit_rate_limited_warning("push_metrics_v1", Duration::from_secs(30)) { - warn!( - event = "metrics_push_failed", - run_id = %self.run_id, - error = ?e, - "Failed to push metrics" - ); - } else { - debug!( - event = "metrics_push_failed_suppressed", - run_id = %self.run_id, - error = ?e, - "Suppressing repeated metrics push failure" - ); - } - } - estimated_slot = self.slot_tracker.estimated_current_slot(); - - if items_processed_this_iteration == 0 { - // No items processed. Short sleep before re-checking — the queue - // may grow above min_queue_items within this light slot. - tokio::time::sleep(Duration::from_secs(5)).await; - } - // When items were processed, loop immediately to fetch the next batch. - } - Ok(()) - } - - #[instrument( - level = "debug", - skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details, consecutive_eligibility_end), - fields(tree = %tree_accounts.merkle_tree) - )] - async fn process_light_slot_v2( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, - consecutive_eligibility_end: u64, - ) -> std::result::Result<(), ForesterError> { - debug!( - event = "v2_light_slot_processing_started", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - light_slot = forester_slot_details.slot, - slot_start = forester_slot_details.start_solana_slot, - slot_end = forester_slot_details.end_solana_slot, - consecutive_eligibility_end_slot = consecutive_eligibility_end, - "Processing V2 light slot" - ); - - let tree_pubkey = tree_accounts.merkle_tree; - - let mut rpc = self.rpc_pool.get_connection().await?; - wait_until_slot_reached( - &mut *rpc, - &self.slot_tracker, - forester_slot_details.start_solana_slot, - ) - .await?; - - // Try to send any cached proofs first - let cached_send_start = Instant::now(); - if let Some(items_sent) = self - .try_send_cached_proofs(epoch_info, tree_accounts, consecutive_eligibility_end) - .await? 
- { - if items_sent > 0 { - let cached_send_duration = cached_send_start.elapsed(); - info!( - event = "cached_proofs_sent", - run_id = %self.run_id, - tree = %tree_pubkey, - items = items_sent, - duration_ms = cached_send_duration.as_millis() as u64, - "Sent items from proof cache" - ); - self.update_metrics_and_counts(epoch_info.epoch, items_sent, cached_send_duration) - .await; - } - } - - let mut estimated_slot = self.slot_tracker.estimated_current_slot(); - - // Polling interval for checking queue - const POLL_INTERVAL: Duration = Duration::from_millis(200); - - 'inner_processing_loop: loop { - if estimated_slot >= forester_slot_details.end_solana_slot { - trace!( - "Ending V2 processing for slot {:?}", - forester_slot_details.slot - ); - break 'inner_processing_loop; - } - - let current_light_slot = (estimated_slot - epoch_info.phases.active.start) - / epoch_pda.protocol_config.slot_length; - if current_light_slot != forester_slot_details.slot { - warn!( - event = "v2_light_slot_mismatch", - run_id = %self.run_id, - tree = %tree_pubkey, - expected_light_slot = forester_slot_details.slot, - actual_light_slot = current_light_slot, - estimated_slot, - "V2 slot mismatch; exiting processing" - ); - break 'inner_processing_loop; - } - - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &tree_accounts.merkle_tree, - epoch_info.epoch, - epoch_info, - ) - .await? - { - break 'inner_processing_loop; - } - - // Process directly - the processor fetches queue data from the indexer - let processing_start_time = Instant::now(); - match self - .dispatch_tree_processing( - epoch_info, - epoch_pda, - tree_accounts, - forester_slot_details, - consecutive_eligibility_end, - estimated_slot, - ) - .await - { - Ok(count) => { - if count > 0 { - info!( - event = "v2_tree_processed_items", - run_id = %self.run_id, - tree = %tree_pubkey, - items = count, - epoch = epoch_info.epoch, - "V2 processed items for tree" - ); - self.update_metrics_and_counts( - epoch_info.epoch, - count, - processing_start_time.elapsed(), - ) - .await; - } else { - // No items to process, wait before polling again - tokio::time::sleep(POLL_INTERVAL).await; - } - } - Err(e) => { - if e.is_forester_not_eligible() { - return Err(e); - } - error!( - event = "v2_tree_processing_failed", - run_id = %self.run_id, - tree = %tree_pubkey, - error = ?e, - "V2 processing failed for tree" - ); - tokio::time::sleep(POLL_INTERVAL).await; - } - } - - if let Err(e) = push_metrics(&self.config.external_services.pushgateway_url).await { - if should_emit_rate_limited_warning("push_metrics_v2", Duration::from_secs(30)) { - warn!( - event = "metrics_push_failed", - run_id = %self.run_id, - error = ?e, - "Failed to push metrics" - ); - } else { - debug!( - event = "metrics_push_failed_suppressed", - run_id = %self.run_id, - error = ?e, - "Suppressing repeated metrics push failure" - ); - } - } - estimated_slot = self.slot_tracker.estimated_current_slot(); - } - - Ok(()) - } - - async fn check_forester_eligibility( - &self, - epoch_pda: &ForesterEpochPda, - current_light_slot: u64, - queue_pubkey: &Pubkey, - current_epoch_num: u64, - epoch_info: &Epoch, - ) -> Result { - let current_slot = self.slot_tracker.estimated_current_slot(); - let current_phase_state = epoch_info.phases.get_current_epoch_state(current_slot); - - if current_phase_state != EpochState::Active { - trace!( - "Skipping processing: not in active phase (current phase: {:?}, slot: {})", - current_phase_state, - current_slot - ); - return Ok(false); - } - - let 
total_epoch_weight = epoch_pda.total_epoch_weight.ok_or_else(|| { - anyhow::anyhow!( - "Total epoch weight not available in ForesterEpochPda for epoch {}", - current_epoch_num - ) - })?; - - let eligible_forester_slot_index = ForesterEpochPda::get_eligible_forester_index( - current_light_slot, - queue_pubkey, - total_epoch_weight, - current_epoch_num, - ) - .map_err(|e| { - error!( - event = "eligibility_index_calculation_failed", - run_id = %self.run_id, - queue = %queue_pubkey, - epoch = current_epoch_num, - light_slot = current_light_slot, - error = ?e, - "Failed to calculate eligible forester index" - ); - anyhow::anyhow!("Eligibility calculation failed: {}", e) - })?; - - if !epoch_pda.is_eligible(eligible_forester_slot_index) { - warn!( - event = "forester_not_eligible_for_slot", - run_id = %self.run_id, - forester = %self.config.payer_keypair.pubkey(), - queue = %queue_pubkey, - light_slot = current_light_slot, - "Forester is no longer eligible to process this queue in current light slot" - ); - return Ok(false); - } - Ok(true) - } - - #[allow(clippy::too_many_arguments)] - async fn dispatch_tree_processing( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, - consecutive_eligibility_end: u64, - current_solana_slot: u64, - ) -> std::result::Result { - match tree_accounts.tree_type { - TreeType::Unknown => self - .dispatch_compression( - epoch_info, - epoch_pda, - forester_slot_details, - consecutive_eligibility_end, - ) - .await - .map_err(ForesterError::from), - TreeType::StateV1 | TreeType::AddressV1 => { - self.process_v1( - epoch_info, - epoch_pda, - tree_accounts, - forester_slot_details, - current_solana_slot, - ) - .await - } - TreeType::StateV2 | TreeType::AddressV2 => { - let result = self - .process_v2(epoch_info, tree_accounts, consecutive_eligibility_end) - .await?; - // Accumulate processing metrics for this epoch - self.add_processing_metrics(epoch_info.epoch, result.metrics) - .await; - Ok(result.items_processed) - } - } - } - - async fn dispatch_compression( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - forester_slot_details: &ForesterSlot, - consecutive_eligibility_end: u64, - ) -> Result { - let current_slot = self.slot_tracker.estimated_current_slot(); - if current_slot >= consecutive_eligibility_end { - debug!( - "Skipping compression: forester no longer eligible (current_slot={}, eligibility_end={})", - current_slot, consecutive_eligibility_end - ); - return Ok(0); - } - - if current_slot >= forester_slot_details.end_solana_slot { - debug!( - "Skipping compression: forester slot ended (current_slot={}, slot_end={})", - current_slot, forester_slot_details.end_solana_slot - ); - return Ok(0); - } - - let current_light_slot = current_slot.saturating_sub(epoch_info.phases.active.start) - / epoch_pda.protocol_config.slot_length; - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &Pubkey::default(), - epoch_info.epoch, - epoch_info, - ) - .await? 
- { - debug!( - "Skipping compression: forester not eligible for current light slot {}", - current_light_slot - ); - return Ok(0); - } - - debug!("Dispatching compression for epoch {}", epoch_info.epoch); - - let tracker = self - .compressible_tracker - .as_ref() - .ok_or_else(|| anyhow!("Compressible tracker not initialized"))?; - - let config = self - .config - .compressible_config - .as_ref() - .ok_or_else(|| anyhow!("Compressible config not set"))?; - let accounts = tracker.get_ready_to_compress(current_slot); - - if accounts.is_empty() { - trace!("No compressible accounts ready for compression"); - return Ok(0); - } - - let num_batches = accounts.len().div_ceil(config.batch_size); - info!( - event = "compression_ctoken_started", - run_id = %self.run_id, - accounts = accounts.len(), - batches = num_batches, - batch_size = config.batch_size, - "Starting ctoken compression batches" - ); - - let compressor = CTokenCompressor::new( - self.rpc_pool.clone(), - tracker.clone(), - self.config.payer_keypair.insecure_clone(), - self.transaction_policy(), - ); - - // Derive registered forester PDA once for all batches - let (registered_forester_pda, _) = - light_registry::utils::get_forester_epoch_pda_from_authority( - &self.config.derivation_pubkey, - epoch_info.epoch, - ); - - // Create parallel compression futures - use futures::stream::StreamExt; - - // Collect chunks into owned vectors to avoid lifetime issues - let batches: Vec<(usize, Vec<_>)> = accounts - .chunks(config.batch_size) - .enumerate() - .map(|(idx, chunk)| (idx, chunk.to_vec())) - .collect(); - - let slot_tracker = self.slot_tracker.clone(); - // Shared cancellation flag - when set, all pending futures should skip processing - let cancelled = Arc::new(AtomicBool::new(false)); - - let compression_futures = batches.into_iter().map(|(batch_idx, batch)| { - let compressor = compressor.clone(); - let slot_tracker = slot_tracker.clone(); - let cancelled = cancelled.clone(); - async move { - // Check if already cancelled by another future - if cancelled.load(Ordering::Relaxed) { - debug!( - "Skipping compression batch {}/{}: cancelled", - batch_idx + 1, - num_batches - ); - return Err((batch_idx, batch.len(), Cancelled.into())); - } - - // Check forester is still eligible before processing this batch - let current_slot = slot_tracker.estimated_current_slot(); - if current_slot >= consecutive_eligibility_end { - // Signal cancellation to all other futures - cancelled.store(true, Ordering::Relaxed); - warn!( - event = "compression_ctoken_cancelled_not_eligible", - run_id = %self.run_id, - current_slot, - eligibility_end_slot = consecutive_eligibility_end, - "Cancelling compression because forester is no longer eligible" - ); - return Err(( - batch_idx, - batch.len(), - anyhow!("Forester no longer eligible"), - )); - } - - debug!( - "Processing compression batch {}/{} with {} accounts", - batch_idx + 1, - num_batches, - batch.len() - ); - - match compressor - .compress_batch(&batch, registered_forester_pda) - .await - { - Ok(sig) => { - debug!( - "Compression batch {}/{} succeeded: {}", - batch_idx + 1, - num_batches, - sig - ); - Ok((batch_idx, batch.len(), sig)) - } - Err(e) => { - error!( - event = "compression_ctoken_batch_failed", - run_id = %self.run_id, - batch = batch_idx + 1, - total_batches = num_batches, - error = ?e, - "Compression batch failed" - ); - Err((batch_idx, batch.len(), e)) - } - } - } - }); - - // Execute batches in parallel with concurrency limit - let results = futures::stream::iter(compression_futures) - 
.buffer_unordered(config.max_concurrent_batches) - .collect::>() - .await; - - // Aggregate results - let mut total_compressed = 0; - for result in results { - match result { - Ok((batch_idx, count, sig)) => { - info!( - event = "compression_ctoken_batch_succeeded", - run_id = %self.run_id, - batch = batch_idx + 1, - total_batches = num_batches, - accounts = count, - signature = %sig, - "Compression batch succeeded" - ); - total_compressed += count; - } - Err((batch_idx, count, e)) => { - error!( - event = "compression_ctoken_batch_failed_final", - run_id = %self.run_id, - batch = batch_idx + 1, - total_batches = num_batches, - accounts = count, - error = ?e, - "Compression batch failed" - ); - } - } - } - - info!( - event = "compression_ctoken_completed", - run_id = %self.run_id, - epoch = epoch_info.epoch, - compressed_accounts = total_compressed, - "Completed ctoken compression" - ); - - // Process PDA compression if configured - let pda_compressed = self - .dispatch_pda_compression(epoch_info, epoch_pda, consecutive_eligibility_end) - .await - .unwrap_or_else(|e| { - error!( - event = "compression_pda_dispatch_failed", - run_id = %self.run_id, - error = ?e, - "PDA compression failed" - ); - 0 - }); - - // Process Mint compression - let mint_compressed = self - .dispatch_mint_compression(epoch_info, epoch_pda, consecutive_eligibility_end) - .await - .unwrap_or_else(|e| { - error!( - event = "compression_mint_dispatch_failed", - run_id = %self.run_id, - error = ?e, - "Mint compression failed" - ); - 0 - }); - - let total = total_compressed + pda_compressed + mint_compressed; - info!( - event = "compression_all_completed", - run_id = %self.run_id, - epoch = epoch_info.epoch, - ctoken_compressed = total_compressed, - pda_compressed, - mint_compressed, - total_compressed = total, - "Completed all compression" - ); - Ok(total) - } - - async fn dispatch_pda_compression( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - consecutive_eligibility_end: u64, - ) -> Result { - let Some((pda_tracker, config, current_slot)) = self - .prepare_compression_dispatch( - self.pda_tracker.as_ref(), - "PDA", - epoch_info, - epoch_pda, - consecutive_eligibility_end, - ) - .await? 
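// The ctoken path above fans batches out with buffer_unordered plus a shared
// AtomicBool, so that once one batch observes the eligibility window closing,
// the batches still queued become no-ops instead of sending doomed
// transactions. Distilled into a self-contained toy (all names illustrative):
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};
use futures::StreamExt;

async fn bounded_cancelable(batches: Vec<u64>, limit: usize, cutoff: u64) -> usize {
    let cancelled = Arc::new(AtomicBool::new(false));
    let jobs = batches.into_iter().map(|b| {
        let cancelled = cancelled.clone();
        async move {
            if cancelled.load(Ordering::Relaxed) {
                return 0; // another batch already hit the cutoff
            }
            if b >= cutoff {
                cancelled.store(true, Ordering::Relaxed); // stop the rest
                return 0;
            }
            1 // pretend this batch compressed one account
        }
    });
    futures::stream::iter(jobs)
        .buffer_unordered(limit) // at most `limit` batches in flight
        .collect::<Vec<_>>()
        .await
        .into_iter()
        .sum()
}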
- else { - return Ok(0); - }; - - if config.pda_programs.is_empty() { - return Ok(0); - } - - let mut total_compressed = 0; - - // Shared cancellation flag across all programs - let cancelled = Arc::new(AtomicBool::new(false)); - - // Process each configured PDA program - for program_config in &config.pda_programs { - // Check cancellation at program level - if cancelled.load(Ordering::Relaxed) { - break; - } - - let accounts = pda_tracker - .get_ready_to_compress_for_program(&program_config.program_id, current_slot); - - if accounts.is_empty() { - trace!( - "No compressible PDA accounts ready for program {}", - program_config.program_id - ); - continue; - } - - info!( - event = "compression_pda_program_started", - run_id = %self.run_id, - program = %program_config.program_id, - accounts = accounts.len(), - "Processing compressible PDA accounts for program" - ); - - let pda_compressor = crate::compressible::pda::PdaCompressor::new( - self.rpc_pool.clone(), - pda_tracker.clone(), - self.config.payer_keypair.insecure_clone(), - self.transaction_policy(), - ); - - // Fetch and cache config once per program - let cached_config = match pda_compressor.fetch_program_config(program_config).await { - Ok(cfg) => cfg, - Err(e) => { - error!( - event = "compression_pda_program_config_failed", - run_id = %self.run_id, - program = %program_config.program_id, - error = ?e, - "Failed to fetch config for PDA program" - ); - continue; - } - }; - - // Check eligibility before processing - let current_slot = self.slot_tracker.estimated_current_slot(); - if current_slot >= consecutive_eligibility_end { - cancelled.store(true, Ordering::Relaxed); - warn!( - event = "compression_pda_cancelled_not_eligible", - run_id = %self.run_id, - current_slot, - eligibility_end_slot = consecutive_eligibility_end, - "Stopping PDA compression because forester is no longer eligible" - ); - break; - } - - // Process all accounts for this program concurrently - let results = pda_compressor - .compress_batch_concurrent( - &accounts, - program_config, - &cached_config, - config.max_concurrent_batches, - cancelled.clone(), - ) - .await; - - // Process results (tracker cleanup already done by compressor) - for result in results { - match result { - CompressionOutcome::Compressed { - signature: sig, - state: account_state, - } => { - debug!( - "Compressed PDA {} for program {}: {}", - account_state.pubkey, program_config.program_id, sig - ); - total_compressed += 1; - } - CompressionOutcome::Failed { - state: _account_state, - error: CompressionTaskError::Cancelled, - } => {} - CompressionOutcome::Failed { - state: account_state, - error: CompressionTaskError::Failed(e), - } => { - error!( - event = "compression_pda_account_failed", - run_id = %self.run_id, - account = %account_state.pubkey, - program = %program_config.program_id, - error = ?e, - "Failed to compress PDA account" - ); - } - } - } - } - - info!( - event = "compression_pda_completed", - run_id = %self.run_id, - compressed_accounts = total_compressed, - "Completed PDA compression" - ); - Ok(total_compressed) - } - - async fn dispatch_mint_compression( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - consecutive_eligibility_end: u64, - ) -> Result { - let Some((mint_tracker, config, current_slot)) = self - .prepare_compression_dispatch( - self.mint_tracker.as_ref(), - "Mint", - epoch_info, - epoch_pda, - consecutive_eligibility_end, - ) - .await? 
- else { - return Ok(0); - }; - - let accounts = mint_tracker.get_ready_to_compress(current_slot); - - if accounts.is_empty() { - trace!("No compressible Mint accounts ready"); - return Ok(0); - } - - info!( - event = "compression_mint_started", - run_id = %self.run_id, - accounts = accounts.len(), - max_concurrent = config.max_concurrent_batches, - "Processing compressible Mint accounts" - ); - - let mint_compressor = crate::compressible::mint::MintCompressor::new( - self.rpc_pool.clone(), - mint_tracker.clone(), - self.config.payer_keypair.insecure_clone(), - self.transaction_policy(), - ); - - // Shared cancellation flag - let cancelled = Arc::new(AtomicBool::new(false)); - - // Process all mints concurrently - let results = mint_compressor - .compress_batch_concurrent(&accounts, config.max_concurrent_batches, cancelled) - .await; - - // Process results (tracker cleanup already done by compressor) - let mut total_compressed = 0; - for result in results { - match result { - CompressionOutcome::Compressed { - signature: sig, - state: mint_state, - } => { - debug!("Compressed Mint {}: {}", mint_state.pubkey, sig); - total_compressed += 1; - } - CompressionOutcome::Failed { - state: _mint_state, - error: CompressionTaskError::Cancelled, - } => {} - CompressionOutcome::Failed { - state: mint_state, - error: CompressionTaskError::Failed(e), - } => { - error!( - event = "compression_mint_account_failed", - run_id = %self.run_id, - mint = %mint_state.pubkey, - error = ?e, - "Failed to compress mint account" - ); - } - } - } - - info!( - event = "compression_mint_completed", - run_id = %self.run_id, - compressed_accounts = total_compressed, - "Completed Mint compression" - ); - Ok(total_compressed) - } - - async fn prepare_compression_dispatch<'a, T>( - &'a self, - tracker: Option<&'a T>, - label: &'static str, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - consecutive_eligibility_end: u64, - ) -> Result> { - let Some(tracker) = tracker else { - return Ok(None); - }; - - let Some(config) = self.config.compressible_config.as_ref() else { - return Ok(None); - }; - - let current_slot = self.slot_tracker.estimated_current_slot(); - if current_slot >= consecutive_eligibility_end { - debug!( - "Skipping {} compression: forester no longer eligible (current_slot={}, eligibility_end={})", - label, current_slot, consecutive_eligibility_end - ); - return Ok(None); - } - - let current_light_slot = current_slot.saturating_sub(epoch_info.phases.active.start) - / epoch_pda.protocol_config.slot_length; - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &Pubkey::default(), - epoch_info.epoch, - epoch_info, - ) - .await? 
- { - debug!( - "Skipping {} compression: forester not eligible for current light slot {}", - label, current_light_slot - ); - return Ok(None); - } - - Ok(Some((tracker, config, current_slot))) - } - - async fn process_v1( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, - current_solana_slot: u64, - ) -> std::result::Result { - let slots_remaining = forester_slot_details - .end_solana_slot - .saturating_sub(current_solana_slot); - let Some(remaining_time_timeout) = scheduled_v1_batch_timeout(slots_remaining) else { - debug!( - event = "v1_tree_skipped_low_slot_budget", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - slots_remaining, - "Skipping V1 tree: not enough scheduled slot budget left to confirm a transaction" - ); - return Ok(0); - }; - - let batched_tx_config = SendBatchedTransactionsConfig { - num_batches: 1, - build_transaction_batch_config: BuildTransactionBatchConfig { - batch_size: self.config.transaction_config.legacy_ixs_per_tx as u64, - compute_unit_price: self.config.transaction_config.priority_fee_microlamports, - compute_unit_limit: Some(self.config.transaction_config.cu_limit), - enable_priority_fees: self.config.transaction_config.enable_priority_fees, - max_concurrent_sends: Some(self.config.transaction_config.max_concurrent_sends), - }, - queue_config: self.config.queue_config, - retry_config: RetryConfig { - timeout: remaining_time_timeout, - ..self.config.retry_config - }, - light_slot_length: epoch_pda.protocol_config.slot_length, - confirmation_poll_interval: Duration::from_millis( - self.config.transaction_config.confirmation_poll_interval_ms, - ), - confirmation_max_attempts: self.config.transaction_config.confirmation_max_attempts - as usize, - min_queue_items: if self.config.enable_v1_multi_nullify - && !self.address_lookup_tables.is_empty() - { - self.config.min_queue_items - } else { - None - }, - enable_presort: self.config.enable_v1_multi_nullify - && !self.address_lookup_tables.is_empty(), - work_item_batch_size: self.config.work_item_batch_size, - }; - - let alt_snapshot = (*self.address_lookup_tables).clone(); - let transaction_builder = Arc::new(EpochManagerTransactions::new( - self.rpc_pool.clone(), - epoch_info.epoch, - self.tx_cache.clone(), - alt_snapshot, - self.config.enable_v1_multi_nullify, - )); - - let num_sent = send_batched_transactions( - &self.config.payer_keypair, - &self.config.derivation_pubkey, - self.rpc_pool.clone(), - &batched_tx_config, - *tree_accounts, - transaction_builder, - ) - .await?; - - if num_sent > 0 { - debug!( - event = "v1_tree_items_processed", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - items = num_sent, - "Processed items for V1 tree" - ); - } - - match self.rollover_if_needed(tree_accounts).await { - Ok(_) => Ok(num_sent), - Err(e) => { - error!( - event = "tree_rollover_failed", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - tree_type = ?tree_accounts.tree_type, - error = ?e, - "Failed to rollover tree" - ); - Err(e.into()) - } - } - } - - fn build_batch_context( - &self, - epoch_info: &Epoch, - tree_accounts: &TreeAccounts, - input_queue_hint: Option, - output_queue_hint: Option, - eligibility_end: Option, - address_lookup_tables: Arc>, - ) -> BatchContext { - let default_prover_url = "http://127.0.0.1:3001".to_string(); - let eligibility_end = eligibility_end.unwrap_or(0); - BatchContext { - rpc_pool: self.rpc_pool.clone(), - authority: self.authority.clone(), - run_id: 
self.run_id.clone(), - derivation: self.config.derivation_pubkey, - epoch: epoch_info.epoch, - merkle_tree: tree_accounts.merkle_tree, - output_queue: tree_accounts.queue, - prover_config: Arc::new(ProverConfig { - append_url: self - .config - .external_services - .prover_append_url - .clone() - .unwrap_or_else(|| default_prover_url.clone()), - update_url: self - .config - .external_services - .prover_update_url - .clone() - .unwrap_or_else(|| default_prover_url.clone()), - address_append_url: self - .config - .external_services - .prover_address_append_url - .clone() - .unwrap_or_else(|| default_prover_url.clone()), - api_key: self.config.external_services.prover_api_key.clone(), - polling_interval: self - .config - .external_services - .prover_polling_interval - .unwrap_or(Duration::from_secs(1)), - max_wait_time: self - .config - .external_services - .prover_max_wait_time - .unwrap_or(Duration::from_secs(600)), - }), - ops_cache: self.ops_cache.clone(), - epoch_phases: epoch_info.phases.clone(), - slot_tracker: self.slot_tracker.clone(), - input_queue_hint, - output_queue_hint, - num_proof_workers: self.config.transaction_config.max_concurrent_batches, - forester_eligibility_end_slot: Arc::new(AtomicU64::new(eligibility_end)), - address_lookup_tables, - transaction_policy: self.transaction_policy(), - max_batches_per_tree: self.config.transaction_config.max_batches_per_tree, - } - } - - fn confirmation_config(&self) -> ConfirmationConfig { - ConfirmationConfig { - max_attempts: self.config.transaction_config.confirmation_max_attempts, - poll_interval: Duration::from_millis( - self.config.transaction_config.confirmation_poll_interval_ms, - ), - } - } - - fn transaction_priority_fee_config(&self) -> PriorityFeeConfig { - PriorityFeeConfig { - compute_unit_price: self.config.transaction_config.priority_fee_microlamports, - enable_priority_fees: self.config.transaction_config.enable_priority_fees, - } - } - - fn transaction_policy(&self) -> TransactionPolicy { - TransactionPolicy { - priority_fee_config: self.transaction_priority_fee_config(), - compute_unit_limit: Some(self.config.transaction_config.cu_limit), - confirmation: Some(self.confirmation_config()), - } - } - - async fn resolve_epoch_priority_fee( - &self, - rpc: &RpcT, - epoch: u64, - ) -> Result> { - self.transaction_priority_fee_config() - .resolve( - rpc, - vec![ - self.config.payer_keypair.pubkey(), - get_forester_epoch_pda_from_authority(&self.config.derivation_pubkey, epoch).0, - ], - ) - .await - } - - async fn resolve_tree_priority_fee( - &self, - rpc: &RpcT, - epoch: u64, - tree_accounts: &TreeAccounts, - ) -> Result> { - self.transaction_priority_fee_config() - .resolve( - rpc, - vec![ - self.config.payer_keypair.pubkey(), - get_forester_epoch_pda_from_authority(&self.config.derivation_pubkey, epoch).0, - tree_accounts.queue, - tree_accounts.merkle_tree, - ], - ) - .await - } - - async fn get_or_create_state_processor( - &self, - epoch_info: &Epoch, - tree_accounts: &TreeAccounts, - ) -> Result>>> { - // Serialize initialization per tree to avoid duplicate expensive processor construction. 
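// PriorityFeeConfig::resolve is defined elsewhere; the account lists passed
// to it above (payer, forester epoch PDA, and for trees also the queue and
// merkle tree) are the writable accounts the transaction will lock. One
// plausible resolution strategy against the standard RPC endpoint follows;
// the percentile choice and fallback handling are illustrative assumptions,
// not the crate's code.
use solana_client::nonblocking::rpc_client::RpcClient;
use solana_sdk::pubkey::Pubkey;

async fn resolve_fee_sketch(
    rpc: &RpcClient,
    writable: &[Pubkey],
    enabled: bool,
    fallback_micro_lamports: Option<u64>,
) -> anyhow::Result<Option<u64>> {
    if !enabled {
        return Ok(fallback_micro_lamports);
    }
    // Recent per-slot priority fees observed for txs locking these accounts.
    let mut fees: Vec<u64> = rpc
        .get_recent_prioritization_fees(writable)
        .await?
        .into_iter()
        .map(|f| f.prioritization_fee)
        .collect();
    if fees.is_empty() {
        return Ok(fallback_micro_lamports);
    }
    fees.sort_unstable();
    // Take a high percentile (75th here, for illustration) so the tx still
    // competes during fee spikes on the contended accounts.
    let idx = fees.len().saturating_sub(1) * 3 / 4;
    Ok(Some(fees[idx].max(fallback_micro_lamports.unwrap_or(0))))
}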
- let init_lock = self - .state_processor_init_locks - .entry(tree_accounts.merkle_tree) - .or_insert_with(|| Arc::new(Mutex::new(()))) - .clone(); - let _init_guard = init_lock.lock().await; - - // First check if we already have a processor for this tree - // We REUSE processors across epochs to preserve cached state for optimistic processing - if let Some(entry) = self.state_processors.get(&tree_accounts.merkle_tree) { - let (stored_epoch, processor_ref) = entry.value(); - let processor_clone = processor_ref.clone(); - let old_epoch = *stored_epoch; - drop(entry); // Release read lock before any async operation - - if old_epoch != epoch_info.epoch { - // Update epoch in the map (processor is reused with its cached state) - debug!( - "Reusing StateBatchProcessor for tree {} across epoch transition ({} -> {})", - tree_accounts.merkle_tree, old_epoch, epoch_info.epoch - ); - self.state_processors.insert( - tree_accounts.merkle_tree, - (epoch_info.epoch, processor_clone.clone()), - ); - // Update the processor's epoch context and phases - processor_clone - .lock() - .await - .update_epoch(epoch_info.epoch, epoch_info.phases.clone()); - } - return Ok(processor_clone); - } - - // No existing processor - create new one - let batch_context = self.build_batch_context( - epoch_info, - tree_accounts, - None, - None, - None, - self.address_lookup_tables.clone(), - ); - let processor = Arc::new(Mutex::new( - QueueProcessor::new(batch_context, StateTreeStrategy).await?, - )); - - // Cache the zkp_batch_size for early filtering of queue updates - let batch_size = processor.lock().await.zkp_batch_size(); - self.zkp_batch_sizes - .insert(tree_accounts.merkle_tree, batch_size); - - self.state_processors.insert( - tree_accounts.merkle_tree, - (epoch_info.epoch, processor.clone()), - ); - Ok(processor) - } - - async fn get_or_create_address_processor( - &self, - epoch_info: &Epoch, - tree_accounts: &TreeAccounts, - ) -> Result>>> { - // Serialize initialization per tree to avoid duplicate expensive processor construction. 
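// The state-processor getter above and the address-processor getter below
// share one pattern: a DashMap of per-key init locks serializes construction,
// and the processor map is re-checked after acquiring the lock so concurrent
// callers receive the same Arc instead of each building an expensive
// processor. Distilled with illustrative types (Expensive stands in for
// QueueProcessor):
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::Mutex;

struct Expensive;

async fn get_or_create(
    map: &DashMap<u64, Arc<Mutex<Expensive>>>,
    locks: &DashMap<u64, Arc<Mutex<()>>>,
    key: u64,
) -> Arc<Mutex<Expensive>> {
    let init_lock = locks
        .entry(key)
        .or_insert_with(|| Arc::new(Mutex::new(())))
        .clone();
    let _guard = init_lock.lock().await; // serialize per-key construction
    if let Some(existing) = map.get(&key) {
        return Arc::clone(existing.value()); // double-check after locking
    }
    let built = Arc::new(Mutex::new(Expensive)); // expensive construction here
    map.insert(key, built.clone());
    built
}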
- let init_lock = self - .address_processor_init_locks - .entry(tree_accounts.merkle_tree) - .or_insert_with(|| Arc::new(Mutex::new(()))) - .clone(); - let _init_guard = init_lock.lock().await; - - if let Some(entry) = self.address_processors.get(&tree_accounts.merkle_tree) { - let (stored_epoch, processor_ref) = entry.value(); - let processor_clone = processor_ref.clone(); - let old_epoch = *stored_epoch; - drop(entry); - - if old_epoch != epoch_info.epoch { - debug!( - "Reusing AddressBatchProcessor for tree {} across epoch transition ({} -> {})", - tree_accounts.merkle_tree, old_epoch, epoch_info.epoch - ); - self.address_processors.insert( - tree_accounts.merkle_tree, - (epoch_info.epoch, processor_clone.clone()), - ); - processor_clone - .lock() - .await - .update_epoch(epoch_info.epoch, epoch_info.phases.clone()); - } - return Ok(processor_clone); - } - - // No existing processor - create new one - let batch_context = self.build_batch_context( - epoch_info, - tree_accounts, - None, - None, - None, - self.address_lookup_tables.clone(), - ); - let processor = Arc::new(Mutex::new( - QueueProcessor::new(batch_context, AddressTreeStrategy).await?, - )); - - // Cache the zkp_batch_size for early filtering of queue updates - let batch_size = processor.lock().await.zkp_batch_size(); - self.zkp_batch_sizes - .insert(tree_accounts.merkle_tree, batch_size); - - self.address_processors.insert( - tree_accounts.merkle_tree, - (epoch_info.epoch, processor.clone()), - ); - Ok(processor) - } - - async fn process_v2( - &self, - epoch_info: &Epoch, - tree_accounts: &TreeAccounts, - consecutive_eligibility_end: u64, - ) -> std::result::Result { - match tree_accounts.tree_type { - TreeType::StateV2 => { - let processor = self - .get_or_create_state_processor(epoch_info, tree_accounts) - .await?; - - let cache = self - .proof_caches - .entry(tree_accounts.merkle_tree) - .or_insert_with(|| Arc::new(SharedProofCache::new(tree_accounts.merkle_tree))) - .clone(); - - { - let mut proc = processor.lock().await; - proc.update_eligibility(consecutive_eligibility_end); - proc.set_proof_cache(cache); - } - - let mut proc = processor.lock().await; - match proc.process().await { - Ok(res) => Ok(res), - Err(error) if matches!(&error, ForesterError::V2(v2_error) if v2_error.is_constraint()) => - { - warn!( - event = "v2_state_constraint_error", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - error = %error, - "State processing hit constraint error. Dropping processor to flush cache." - ); - drop(proc); // Release lock before removing - self.state_processors.remove(&tree_accounts.merkle_tree); - self.proof_caches.remove(&tree_accounts.merkle_tree); - Err(error) - } - Err(ForesterError::V2(v2_error)) if v2_error.is_hashchain_mismatch() => { - let warning_key = - format!("v2_state_hashchain_mismatch:{}", tree_accounts.merkle_tree); - if should_emit_rate_limited_warning(warning_key, Duration::from_secs(15)) { - warn!( - event = "v2_state_hashchain_mismatch", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - error = %v2_error, - "State processing hit hashchain mismatch. Clearing cache and retrying." 
- ); - } - self.heartbeat.increment_v2_recoverable_error(); - proc.clear_cache().await; - Ok(ProcessingResult::default()) - } - Err(e) => { - let warning_key = - format!("v2_state_process_failed:{}", tree_accounts.merkle_tree); - if should_emit_rate_limited_warning(warning_key, Duration::from_secs(10)) { - warn!( - event = "v2_state_process_failed_retrying", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - error = %e, - "Failed to process state queue. Will retry next tick without dropping processor." - ); - } - self.heartbeat.increment_v2_recoverable_error(); - Ok(ProcessingResult::default()) - } - } - } - TreeType::AddressV2 => { - let processor = self - .get_or_create_address_processor(epoch_info, tree_accounts) - .await?; - - let cache = self - .proof_caches - .entry(tree_accounts.merkle_tree) - .or_insert_with(|| Arc::new(SharedProofCache::new(tree_accounts.merkle_tree))) - .clone(); - - { - let mut proc = processor.lock().await; - proc.update_eligibility(consecutive_eligibility_end); - proc.set_proof_cache(cache); - } - - let mut proc = processor.lock().await; - match proc.process().await { - Ok(res) => Ok(res), - Err(error) if matches!(&error, ForesterError::V2(v2_error) if v2_error.is_constraint()) => - { - warn!( - event = "v2_address_constraint_error", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - error = %error, - "Address processing hit constraint error. Dropping processor to flush cache." - ); - drop(proc); - self.address_processors.remove(&tree_accounts.merkle_tree); - self.proof_caches.remove(&tree_accounts.merkle_tree); - Err(error) - } - Err(ForesterError::V2(v2_error)) if v2_error.is_hashchain_mismatch() => { - let warning_key = format!( - "v2_address_hashchain_mismatch:{}", - tree_accounts.merkle_tree - ); - if should_emit_rate_limited_warning(warning_key, Duration::from_secs(15)) { - warn!( - event = "v2_address_hashchain_mismatch", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - error = %v2_error, - "Address processing hit hashchain mismatch. Clearing cache and retrying." - ); - } - self.heartbeat.increment_v2_recoverable_error(); - proc.clear_cache().await; - Ok(ProcessingResult::default()) - } - Err(e) => { - let warning_key = - format!("v2_address_process_failed:{}", tree_accounts.merkle_tree); - if should_emit_rate_limited_warning(warning_key, Duration::from_secs(10)) { - warn!( - event = "v2_address_process_failed_retrying", - run_id = %self.run_id, - tree = %tree_accounts.merkle_tree, - error = %e, - "Failed to process address queue. Will retry next tick without dropping processor." 
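// The v2 match arms around this point implement a three-way retry policy.
// Stated as an illustrative enum (the real errors are ForesterError variants
// classified via is_constraint / is_hashchain_mismatch):
enum V2Failure {
    Constraint,        // on-chain state diverged: drop processor and proof cache
    HashchainMismatch, // cached queue view stale: clear cache, keep processor
    Other,             // transient: keep everything, retry next tick
}

fn retry_policy(failure: V2Failure) -> &'static str {
    match failure {
        V2Failure::Constraint => "drop processor and cache, surface the error",
        V2Failure::HashchainMismatch => "clear cache, return an empty result",
        V2Failure::Other => "return an empty result, retry with state intact",
    }
}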
- ); - } - self.heartbeat.increment_v2_recoverable_error(); - Ok(ProcessingResult::default()) - } - } - } - _ => { - warn!( - event = "v2_unsupported_tree_type", - run_id = %self.run_id, - tree_type = ?tree_accounts.tree_type, - "Unsupported tree type for V2 processing" - ); - Ok(ProcessingResult::default()) - } - } - } - - async fn update_metrics_and_counts( - &self, - epoch_num: u64, - items_processed: usize, - duration: Duration, - ) { - if items_processed > 0 { - trace!( - "{} items processed in this iteration, duration: {:?}", - items_processed, - duration - ); - queue_metric_update(epoch_num, items_processed, duration); - self.increment_processed_items_count(epoch_num, items_processed) - .await; - self.heartbeat.add_items_processed(items_processed); - } - } - - async fn prewarm_all_trees_during_wait( - &self, - epoch_info: &ForesterEpochInfo, - deadline_slot: u64, - ) { - let current_slot = self.slot_tracker.estimated_current_slot(); - let slots_until_active = deadline_slot.saturating_sub(current_slot); - - let trees = self.trees.lock().await; - let total_v2_state = trees - .iter() - .filter(|t| matches!(t.tree_type, TreeType::StateV2)) - .count(); - let v2_state_trees: Vec<_> = trees - .iter() - .filter(|t| { - matches!(t.tree_type, TreeType::StateV2) - && !should_skip_tree(&self.config, &t.tree_type) - }) - .cloned() - .collect(); - let skipped_count = total_v2_state - v2_state_trees.len(); - drop(trees); - - if v2_state_trees.is_empty() { - if skipped_count > 0 { - info!( - event = "prewarm_skipped_all_trees_filtered", - run_id = %self.run_id, - skipped_trees = skipped_count, - "No trees to pre-warm; all StateV2 trees skipped by config" - ); - } - return; - } - - if slots_until_active < 15 { - info!( - event = "prewarm_skipped_not_enough_time", - run_id = %self.run_id, - slots_until_active, - min_required_slots = 15, - "Skipping pre-warming; not enough slots until active phase" - ); - return; - } - - let prewarm_futures: Vec<_> = v2_state_trees - .iter() - .map(|tree_accounts| { - let tree_pubkey = tree_accounts.merkle_tree; - let epoch_info = epoch_info.clone(); - let tree_accounts = *tree_accounts; - let self_clone = self.clone(); - - async move { - let cache = self_clone - .proof_caches - .entry(tree_pubkey) - .or_insert_with(|| Arc::new(SharedProofCache::new(tree_pubkey))) - .clone(); - - let cache_len = cache.len().await; - if cache_len > 0 && !cache.is_warming().await { - let mut rpc = match self_clone.rpc_pool.get_connection().await { - Ok(r) => r, - Err(e) => { - warn!( - event = "prewarm_cache_validation_rpc_failed", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - error = ?e, - "Failed to get RPC for cache validation" - ); - return; - } - }; - if let Ok(current_root) = - self_clone.fetch_current_root(&mut *rpc, &tree_accounts).await - { - info!( - event = "prewarm_skipped_cache_already_warm", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - cached_proofs = cache_len, - root_prefix = ?&current_root[..4], - "Tree already has cached proofs from previous epoch; skipping pre-warm" - ); - return; - } - } - - let processor = match self_clone - .get_or_create_state_processor(&epoch_info.epoch, &tree_accounts) - .await - { - Ok(p) => p, - Err(e) => { - warn!( - event = "prewarm_processor_create_failed", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - error = ?e, - "Failed to create processor for pre-warming tree" - ); - return; - } - }; - - const PREWARM_MAX_BATCHES: usize = 4; - let mut p = processor.lock().await; - match p - .prewarm_from_indexer(
cache.clone(), - light_compressed_account::QueueType::OutputStateV2, - PREWARM_MAX_BATCHES, - ) - .await - { - Ok(result) => { - if result.items_processed > 0 { - info!( - event = "prewarm_tree_completed", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - items = result.items_processed, - "Pre-warmed items for tree during wait" - ); - self_clone - .add_processing_metrics(epoch_info.epoch.epoch, result.metrics) - .await; - } - } - Err(e) => { - debug!( - "Pre-warming from indexer failed for tree {}: {:?}", - tree_pubkey, e - ); - cache.clear().await; - } - } - } - }) - .collect(); - - let timeout_slots = slots_until_active.saturating_sub(5); - let timeout_duration = - (slot_duration() * timeout_slots as u32).min(Duration::from_secs(30)); - - info!( - event = "prewarm_started", - run_id = %self.run_id, - trees = v2_state_trees.len(), - skipped_trees = skipped_count, - timeout_ms = timeout_duration.as_millis() as u64, - "Starting pre-warming" - ); - - match tokio::time::timeout(timeout_duration, futures::future::join_all(prewarm_futures)) - .await - { - Ok(_) => { - info!( - event = "prewarm_completed", - run_id = %self.run_id, - trees = v2_state_trees.len(), - "Completed pre-warming for all trees" - ); - } - Err(_) => { - info!( - event = "prewarm_timed_out", - run_id = %self.run_id, - timeout_ms = timeout_duration.as_millis() as u64, - "Pre-warming timed out" - ); - } - } - } - - async fn try_send_cached_proofs( - &self, - epoch_info: &Epoch, - tree_accounts: &TreeAccounts, - consecutive_eligibility_end: u64, - ) -> Result<Option<usize>> { - let tree_pubkey = tree_accounts.merkle_tree; - - // Check eligibility window before attempting to send cached proofs - let current_slot = self.slot_tracker.estimated_current_slot(); - if current_slot >= consecutive_eligibility_end { - debug!( - event = "cached_proofs_skipped_outside_eligibility", - run_id = %self.run_id, - tree = %tree_pubkey, - current_slot, - eligibility_end_slot = consecutive_eligibility_end, - "Skipping cached proof send because eligibility window has ended" - ); - return Ok(None); - } - - let Some(confirmation_deadline) = scheduled_confirmation_deadline( - consecutive_eligibility_end.saturating_sub(current_slot), - ) else { - debug!( - event = "cached_proofs_skipped_confirmation_budget_exhausted", - run_id = %self.run_id, - tree = %tree_pubkey, - current_slot, - eligibility_end_slot = consecutive_eligibility_end, - "Skipping cached proofs because not enough eligible slots remain for confirmation" - ); - return Ok(None); - }; - - let cache = match self.proof_caches.get(&tree_pubkey) { - Some(c) => c.clone(), - None => return Ok(None), - }; - - if cache.is_warming().await { - debug!( - event = "cached_proofs_skipped_cache_warming", - run_id = %self.run_id, - tree = %tree_pubkey, - "Skipping cached proofs because cache is still warming" - ); - return Ok(None); - } - - let mut rpc = self.rpc_pool.get_connection().await?; - let current_root = match self.fetch_current_root(&mut *rpc, tree_accounts).await { - Ok(root) => root, - Err(e) => { - warn!( - event = "cached_proofs_root_fetch_failed", - run_id = %self.run_id, - tree = %tree_pubkey, - error = ?e, - "Failed to fetch current root for tree" - ); - return Ok(None); - } - }; - - let cached_proofs = match cache.take_if_valid(&current_root).await { - Some(proofs) => proofs, - None => { - debug!( - event = "cached_proofs_not_available", - run_id = %self.run_id, - tree = %tree_pubkey, - root_prefix = ?&current_root[..4], - "No valid cached proofs for tree" - ); - return Ok(None); - } - }; - - if
cached_proofs.is_empty() { - return Ok(Some(0)); - } - - info!( - event = "cached_proofs_send_started", - run_id = %self.run_id, - tree = %tree_pubkey, - proofs = cached_proofs.len(), - root_prefix = ?&current_root[..4], - "Sending cached proofs for tree" - ); - - let items_sent = self - .send_cached_proofs_as_transactions( - epoch_info, - tree_accounts, - cached_proofs, - confirmation_deadline, - ) - .await?; - - Ok(Some(items_sent)) - } - - async fn fetch_current_root( - &self, - rpc: &mut impl Rpc, - tree_accounts: &TreeAccounts, - ) -> Result<[u8; 32]> { - use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; - - let mut account = rpc - .get_account(tree_accounts.merkle_tree) - .await? - .ok_or_else(|| anyhow!("Tree account not found: {}", tree_accounts.merkle_tree))?; - - let tree = match tree_accounts.tree_type { - TreeType::StateV2 => BatchedMerkleTreeAccount::state_from_bytes( - &mut account.data, - &tree_accounts.merkle_tree.into(), - )?, - TreeType::AddressV2 => BatchedMerkleTreeAccount::address_from_bytes( - &mut account.data, - &tree_accounts.merkle_tree.into(), - )?, - _ => return Err(anyhow!("Unsupported tree type for root fetch")), - }; - - let root = tree.root_history.last().copied().unwrap_or([0u8; 32]); - Ok(root) - } - - async fn send_cached_proofs_as_transactions( - &self, - epoch_info: &Epoch, - tree_accounts: &TreeAccounts, - cached_proofs: Vec, - confirmation_deadline: Instant, - ) -> Result { - let mut total_items = 0; - let authority = self.config.payer_keypair.pubkey(); - let derivation = self.config.derivation_pubkey; - - const PROOFS_PER_TX: usize = 4; - for chunk in cached_proofs.chunks(PROOFS_PER_TX) { - let mut instructions = Vec::new(); - let mut chunk_items = 0; - - for proof in chunk { - match &proof.instruction { - BatchInstruction::Append(data) => { - for d in data { - let serialized = d - .try_to_vec() - .with_context(|| "Failed to serialize batch append payload")?; - instructions.push(create_batch_append_instruction( - authority, - derivation, - tree_accounts.merkle_tree, - tree_accounts.queue, - epoch_info.epoch, - serialized, - )); - } - } - BatchInstruction::Nullify(data) => { - for d in data { - let serialized = d - .try_to_vec() - .with_context(|| "Failed to serialize batch nullify payload")?; - instructions.push(create_batch_nullify_instruction( - authority, - derivation, - tree_accounts.merkle_tree, - epoch_info.epoch, - serialized, - )); - } - } - BatchInstruction::AddressAppend(data) => { - for d in data { - let serialized = d.try_to_vec().with_context(|| { - "Failed to serialize batch address append payload" - })?; - instructions.push(create_batch_update_address_tree_instruction( - authority, - derivation, - tree_accounts.merkle_tree, - epoch_info.epoch, - serialized, - )); - } - } - } - chunk_items += proof.items; - } - - if !instructions.is_empty() { - let mut rpc = self.rpc_pool.get_connection().await?; - let priority_fee = self - .resolve_tree_priority_fee(&*rpc, epoch_info.epoch, tree_accounts) - .await?; - let instruction_count = instructions.len(); - let payer = self.config.payer_keypair.pubkey(); - let signers = [&self.config.payer_keypair]; - match send_smart_transaction( - &mut *rpc, - SendSmartTransactionConfig { - instructions, - payer: &payer, - signers: &signers, - address_lookup_tables: &self.address_lookup_tables, - compute_budget: ComputeBudgetConfig { - compute_unit_price: priority_fee, - compute_unit_limit: Some(self.config.transaction_config.cu_limit), - }, - confirmation: Some(self.confirmation_config()), -
confirmation_deadline: Some(confirmation_deadline), - }, - ) - .await - .map_err(RpcError::from) - { - Ok(sig) => { - info!( - event = "cached_proofs_tx_sent", - run_id = %self.run_id, - signature = %sig, - instruction_count, - "Sent cached proofs transaction" - ); - total_items += chunk_items; - } - Err(e) => { - warn!( - event = "cached_proofs_tx_send_failed", - run_id = %self.run_id, - error = ?e, - "Failed to send cached proofs transaction" - ); - } - } - } - } - - Ok(total_items) - } - - async fn rollover_if_needed(&self, tree_account: &TreeAccounts) -> Result<()> { - let mut rpc = self.rpc_pool.get_connection().await?; - if is_tree_ready_for_rollover(&mut *rpc, tree_account.merkle_tree, tree_account.tree_type) - .await? - { - info!( - event = "tree_rollover_started", - run_id = %self.run_id, - tree = %tree_account.merkle_tree, - tree_type = ?tree_account.tree_type, - "Starting tree rollover" - ); - self.perform_rollover(tree_account).await?; - } - Ok(()) - } - - fn is_in_active_phase(&self, slot: u64, epoch_info: &ForesterEpochInfo) -> Result { - let current_epoch = self.protocol_config.get_current_active_epoch(slot)?; - if current_epoch != epoch_info.epoch.epoch { - return Ok(false); - } - - Ok(self - .protocol_config - .is_active_phase(slot, epoch_info.epoch.epoch) - .is_ok()) - } - - #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch - ))] - async fn wait_for_report_work_phase(&self, epoch_info: &ForesterEpochInfo) -> Result<()> { - info!( - event = "wait_for_report_work_phase", - run_id = %self.run_id, - epoch = epoch_info.epoch.epoch, - report_work_start_slot = epoch_info.epoch.phases.report_work.start, - "Waiting for report work phase" - ); - let mut rpc = self.rpc_pool.get_connection().await?; - let report_work_start_slot = epoch_info.epoch.phases.report_work.start; - wait_until_slot_reached(&mut *rpc, &self.slot_tracker, report_work_start_slot).await?; - - info!( - event = "report_work_phase_ready", - run_id = %self.run_id, - epoch = epoch_info.epoch.epoch, - "Finished waiting for report work phase" - ); - Ok(()) - } - - #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch - ))] - async fn send_work_report(&self, epoch_info: &ForesterEpochInfo) -> Result<()> { - let report = WorkReport { - epoch: epoch_info.epoch.epoch, - processed_items: self.get_processed_items_count(epoch_info.epoch.epoch).await, - metrics: self.get_processing_metrics(epoch_info.epoch.epoch).await, - }; - - info!( - event = "work_report_sent_to_channel", - run_id = %self.run_id, - epoch = report.epoch, - items = report.processed_items, - total_circuit_inputs_ms = report.metrics.total_circuit_inputs().as_millis() as u64, - total_proof_generation_ms = report.metrics.total_proof_generation().as_millis() as u64, - total_round_trip_ms = report.metrics.total_round_trip().as_millis() as u64, - tx_sending_ms = report.metrics.tx_sending_duration.as_millis() as u64, - "Sending work report to channel" - ); - - self.work_report_sender - .send(report) - .await - .map_err(|e| ChannelError::WorkReportSend { - epoch: report.epoch, - error: e.to_string(), - })?; - self.heartbeat.increment_work_report_sent(); - - Ok(()) - } - - #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch - ))] - async fn report_work_onchain(&self, epoch_info: &ForesterEpochInfo) -> 
Result<()> { - info!( - event = "work_report_onchain_started", - run_id = %self.run_id, - epoch = epoch_info.epoch.epoch, - "Reporting work on-chain" - ); - let mut rpc = LightClient::new(LightClientConfig { - url: self.config.external_services.rpc_url.to_string(), - photon_url: self.config.external_services.indexer_url.clone(), - commitment_config: Some(solana_sdk::commitment_config::CommitmentConfig::processed()), - fetch_active_tree: false, - }) - .await?; - - let forester_epoch_pda_pubkey = get_forester_epoch_pda_from_authority( - &self.config.derivation_pubkey, - epoch_info.epoch.epoch, - ) - .0; - if let Some(forester_epoch_pda) = rpc - .get_anchor_account::(&forester_epoch_pda_pubkey) - .await? - { - if forester_epoch_pda.has_reported_work { - return Ok(()); - } - } - - let forester_epoch_pda = &epoch_info.forester_epoch_pda; - if forester_epoch_pda.has_reported_work { - return Ok(()); - } - - let ix = create_report_work_instruction( - &self.config.payer_keypair.pubkey(), - &self.config.derivation_pubkey, - epoch_info.epoch.epoch, - ); - - let priority_fee = self - .resolve_epoch_priority_fee(&rpc, epoch_info.epoch.epoch) - .await?; - let payer = self.config.payer_keypair.pubkey(); - let signers = [&self.config.payer_keypair]; - match send_smart_transaction( - &mut rpc, - SendSmartTransactionConfig { - instructions: vec![ix], - payer: &payer, - signers: &signers, - address_lookup_tables: &self.address_lookup_tables, - compute_budget: ComputeBudgetConfig { - compute_unit_price: priority_fee, - compute_unit_limit: Some(self.config.transaction_config.cu_limit), - }, - confirmation: Some(self.confirmation_config()), - confirmation_deadline: None, - }, - ) - .await - .map_err(RpcError::from) - { - Ok(_) => { - info!( - event = "work_report_onchain_succeeded", - run_id = %self.run_id, - epoch = epoch_info.epoch.epoch, - "Work reported on-chain" - ); - } - Err(e) => { - if rpc_is_already_processed(&e) { - info!( - event = "work_report_onchain_already_reported", - run_id = %self.run_id, - epoch = epoch_info.epoch.epoch, - "Work already reported on-chain for epoch" - ); - return Ok(()); - } - if let RpcError::ClientError(client_error) = &e { - if let Some(TransactionError::InstructionError( - _, - InstructionError::Custom(error_code), - )) = client_error.get_transaction_error() - { - return WorkReportError::from_registry_error( - error_code, - epoch_info.epoch.epoch, - ) - .map_err(|e| anyhow::Error::from(ForesterError::from(e))); - } - } - return Err(anyhow::Error::from(WorkReportError::Transaction(Box::new( - e, - )))); - } - } - - Ok(()) - } - - async fn perform_rollover(&self, tree_account: &TreeAccounts) -> Result<()> { - let mut rpc = self.rpc_pool.get_connection().await?; - let (_, current_epoch) = self.get_current_slot_and_epoch().await?; - - let result = match tree_account.tree_type { - TreeType::AddressV1 => { - let new_nullifier_queue_keypair = Keypair::new(); - let new_merkle_tree_keypair = Keypair::new(); - - let rollover_signature = perform_address_merkle_tree_rollover( - &self.config.payer_keypair, - &self.config.derivation_pubkey, - &mut *rpc, - &new_nullifier_queue_keypair, - &new_merkle_tree_keypair, - &tree_account.merkle_tree, - &tree_account.queue, - current_epoch, - ) - .await?; - - info!( - event = "address_tree_rollover_succeeded", - run_id = %self.run_id, - tree = %tree_account.merkle_tree, - signature = %rollover_signature, - "Address tree rollover succeeded" - ); - Ok(()) - } - TreeType::StateV1 => { - let new_nullifier_queue_keypair = Keypair::new(); - let 
new_merkle_tree_keypair = Keypair::new(); - let new_cpi_signature_keypair = Keypair::new(); - - let rollover_signature = perform_state_merkle_tree_rollover_forester( - &self.config.payer_keypair, - &self.config.derivation_pubkey, - &mut *rpc, - &new_nullifier_queue_keypair, - &new_merkle_tree_keypair, - &new_cpi_signature_keypair, - &tree_account.merkle_tree, - &tree_account.queue, - &Pubkey::default(), - current_epoch, - ) - .await?; - - info!( - event = "state_tree_rollover_succeeded", - run_id = %self.run_id, - tree = %tree_account.merkle_tree, - signature = %rollover_signature, - "State tree rollover succeeded" - ); - - Ok(()) - } - _ => Err(ForesterError::InvalidTreeType(tree_account.tree_type)), - }; - - match result { - Ok(_) => debug!( - "{:?} tree rollover completed successfully", - tree_account.tree_type - ), - Err(e) => warn!("{:?} tree rollover failed: {:?}", tree_account.tree_type, e), - } - Ok(()) - } -} - -fn should_skip_tree(config: &ForesterConfig, tree_type: &TreeType) -> bool { - match tree_type { - TreeType::AddressV1 => config.general_config.skip_v1_address_trees, - TreeType::AddressV2 => config.general_config.skip_v2_address_trees, - TreeType::StateV1 => config.general_config.skip_v1_state_trees, - TreeType::StateV2 => config.general_config.skip_v2_state_trees, - TreeType::Unknown => false, // Never skip compression tree - } -} - -pub fn generate_run_id() -> String { - let epoch_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_millis(); - format!("{}-{}", std::process::id(), epoch_ms) -} - -fn spawn_heartbeat_task( - heartbeat: Arc, - slot_tracker: Arc, - protocol_config: Arc, - run_id: String, -) -> JoinHandle<()> { - tokio::spawn(async move { - let mut interval = tokio::time::interval(Duration::from_secs(20)); - interval.set_missed_tick_behavior(MissedTickBehavior::Skip); - let mut previous = heartbeat.snapshot(); - - loop { - interval.tick().await; - - let slot = slot_tracker.estimated_current_slot(); - let epoch = protocol_config.get_current_active_epoch(slot).ok(); - let epoch_known = epoch.is_some(); - let epoch_value = epoch.unwrap_or_default(); - let current = heartbeat.snapshot(); - let delta = current.delta_since(&previous); - previous = current; - - info!( - event = "service_heartbeat", - run_id = %run_id, - slot, - epoch = epoch_value, - epoch_known, - cycle_delta = delta.active_cycles, - tree_tasks_delta = delta.tree_tasks_spawned, - queues_started_delta = delta.queues_started, - queues_finished_delta = delta.queues_finished, - items_processed_delta = delta.items_processed, - work_reports_delta = delta.work_reports_sent, - recoverable_v2_errors_delta = delta.v2_recoverable_errors, - cycle_total = current.active_cycles, - items_processed_total = current.items_processed, - "Forester heartbeat" - ); - } - }) -} - -#[instrument( - level = "info", - skip( - config, - protocol_config, - rpc_pool, - shutdown, - work_report_sender, - slot_tracker, - tx_cache, - ops_cache, - compressible_tracker, - pda_tracker, - mint_tracker, - run_id - ), - fields(forester = %config.payer_keypair.pubkey()) -)] -#[allow(clippy::too_many_arguments)] -pub async fn run_service( - config: Arc, - protocol_config: Arc, - rpc_pool: Arc>, - mut shutdown: oneshot::Receiver<()>, - work_report_sender: mpsc::Sender, - slot_tracker: Arc, - tx_cache: Arc>, - ops_cache: Arc>, - compressible_tracker: Option>, - pda_tracker: Option>, - mint_tracker: Option>, - run_id: String, -) -> Result<()> { - let heartbeat = Arc::new(ServiceHeartbeat::default()); - let 
heartbeat_handle = spawn_heartbeat_task( - heartbeat.clone(), - slot_tracker.clone(), - protocol_config.clone(), - run_id.clone(), - ); - - let run_id_for_logs = run_id.clone(); - let result = info_span!( - "run_service", - forester = %config.payer_keypair.pubkey() - ) - .in_scope(|| async move { - let processor_mode_str = match ( - config.general_config.skip_v1_state_trees - && config.general_config.skip_v1_address_trees, - config.general_config.skip_v2_state_trees - && config.general_config.skip_v2_address_trees, - ) { - (true, false) => "v2", - (false, true) => "v1", - (false, false) => "all", - _ => "unknown", - }; - info!( - event = "forester_starting", - run_id = %run_id_for_logs, - processor_mode = processor_mode_str, - "Starting forester" - ); - - const INITIAL_RETRY_DELAY: Duration = Duration::from_secs(1); - const MAX_RETRY_DELAY: Duration = Duration::from_secs(30); - - let mut retry_count = 0; - let mut retry_delay = INITIAL_RETRY_DELAY; - let start_time = Instant::now(); - - let trees = { - let max_attempts = 10; - let mut attempts = 0; - let mut delay = Duration::from_secs(2); - - loop { - tokio::select! { - biased; - _ = &mut shutdown => { - info!( - event = "shutdown_received", - run_id = %run_id_for_logs, - phase = "tree_fetch", - "Received shutdown signal during tree fetch. Stopping." - ); - return Ok(()); - } - result = rpc_pool.get_connection() => { - match result { - Ok(rpc) => { - tokio::select! { - biased; - _ = &mut shutdown => { - info!( - event = "shutdown_received", - run_id = %run_id_for_logs, - phase = "tree_fetch", - "Received shutdown signal during tree fetch. Stopping." - ); - return Ok(()); - } - fetch_result = fetch_trees(&*rpc) => { - match fetch_result { - Ok(mut fetched_trees) => { - let group_authority = match config.general_config.group_authority { - Some(ga) => Some(ga), - None => { - match fetch_protocol_group_authority(&*rpc, run_id_for_logs.as_str()).await { - Ok(ga) => { - info!( - event = "group_authority_default_fetched", - run_id = %run_id_for_logs, - group_authority = %ga, - "Using protocol default group authority" - ); - Some(ga) - } - Err(e) => { - warn!( - event = "group_authority_fetch_failed", - run_id = %run_id_for_logs, - error = ?e, - "Failed to fetch protocol group authority; processing all trees" - ); - None - } - } - } - }; - - if let Some(group_authority) = group_authority { - let before_count = fetched_trees.len(); - fetched_trees.retain(|tree| tree.owner == group_authority); - info!( - event = "trees_filtered_by_group_authority", - run_id = %run_id_for_logs, - group_authority = %group_authority, - trees_before = before_count, - trees_after = fetched_trees.len(), - "Filtered trees by group authority" - ); - } - - if !config.general_config.tree_ids.is_empty() { - let tree_ids = &config.general_config.tree_ids; - fetched_trees.retain(|tree| tree_ids.contains(&tree.merkle_tree)); - if fetched_trees.is_empty() { - error!( - event = "trees_filter_explicit_ids_empty", - run_id = %run_id_for_logs, - requested_tree_count = tree_ids.len(), - requested_trees = ?tree_ids, - "None of the specified trees were found" - ); - return Err(anyhow::anyhow!( - "None of the specified trees found: {:?}", - tree_ids - )); - } - info!( - event = "trees_filter_explicit_ids", - run_id = %run_id_for_logs, - tree_count = tree_ids.len(), - "Processing only explicitly requested trees" - ); - } - break fetched_trees; - } - Err(e) => { - attempts += 1; - if attempts >= max_attempts { - return Err(anyhow::anyhow!( - "Failed to fetch trees after {} attempts: {:?}", 
- max_attempts, - e - )); - } - warn!( - event = "fetch_trees_failed_retrying", - run_id = %run_id_for_logs, - attempt = attempts, - max_attempts, - retry_delay_ms = delay.as_millis() as u64, - error = ?e, - "Failed to fetch trees; retrying" - ); - } - } - } - } - } - Err(e) => { - attempts += 1; - if attempts >= max_attempts { - return Err(anyhow::anyhow!( - "Failed to get RPC connection for trees after {} attempts: {:?}", - max_attempts, - e - )); - } - warn!( - event = "rpc_connection_failed_retrying", - run_id = %run_id_for_logs, - attempt = attempts, - max_attempts, - retry_delay_ms = delay.as_millis() as u64, - error = ?e, - "Failed to get RPC connection; retrying" - ); - } - } - } - } - - tokio::select! { - biased; - _ = &mut shutdown => { - info!( - event = "shutdown_received", - run_id = %run_id_for_logs, - phase = "tree_fetch_retry_wait", - "Received shutdown signal during retry wait. Stopping." - ); - return Ok(()); - } - _ = sleep(delay) => { - delay = std::cmp::min(delay * 2, Duration::from_secs(30)); - } - } - } - }; - trace!("Fetched initial trees: {:?}", trees); - - if !config.general_config.tree_ids.is_empty() { - info!( - event = "tree_discovery_limited_to_explicit_ids", - run_id = %run_id_for_logs, - tree_count = config.general_config.tree_ids.len(), - "Processing specific trees; tree discovery will be limited" - ); - } - - while retry_count < config.retry_config.max_retries { - debug!("Creating EpochManager (attempt {})", retry_count + 1); - - let address_lookup_tables = { - if let Some(lut_address) = config.lookup_table_address { - let rpc = rpc_pool.get_connection().await?; - let lut = load_lookup_table_async(&*rpc, lut_address).await - .map_err(|e| { - error!( - event = "lookup_table_load_failed", - run_id = %run_id_for_logs, - lookup_table = %lut_address, - error = %e, - "Failed to load lookup table" - ); - e - })?; - info!( - event = "lookup_table_loaded", - run_id = %run_id_for_logs, - lookup_table = %lut_address, - address_count = lut.addresses.len(), - "Loaded lookup table" - ); - Arc::new(vec![lut]) - } else { - debug!("No lookup table address configured. Using v1 state single nullify transactions."); - Arc::new(Vec::new()) - } - }; - - match EpochManager::new( - config.clone(), - protocol_config.clone(), - rpc_pool.clone(), - work_report_sender.clone(), - trees.clone(), - slot_tracker.clone(), - tx_cache.clone(), - ops_cache.clone(), - compressible_tracker.clone(), - pda_tracker.clone(), - mint_tracker.clone(), - address_lookup_tables, - heartbeat.clone(), - run_id.clone(), - ) - .await - { - Ok(epoch_manager) => { - let epoch_manager = Arc::new(epoch_manager); - debug!( - "Successfully created EpochManager after {} attempts", - retry_count + 1 - ); - - let result = tokio::select! { - result = epoch_manager.run() => result, - _ = shutdown => { - info!( - event = "shutdown_received", - run_id = %run_id_for_logs, - phase = "service_run", - "Received shutdown signal. Stopping the service." 
- ); - Ok(()) - } - }; - - return result; - } - Err(e) => { - warn!( - event = "epoch_manager_create_failed", - run_id = %run_id_for_logs, - attempt = retry_count + 1, - error = ?e, - "Failed to create EpochManager" - ); - retry_count += 1; - if retry_count < config.retry_config.max_retries { - debug!("Retrying in {:?}", retry_delay); - sleep(retry_delay).await; - retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY); - } else { - error!( - event = "forester_start_failed_max_retries", - run_id = %run_id_for_logs, - attempts = config.retry_config.max_retries, - elapsed_ms = start_time.elapsed().as_millis() as u64, - error = ?e, - "Failed to start forester after max retries" - ); - return Err(InitializationError::MaxRetriesExceeded { - attempts: config.retry_config.max_retries, - error: e.to_string(), - } - .into()); - } - } - } - } - - Err( - InitializationError::Unexpected("Retry loop exited without returning".to_string()) - .into(), - ) - }) - .await; - - heartbeat_handle.abort(); - result -} - -/// Async version of load_lookup_table that works with the Rpc trait -async fn load_lookup_table_async( - rpc: &R, - lookup_table_address: Pubkey, -) -> anyhow::Result { - use light_client::rpc::lut::AddressLookupTable; - - let account = rpc - .get_account(lookup_table_address) - .await? - .ok_or_else(|| { - anyhow::anyhow!("Lookup table account not found: {}", lookup_table_address) - })?; - - let address_lookup_table = AddressLookupTable::deserialize(&account.data) - .map_err(|e| anyhow::anyhow!("Failed to deserialize AddressLookupTable: {:?}", e))?; - - Ok(AddressLookupTableAccount { - key: lookup_table_address, - addresses: address_lookup_table.addresses.to_vec(), - }) -} - -#[cfg(test)] -mod tests { - use light_client::rpc::RetryConfig; - use solana_sdk::{pubkey::Pubkey, signature::Keypair}; - - use super::*; - use crate::{ - config::{ExternalServicesConfig, GeneralConfig}, - ForesterConfig, - }; - - fn create_test_config_with_skip_flags( - skip_v1_state: bool, - skip_v1_address: bool, - skip_v2_state: bool, - skip_v2_address: bool, - ) -> ForesterConfig { - ForesterConfig { - external_services: ExternalServicesConfig { - rpc_url: "http://localhost:8899".to_string(), - ws_rpc_url: None, - indexer_url: None, - prover_url: None, - prover_append_url: None, - prover_update_url: None, - prover_address_append_url: None, - prover_api_key: None, - photon_grpc_url: None, - pushgateway_url: None, - pagerduty_routing_key: None, - rpc_rate_limit: None, - photon_rate_limit: None, - send_tx_rate_limit: None, - prover_polling_interval: None, - prover_max_wait_time: None, - fallback_rpc_url: None, - fallback_indexer_url: None, - }, - retry_config: RetryConfig::default(), - queue_config: Default::default(), - indexer_config: Default::default(), - transaction_config: Default::default(), - general_config: GeneralConfig { - enable_metrics: false, - skip_v1_state_trees: skip_v1_state, - skip_v1_address_trees: skip_v1_address, - skip_v2_state_trees: skip_v2_state, - skip_v2_address_trees: skip_v2_address, - sleep_after_processing_ms: 50, - sleep_when_idle_ms: 100, - ..Default::default() - }, - rpc_pool_config: Default::default(), - registry_pubkey: Pubkey::default(), - payer_keypair: Keypair::new(), - derivation_pubkey: Pubkey::default(), - address_tree_data: vec![], - state_tree_data: vec![], - compressible_config: None, - lookup_table_address: None, - min_queue_items: None, - enable_v1_multi_nullify: false, - work_item_batch_size: 50, - } - } - - #[test] - fn test_should_skip_tree_none_skipped() { - let 
config = create_test_config_with_skip_flags(false, false, false, false); - - assert!(!should_skip_tree(&config, &TreeType::StateV1)); - assert!(!should_skip_tree(&config, &TreeType::StateV2)); - assert!(!should_skip_tree(&config, &TreeType::AddressV1)); - assert!(!should_skip_tree(&config, &TreeType::AddressV2)); - } - - #[test] - fn test_should_skip_tree_all_v1_skipped() { - let config = create_test_config_with_skip_flags(true, true, false, false); - - assert!(should_skip_tree(&config, &TreeType::StateV1)); - assert!(should_skip_tree(&config, &TreeType::AddressV1)); - assert!(!should_skip_tree(&config, &TreeType::StateV2)); - assert!(!should_skip_tree(&config, &TreeType::AddressV2)); - } - - #[test] - fn test_should_skip_tree_all_v2_skipped() { - let config = create_test_config_with_skip_flags(false, false, true, true); - - assert!(!should_skip_tree(&config, &TreeType::StateV1)); - assert!(!should_skip_tree(&config, &TreeType::AddressV1)); - assert!(should_skip_tree(&config, &TreeType::StateV2)); - assert!(should_skip_tree(&config, &TreeType::AddressV2)); - } - - #[test] - fn test_should_skip_tree_only_state_trees() { - let config = create_test_config_with_skip_flags(true, false, true, false); - - assert!(should_skip_tree(&config, &TreeType::StateV1)); - assert!(should_skip_tree(&config, &TreeType::StateV2)); - assert!(!should_skip_tree(&config, &TreeType::AddressV1)); - assert!(!should_skip_tree(&config, &TreeType::AddressV2)); - } - - #[test] - fn test_should_skip_tree_only_address_trees() { - let config = create_test_config_with_skip_flags(false, true, false, true); - - assert!(!should_skip_tree(&config, &TreeType::StateV1)); - assert!(!should_skip_tree(&config, &TreeType::StateV2)); - assert!(should_skip_tree(&config, &TreeType::AddressV1)); - assert!(should_skip_tree(&config, &TreeType::AddressV2)); - } - - #[test] - fn test_should_skip_tree_mixed_config() { - // Skip V1 state and V2 address - let config = create_test_config_with_skip_flags(true, false, false, true); - - assert!(should_skip_tree(&config, &TreeType::StateV1)); - assert!(!should_skip_tree(&config, &TreeType::StateV2)); - assert!(!should_skip_tree(&config, &TreeType::AddressV1)); - assert!(should_skip_tree(&config, &TreeType::AddressV2)); - } - - #[test] - fn test_general_config_test_address_v2() { - let config = GeneralConfig::test_address_v2(); - - assert!(config.skip_v1_state_trees); - assert!(config.skip_v1_address_trees); - assert!(config.skip_v2_state_trees); - assert!(!config.skip_v2_address_trees); - } - - #[test] - fn test_general_config_test_state_v2() { - let config = GeneralConfig::test_state_v2(); - - assert!(config.skip_v1_state_trees); - assert!(config.skip_v1_address_trees); - assert!(!config.skip_v2_state_trees); - assert!(config.skip_v2_address_trees); - } - - #[test] - fn test_work_item_is_address_tree() { - let tree_account = TreeAccounts { - merkle_tree: Pubkey::new_unique(), - queue: Pubkey::new_unique(), - is_rolledover: false, - tree_type: TreeType::AddressV1, - owner: Default::default(), - }; - - let work_item = WorkItem { - tree_account, - queue_item_data: QueueItemData { - hash: [0u8; 32], - index: 0, - leaf_index: None, - }, - }; - - assert!(work_item.is_address_tree()); - assert!(!work_item.is_state_tree()); - } - - #[test] - fn test_work_item_is_state_tree() { - let tree_account = TreeAccounts { - merkle_tree: Pubkey::new_unique(), - queue: Pubkey::new_unique(), - is_rolledover: false, - tree_type: TreeType::StateV1, - owner: Default::default(), - }; - - let work_item = WorkItem { - 
tree_account, - queue_item_data: QueueItemData { - hash: [0u8; 32], - index: 0, - leaf_index: None, - }, - }; - - assert!(!work_item.is_address_tree()); - assert!(work_item.is_state_tree()); - } - - #[test] - fn test_work_report_creation() { - let report = WorkReport { - epoch: 42, - processed_items: 100, - metrics: ProcessingMetrics { - append: CircuitMetrics { - circuit_inputs_duration: std::time::Duration::from_secs(1), - proof_generation_duration: std::time::Duration::from_secs(3), - round_trip_duration: std::time::Duration::from_secs(10), - }, - nullify: CircuitMetrics { - circuit_inputs_duration: std::time::Duration::from_secs(1), - proof_generation_duration: std::time::Duration::from_secs(2), - round_trip_duration: std::time::Duration::from_secs(8), - }, - address_append: CircuitMetrics { - circuit_inputs_duration: std::time::Duration::from_secs(1), - proof_generation_duration: std::time::Duration::from_secs(2), - round_trip_duration: std::time::Duration::from_secs(9), - }, - tx_sending_duration: std::time::Duration::ZERO, - }, - }; - - assert_eq!(report.epoch, 42); - assert_eq!(report.processed_items, 100); - assert_eq!(report.metrics.total().as_secs(), 10); - assert_eq!(report.metrics.total_circuit_inputs().as_secs(), 3); - assert_eq!(report.metrics.total_proof_generation().as_secs(), 7); - assert_eq!(report.metrics.total_round_trip().as_secs(), 27); - } -} diff --git a/forester/src/epoch_manager/compression.rs b/forester/src/epoch_manager/compression.rs new file mode 100644 index 0000000000..ce51b01f80 --- /dev/null +++ b/forester/src/epoch_manager/compression.rs @@ -0,0 +1,547 @@ +//! Compression dispatch: ctoken, PDA, and mint compression during active phase. + +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +use anyhow::anyhow; +use forester_utils::forester_epoch::{Epoch, ForesterSlot}; +use light_client::{indexer::Indexer, rpc::Rpc}; +use light_registry::ForesterEpochPda; +use solana_program::pubkey::Pubkey; +use tracing::{debug, error, info, trace, warn}; + +use super::EpochManager; +use crate::compressible::{ + traits::{ + Cancelled, CompressibleState, CompressibleTracker, CompressionOutcome, CompressionTaskError, + }, + CTokenCompressor, CompressibleConfig, +}; + +impl EpochManager { + pub(crate) async fn dispatch_compression( + &self, + epoch_info: &Epoch, + epoch_pda: &ForesterEpochPda, + forester_slot_details: &ForesterSlot, + consecutive_eligibility_end: u64, + ) -> crate::Result { + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + if current_slot >= consecutive_eligibility_end { + debug!( + "Skipping compression: forester no longer eligible (current_slot={}, eligibility_end={})", + current_slot, consecutive_eligibility_end + ); + return Ok(0); + } + + if current_slot >= forester_slot_details.end_solana_slot { + debug!( + "Skipping compression: forester slot ended (current_slot={}, slot_end={})", + current_slot, forester_slot_details.end_solana_slot + ); + return Ok(0); + } + + let current_light_slot = current_slot.saturating_sub(epoch_info.phases.active.start) + / epoch_pda.protocol_config.slot_length; + if !self + .check_forester_eligibility( + epoch_pda, + current_light_slot, + &Pubkey::default(), + epoch_info.epoch, + epoch_info, + ) + .await? 
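+        // Eligibility is keyed to the light slot derived above as + // (current_slot - active_phase_start) / slot_length; it is re-derived + // from the slot tracker on every dispatch rather than cached, so a + // forester that loses its light slot stops compressing promptly.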
+ { + debug!( + "Skipping compression: forester not eligible for current light slot {}", + current_light_slot + ); + return Ok(0); + } + + debug!("Dispatching compression for epoch {}", epoch_info.epoch); + + let tracker = self + .compressible_tracker + .as_ref() + .ok_or_else(|| anyhow!("Compressible tracker not initialized"))?; + + let config = self + .ctx + .config + .compressible_config + .as_ref() + .ok_or_else(|| anyhow!("Compressible config not set"))?; + + let pending = tracker.pending(); + let accounts: Vec<_> = tracker + .accounts() + .iter() + .filter(|entry| { + entry.value().is_ready_to_compress(current_slot) && !pending.contains(entry.key()) + }) + .map(|entry| entry.value().clone()) + .collect(); + let _ = pending; + + if accounts.is_empty() { + trace!("No compressible accounts ready for compression"); + return Ok(0); + } + + let num_batches = accounts.len().div_ceil(config.batch_size); + info!( + event = "compression_ctoken_started", + run_id = %self.ctx.run_id, + accounts = accounts.len(), + batches = num_batches, + batch_size = config.batch_size, + "Starting ctoken compression batches" + ); + + let compressor = CTokenCompressor::new( + self.ctx.rpc_pool.clone(), + tracker.clone(), + self.ctx.authority.clone(), + self.ctx.transaction_policy(), + ); + + let (registered_forester_pda, _) = + light_registry::utils::get_forester_epoch_pda_from_authority( + &self.ctx.config.derivation_pubkey, + epoch_info.epoch, + ); + + use futures::stream::StreamExt; + + let batches: Vec<(usize, Vec<_>)> = accounts + .chunks(config.batch_size) + .enumerate() + .map(|(idx, chunk)| (idx, chunk.to_vec())) + .collect(); + + let slot_tracker = self.ctx.slot_tracker.clone(); + let cancelled = Arc::new(AtomicBool::new(false)); + + let compression_futures = batches.into_iter().map(|(batch_idx, batch)| { + let compressor = compressor.clone(); + let slot_tracker = slot_tracker.clone(); + let cancelled = cancelled.clone(); + let run_id = self.ctx.run_id.clone(); + async move { + if cancelled.load(Ordering::Relaxed) { + debug!( + "Skipping compression batch {}/{}: cancelled", + batch_idx + 1, + num_batches + ); + return Err((batch_idx, batch.len(), Cancelled.into())); + } + + let current_slot = slot_tracker.estimated_current_slot(); + if current_slot >= consecutive_eligibility_end { + cancelled.store(true, Ordering::Relaxed); + warn!( + event = "compression_ctoken_cancelled_not_eligible", + run_id = %run_id, + current_slot, + eligibility_end_slot = consecutive_eligibility_end, + "Cancelling compression because forester is no longer eligible" + ); + return Err(( + batch_idx, + batch.len(), + anyhow!("Forester no longer eligible"), + )); + } + + debug!( + "Processing compression batch {}/{} with {} accounts", + batch_idx + 1, + num_batches, + batch.len() + ); + + match compressor + .compress_batch(&batch, registered_forester_pda) + .await + { + Ok(sig) => { + debug!( + "Compression batch {}/{} succeeded: {}", + batch_idx + 1, + num_batches, + sig + ); + Ok((batch_idx, batch.len(), sig)) + } + Err(e) => { + error!( + event = "compression_ctoken_batch_failed", + run_id = %run_id, + batch = batch_idx + 1, + total_batches = num_batches, + error = ?e, + "Compression batch failed" + ); + Err((batch_idx, batch.len(), e)) + } + } + } + }); + + let results = futures::stream::iter(compression_futures) + .buffer_unordered(config.max_concurrent_batches) + .collect::>() + .await; + + let mut total_compressed = 0; + for result in results { + match result { + Ok((batch_idx, count, sig)) => { + info!( + event = 
"compression_ctoken_batch_succeeded", + run_id = %self.ctx.run_id, + batch = batch_idx + 1, + total_batches = num_batches, + accounts = count, + signature = %sig, + "Compression batch succeeded" + ); + total_compressed += count; + } + Err((batch_idx, count, e)) => { + error!( + event = "compression_ctoken_batch_failed_final", + run_id = %self.ctx.run_id, + batch = batch_idx + 1, + total_batches = num_batches, + accounts = count, + error = ?e, + "Compression batch failed" + ); + } + } + } + + info!( + event = "compression_ctoken_completed", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch, + compressed_accounts = total_compressed, + "Completed ctoken compression" + ); + + let pda_compressed = self + .dispatch_pda_compression(epoch_info, epoch_pda, consecutive_eligibility_end) + .await + .unwrap_or_else(|e| { + error!( + event = "compression_pda_dispatch_failed", + run_id = %self.ctx.run_id, + error = ?e, + "PDA compression failed" + ); + 0 + }); + + let mint_compressed = self + .dispatch_mint_compression(epoch_info, epoch_pda, consecutive_eligibility_end) + .await + .unwrap_or_else(|e| { + error!( + event = "compression_mint_dispatch_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Mint compression failed" + ); + 0 + }); + + let total = total_compressed + pda_compressed + mint_compressed; + info!( + event = "compression_all_completed", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch, + ctoken_compressed = total_compressed, + pda_compressed, + mint_compressed, + total_compressed = total, + "Completed all compression" + ); + Ok(total) + } + + async fn dispatch_pda_compression( + &self, + epoch_info: &Epoch, + epoch_pda: &ForesterEpochPda, + consecutive_eligibility_end: u64, + ) -> crate::Result { + let Some((pda_tracker, config, current_slot)) = self + .prepare_compression_dispatch( + self.pda_tracker.as_ref(), + "PDA", + epoch_info, + epoch_pda, + consecutive_eligibility_end, + ) + .await? 
+ else { + return Ok(0); + }; + + if config.pda_programs.is_empty() { + return Ok(0); + } + + let mut total_compressed = 0; + let cancelled = Arc::new(AtomicBool::new(false)); + + for program_config in &config.pda_programs { + if cancelled.load(Ordering::Relaxed) { + break; + } + + let accounts = pda_tracker + .get_ready_to_compress_for_program(&program_config.program_id, current_slot); + + if accounts.is_empty() { + trace!( + "No compressible PDA accounts ready for program {}", + program_config.program_id + ); + continue; + } + + info!( + event = "compression_pda_program_started", + run_id = %self.ctx.run_id, + program = %program_config.program_id, + accounts = accounts.len(), + "Processing compressible PDA accounts for program" + ); + + let pda_compressor = crate::compressible::pda::PdaCompressor::new( + self.ctx.rpc_pool.clone(), + pda_tracker.clone(), + self.ctx.authority.clone(), + self.ctx.transaction_policy(), + ); + + let cached_config = match pda_compressor.fetch_program_config(program_config).await { + Ok(cfg) => cfg, + Err(e) => { + error!( + event = "compression_pda_program_config_failed", + run_id = %self.ctx.run_id, + program = %program_config.program_id, + error = ?e, + "Failed to fetch config for PDA program" + ); + continue; + } + }; + + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + if current_slot >= consecutive_eligibility_end { + cancelled.store(true, Ordering::Relaxed); + warn!( + event = "compression_pda_cancelled_not_eligible", + run_id = %self.ctx.run_id, + current_slot, + eligibility_end_slot = consecutive_eligibility_end, + "Stopping PDA compression because forester is no longer eligible" + ); + break; + } + + let results = pda_compressor + .compress_batch_concurrent( + &accounts, + program_config, + &cached_config, + config.max_concurrent_batches, + cancelled.clone(), + ) + .await; + + for result in results { + match result { + CompressionOutcome::Compressed { + signature: sig, + pubkey, + } => { + debug!( + "Compressed PDA {} for program {}: {}", + pubkey, program_config.program_id, sig + ); + total_compressed += 1; + } + CompressionOutcome::Failed { + error: CompressionTaskError::Cancelled, + .. + } => {} + CompressionOutcome::Failed { + pubkey, + error: CompressionTaskError::Failed(e), + } => { + error!( + event = "compression_pda_account_failed", + run_id = %self.ctx.run_id, + account = %pubkey, + program = %program_config.program_id, + error = ?e, + "Failed to compress PDA account" + ); + } + } + } + } + + info!( + event = "compression_pda_completed", + run_id = %self.ctx.run_id, + compressed_accounts = total_compressed, + "Completed PDA compression" + ); + Ok(total_compressed) + } + + async fn dispatch_mint_compression( + &self, + epoch_info: &Epoch, + epoch_pda: &ForesterEpochPda, + consecutive_eligibility_end: u64, + ) -> crate::Result { + let Some((mint_tracker, config, current_slot)) = self + .prepare_compression_dispatch( + self.mint_tracker.as_ref(), + "Mint", + epoch_info, + epoch_pda, + consecutive_eligibility_end, + ) + .await? 
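+        // Same gate as the PDA path. Unlike PDAs, mints are not partitioned + // per program: every account ready at `current_slot` is handed to one + // concurrent batch below.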
+ else { + return Ok(0); + }; + + let accounts = mint_tracker.get_ready_to_compress(current_slot); + + if accounts.is_empty() { + trace!("No compressible Mint accounts ready"); + return Ok(0); + } + + info!( + event = "compression_mint_started", + run_id = %self.ctx.run_id, + accounts = accounts.len(), + max_concurrent = config.max_concurrent_batches, + "Processing compressible Mint accounts" + ); + + let mint_compressor = crate::compressible::mint::MintCompressor::new( + self.ctx.rpc_pool.clone(), + mint_tracker.clone(), + self.ctx.authority.clone(), + self.ctx.transaction_policy(), + ); + + let cancelled = Arc::new(AtomicBool::new(false)); + + let results = mint_compressor + .compress_batch_concurrent(&accounts, config.max_concurrent_batches, cancelled) + .await; + + let mut total_compressed = 0; + for result in results { + match result { + CompressionOutcome::Compressed { + signature: sig, + pubkey, + } => { + debug!("Compressed Mint {}: {}", pubkey, sig); + total_compressed += 1; + } + CompressionOutcome::Failed { + error: CompressionTaskError::Cancelled, + .. + } => {} + CompressionOutcome::Failed { + pubkey, + error: CompressionTaskError::Failed(e), + } => { + error!( + event = "compression_mint_account_failed", + run_id = %self.ctx.run_id, + mint = %pubkey, + error = ?e, + "Failed to compress mint account" + ); + } + } + } + + info!( + event = "compression_mint_completed", + run_id = %self.ctx.run_id, + compressed_accounts = total_compressed, + "Completed Mint compression" + ); + Ok(total_compressed) + } + + async fn prepare_compression_dispatch<'a, T>( + &'a self, + tracker: Option<&'a T>, + label: &'static str, + epoch_info: &Epoch, + epoch_pda: &ForesterEpochPda, + consecutive_eligibility_end: u64, + ) -> crate::Result> { + let Some(tracker) = tracker else { + return Ok(None); + }; + + let Some(config) = self.ctx.config.compressible_config.as_ref() else { + return Ok(None); + }; + + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + if current_slot >= consecutive_eligibility_end { + debug!( + "Skipping {} compression: forester no longer eligible (current_slot={}, eligibility_end={})", + label, current_slot, consecutive_eligibility_end + ); + return Ok(None); + } + + let current_light_slot = current_slot.saturating_sub(epoch_info.phases.active.start) + / epoch_pda.protocol_config.slot_length; + if !self + .check_forester_eligibility( + epoch_pda, + current_light_slot, + &Pubkey::default(), + epoch_info.epoch, + epoch_info, + ) + .await? 
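+        // Repeats the light-slot check done in `dispatch_compression` so the + // PDA and mint passes, which run later in the cycle, judge eligibility + // against a fresh slot estimate rather than a stale verdict.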
+ { + debug!( + "Skipping {} compression: forester not eligible for current light slot {}", + label, current_light_slot + ); + return Ok(None); + } + + Ok(Some((tracker, config, current_slot))) + } +} diff --git a/forester/src/epoch_manager/context.rs b/forester/src/epoch_manager/context.rs new file mode 100644 index 0000000000..40814ef322 --- /dev/null +++ b/forester/src/epoch_manager/context.rs @@ -0,0 +1,210 @@ +use std::{ + sync::{atomic::AtomicU64, Arc}, + time::Duration, +}; + +use forester_utils::{ + forester_epoch::{Epoch, TreeAccounts}, + rpc_pool::SolanaRpcPool, +}; +use light_client::{indexer::Indexer, rpc::Rpc}; +use light_compressed_account::TreeType; +use light_registry::{ + protocol_config::state::ProtocolConfig, utils::get_forester_epoch_pda_from_authority, +}; +use solana_sdk::{ + address_lookup_table::AddressLookupTableAccount, + signature::{Keypair, Signer}, +}; +use tokio::sync::Mutex; + +use crate::{ + logging::ServiceHeartbeat, + priority_fee::PriorityFeeConfig, + processor::{ + tx_cache::ProcessedHashCache, + v2::{BatchContext, ProverConfig}, + }, + slot_tracker::SlotTracker, + smart_transaction::{ConfirmationConfig, TransactionPolicy}, + ForesterConfig, Result, +}; + +/// Shared immutable infrastructure, cheap to clone (all Arc). +/// +/// This is the "context" that nearly every method needs: config, RPC access, +/// slot tracking, and transaction helpers. Extracted from EpochManager to +/// make dependencies explicit and enable isolated testing. +#[derive(Debug)] +pub(crate) struct ForesterContext { + pub config: Arc, + pub protocol_config: Arc, + pub rpc_pool: Arc>, + pub authority: Arc, + pub slot_tracker: Arc, + pub address_lookup_tables: Arc>, + pub heartbeat: Arc, + pub run_id: Arc, +} + +impl Clone for ForesterContext { + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + protocol_config: self.protocol_config.clone(), + rpc_pool: self.rpc_pool.clone(), + authority: self.authority.clone(), + slot_tracker: self.slot_tracker.clone(), + address_lookup_tables: self.address_lookup_tables.clone(), + heartbeat: self.heartbeat.clone(), + run_id: self.run_id.clone(), + } + } +} + +impl ForesterContext { + pub fn confirmation_config(&self) -> ConfirmationConfig { + ConfirmationConfig { + max_attempts: self.config.transaction_config.confirmation_max_attempts, + poll_interval: Duration::from_millis( + self.config.transaction_config.confirmation_poll_interval_ms, + ), + } + } + + pub fn transaction_priority_fee_config(&self) -> PriorityFeeConfig { + PriorityFeeConfig { + compute_unit_price: self.config.transaction_config.priority_fee_microlamports, + enable_priority_fees: self.config.transaction_config.enable_priority_fees, + } + } + + pub fn transaction_policy(&self) -> TransactionPolicy { + TransactionPolicy { + priority_fee_config: self.transaction_priority_fee_config(), + compute_unit_limit: Some(self.config.transaction_config.cu_limit), + confirmation: Some(self.confirmation_config()), + } + } + + pub async fn resolve_epoch_priority_fee( + &self, + rpc: &RpcT, + epoch: u64, + ) -> Result> { + self.transaction_priority_fee_config() + .resolve( + rpc, + vec![ + self.config.payer_keypair.pubkey(), + get_forester_epoch_pda_from_authority(&self.config.derivation_pubkey, epoch).0, + ], + ) + .await + } + + pub async fn resolve_tree_priority_fee( + &self, + rpc: &RpcT, + epoch: u64, + tree_accounts: &TreeAccounts, + ) -> Result> { + self.transaction_priority_fee_config() + .resolve( + rpc, + vec![ + self.config.payer_keypair.pubkey(), + 
get_forester_epoch_pda_from_authority(&self.config.derivation_pubkey, epoch).0, + tree_accounts.queue, + tree_accounts.merkle_tree, + ], + ) + .await + } + + pub async fn sync_slot(&self) -> Result { + let rpc = self.rpc_pool.get_connection().await?; + let current_slot = rpc.get_slot().await?; + self.slot_tracker.update(current_slot); + Ok(current_slot) + } + + pub async fn get_current_slot_and_epoch(&self) -> Result<(u64, u64)> { + let slot = self.slot_tracker.estimated_current_slot(); + let epoch = self.protocol_config.get_current_active_epoch(slot)?; + Ok((slot, epoch)) + } + + pub fn build_batch_context( + &self, + epoch_info: &Epoch, + tree_accounts: &TreeAccounts, + input_queue_hint: Option, + output_queue_hint: Option, + eligibility_end: Option, + ops_cache: Arc>, + ) -> BatchContext { + let default_prover_url = "http://127.0.0.1:3001".to_string(); + let eligibility_end = eligibility_end.unwrap_or(0); + BatchContext { + rpc_pool: self.rpc_pool.clone(), + authority: self.authority.clone(), + run_id: self.run_id.clone(), + derivation: self.config.derivation_pubkey, + epoch: epoch_info.epoch, + merkle_tree: tree_accounts.merkle_tree, + output_queue: tree_accounts.queue, + prover_config: Arc::new(ProverConfig { + append_url: self + .config + .external_services + .prover_append_url + .clone() + .unwrap_or_else(|| default_prover_url.clone()), + update_url: self + .config + .external_services + .prover_update_url + .clone() + .unwrap_or_else(|| default_prover_url.clone()), + address_append_url: self + .config + .external_services + .prover_address_append_url + .clone() + .unwrap_or_else(|| default_prover_url.clone()), + api_key: self.config.external_services.prover_api_key.clone(), + polling_interval: self + .config + .external_services + .prover_polling_interval + .unwrap_or(Duration::from_secs(1)), + max_wait_time: self + .config + .external_services + .prover_max_wait_time + .unwrap_or(Duration::from_secs(600)), + }), + ops_cache, + epoch_phases: epoch_info.phases.clone(), + slot_tracker: self.slot_tracker.clone(), + input_queue_hint, + output_queue_hint, + num_proof_workers: self.config.transaction_config.max_concurrent_batches, + forester_eligibility_end_slot: Arc::new(AtomicU64::new(eligibility_end)), + address_lookup_tables: self.address_lookup_tables.clone(), + transaction_policy: self.transaction_policy(), + max_batches_per_tree: self.config.transaction_config.max_batches_per_tree, + } + } +} + +pub(crate) fn should_skip_tree(config: &ForesterConfig, tree_type: &TreeType) -> bool { + match tree_type { + TreeType::AddressV1 => config.general_config.skip_v1_address_trees, + TreeType::AddressV2 => config.general_config.skip_v2_address_trees, + TreeType::StateV1 => config.general_config.skip_v1_state_trees, + TreeType::StateV2 => config.general_config.skip_v2_state_trees, + TreeType::Unknown => false, // Never skip compression tree + } +} diff --git a/forester/src/epoch_manager/mod.rs b/forester/src/epoch_manager/mod.rs new file mode 100644 index 0000000000..fc8be00b47 --- /dev/null +++ b/forester/src/epoch_manager/mod.rs @@ -0,0 +1,1068 @@ +mod compression; +pub(crate) mod context; +mod monitor; +mod pipeline; +pub(crate) mod processor_pool; +mod registration; +mod reporting; +pub(crate) mod tracker; +mod v1; +mod v2; + +use std::{ + sync::Arc, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use anyhow::anyhow; +use forester_utils::{forester_epoch::TreeAccounts, rpc_pool::SolanaRpcPool}; +use light_client::{indexer::Indexer, rpc::Rpc}; +use light_compressed_account::TreeType; +use 
light_registry::protocol_config::state::ProtocolConfig; +use solana_sdk::{address_lookup_table::AddressLookupTableAccount, signature::Signer}; +use tokio::{ + sync::{mpsc, oneshot, watch, Mutex}, + task::JoinHandle, + time::{sleep, Instant, MissedTickBehavior}, +}; +use tracing::{debug, error, info, info_span}; + +use self::{context::ForesterContext, processor_pool::ProcessorPool, tracker::EpochTracker}; +use crate::{ + compressible::CTokenAccountTracker, + errors::InitializationError, + logging::ServiceHeartbeat, + processor::tx_cache::ProcessedHashCache, + queue_helpers::QueueItemData, + slot_tracker::SlotTracker, + tree_data_sync::{fetch_protocol_group_authority, fetch_trees}, + ForesterConfig, Result, +}; + +// ── Public re-exports (preserve existing public API) ───────────────────── + +/// Timing for a single circuit type (circuit inputs + proof generation) +#[derive(Copy, Clone, Debug, Default)] +pub struct CircuitMetrics { + /// Time spent building circuit inputs + pub circuit_inputs_duration: std::time::Duration, + /// Time spent generating ZK proofs (pure prover server time) + pub proof_generation_duration: std::time::Duration, + /// Total round-trip time (submit to result, includes queue wait) + pub round_trip_duration: std::time::Duration, +} + +impl CircuitMetrics { + pub fn total(&self) -> std::time::Duration { + self.circuit_inputs_duration + self.proof_generation_duration + } +} + +impl std::ops::AddAssign for CircuitMetrics { + fn add_assign(&mut self, rhs: Self) { + self.circuit_inputs_duration += rhs.circuit_inputs_duration; + self.proof_generation_duration += rhs.proof_generation_duration; + self.round_trip_duration += rhs.round_trip_duration; + } +} + +/// Timing breakdown by circuit type +#[derive(Copy, Clone, Debug, Default)] +pub struct ProcessingMetrics { + /// State append circuit (output queue processing) + pub append: CircuitMetrics, + /// State nullify circuit (input queue processing) + pub nullify: CircuitMetrics, + /// Address append circuit + pub address_append: CircuitMetrics, + /// Time spent sending transactions (overlapped with proof gen) + pub tx_sending_duration: std::time::Duration, +} + +impl ProcessingMetrics { + pub fn total(&self) -> std::time::Duration { + self.append.total() + + self.nullify.total() + + self.address_append.total() + + self.tx_sending_duration + } + + pub fn total_circuit_inputs(&self) -> std::time::Duration { + self.append.circuit_inputs_duration + + self.nullify.circuit_inputs_duration + + self.address_append.circuit_inputs_duration + } + + pub fn total_proof_generation(&self) -> std::time::Duration { + self.append.proof_generation_duration + + self.nullify.proof_generation_duration + + self.address_append.proof_generation_duration + } + + pub fn total_round_trip(&self) -> std::time::Duration { + self.append.round_trip_duration + + self.nullify.round_trip_duration + + self.address_append.round_trip_duration + } +} + +impl std::ops::AddAssign for ProcessingMetrics { + fn add_assign(&mut self, rhs: Self) { + self.append += rhs.append; + self.nullify += rhs.nullify; + self.address_append += rhs.address_append; + self.tx_sending_duration += rhs.tx_sending_duration; + } +} + +#[derive(Copy, Clone, Debug)] +pub struct WorkReport { + pub epoch: u64, + pub processed_items: usize, + pub metrics: ProcessingMetrics, +} + +#[derive(Debug, Clone)] +pub struct WorkItem { + pub tree_account: TreeAccounts, + pub queue_item_data: QueueItemData, +} + +impl WorkItem { + pub fn is_address_tree(&self) -> bool { + self.tree_account.tree_type == 
TreeType::AddressV1 + } + pub fn is_state_tree(&self) -> bool { + self.tree_account.tree_type == TreeType::StateV1 + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone)] +pub enum MerkleProofType { + AddressProof(light_client::indexer::NewAddressProofWithContext), + StateProof(light_client::indexer::MerkleProof), +} + +// ── EpochManager ───────────────────────────────────────────────────────── + +#[derive(Debug)] +pub struct EpochManager { + pub(crate) ctx: ForesterContext, + pub(crate) epoch_tracker: EpochTracker, + pub(crate) processor_pool: ProcessorPool, + pub(crate) trees: Arc>>, + // External integrations + pub(crate) work_report_sender: mpsc::Sender, + pub(crate) tx_cache: Arc>, + pub(crate) ops_cache: Arc>, + pub(crate) compressible_tracker: Option>, + pub(crate) pda_tracker: Option>, + pub(crate) mint_tracker: Option>, + pub(crate) shutdown_tx: watch::Sender, +} + +impl Clone for EpochManager { + fn clone(&self) -> Self { + Self { + ctx: self.ctx.clone(), + epoch_tracker: self.epoch_tracker.clone(), + processor_pool: self.processor_pool.clone(), + trees: self.trees.clone(), + work_report_sender: self.work_report_sender.clone(), + tx_cache: self.tx_cache.clone(), + ops_cache: self.ops_cache.clone(), + compressible_tracker: self.compressible_tracker.clone(), + pda_tracker: self.pda_tracker.clone(), + mint_tracker: self.mint_tracker.clone(), + shutdown_tx: self.shutdown_tx.clone(), + } + } +} + +impl EpochManager { + #[allow(clippy::too_many_arguments)] + pub async fn new( + config: Arc, + protocol_config: Arc, + rpc_pool: Arc>, + work_report_sender: mpsc::Sender, + trees: Vec, + slot_tracker: Arc, + tx_cache: Arc>, + ops_cache: Arc>, + compressible_tracker: Option>, + pda_tracker: Option>, + mint_tracker: Option>, + address_lookup_tables: Arc>, + heartbeat: Arc, + run_id: String, + ) -> Result { + let authority = Arc::new(config.payer_keypair.insecure_clone()); + let ctx = ForesterContext { + config, + protocol_config, + rpc_pool, + authority, + slot_tracker, + address_lookup_tables, + heartbeat, + run_id: Arc::::from(run_id), + }; + Ok(Self { + ctx, + epoch_tracker: EpochTracker::new(), + processor_pool: ProcessorPool::new(), + trees: Arc::new(Mutex::new(trees)), + work_report_sender, + tx_cache, + ops_cache, + compressible_tracker, + pda_tracker, + mint_tracker, + shutdown_tx: watch::channel(false).0, + }) + } + + pub(crate) fn request_shutdown(&self) { + let _ = self.shutdown_tx.send(true); + } + + pub async fn run(self: Arc) -> Result<()> { + let (tx, mut rx) = mpsc::channel(100); + let tx = Arc::new(tx); + + let mut monitor_handle = tokio::spawn({ + let self_clone = self.clone(); + let tx_clone = tx.clone(); + async move { self_clone.monitor_epochs(tx_clone).await } + }); + + let current_previous_handle = tokio::spawn({ + let self_clone = self.clone(); + let tx_clone = tx.clone(); + async move { + self_clone + .process_current_and_previous_epochs(tx_clone) + .await + } + }); + + let tree_discovery_handle = tokio::spawn({ + let self_clone = self.clone(); + async move { self_clone.discover_trees_periodically().await } + }); + + let balance_check_handle = tokio::spawn({ + let self_clone = self.clone(); + async move { self_clone.check_sol_balance_periodically().await } + }); + + let _guard = scopeguard::guard( + ( + monitor_handle.abort_handle(), + current_previous_handle, + tree_discovery_handle, + balance_check_handle, + ), + |(monitor, h2, h3, h4)| { + info!( + event = "background_tasks_aborting", + run_id = %self.ctx.run_id, + "Aborting EpochManager background 
tasks" + ); + monitor.abort(); + h2.abort(); + h3.abort(); + h4.abort(); + }, + ); + + let mut shutdown_rx = self.shutdown_tx.subscribe(); + let mut epoch_tasks = tokio::task::JoinSet::new(); + let result = loop { + if *shutdown_rx.borrow_and_update() { + info!( + event = "epoch_manager_shutdown_requested", + run_id = %self.ctx.run_id, + "Stopping EpochManager after shutdown request" + ); + break Ok(()); + } + + tokio::select! { + _ = shutdown_rx.changed() => {} + Some(join_result) = epoch_tasks.join_next() => { + match join_result { + Ok(Ok(())) => debug!( + event = "epoch_processing_completed", + run_id = %self.ctx.run_id, + "Epoch processed successfully" + ), + Ok(Err(e)) => error!( + event = "epoch_processing_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Error processing epoch" + ), + Err(join_error) => { + if join_error.is_panic() { + error!( + event = "epoch_processing_panicked", + run_id = %self.ctx.run_id, + error = %join_error, + "Epoch processing panicked" + ); + } + } + } + } + epoch_opt = rx.recv() => { + match epoch_opt { + Some(epoch) => { + debug!( + event = "epoch_queued_for_processing", + run_id = %self.ctx.run_id, + epoch, + "Received epoch from monitor" + ); + let self_clone = self.clone(); + epoch_tasks.spawn(async move { + self_clone.process_epoch(epoch).await + }); + } + None => { + error!( + event = "epoch_monitor_channel_closed", + run_id = %self.ctx.run_id, + "Epoch monitor channel closed unexpectedly" + ); + break Err(anyhow!( + "Epoch monitor channel closed - forester cannot function without it" + )); + } + } + } + result = &mut monitor_handle => { + match result { + Ok(Ok(())) => { + error!( + event = "epoch_monitor_exited_unexpected_ok", + run_id = %self.ctx.run_id, + "Epoch monitor exited unexpectedly with Ok(())" + ); + } + Ok(Err(e)) => { + error!( + event = "epoch_monitor_exited_with_error", + run_id = %self.ctx.run_id, + error = ?e, + "Epoch monitor exited with error" + ); + } + Err(e) => { + error!( + event = "epoch_monitor_task_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Epoch monitor task panicked or was cancelled" + ); + } + } + if let Some(pagerduty_key) = &self.ctx.config.external_services.pagerduty_routing_key { + let _ = crate::pagerduty::send_pagerduty_alert( + pagerduty_key, + &format!("Forester epoch monitor died unexpectedly on {}", self.ctx.config.payer_keypair.pubkey()), + "critical", + "epoch_monitor_dead", + ).await; + } + break Err(anyhow!("Epoch monitor exited unexpectedly - forester cannot function without it")); + } + } + }; + + result + } +} + +// ── run_service (top-level entry point) ────────────────────────────────── + +pub fn generate_run_id() -> String { + let epoch_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + format!("{}-{}", std::process::id(), epoch_ms) +} + +fn spawn_heartbeat_task( + heartbeat: Arc, + slot_tracker: Arc, + protocol_config: Arc, + run_id: String, +) -> JoinHandle<()> { + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(20)); + interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + let mut previous = heartbeat.snapshot(); + + loop { + interval.tick().await; + + let slot = slot_tracker.estimated_current_slot(); + let epoch = protocol_config.get_current_active_epoch(slot).ok(); + let epoch_known = epoch.is_some(); + let epoch_value = epoch.unwrap_or_default(); + let current = heartbeat.snapshot(); + + let delta_active = current.active_cycles.saturating_sub(previous.active_cycles); + let 
delta_queues_started = current + .queues_started + .saturating_sub(previous.queues_started); + let delta_queues_finished = current + .queues_finished + .saturating_sub(previous.queues_finished); + let delta_items = current + .items_processed + .saturating_sub(previous.items_processed); + let delta_work_reports = current + .work_reports_sent + .saturating_sub(previous.work_reports_sent); + let delta_v2_recoverable = current + .v2_recoverable_errors + .saturating_sub(previous.v2_recoverable_errors); + let delta_tree_tasks = current + .tree_tasks_spawned + .saturating_sub(previous.tree_tasks_spawned); + + info!( + event = "heartbeat", + run_id = %run_id, + slot, + epoch_known, + epoch = epoch_value, + delta_active_cycles = delta_active, + delta_queues_started, + delta_queues_finished, + delta_items_processed = delta_items, + delta_work_reports = delta_work_reports, + delta_v2_recoverable_errors = delta_v2_recoverable, + delta_tree_tasks_spawned = delta_tree_tasks, + cumulative_active_cycles = current.active_cycles, + cumulative_items_processed = current.items_processed, + cumulative_work_reports = current.work_reports_sent, + cumulative_tree_tasks = current.tree_tasks_spawned, + "Service heartbeat" + ); + + previous = current; + } + }) +} + +#[allow(clippy::too_many_arguments)] +#[tracing::instrument( + level = "debug", + skip( + config, + protocol_config, + rpc_pool, + shutdown, + work_report_sender, + slot_tracker, + tx_cache, + ops_cache, + compressible_tracker, + pda_tracker, + mint_tracker, + run_id + ), + fields(forester = %config.payer_keypair.pubkey()) +)] +pub async fn run_service( + config: Arc, + protocol_config: Arc, + rpc_pool: Arc>, + mut shutdown: oneshot::Receiver<()>, + work_report_sender: mpsc::Sender, + slot_tracker: Arc, + tx_cache: Arc>, + ops_cache: Arc>, + compressible_tracker: Option>, + pda_tracker: Option>, + mint_tracker: Option>, + run_id: String, +) -> Result<()> { + let heartbeat = Arc::new(ServiceHeartbeat::default()); + let heartbeat_handle = spawn_heartbeat_task( + heartbeat.clone(), + slot_tracker.clone(), + protocol_config.clone(), + run_id.clone(), + ); + + let run_id_for_logs = run_id.clone(); + let result = info_span!( + "run_service", + forester = %config.payer_keypair.pubkey() + ) + .in_scope(|| async move { + let processor_mode_str = match ( + config.general_config.skip_v1_state_trees + && config.general_config.skip_v1_address_trees, + config.general_config.skip_v2_state_trees + && config.general_config.skip_v2_address_trees, + ) { + (true, false) => "v2", + (false, true) => "v1", + (false, false) => "all", + _ => "unknown", + }; + info!( + event = "forester_starting", + run_id = %run_id_for_logs, + processor_mode = processor_mode_str, + "Starting forester" + ); + + const INITIAL_RETRY_DELAY: Duration = Duration::from_secs(1); + const MAX_RETRY_DELAY: Duration = Duration::from_secs(30); + + let mut retry_count = 0; + let mut retry_delay = INITIAL_RETRY_DELAY; + let start_time = Instant::now(); + + let trees = { + let max_attempts = 10; + let mut attempts = 0; + let mut delay = Duration::from_secs(2); + + loop { + tokio::select! { + biased; + _ = &mut shutdown => { + info!( + event = "shutdown_received", + run_id = %run_id_for_logs, + phase = "tree_fetch", + "Received shutdown signal during tree fetch. Stopping." + ); + return Ok(()); + } + result = rpc_pool.get_connection() => { + match result { + Ok(rpc) => { + tokio::select! 
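+ /* Biased select: a pending shutdown signal wins over the in-flight tree fetch. */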
{ + biased; + _ = &mut shutdown => { + info!( + event = "shutdown_received", + run_id = %run_id_for_logs, + phase = "tree_fetch", + "Received shutdown signal during tree fetch. Stopping." + ); + return Ok(()); + } + fetch_result = fetch_trees(&*rpc) => { + match fetch_result { + Ok(mut fetched_trees) => { + let group_authority = match config.general_config.group_authority { + Some(ga) => Some(ga), + None => { + match fetch_protocol_group_authority(&*rpc, run_id_for_logs.as_str()).await { + Ok(ga) => { + info!( + event = "group_authority_default_fetched", + run_id = %run_id_for_logs, + group_authority = %ga, + "Using protocol default group authority" + ); + Some(ga) + } + Err(e) => { + tracing::warn!( + event = "group_authority_fetch_failed", + run_id = %run_id_for_logs, + error = ?e, + "Failed to fetch protocol group authority; processing all trees" + ); + None + } + } + } + }; + + if let Some(group_authority) = group_authority { + let before_count = fetched_trees.len(); + fetched_trees.retain(|tree| tree.owner == group_authority); + info!( + event = "trees_filtered_by_group_authority", + run_id = %run_id_for_logs, + group_authority = %group_authority, + trees_before = before_count, + trees_after = fetched_trees.len(), + "Filtered trees by group authority" + ); + } + + if !config.general_config.tree_ids.is_empty() { + let tree_ids = &config.general_config.tree_ids; + fetched_trees.retain(|tree| tree_ids.contains(&tree.merkle_tree)); + if fetched_trees.is_empty() { + error!( + event = "trees_filter_explicit_ids_empty", + run_id = %run_id_for_logs, + requested_tree_count = tree_ids.len(), + requested_trees = ?tree_ids, + "None of the specified trees were found" + ); + return Err(anyhow::anyhow!( + "None of the specified trees found: {:?}", + tree_ids + )); + } + info!( + event = "trees_filter_explicit_ids", + run_id = %run_id_for_logs, + tree_count = tree_ids.len(), + "Processing only explicitly requested trees" + ); + } + break fetched_trees; + } + Err(e) => { + attempts += 1; + if attempts >= max_attempts { + return Err(anyhow::anyhow!( + "Failed to fetch trees after {} attempts: {:?}", + max_attempts, + e + )); + } + tracing::warn!( + event = "fetch_trees_failed_retrying", + run_id = %run_id_for_logs, + attempt = attempts, + max_attempts, + retry_delay_ms = delay.as_millis() as u64, + error = ?e, + "Failed to fetch trees; retrying" + ); + } + } + } + } + } + Err(e) => { + attempts += 1; + if attempts >= max_attempts { + return Err(anyhow::anyhow!( + "Failed to get RPC connection for trees after {} attempts: {:?}", + max_attempts, + e + )); + } + tracing::warn!( + event = "rpc_connection_failed_retrying", + run_id = %run_id_for_logs, + attempt = attempts, + max_attempts, + retry_delay_ms = delay.as_millis() as u64, + error = ?e, + "Failed to get RPC connection; retrying" + ); + } + } + } + } + + tokio::select! { + biased; + _ = &mut shutdown => { + info!( + event = "shutdown_received", + run_id = %run_id_for_logs, + phase = "tree_fetch_retry_wait", + "Received shutdown signal during retry wait. Stopping." 
+ ); + return Ok(()); + } + _ = sleep(delay) => { + delay = std::cmp::min(delay * 2, Duration::from_secs(30)); + } + } + } + }; + tracing::trace!("Fetched initial trees: {:?}", trees); + + if !config.general_config.tree_ids.is_empty() { + info!( + event = "tree_discovery_limited_to_explicit_ids", + run_id = %run_id_for_logs, + tree_count = config.general_config.tree_ids.len(), + "Processing specific trees; tree discovery will be limited" + ); + } + + while retry_count < config.retry_config.max_retries { + debug!("Creating EpochManager (attempt {})", retry_count + 1); + + let address_lookup_tables = { + if let Some(lut_address) = config.lookup_table_address { + let rpc = rpc_pool.get_connection().await?; + match load_lookup_table_async(&*rpc, lut_address).await { + Ok(lut) => { + info!( + event = "lookup_table_loaded", + run_id = %run_id_for_logs, + lookup_table = %lut_address, + address_count = lut.addresses.len(), + "Loaded lookup table" + ); + Arc::new(vec![lut]) + } + Err(e) => { + debug!( + "Lookup table {} not available: {}. Using legacy transactions.", + lut_address, e + ); + Arc::new(Vec::new()) + } + } + } else { + debug!("No lookup table address configured. Using legacy transactions."); + Arc::new(Vec::new()) + } + }; + + match EpochManager::new( + config.clone(), + protocol_config.clone(), + rpc_pool.clone(), + work_report_sender.clone(), + trees.clone(), + slot_tracker.clone(), + tx_cache.clone(), + ops_cache.clone(), + compressible_tracker.clone(), + pda_tracker.clone(), + mint_tracker.clone(), + address_lookup_tables, + heartbeat.clone(), + run_id.clone(), + ) + .await + { + Ok(epoch_manager) => { + let epoch_manager = Arc::new(epoch_manager); + debug!( + "Successfully created EpochManager after {} attempts", + retry_count + 1 + ); + + let run_future = epoch_manager.clone().run(); + tokio::pin!(run_future); + + let result = tokio::select! { + result = &mut run_future => result, + _ = &mut shutdown => { + info!( + event = "shutdown_received", + run_id = %run_id_for_logs, + phase = "service_run", + "Received shutdown signal. Stopping the service." + ); + epoch_manager.request_shutdown(); + run_future.await + } + }; + + return result; + } + Err(e) => { + tracing::warn!( + event = "epoch_manager_create_failed", + run_id = %run_id_for_logs, + attempt = retry_count + 1, + error = ?e, + "Failed to create EpochManager" + ); + retry_count += 1; + if retry_count < config.retry_config.max_retries { + debug!("Retrying in {:?}", retry_delay); + sleep(retry_delay).await; + retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY); + } else { + error!( + event = "forester_start_failed_max_retries", + run_id = %run_id_for_logs, + attempts = config.retry_config.max_retries, + elapsed_ms = start_time.elapsed().as_millis() as u64, + error = ?e, + "Failed to start forester after max retries" + ); + return Err(InitializationError::MaxRetriesExceeded { + attempts: config.retry_config.max_retries, + error: e.to_string(), + } + .into()); + } + } + } + } + + Err( + InitializationError::Unexpected("Retry loop exited without returning".to_string()) + .into(), + ) + }) + .await; + + heartbeat_handle.abort(); + result +} + +/// Async version of load_lookup_table that works with the Rpc trait +async fn load_lookup_table_async( + rpc: &R, + lookup_table_address: solana_program::pubkey::Pubkey, +) -> anyhow::Result { + use light_client::rpc::lut::AddressLookupTable; + + let account = rpc + .get_account(lookup_table_address) + .await? 
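+ // get_account returns Ok(None) for a missing account; surface that as an error here.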
+ .ok_or_else(|| { + anyhow::anyhow!("Lookup table account not found: {}", lookup_table_address) + })?; + + let address_lookup_table = AddressLookupTable::deserialize(&account.data) + .map_err(|e| anyhow::anyhow!("Failed to deserialize AddressLookupTable: {:?}", e))?; + + Ok(AddressLookupTableAccount { + key: lookup_table_address, + addresses: address_lookup_table.addresses.to_vec(), + }) +} + +#[cfg(test)] +mod tests { + use light_client::rpc::RetryConfig; + use solana_sdk::{pubkey::Pubkey, signature::Keypair}; + + use super::{context::should_skip_tree, *}; + use crate::{ + config::{ExternalServicesConfig, GeneralConfig}, + ForesterConfig, + }; + + fn create_test_config_with_skip_flags( + skip_v1_state: bool, + skip_v1_address: bool, + skip_v2_state: bool, + skip_v2_address: bool, + ) -> ForesterConfig { + ForesterConfig { + external_services: ExternalServicesConfig { + rpc_url: "http://localhost:8899".to_string(), + ws_rpc_url: None, + indexer_url: None, + prover_url: None, + prover_append_url: None, + prover_update_url: None, + prover_address_append_url: None, + prover_api_key: None, + photon_grpc_url: None, + pushgateway_url: None, + pagerduty_routing_key: None, + rpc_rate_limit: None, + photon_rate_limit: None, + send_tx_rate_limit: None, + prover_polling_interval: None, + prover_max_wait_time: None, + fallback_rpc_url: None, + fallback_indexer_url: None, + }, + retry_config: RetryConfig::default(), + queue_config: Default::default(), + indexer_config: Default::default(), + transaction_config: Default::default(), + general_config: GeneralConfig { + enable_metrics: false, + skip_v1_state_trees: skip_v1_state, + skip_v1_address_trees: skip_v1_address, + skip_v2_state_trees: skip_v2_state, + skip_v2_address_trees: skip_v2_address, + sleep_after_processing_ms: 50, + sleep_when_idle_ms: 100, + ..Default::default() + }, + rpc_pool_config: Default::default(), + registry_pubkey: Pubkey::default(), + payer_keypair: Keypair::new(), + derivation_pubkey: Pubkey::default(), + address_tree_data: vec![], + state_tree_data: vec![], + compressible_config: None, + lookup_table_address: None, + min_queue_items: None, + enable_v1_multi_nullify: false, + work_item_batch_size: 50, + } + } + + #[test] + fn test_should_skip_tree_none_skipped() { + let config = create_test_config_with_skip_flags(false, false, false, false); + assert!(!should_skip_tree(&config, &TreeType::StateV1)); + assert!(!should_skip_tree(&config, &TreeType::StateV2)); + assert!(!should_skip_tree(&config, &TreeType::AddressV1)); + assert!(!should_skip_tree(&config, &TreeType::AddressV2)); + } + + #[test] + fn test_should_skip_tree_all_v1_skipped() { + let config = create_test_config_with_skip_flags(true, true, false, false); + assert!(should_skip_tree(&config, &TreeType::StateV1)); + assert!(should_skip_tree(&config, &TreeType::AddressV1)); + assert!(!should_skip_tree(&config, &TreeType::StateV2)); + assert!(!should_skip_tree(&config, &TreeType::AddressV2)); + } + + #[test] + fn test_should_skip_tree_all_v2_skipped() { + let config = create_test_config_with_skip_flags(false, false, true, true); + assert!(!should_skip_tree(&config, &TreeType::StateV1)); + assert!(!should_skip_tree(&config, &TreeType::AddressV1)); + assert!(should_skip_tree(&config, &TreeType::StateV2)); + assert!(should_skip_tree(&config, &TreeType::AddressV2)); + } + + #[test] + fn test_should_skip_tree_only_state_trees() { + let config = create_test_config_with_skip_flags(true, false, true, false); + assert!(should_skip_tree(&config, &TreeType::StateV1)); + 
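// Flags are (v1_state, v1_address, v2_state, v2_address); both state variants are skipped. +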
assert!(should_skip_tree(&config, &TreeType::StateV2)); + assert!(!should_skip_tree(&config, &TreeType::AddressV1)); + assert!(!should_skip_tree(&config, &TreeType::AddressV2)); + } + + #[test] + fn test_should_skip_tree_only_address_trees() { + let config = create_test_config_with_skip_flags(false, true, false, true); + assert!(!should_skip_tree(&config, &TreeType::StateV1)); + assert!(!should_skip_tree(&config, &TreeType::StateV2)); + assert!(should_skip_tree(&config, &TreeType::AddressV1)); + assert!(should_skip_tree(&config, &TreeType::AddressV2)); + } + + #[test] + fn test_should_skip_tree_mixed_config() { + let config = create_test_config_with_skip_flags(true, false, false, true); + assert!(should_skip_tree(&config, &TreeType::StateV1)); + assert!(!should_skip_tree(&config, &TreeType::StateV2)); + assert!(!should_skip_tree(&config, &TreeType::AddressV1)); + assert!(should_skip_tree(&config, &TreeType::AddressV2)); + } + + #[test] + fn test_general_config_test_address_v2() { + let config = GeneralConfig::test_address_v2(); + assert!(config.skip_v1_state_trees); + assert!(config.skip_v1_address_trees); + assert!(config.skip_v2_state_trees); + assert!(!config.skip_v2_address_trees); + } + + #[test] + fn test_general_config_test_state_v2() { + let config = GeneralConfig::test_state_v2(); + assert!(config.skip_v1_state_trees); + assert!(config.skip_v1_address_trees); + assert!(!config.skip_v2_state_trees); + assert!(config.skip_v2_address_trees); + } + + #[test] + fn test_work_item_is_address_tree() { + let tree_account = TreeAccounts { + merkle_tree: Pubkey::new_unique(), + queue: Pubkey::new_unique(), + is_rolledover: false, + tree_type: TreeType::AddressV1, + owner: Default::default(), + }; + let work_item = WorkItem { + tree_account, + queue_item_data: QueueItemData { + hash: [0u8; 32], + index: 0, + leaf_index: None, + }, + }; + assert!(work_item.is_address_tree()); + assert!(!work_item.is_state_tree()); + } + + #[test] + fn test_work_item_is_state_tree() { + let tree_account = TreeAccounts { + merkle_tree: Pubkey::new_unique(), + queue: Pubkey::new_unique(), + is_rolledover: false, + tree_type: TreeType::StateV1, + owner: Default::default(), + }; + let work_item = WorkItem { + tree_account, + queue_item_data: QueueItemData { + hash: [0u8; 32], + index: 0, + leaf_index: None, + }, + }; + assert!(!work_item.is_address_tree()); + assert!(work_item.is_state_tree()); + } + + #[test] + fn test_work_report_creation() { + let report = WorkReport { + epoch: 42, + processed_items: 100, + metrics: ProcessingMetrics { + append: CircuitMetrics { + circuit_inputs_duration: std::time::Duration::from_secs(1), + proof_generation_duration: std::time::Duration::from_secs(3), + round_trip_duration: std::time::Duration::from_secs(10), + }, + nullify: CircuitMetrics { + circuit_inputs_duration: std::time::Duration::from_secs(1), + proof_generation_duration: std::time::Duration::from_secs(2), + round_trip_duration: std::time::Duration::from_secs(8), + }, + address_append: CircuitMetrics { + circuit_inputs_duration: std::time::Duration::from_secs(1), + proof_generation_duration: std::time::Duration::from_secs(2), + round_trip_duration: std::time::Duration::from_secs(9), + }, + tx_sending_duration: std::time::Duration::ZERO, + }, + }; + assert_eq!(report.epoch, 42); + assert_eq!(report.processed_items, 100); + assert_eq!(report.metrics.total().as_secs(), 10); + assert_eq!(report.metrics.total_circuit_inputs().as_secs(), 3); + assert_eq!(report.metrics.total_proof_generation().as_secs(), 7); + 
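// Per-circuit sums: inputs 1+1+1 = 3s, proofs 3+2+2 = 7s, round trips 10+8+9 = 27s. +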
assert_eq!(report.metrics.total_round_trip().as_secs(), 27); + } + + #[tokio::test] + async fn watch_shutdown_observed_after_request() { + let (tx, _initial_rx) = watch::channel(false); + tx.send(true).expect("send"); + let mut rx = tx.subscribe(); + assert!(*rx.borrow_and_update()); + } +} diff --git a/forester/src/epoch_manager/monitor.rs b/forester/src/epoch_manager/monitor.rs new file mode 100644 index 0000000000..417b536679 --- /dev/null +++ b/forester/src/epoch_manager/monitor.rs @@ -0,0 +1,508 @@ +//! Background monitoring tasks: epoch detection, tree discovery, balance checks. + +use std::{sync::Arc, time::Duration}; + +use anyhow::anyhow; +use forester_utils::forester_epoch::{get_epoch_phases, TreeAccounts, TreeForesterSchedule}; +use light_client::{indexer::Indexer, rpc::Rpc}; +use solana_program::{native_token::LAMPORTS_PER_SOL, pubkey::Pubkey}; +use solana_sdk::signature::Signer; +use tokio::sync::mpsc; +use tracing::{debug, error, info, warn}; + +use super::{context::should_skip_tree, EpochManager}; +use crate::{ + metrics::update_forester_sol_balance, + slot_tracker::wait_until_slot_reached, + tree_data_sync::{fetch_protocol_group_authority, fetch_trees}, +}; + +impl EpochManager { + pub(super) async fn check_sol_balance_periodically(self: Arc) -> crate::Result<()> { + let interval_duration = Duration::from_secs(300); + let mut interval = tokio::time::interval(interval_duration); + + loop { + interval.tick().await; + match self.ctx.rpc_pool.get_connection().await { + Ok(rpc) => match rpc + .get_balance(&self.ctx.config.payer_keypair.pubkey()) + .await + { + Ok(balance) => { + let balance_in_sol = balance as f64 / (LAMPORTS_PER_SOL as f64); + update_forester_sol_balance( + &self.ctx.config.payer_keypair.pubkey().to_string(), + balance_in_sol, + ); + debug!( + event = "forester_balance_updated", + run_id = %self.ctx.run_id, + balance_sol = balance_in_sol, + "Current SOL balance updated" + ); + } + Err(e) => error!( + event = "forester_balance_fetch_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Failed to get balance" + ), + }, + Err(e) => error!( + event = "forester_balance_rpc_connection_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Failed to get RPC connection for balance check" + ), + } + } + } + + pub(super) async fn discover_trees_periodically(self: Arc) -> crate::Result<()> { + let interval_secs = self + .ctx + .config + .general_config + .tree_discovery_interval_seconds; + if interval_secs == 0 { + info!(event = "tree_discovery_disabled", run_id = %self.ctx.run_id, "Tree discovery disabled (interval=0)"); + return Ok(()); + } + let mut interval = tokio::time::interval(Duration::from_secs(interval_secs)); + interval.tick().await; + + info!( + event = "tree_discovery_started", + run_id = %self.ctx.run_id, + interval_secs, + "Starting periodic tree discovery" + ); + + let mut group_authority: Option = self.ctx.config.general_config.group_authority; + + loop { + interval.tick().await; + + let rpc = match self.ctx.rpc_pool.get_connection().await { + Ok(rpc) => rpc, + Err(e) => { + warn!(event = "tree_discovery_rpc_failed", run_id = %self.ctx.run_id, error = ?e, "Tree discovery: failed to get RPC connection"); + continue; + } + }; + + if group_authority.is_none() { + if let Ok(ga) = fetch_protocol_group_authority(&*rpc, &self.ctx.run_id).await { + group_authority = Some(ga); + let mut trees = self.trees.lock().await; + let before = trees.len(); + trees.retain(|t| t.owner == ga); + if !self.ctx.config.general_config.tree_ids.is_empty() { + let tree_ids = 
&self.ctx.config.general_config.tree_ids; + trees.retain(|t| tree_ids.contains(&t.merkle_tree)); + } + if trees.len() < before { + info!( + event = "tree_discovery_retroactive_filter", + run_id = %self.ctx.run_id, + group_authority = %ga, + trees_before = before, + trees_after = trees.len(), + "Filtered existing trees after resolving group authority" + ); + } + } + } + + let mut fetched_trees = match fetch_trees(&*rpc).await { + Ok(trees) => trees, + Err(e) => { + warn!(event = "tree_discovery_fetch_failed", run_id = %self.ctx.run_id, error = ?e, "Tree discovery: failed to fetch trees"); + continue; + } + }; + + if let Some(ga) = group_authority { + fetched_trees.retain(|tree| tree.owner == ga); + } + if !self.ctx.config.general_config.tree_ids.is_empty() { + let tree_ids = &self.ctx.config.general_config.tree_ids; + fetched_trees.retain(|tree| tree_ids.contains(&tree.merkle_tree)); + } + + let known_trees = self.trees.lock().await; + let known_pubkeys: std::collections::HashSet<Pubkey> = + known_trees.iter().map(|t| t.merkle_tree).collect(); + drop(known_trees); + + for tree in fetched_trees { + if known_pubkeys.contains(&tree.merkle_tree) { + continue; + } + if should_skip_tree(&self.ctx.config, &tree.tree_type) { + debug!( + event = "tree_discovery_skipped", + run_id = %self.ctx.run_id, + tree = %tree.merkle_tree, + tree_type = ?tree.tree_type, + "Skipping tree per tree-type skip config" + ); + continue; + } + info!( + event = "tree_discovery_new_tree", + run_id = %self.ctx.run_id, + tree = %tree.merkle_tree, + tree_type = ?tree.tree_type, + queue = %tree.queue, + "Discovered new tree" + ); + if let Err(e) = self.add_new_tree(tree).await { + error!( + event = "tree_discovery_add_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Failed to add discovered tree" + ); + } + } + } + } + + async fn add_new_tree(self: &Arc<Self>, new_tree: TreeAccounts) -> crate::Result<()> { + info!( + event = "new_tree_add_started", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + tree_type = ?new_tree.tree_type, + "Adding new tree" + ); + let mut trees = self.trees.lock().await; + trees.push(new_tree); + drop(trees); + + info!( + event = "new_tree_added", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + "New tree added to tracked list" + ); + + let (current_slot, current_epoch) = self.ctx.get_current_slot_and_epoch().await?; + let phases = get_epoch_phases(&self.ctx.protocol_config, current_epoch); + + if current_slot >= phases.active.start && current_slot < phases.active.end { + info!( + event = "new_tree_active_phase_injection", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + current_slot, + active_phase_start_slot = phases.active.start, + active_phase_end_slot = phases.active.end, + "In active phase; attempting immediate processing for new tree" + ); + info!( + event = "new_tree_recover_registration_started", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + epoch = current_epoch, + "Recovering registration info for new tree" + ); + match self + .recover_registration_info_if_exists(current_epoch) + .await + { + Ok(Some(mut epoch_info)) => { + info!( + event = "new_tree_recover_registration_succeeded", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + epoch = current_epoch, + "Recovered registration info for current epoch" + ); + let tree_schedule = TreeForesterSchedule::new_with_schedule( + &new_tree, + current_slot, + &epoch_info.forester_epoch_pda, + &epoch_info.epoch_pda, + )?; + epoch_info.trees.push(tree_schedule.clone()); + + let self_clone 
= self.clone(); + let tracker = self.epoch_tracker.get_or_create_tracker( + current_epoch, + epoch_info.epoch_pda.registered_weight, + ); + + info!( + event = "new_tree_processing_task_spawned", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + epoch = current_epoch, + "Spawning task to process new tree in current epoch" + ); + tokio::spawn(async move { + let tree_pubkey = tree_schedule.tree_accounts.merkle_tree; + if let Err(e) = self_clone + .process_queue( + &epoch_info.epoch, + epoch_info.forester_epoch_pda.clone(), + tree_schedule, + tracker, + ) + .await + { + error!( + event = "new_tree_process_queue_failed", + run_id = %self_clone.ctx.run_id, + tree = %tree_pubkey, + error = ?e, + "Error processing queue for new tree" + ); + } else { + info!( + event = "new_tree_process_queue_succeeded", + run_id = %self_clone.ctx.run_id, + tree = %tree_pubkey, + "Successfully processed new tree in current epoch" + ); + } + }); + } + Ok(None) => { + debug!( + "Not registered for current epoch yet, new tree will be picked up during next registration" + ); + } + Err(e) => { + warn!( + event = "new_tree_recover_registration_failed", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + epoch = current_epoch, + error = ?e, + "Failed to recover registration info for new tree" + ); + } + } + + info!( + event = "new_tree_injected_into_current_epoch", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + epoch = current_epoch, + "Injected new tree into current epoch" + ); + } else { + info!( + event = "new_tree_queued_for_next_registration", + run_id = %self.ctx.run_id, + tree = %new_tree.merkle_tree, + current_slot, + active_phase_start_slot = phases.active.start, + "Not in active phase; new tree will be picked up in next registration" + ); + } + + Ok(()) + } + + pub(super) async fn monitor_epochs(&self, tx: Arc>) -> crate::Result<()> { + let mut last_epoch: Option = None; + let mut consecutive_failures = 0u32; + const MAX_BACKOFF_SECS: u64 = 60; + + info!( + event = "epoch_monitor_started", + run_id = %self.ctx.run_id, + "Starting epoch monitor" + ); + + loop { + let (slot, current_epoch) = match self.ctx.get_current_slot_and_epoch().await { + Ok(result) => { + if consecutive_failures > 0 { + info!( + event = "epoch_monitor_recovered", + run_id = %self.ctx.run_id, + consecutive_failures, "Epoch monitor recovered after failures" + ); + } + consecutive_failures = 0; + result + } + Err(e) => { + consecutive_failures += 1; + let backoff_secs = 2u64.pow(consecutive_failures.min(6)).min(MAX_BACKOFF_SECS); + let backoff = Duration::from_secs(backoff_secs); + + if consecutive_failures == 1 { + warn!( + event = "epoch_monitor_slot_epoch_failed", + run_id = %self.ctx.run_id, + consecutive_failures, + error = ?e, + backoff_ms = backoff.as_millis() as u64, + "Epoch monitor failed to get slot/epoch; retrying" + ); + } else if consecutive_failures.is_multiple_of(10) { + error!( + event = "epoch_monitor_slot_epoch_failed_repeated", + run_id = %self.ctx.run_id, + consecutive_failures, + error = ?e, + backoff_ms = backoff.as_millis() as u64, + "Epoch monitor still failing repeatedly" + ); + } + + tokio::time::sleep(backoff).await; + continue; + } + }; + + debug!( + event = "epoch_monitor_tick", + run_id = %self.ctx.run_id, + last_epoch = ?last_epoch, + current_epoch, + slot, + "Epoch monitor tick" + ); + + if last_epoch.is_none_or(|last| current_epoch > last) { + debug!( + event = "epoch_monitor_new_epoch_detected", + run_id = %self.ctx.run_id, + epoch = current_epoch, + "New epoch 
detected; sending for processing" + ); + if let Err(e) = tx.send(current_epoch).await { + error!( + event = "epoch_monitor_send_current_epoch_failed", + run_id = %self.ctx.run_id, + epoch = current_epoch, + error = ?e, + "Failed to send current epoch for processing; channel closed" + ); + return Err(anyhow!("Epoch channel closed: {}", e)); + } + last_epoch = Some(current_epoch); + } + + let target_epoch = current_epoch + 1; + if last_epoch.is_none_or(|last| target_epoch > last) { + let target_phases = get_epoch_phases(&self.ctx.protocol_config, target_epoch); + + if slot < target_phases.registration.start { + let mut rpc = match self.ctx.rpc_pool.get_connection().await { + Ok(rpc) => rpc, + Err(e) => { + warn!( + event = "epoch_monitor_wait_rpc_connection_failed", + run_id = %self.ctx.run_id, + target_epoch, + error = ?e, + "Failed to get RPC connection while waiting for registration slot" + ); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + }; + + const REGISTRATION_BUFFER_SLOTS: u64 = 30; + let wait_target = target_phases + .registration + .start + .saturating_sub(REGISTRATION_BUFFER_SLOTS); + let slots_to_wait = wait_target.saturating_sub(slot); + + debug!( + event = "epoch_monitor_wait_for_registration", + run_id = %self.ctx.run_id, + target_epoch, + current_slot = slot, + wait_target_slot = wait_target, + registration_start_slot = target_phases.registration.start, + slots_to_wait, + "Waiting for target epoch registration phase" + ); + + if let Err(e) = + wait_until_slot_reached(&mut *rpc, &self.ctx.slot_tracker, wait_target) + .await + { + error!( + event = "epoch_monitor_wait_for_registration_failed", + run_id = %self.ctx.run_id, + target_epoch, + error = ?e, + "Error waiting for registration phase" + ); + continue; + } + } + + debug!( + event = "epoch_monitor_send_target_epoch", + run_id = %self.ctx.run_id, + target_epoch, + "Sending target epoch for processing" + ); + if let Err(e) = tx.send(target_epoch).await { + error!( + event = "epoch_monitor_send_target_epoch_failed", + run_id = %self.ctx.run_id, + target_epoch, + error = ?e, + "Failed to send target epoch for processing; channel closed" + ); + return Err(anyhow!("Epoch channel closed: {}", e)); + } + last_epoch = Some(target_epoch); + continue; + } else { + tokio::time::sleep(Duration::from_secs(10)).await; + } + } + } + + pub(super) async fn process_current_and_previous_epochs( + &self, + tx: Arc>, + ) -> crate::Result<()> { + let (slot, current_epoch) = self.ctx.get_current_slot_and_epoch().await?; + let current_phases = get_epoch_phases(&self.ctx.protocol_config, current_epoch); + let previous_epoch = current_epoch.saturating_sub(1); + + if slot > current_phases.registration.start { + debug!("Processing previous epoch: {}", previous_epoch); + if let Err(e) = tx.send(previous_epoch).await { + error!( + event = "initial_epoch_send_previous_failed", + run_id = %self.ctx.run_id, + epoch = previous_epoch, + error = ?e, + "Failed to send previous epoch for processing" + ); + return Ok(()); + } + } + + debug!("Processing current epoch: {}", current_epoch); + if let Err(e) = tx.send(current_epoch).await { + error!( + event = "initial_epoch_send_current_failed", + run_id = %self.ctx.run_id, + epoch = current_epoch, + error = ?e, + "Failed to send current epoch for processing" + ); + return Ok(()); + } + + debug!("Finished processing current and previous epochs"); + Ok(()) + } +} diff --git a/forester/src/epoch_manager/pipeline.rs b/forester/src/epoch_manager/pipeline.rs new file mode 100644 index 
0000000000..9b707d2e20 --- /dev/null +++ b/forester/src/epoch_manager/pipeline.rs @@ -0,0 +1,832 @@ +//! Epoch processing pipeline: process_epoch → perform_active_work → process_queue → slot processing. + +use std::{sync::Arc, time::Duration}; + +use forester_utils::forester_epoch::{ + get_epoch_phases, Epoch, ForesterSlot, TreeAccounts, TreeForesterSchedule, +}; +use light_client::{indexer::Indexer, rpc::Rpc}; +use light_compressed_account::TreeType; +use light_registry::{protocol_config::state::EpochState, ForesterEpochPda}; +use solana_sdk::signature::Signer; +use tokio::time::Instant; +use tracing::{debug, error, info, instrument, trace, warn}; + +use super::{context::should_skip_tree, tracker::RegistrationTracker, EpochManager}; +use crate::{ + errors::ForesterError, + logging::should_emit_rate_limited_warning, + metrics::{push_metrics, queue_metric_update, update_epoch_detected, update_epoch_registered}, + slot_tracker::wait_until_slot_reached, + ForesterEpochInfo, +}; + +impl EpochManager { + #[instrument(level = "debug", skip(self), fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch))] + pub(super) async fn process_epoch(self: Arc, epoch: u64) -> crate::Result<()> { + let _epoch_guard = match self.epoch_tracker.try_claim_epoch(epoch) { + Some(guard) => guard, + None => { + debug!("Epoch {} is already being processed, skipping", epoch); + return Ok(()); + } + }; + + let phases = get_epoch_phases(&self.ctx.protocol_config, epoch); + update_epoch_detected(epoch); + + debug!("Recovering registration info for epoch {}", epoch); + let mut registration_info = match self.recover_registration_info_if_exists(epoch).await { + Ok(Some(info)) => info, + Ok(None) => { + debug!( + "No existing registration found for epoch {}, will register fresh", + epoch + ); + match self + .register_for_epoch_with_retry(epoch, 100, Duration::from_millis(1000)) + .await + { + Ok(info) => info, + Err(e) => return Err(e.into()), + } + } + Err(e) => { + warn!( + event = "recover_registration_info_failed", + run_id = %self.ctx.run_id, + epoch, + error = ?e, + "Failed to recover registration info" + ); + return Err(e.into()); + } + }; + debug!("Recovered registration info for epoch {}", epoch); + update_epoch_registered(epoch); + + registration_info = match self.wait_for_active_phase(®istration_info).await? { + Some(info) => info, + None => { + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + debug!( + event = "epoch_processing_skipped_finalize_registration_phase_ended", + run_id = %self.ctx.run_id, + epoch, + current_slot, + active_phase_end_slot = registration_info.epoch.phases.active.end, + "Skipping epoch processing because FinalizeRegistration is no longer possible" + ); + return Ok(()); + } + }; + + if self.ctx.sync_slot().await? < phases.active.end { + self.clone().perform_active_work(®istration_info).await?; + } + if self.ctx.sync_slot().await? < phases.report_work.start { + self.wait_for_report_work_phase(®istration_info).await?; + } + + self.send_work_report(®istration_info).await?; + + if self.ctx.sync_slot().await? 
< phases.report_work.end { + self.report_work_onchain(®istration_info).await?; + } else { + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + info!( + event = "skip_onchain_work_report_phase_ended", + run_id = %self.ctx.run_id, + epoch = registration_info.epoch.epoch, + current_slot, + report_work_end_slot = phases.report_work.end, + "Skipping on-chain work report because report_work phase has ended" + ); + } + + self.epoch_tracker.cleanup(epoch).await; + + info!( + event = "process_epoch_completed", + run_id = %self.ctx.run_id, + epoch, "Exiting process_epoch" + ); + Ok(()) + } + + #[instrument( + level = "debug", + skip(self, epoch_info), + fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch) + )] + pub(super) async fn perform_active_work( + self: Arc, + epoch_info: &ForesterEpochInfo, + ) -> crate::Result<()> { + self.ctx.heartbeat.increment_active_cycle(); + + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + let active_phase_end = epoch_info.epoch.phases.active.end; + + if !self.is_in_active_phase(current_slot, epoch_info)? { + info!( + event = "active_work_skipped_not_in_phase", + run_id = %self.ctx.run_id, + current_slot, + active_phase_end, + "No longer in active phase. Skipping work." + ); + return Ok(()); + } + + self.ctx.sync_slot().await?; + + let trees_to_process: Vec<_> = epoch_info + .trees + .iter() + .filter(|tree| !should_skip_tree(&self.ctx.config, &tree.tree_accounts.tree_type)) + .cloned() + .collect(); + + if trees_to_process.is_empty() { + debug!( + event = "active_work_cycle_no_trees", + run_id = %self.ctx.run_id, + "No trees to process this cycle" + ); + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + wait_until_slot_reached(&mut *rpc, &self.ctx.slot_tracker, active_phase_end).await?; + return Ok(()); + } + + info!( + event = "active_work_cycle_started", + run_id = %self.ctx.run_id, + current_slot, + active_phase_end, + tree_count = trees_to_process.len(), + "Starting active work cycle" + ); + + let registration_tracker = self.epoch_tracker.get_or_create_tracker( + epoch_info.epoch.epoch, + epoch_info.epoch_pda.registered_weight, + ); + + let mut tree_tasks = tokio::task::JoinSet::new(); + + for tree in trees_to_process { + debug!( + event = "tree_processing_task_spawned", + run_id = %self.ctx.run_id, + tree = %tree.tree_accounts.merkle_tree, + tree_type = ?tree.tree_accounts.tree_type, + "Spawning tree processing task" + ); + self.ctx.heartbeat.add_tree_tasks_spawned(1); + + let self_clone = self.clone(); + let epoch_clone = epoch_info.epoch.clone(); + let forester_epoch_pda = epoch_info.forester_epoch_pda.clone(); + let tracker = registration_tracker.clone(); + tree_tasks.spawn(async move { + self_clone + .process_queue(&epoch_clone, forester_epoch_pda, tree, tracker) + .await + }); + } + + let mut success_count = 0usize; + let mut error_count = 0usize; + let mut panic_count = 0usize; + while let Some(join_result) = tree_tasks.join_next().await { + match join_result { + Ok(Ok(())) => success_count += 1, + Ok(Err(e)) => { + error_count += 1; + error!( + event = "tree_processing_task_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Error processing queue" + ); + } + Err(join_error) => { + if join_error.is_panic() { + panic_count += 1; + error!( + event = "tree_processing_task_panicked", + run_id = %self.ctx.run_id, + error = %join_error, + "Tree processing task panicked" + ); + } + } + } + } + info!( + event = "active_work_cycle_completed", + run_id = %self.ctx.run_id, + 
tree_tasks = success_count + error_count + panic_count, + succeeded = success_count, + failed = error_count, + panicked = panic_count, + "Active work cycle completed" + ); + + debug!("Waiting for active phase to end"); + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + wait_until_slot_reached(&mut *rpc, &self.ctx.slot_tracker, active_phase_end).await?; + Ok(()) + } + + #[instrument( + level = "debug", + skip(self, epoch_info, forester_epoch_pda, tree_schedule, registration_tracker), + fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, + tree = %tree_schedule.tree_accounts.merkle_tree) + )] + pub(crate) async fn process_queue( + &self, + epoch_info: &Epoch, + mut forester_epoch_pda: ForesterEpochPda, + mut tree_schedule: TreeForesterSchedule, + registration_tracker: Arc, + ) -> crate::Result<()> { + self.ctx.heartbeat.increment_queue_started(); + let mut current_slot = self.ctx.slot_tracker.estimated_current_slot(); + + let total_slots = tree_schedule.slots.len(); + let eligible_slots = tree_schedule.slots.iter().filter(|s| s.is_some()).count(); + let tree_type = tree_schedule.tree_accounts.tree_type; + + debug!( + event = "process_queue_started", + run_id = %self.ctx.run_id, + tree = %tree_schedule.tree_accounts.merkle_tree, + tree_type = ?tree_type, + total_slots, + eligible_slots, + current_slot, + active_phase_end = epoch_info.phases.active.end, + "Processing queue for tree" + ); + + let mut last_weight_check = Instant::now(); + const WEIGHT_CHECK_INTERVAL: Duration = Duration::from_secs(30); + + 'outer_slot_loop: while current_slot < epoch_info.phases.active.end { + let next_slot_to_process = tree_schedule + .slots + .iter_mut() + .enumerate() + .find_map(|(idx, opt_slot)| opt_slot.as_ref().map(|s| (idx, s.clone()))); + + if let Some((slot_idx, light_slot_details)) = next_slot_to_process { + let result = match tree_type { + TreeType::StateV1 | TreeType::AddressV1 | TreeType::Unknown => { + self.process_light_slot( + epoch_info, + &forester_epoch_pda, + &tree_schedule.tree_accounts, + &light_slot_details, + ) + .await + } + TreeType::StateV2 | TreeType::AddressV2 => { + let consecutive_end = tree_schedule + .get_consecutive_eligibility_end(slot_idx) + .unwrap_or(light_slot_details.end_solana_slot); + self.process_light_slot_v2( + epoch_info, + &forester_epoch_pda, + &tree_schedule.tree_accounts, + &light_slot_details, + consecutive_end, + ) + .await + } + }; + + let mut force_refinalize = false; + match result { + Ok(_) => { + trace!( + "Successfully processed light slot {:?}", + light_slot_details.slot + ); + } + Err(e) => { + force_refinalize = e.is_forester_not_eligible(); + if force_refinalize { + warn!( + event = "light_slot_processing_stale_eligibility", + run_id = %self.ctx.run_id, + tree = %tree_schedule.tree_accounts.merkle_tree, + light_slot = light_slot_details.slot, + "Detected ForesterNotEligible; forcing immediate re-finalization" + ); + } + error!( + event = "light_slot_processing_error", + run_id = %self.ctx.run_id, + light_slot = light_slot_details.slot, + error = ?e, + "Error processing light slot" + ); + } + } + tree_schedule.slots[slot_idx] = None; + + if force_refinalize || last_weight_check.elapsed() >= WEIGHT_CHECK_INTERVAL { + last_weight_check = Instant::now(); + if let Err(e) = self + .maybe_refinalize( + epoch_info, + &mut forester_epoch_pda, + &mut tree_schedule, + ®istration_tracker, + force_refinalize, + ) + .await + { + warn!( + event = "refinalize_check_failed", + run_id = %self.ctx.run_id, + forced = 
force_refinalize, + error = ?e, + "Failed to check/perform re-finalization" + ); + } + } + } else { + debug!( + event = "process_queue_no_eligible_slots", + run_id = %self.ctx.run_id, + tree = %tree_schedule.tree_accounts.merkle_tree, + "No further eligible slots in schedule" + ); + break 'outer_slot_loop; + } + + current_slot = self.ctx.slot_tracker.estimated_current_slot(); + } + + self.ctx.heartbeat.increment_queue_finished(); + debug!( + event = "process_queue_finished", + run_id = %self.ctx.run_id, + tree = %tree_schedule.tree_accounts.merkle_tree, + "Exiting process_queue" + ); + Ok(()) + } + + #[instrument( + level = "debug", + skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details), + fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, + tree = %tree_accounts.merkle_tree) + )] + async fn process_light_slot( + &self, + epoch_info: &Epoch, + epoch_pda: &ForesterEpochPda, + tree_accounts: &TreeAccounts, + forester_slot_details: &ForesterSlot, + ) -> std::result::Result<(), ForesterError> { + debug!( + event = "light_slot_processing_started", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + epoch = epoch_info.epoch, + light_slot = forester_slot_details.slot, + slot_start = forester_slot_details.start_solana_slot, + slot_end = forester_slot_details.end_solana_slot, + "Processing light slot" + ); + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + wait_until_slot_reached( + &mut *rpc, + &self.ctx.slot_tracker, + forester_slot_details.start_solana_slot, + ) + .await?; + let mut estimated_slot = self.ctx.slot_tracker.estimated_current_slot(); + + 'inner_processing_loop: loop { + if estimated_slot >= forester_slot_details.end_solana_slot { + trace!( + "Ending processing for slot {:?} due to time limit.", + forester_slot_details.slot + ); + break 'inner_processing_loop; + } + + let current_light_slot = (estimated_slot - epoch_info.phases.active.start) + / epoch_pda.protocol_config.slot_length; + if current_light_slot != forester_slot_details.slot { + warn!( + event = "light_slot_mismatch", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + expected_light_slot = forester_slot_details.slot, + actual_light_slot = current_light_slot, + estimated_slot, + "Light slot mismatch; exiting processing for this slot" + ); + break 'inner_processing_loop; + } + + if !self + .check_forester_eligibility( + epoch_pda, + current_light_slot, + &tree_accounts.queue, + epoch_info.epoch, + epoch_info, + ) + .await? 
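+ // Eligibility is re-checked every iteration; stop working this slot once it is lost.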
+ { + break 'inner_processing_loop; + } + + let processing_start_time = Instant::now(); + let items_processed_this_iteration = match self + .dispatch_tree_processing( + epoch_info, + epoch_pda, + tree_accounts, + forester_slot_details, + forester_slot_details.end_solana_slot, + estimated_slot, + ) + .await + { + Ok(count) => count, + Err(e) => { + if e.is_forester_not_eligible() { + return Err(e); + } + error!( + event = "light_slot_processing_failed", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + light_slot = forester_slot_details.slot, + error = ?e, + "Failed processing in light slot" + ); + break 'inner_processing_loop; + } + }; + if items_processed_this_iteration > 0 { + debug!( + event = "light_slot_items_processed", + run_id = %self.ctx.run_id, + light_slot = forester_slot_details.slot, + items = items_processed_this_iteration, + "Processed items in light slot" + ); + } + + self.update_metrics_and_counts( + epoch_info.epoch, + items_processed_this_iteration, + processing_start_time.elapsed(), + ) + .await; + + if let Err(e) = push_metrics(&self.ctx.config.external_services.pushgateway_url).await { + if should_emit_rate_limited_warning("push_metrics_v1", Duration::from_secs(30)) { + warn!( + event = "metrics_push_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Failed to push metrics" + ); + } else { + debug!( + event = "metrics_push_failed_suppressed", + run_id = %self.ctx.run_id, + error = ?e, + "Suppressing repeated metrics push failure" + ); + } + } + estimated_slot = self.ctx.slot_tracker.estimated_current_slot(); + + let sleep_duration_ms = if items_processed_this_iteration > 0 { + self.ctx.config.general_config.sleep_after_processing_ms + } else { + self.ctx.config.general_config.sleep_when_idle_ms + }; + + tokio::time::sleep(Duration::from_millis(sleep_duration_ms)).await; + } + Ok(()) + } + + #[instrument( + level = "debug", + skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details, consecutive_eligibility_end), + fields(tree = %tree_accounts.merkle_tree) + )] + async fn process_light_slot_v2( + &self, + epoch_info: &Epoch, + epoch_pda: &ForesterEpochPda, + tree_accounts: &TreeAccounts, + forester_slot_details: &ForesterSlot, + consecutive_eligibility_end: u64, + ) -> std::result::Result<(), ForesterError> { + debug!( + event = "v2_light_slot_processing_started", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + light_slot = forester_slot_details.slot, + slot_start = forester_slot_details.start_solana_slot, + slot_end = forester_slot_details.end_solana_slot, + consecutive_eligibility_end_slot = consecutive_eligibility_end, + "Processing V2 light slot" + ); + + let tree_pubkey = tree_accounts.merkle_tree; + + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + wait_until_slot_reached( + &mut *rpc, + &self.ctx.slot_tracker, + forester_slot_details.start_solana_slot, + ) + .await?; + + let cached_send_start = Instant::now(); + if let Some(items_sent) = self + .try_send_cached_proofs(epoch_info, tree_accounts, consecutive_eligibility_end) + .await? 
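+ // Drain any proofs prepared earlier for this tree before entering the polling loop.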
+ { + if items_sent > 0 { + let cached_send_duration = cached_send_start.elapsed(); + info!( + event = "cached_proofs_sent", + run_id = %self.ctx.run_id, + tree = %tree_pubkey, + items = items_sent, + duration_ms = cached_send_duration.as_millis() as u64, + "Sent items from proof cache" + ); + self.update_metrics_and_counts(epoch_info.epoch, items_sent, cached_send_duration) + .await; + } + } + + let mut estimated_slot = self.ctx.slot_tracker.estimated_current_slot(); + + const POLL_INTERVAL: Duration = Duration::from_millis(200); + + 'inner_processing_loop: loop { + if estimated_slot >= forester_slot_details.end_solana_slot { + trace!( + "Ending V2 processing for slot {:?}", + forester_slot_details.slot + ); + break 'inner_processing_loop; + } + + let current_light_slot = (estimated_slot - epoch_info.phases.active.start) + / epoch_pda.protocol_config.slot_length; + if current_light_slot != forester_slot_details.slot { + warn!( + event = "v2_light_slot_mismatch", + run_id = %self.ctx.run_id, + tree = %tree_pubkey, + expected_light_slot = forester_slot_details.slot, + actual_light_slot = current_light_slot, + estimated_slot, + "V2 slot mismatch; exiting processing" + ); + break 'inner_processing_loop; + } + + if !self + .check_forester_eligibility( + epoch_pda, + current_light_slot, + &tree_accounts.merkle_tree, + epoch_info.epoch, + epoch_info, + ) + .await? + { + break 'inner_processing_loop; + } + + let processing_start_time = Instant::now(); + match self + .dispatch_tree_processing( + epoch_info, + epoch_pda, + tree_accounts, + forester_slot_details, + consecutive_eligibility_end, + estimated_slot, + ) + .await + { + Ok(count) => { + if count > 0 { + info!( + event = "v2_tree_processed_items", + run_id = %self.ctx.run_id, + tree = %tree_pubkey, + items = count, + epoch = epoch_info.epoch, + "V2 processed items for tree" + ); + self.update_metrics_and_counts( + epoch_info.epoch, + count, + processing_start_time.elapsed(), + ) + .await; + } else { + tokio::time::sleep(POLL_INTERVAL).await; + } + } + Err(e) => { + if e.is_forester_not_eligible() { + return Err(e); + } + error!( + event = "v2_tree_processing_failed", + run_id = %self.ctx.run_id, + tree = %tree_pubkey, + error = ?e, + "V2 processing failed for tree" + ); + tokio::time::sleep(POLL_INTERVAL).await; + } + } + + if let Err(e) = push_metrics(&self.ctx.config.external_services.pushgateway_url).await { + if should_emit_rate_limited_warning("push_metrics_v2", Duration::from_secs(30)) { + warn!( + event = "metrics_push_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Failed to push metrics" + ); + } else { + debug!( + event = "metrics_push_failed_suppressed", + run_id = %self.ctx.run_id, + error = ?e, + "Suppressing repeated metrics push failure" + ); + } + } + estimated_slot = self.ctx.slot_tracker.estimated_current_slot(); + } + + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + pub(crate) async fn dispatch_tree_processing( + &self, + epoch_info: &Epoch, + epoch_pda: &ForesterEpochPda, + tree_accounts: &TreeAccounts, + forester_slot_details: &ForesterSlot, + consecutive_eligibility_end: u64, + current_solana_slot: u64, + ) -> std::result::Result { + match tree_accounts.tree_type { + TreeType::Unknown => self + .dispatch_compression( + epoch_info, + epoch_pda, + forester_slot_details, + consecutive_eligibility_end, + ) + .await + .map_err(ForesterError::from), + TreeType::StateV1 | TreeType::AddressV1 => { + self.process_v1( + epoch_info, + epoch_pda, + tree_accounts, + forester_slot_details, + current_solana_slot, + ) + 
.await + } + TreeType::StateV2 | TreeType::AddressV2 => { + let result = self + .process_v2(epoch_info, tree_accounts, consecutive_eligibility_end) + .await?; + self.epoch_tracker + .add_processing_metrics(epoch_info.epoch, result.metrics) + .await; + Ok(result.items_processed) + } + } + } + + pub(crate) async fn check_forester_eligibility( + &self, + epoch_pda: &ForesterEpochPda, + current_light_slot: u64, + queue_pubkey: &solana_program::pubkey::Pubkey, + current_epoch_num: u64, + epoch_info: &Epoch, + ) -> crate::Result { + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + let current_phase_state = epoch_info.phases.get_current_epoch_state(current_slot); + + if current_phase_state != EpochState::Active { + trace!( + "Skipping processing: not in active phase (current phase: {:?}, slot: {})", + current_phase_state, + current_slot + ); + return Ok(false); + } + + let total_epoch_weight = epoch_pda.total_epoch_weight.ok_or_else(|| { + anyhow::anyhow!( + "Total epoch weight not available in ForesterEpochPda for epoch {}", + current_epoch_num + ) + })?; + + let eligible_forester_slot_index = ForesterEpochPda::get_eligible_forester_index( + current_light_slot, + queue_pubkey, + total_epoch_weight, + current_epoch_num, + ) + .map_err(|e| { + error!( + event = "eligibility_index_calculation_failed", + run_id = %self.ctx.run_id, + queue = %queue_pubkey, + epoch = current_epoch_num, + light_slot = current_light_slot, + error = ?e, + "Failed to calculate eligible forester index" + ); + anyhow::anyhow!("Eligibility calculation failed: {}", e) + })?; + + if !epoch_pda.is_eligible(eligible_forester_slot_index) { + warn!( + event = "forester_not_eligible_for_slot", + run_id = %self.ctx.run_id, + forester = %self.ctx.config.payer_keypair.pubkey(), + queue = %queue_pubkey, + light_slot = current_light_slot, + "Forester is no longer eligible to process this queue in current light slot" + ); + return Ok(false); + } + Ok(true) + } + + pub(crate) async fn update_metrics_and_counts( + &self, + epoch_num: u64, + items_processed: usize, + duration: Duration, + ) { + if items_processed > 0 { + trace!( + "{} items processed in this iteration, duration: {:?}", + items_processed, + duration + ); + queue_metric_update(epoch_num, items_processed, duration); + self.epoch_tracker + .increment_processed_items_count(epoch_num, items_processed) + .await; + self.ctx.heartbeat.add_items_processed(items_processed); + } + } + + pub(crate) fn is_in_active_phase( + &self, + slot: u64, + epoch_info: &ForesterEpochInfo, + ) -> crate::Result { + let current_epoch = self.ctx.protocol_config.get_current_active_epoch(slot)?; + if current_epoch != epoch_info.epoch.epoch { + return Ok(false); + } + + Ok(self + .ctx + .protocol_config + .is_active_phase(slot, epoch_info.epoch.epoch) + .is_ok()) + } +} diff --git a/forester/src/epoch_manager/processor_pool.rs b/forester/src/epoch_manager/processor_pool.rs new file mode 100644 index 0000000000..c19402e204 --- /dev/null +++ b/forester/src/epoch_manager/processor_pool.rs @@ -0,0 +1,192 @@ +use std::sync::Arc; + +use dashmap::DashMap; +use forester_utils::forester_epoch::{Epoch, TreeAccounts}; +use light_client::{indexer::Indexer, rpc::Rpc}; +use solana_program::pubkey::Pubkey; +use tokio::sync::Mutex; +use tracing::debug; + +use super::context::ForesterContext; +use crate::{ + processor::v2::{ + strategy::{AddressTreeStrategy, StateTreeStrategy}, + QueueProcessor, SharedProofCache, + }, + Result, +}; + +pub(crate) type StateBatchProcessorMap = + Arc>>)>>; +pub(crate) type 
AddressBatchProcessorMap = + Arc>>)>>; +type ProcessorInitLockMap = Arc>>>; + +/// Cached V2 processors and proof caches, keyed by tree pubkey. +/// Processors are reused across epochs to preserve cached state. +#[derive(Debug)] +pub(crate) struct ProcessorPool { + state_processors: StateBatchProcessorMap, + address_processors: AddressBatchProcessorMap, + state_init_locks: ProcessorInitLockMap, + address_init_locks: ProcessorInitLockMap, + pub proof_caches: Arc>>, + pub zkp_batch_sizes: Arc>, +} + +impl Clone for ProcessorPool { + fn clone(&self) -> Self { + Self { + state_processors: self.state_processors.clone(), + address_processors: self.address_processors.clone(), + state_init_locks: self.state_init_locks.clone(), + address_init_locks: self.address_init_locks.clone(), + proof_caches: self.proof_caches.clone(), + zkp_batch_sizes: self.zkp_batch_sizes.clone(), + } + } +} + +impl ProcessorPool { + pub fn new() -> Self { + Self { + state_processors: Arc::new(DashMap::new()), + address_processors: Arc::new(DashMap::new()), + state_init_locks: Arc::new(DashMap::new()), + address_init_locks: Arc::new(DashMap::new()), + proof_caches: Arc::new(DashMap::new()), + zkp_batch_sizes: Arc::new(DashMap::new()), + } + } + + pub async fn get_or_create_state_processor( + &self, + ctx: &ForesterContext, + epoch_info: &Epoch, + tree_accounts: &TreeAccounts, + ops_cache: Arc>, + ) -> Result>>> { + let init_lock = self + .state_init_locks + .entry(tree_accounts.merkle_tree) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + let _init_guard = init_lock.lock().await; + + if let Some(entry) = self.state_processors.get(&tree_accounts.merkle_tree) { + let (stored_epoch, processor_ref) = entry.value(); + let processor_clone = processor_ref.clone(); + let old_epoch = *stored_epoch; + drop(entry); + + if old_epoch != epoch_info.epoch { + debug!( + "Reusing StateBatchProcessor for tree {} across epoch transition ({} -> {})", + tree_accounts.merkle_tree, old_epoch, epoch_info.epoch + ); + self.state_processors.insert( + tree_accounts.merkle_tree, + (epoch_info.epoch, processor_clone.clone()), + ); + processor_clone + .lock() + .await + .update_epoch(epoch_info.epoch, epoch_info.phases.clone()); + } + return Ok(processor_clone); + } + + let batch_context = + ctx.build_batch_context(epoch_info, tree_accounts, None, None, None, ops_cache); + let processor = Arc::new(Mutex::new( + QueueProcessor::new(batch_context, StateTreeStrategy).await?, + )); + + let batch_size = processor.lock().await.zkp_batch_size(); + self.zkp_batch_sizes + .insert(tree_accounts.merkle_tree, batch_size); + + self.state_processors.insert( + tree_accounts.merkle_tree, + (epoch_info.epoch, processor.clone()), + ); + Ok(processor) + } + + pub async fn get_or_create_address_processor( + &self, + ctx: &ForesterContext, + epoch_info: &Epoch, + tree_accounts: &TreeAccounts, + ops_cache: Arc>, + ) -> Result>>> { + let init_lock = self + .address_init_locks + .entry(tree_accounts.merkle_tree) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + let _init_guard = init_lock.lock().await; + + if let Some(entry) = self.address_processors.get(&tree_accounts.merkle_tree) { + let (stored_epoch, processor_ref) = entry.value(); + let processor_clone = processor_ref.clone(); + let old_epoch = *stored_epoch; + drop(entry); + + if old_epoch != epoch_info.epoch { + debug!( + "Reusing AddressBatchProcessor for tree {} across epoch transition ({} -> {})", + tree_accounts.merkle_tree, old_epoch, epoch_info.epoch + ); + self.address_processors.insert( + 
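// Re-key the cached processor under the new epoch so later lookups see the transition. +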
tree_accounts.merkle_tree, + (epoch_info.epoch, processor_clone.clone()), + ); + processor_clone + .lock() + .await + .update_epoch(epoch_info.epoch, epoch_info.phases.clone()); + } + return Ok(processor_clone); + } + + let batch_context = + ctx.build_batch_context(epoch_info, tree_accounts, None, None, None, ops_cache); + let processor = Arc::new(Mutex::new( + QueueProcessor::new(batch_context, AddressTreeStrategy).await?, + )); + + let batch_size = processor.lock().await.zkp_batch_size(); + self.zkp_batch_sizes + .insert(tree_accounts.merkle_tree, batch_size); + + self.address_processors.insert( + tree_accounts.merkle_tree, + (epoch_info.epoch, processor.clone()), + ); + Ok(processor) + } + + pub fn remove_state_processor(&self, tree: &Pubkey) { + self.state_processors.remove(tree); + } + + pub fn remove_address_processor(&self, tree: &Pubkey) { + self.address_processors.remove(tree); + } + + pub fn remove_proof_cache(&self, tree: &Pubkey) { + self.proof_caches.remove(tree); + } + + pub fn get_or_create_proof_cache(&self, tree: Pubkey) -> Arc { + self.proof_caches + .entry(tree) + .or_insert_with(|| Arc::new(SharedProofCache::new(tree))) + .clone() + } + + pub fn get_proof_cache(&self, tree: &Pubkey) -> Option> { + self.proof_caches.get(tree).map(|c| c.clone()) + } +} diff --git a/forester/src/epoch_manager/registration.rs b/forester/src/epoch_manager/registration.rs new file mode 100644 index 0000000000..355a398301 --- /dev/null +++ b/forester/src/epoch_manager/registration.rs @@ -0,0 +1,615 @@ +//! Epoch registration, recovery, wait-for-active-phase, and re-finalization. + +use std::time::Duration; + +use anyhow::Context; +use forester_utils::forester_epoch::{get_epoch_phases, Epoch, TreeAccounts, TreeForesterSchedule}; +use light_client::{ + indexer::Indexer, + rpc::{LightClient, LightClientConfig, Rpc, RpcError}, +}; +use light_compressed_account::TreeType; +use light_registry::{ + sdk::create_finalize_registration_instruction, + utils::{get_epoch_pda_address, get_forester_epoch_pda_from_authority}, + EpochPda, ForesterEpochPda, +}; +use solana_sdk::signature::Signer; +use tokio::time::sleep; +use tracing::{debug, error, info, instrument, warn}; + +use super::{tracker::RegistrationTracker, EpochManager}; +use crate::{ + errors::{ForesterError, RegistrationError}, + pagerduty::send_pagerduty_alert, + slot_tracker::{slot_duration, wait_until_slot_reached}, + smart_transaction::{send_smart_transaction, ComputeBudgetConfig, SendSmartTransactionConfig}, + transaction_timing::scheduled_confirmation_deadline, + ForesterEpochInfo, +}; + +impl EpochManager { + pub(crate) async fn recover_registration_info_if_exists( + &self, + epoch: u64, + ) -> std::result::Result, ForesterError> { + debug!("Recovering registration info for epoch {}", epoch); + + let forester_epoch_pda_pubkey = + get_forester_epoch_pda_from_authority(&self.ctx.config.derivation_pubkey, epoch).0; + + let existing_pda = { + let rpc = self.ctx.rpc_pool.get_connection().await?; + rpc.get_anchor_account::(&forester_epoch_pda_pubkey) + .await? 
+ }; + + match existing_pda { + Some(pda) => self + .recover_registration_info_internal(epoch, forester_epoch_pda_pubkey, pda) + .await + .map(Some) + .map_err(ForesterError::from), + None => Ok(None), + } + } + + async fn recover_registration_info_internal( + &self, + epoch: u64, + forester_epoch_pda_address: solana_program::pubkey::Pubkey, + forester_epoch_pda: ForesterEpochPda, + ) -> crate::Result { + let rpc = self.ctx.rpc_pool.get_connection().await?; + + let phases = get_epoch_phases(&self.ctx.protocol_config, epoch); + let slot = rpc.get_slot().await?; + let state = phases.get_current_epoch_state(slot); + + let epoch_pda_address = get_epoch_pda_address(epoch); + let epoch_pda = rpc + .get_anchor_account::(&epoch_pda_address) + .await + .with_context(|| format!("Failed to fetch EpochPda for epoch {}", epoch))? + .ok_or(RegistrationError::EpochPdaNotFound { + epoch, + pda_address: epoch_pda_address, + })?; + + let epoch_info = Epoch { + epoch, + epoch_pda: epoch_pda_address, + forester_epoch_pda: forester_epoch_pda_address, + phases, + state, + merkle_trees: Vec::new(), + }; + + let forester_epoch_info = ForesterEpochInfo { + epoch: epoch_info, + epoch_pda, + forester_epoch_pda, + trees: Vec::new(), + }; + + Ok(forester_epoch_info) + } + + pub(crate) async fn register_for_epoch_with_retry( + &self, + epoch: u64, + max_retries: u32, + retry_delay: Duration, + ) -> std::result::Result { + let rpc = LightClient::new(LightClientConfig { + url: self.ctx.config.external_services.rpc_url.to_string(), + photon_url: self.ctx.config.external_services.indexer_url.clone(), + commitment_config: Some(solana_sdk::commitment_config::CommitmentConfig::confirmed()), + fetch_active_tree: false, + }) + .await + .map_err(ForesterError::Rpc)?; + let slot = rpc.get_slot().await.map_err(ForesterError::Rpc)?; + let phases = get_epoch_phases(&self.ctx.protocol_config, epoch); + + if slot < phases.registration.start { + let slots_to_wait = phases.registration.start.saturating_sub(slot); + info!( + event = "registration_wait_for_window", + run_id = %self.ctx.run_id, + epoch, + current_slot = slot, + registration_start_slot = phases.registration.start, + slots_to_wait, + "Registration window not open yet; waiting" + ); + let wait_duration = slot_duration() * slots_to_wait as u32; + sleep(wait_duration).await; + } + + for attempt in 0..max_retries { + match self.recover_registration_info_if_exists(epoch).await { + Ok(Some(registration_info)) => return Ok(registration_info), + Ok(None) => {} + Err(e) => return Err(e), + } + + match self.register_for_epoch(epoch).await { + Ok(registration_info) => return Ok(registration_info), + Err(e) => { + warn!( + event = "registration_attempt_failed", + run_id = %self.ctx.run_id, + epoch, + attempt = attempt + 1, + max_attempts = max_retries, + error = ?e, + "Failed to register for epoch; retrying" + ); + if attempt < max_retries - 1 { + sleep(retry_delay).await; + } else { + if let Some(pagerduty_key) = self + .ctx + .config + .external_services + .pagerduty_routing_key + .clone() + { + if let Err(alert_err) = send_pagerduty_alert( + &pagerduty_key, + &format!( + "Forester failed to register for epoch {} after {} attempts", + epoch, max_retries + ), + "critical", + &format!("Forester {}", self.ctx.config.payer_keypair.pubkey()), + ) + .await + { + error!( + event = "pagerduty_alert_failed", + run_id = %self.ctx.run_id, + epoch, + error = ?alert_err, + "Failed to send PagerDuty alert" + ); + } + } + return Err(ForesterError::Other(e)); + } + } + } + } + 
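+        // Falls through when every attempt failed without returning, and also
+        // when max_retries == 0 (the loop body never runs), so callers can
+        // distinguish retry-budget exhaustion from hard registration errors.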
Err(RegistrationError::MaxRetriesExceeded { + epoch, + attempts: max_retries, + } + .into()) + } + + #[instrument(level = "debug", skip(self), fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch))] + async fn register_for_epoch(&self, epoch: u64) -> crate::Result { + info!( + event = "registration_attempt_started", + run_id = %self.ctx.run_id, + epoch, "Registering for epoch" + ); + let mut rpc = LightClient::new(LightClientConfig { + url: self.ctx.config.external_services.rpc_url.to_string(), + photon_url: self.ctx.config.external_services.indexer_url.clone(), + commitment_config: Some(solana_sdk::commitment_config::CommitmentConfig::processed()), + fetch_active_tree: false, + }) + .await?; + let slot = rpc.get_slot().await?; + let phases = get_epoch_phases(&self.ctx.protocol_config, epoch); + + if slot >= phases.registration.start { + let forester_epoch_pda_pubkey = + get_forester_epoch_pda_from_authority(&self.ctx.config.derivation_pubkey, epoch).0; + let existing_registration = rpc + .get_anchor_account::(&forester_epoch_pda_pubkey) + .await?; + + if let Some(existing_pda) = existing_registration { + info!( + event = "registration_already_exists", + run_id = %self.ctx.run_id, + epoch, "Already registered for epoch; recovering registration info" + ); + let registration_info = self + .recover_registration_info_internal( + epoch, + forester_epoch_pda_pubkey, + existing_pda, + ) + .await?; + return Ok(registration_info); + } + + let registration_info = { + debug!("Registering epoch {}", epoch); + let registered_epoch = match Epoch::register( + &mut rpc, + &self.ctx.protocol_config, + &self.ctx.config.payer_keypair, + &self.ctx.config.derivation_pubkey, + Some(epoch), + ) + .await + .with_context(|| { + format!("Failed to execute epoch registration for epoch {}", epoch) + })? { + Some(epoch) => { + debug!("Registered epoch: {:?}", epoch); + epoch + } + None => { + return Err(RegistrationError::EmptyRegistration.into()); + } + }; + + let forester_epoch_pda = rpc + .get_anchor_account::(®istered_epoch.forester_epoch_pda) + .await + .with_context(|| { + format!( + "Failed to fetch ForesterEpochPda from RPC for address {}", + registered_epoch.forester_epoch_pda + ) + })? + .ok_or(RegistrationError::ForesterEpochPdaNotFound { + epoch, + pda_address: registered_epoch.forester_epoch_pda, + })?; + + let epoch_pda_address = get_epoch_pda_address(epoch); + let epoch_pda = rpc + .get_anchor_account::(&epoch_pda_address) + .await + .with_context(|| { + format!( + "Failed to fetch EpochPda from RPC for address {}", + epoch_pda_address + ) + })? 
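+                // `?` above surfaced RPC and deserialization failures; `.ok_or`
+                // below maps a missing account to a typed error instead.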
+ .ok_or(RegistrationError::EpochPdaNotFound { + epoch, + pda_address: epoch_pda_address, + })?; + + ForesterEpochInfo { + epoch: registered_epoch, + epoch_pda, + forester_epoch_pda, + trees: Vec::new(), + } + }; + debug!("Registered: {:?}", registration_info); + Ok(registration_info) + } else { + warn!( + event = "registration_too_early", + run_id = %self.ctx.run_id, + epoch, + current_slot = slot, + registration_start_slot = phases.registration.start, + "Too early to register for epoch" + ); + Err(RegistrationError::RegistrationPhaseNotStarted { + epoch, + current_slot: slot, + registration_start: phases.registration.start, + } + .into()) + } + } + + #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch))] + pub(crate) async fn wait_for_active_phase( + &self, + epoch_info: &ForesterEpochInfo, + ) -> std::result::Result, ForesterError> { + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + let active_phase_start_slot = epoch_info.epoch.phases.active.start; + let active_phase_end_slot = epoch_info.epoch.phases.active.end; + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + + if current_slot >= active_phase_start_slot { + info!( + event = "active_phase_already_started", + run_id = %self.ctx.run_id, + current_slot, + active_phase_start_slot, + active_phase_end_slot, + slots_left = active_phase_end_slot.saturating_sub(current_slot), + "Active phase has already started" + ); + } else { + let waiting_slots = active_phase_start_slot - current_slot; + let waiting_secs = waiting_slots / 2; + info!( + event = "wait_for_active_phase", + run_id = %self.ctx.run_id, + current_slot, + active_phase_start_slot, + waiting_slots, + approx_wait_seconds = waiting_secs, + "Waiting for active phase to start" + ); + } + + self.prewarm_all_trees_during_wait(epoch_info, active_phase_start_slot) + .await; + + wait_until_slot_reached(&mut *rpc, &self.ctx.slot_tracker, active_phase_start_slot).await?; + + let forester_epoch_pda_pubkey = get_forester_epoch_pda_from_authority( + &self.ctx.config.derivation_pubkey, + epoch_info.epoch.epoch, + ) + .0; + let existing_registration = rpc + .get_anchor_account::(&forester_epoch_pda_pubkey) + .await?; + + if let Some(registration) = existing_registration { + if registration.total_epoch_weight.is_none() { + let current_slot = rpc.get_slot().await?; + if current_slot > epoch_info.epoch.phases.active.end { + info!( + event = "skip_finalize_registration_phase_ended", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch.epoch, + current_slot, + active_phase_end_slot = epoch_info.epoch.phases.active.end, + "Skipping FinalizeRegistration because active phase ended" + ); + return Ok(None); + } + + let ix = create_finalize_registration_instruction( + &self.ctx.config.payer_keypair.pubkey(), + &self.ctx.config.derivation_pubkey, + epoch_info.epoch.epoch, + ); + let priority_fee = self + .ctx + .resolve_epoch_priority_fee(&*rpc, epoch_info.epoch.epoch) + .await?; + let Some(confirmation_deadline) = scheduled_confirmation_deadline( + epoch_info + .epoch + .phases + .active + .end + .saturating_sub(current_slot), + ) else { + info!( + event = "skip_finalize_registration_confirmation_budget_exhausted", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch.epoch, + current_slot, + active_phase_end_slot = epoch_info.epoch.phases.active.end, + "Skipping FinalizeRegistration because not enough active-phase time remains for confirmation" + ); + return Ok(None); + }; + let 
payer = self.ctx.config.payer_keypair.pubkey(); + let signers = [&self.ctx.config.payer_keypair]; + send_smart_transaction( + &mut *rpc, + SendSmartTransactionConfig { + instructions: vec![ix], + payer: &payer, + signers: &signers, + address_lookup_tables: &self.ctx.address_lookup_tables, + compute_budget: ComputeBudgetConfig { + compute_unit_price: priority_fee, + compute_unit_limit: Some(self.ctx.config.transaction_config.cu_limit), + }, + confirmation: Some(self.ctx.confirmation_config()), + confirmation_deadline: Some(confirmation_deadline), + }, + ) + .await + .map_err(RpcError::from)?; + } + } + + let mut epoch_info = (*epoch_info).clone(); + epoch_info.forester_epoch_pda = rpc + .get_anchor_account::(&epoch_info.epoch.forester_epoch_pda) + .await + .with_context(|| { + format!( + "Failed to fetch ForesterEpochPda for epoch {} at address {}", + epoch_info.epoch.epoch, epoch_info.epoch.forester_epoch_pda + ) + })? + .ok_or(RegistrationError::ForesterEpochPdaNotFound { + epoch: epoch_info.epoch.epoch, + pda_address: epoch_info.epoch.forester_epoch_pda, + })?; + + let slot = rpc.get_slot().await?; + let trees = self.trees.lock().await; + tracing::trace!("Adding schedule for trees: {:?}", *trees); + epoch_info.add_trees_with_schedule(&trees, slot)?; + + if self.compressible_tracker.is_some() && self.ctx.config.compressible_config.is_some() { + let compression_tree_accounts = TreeAccounts { + merkle_tree: solana_sdk::pubkey::Pubkey::default(), + queue: solana_sdk::pubkey::Pubkey::default(), + tree_type: TreeType::Unknown, + is_rolledover: false, + owner: solana_sdk::pubkey::Pubkey::default(), + }; + let tree_schedule = TreeForesterSchedule::new_with_schedule( + &compression_tree_accounts, + slot, + &epoch_info.forester_epoch_pda, + &epoch_info.epoch_pda, + ) + .map_err(anyhow::Error::from)?; + epoch_info.trees.insert(0, tree_schedule); + debug!("Added compression tree to epoch {}", epoch_info.epoch.epoch); + } + + info!( + event = "active_phase_ready", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch.epoch, + "Finished waiting for active phase" + ); + Ok(Some(epoch_info)) + } + + /// Check if `EpochPda.registered_weight` changed on-chain. If so, + /// one task sends a `finalize_registration` tx while others wait, + /// then all tasks refresh their `ForesterEpochPda` and recompute schedules. + pub(crate) async fn maybe_refinalize( + &self, + epoch_info: &Epoch, + forester_epoch_pda: &mut ForesterEpochPda, + tree_schedule: &mut TreeForesterSchedule, + registration_tracker: &RegistrationTracker, + force: bool, + ) -> crate::Result<()> { + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + let epoch_pda_address = get_epoch_pda_address(epoch_info.epoch); + let on_chain_epoch_pda: EpochPda = rpc + .get_anchor_account::(&epoch_pda_address) + .await? 
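+            // Re-read on every call: a change in the on-chain registered_weight
+            // is the signal that another forester registered mid-epoch.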
+ .ok_or(RegistrationError::EpochPdaNotFound { + epoch: epoch_info.epoch, + pda_address: epoch_pda_address, + })?; + + let on_chain_weight = on_chain_epoch_pda.registered_weight; + let cached_weight = registration_tracker.cached_weight(); + let weight_changed = on_chain_weight != cached_weight; + + if !weight_changed && !force { + return Ok(()); + } + + if weight_changed { + info!( + event = "registered_weight_changed", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch, + old_weight = cached_weight, + new_weight = on_chain_weight, + "Detected new forester registration, re-finalizing" + ); + + if registration_tracker.try_claim_refinalize() { + let ix = create_finalize_registration_instruction( + &self.ctx.config.payer_keypair.pubkey(), + &self.ctx.config.derivation_pubkey, + epoch_info.epoch, + ); + let priority_fee = self + .ctx + .resolve_epoch_priority_fee(&*rpc, epoch_info.epoch) + .await?; + let current_slot = rpc.get_slot().await?; + let Some(confirmation_deadline) = scheduled_confirmation_deadline( + epoch_info.phases.active.end.saturating_sub(current_slot), + ) else { + info!( + event = "refinalize_registration_skipped_confirmation_budget_exhausted", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch, + current_slot, + active_phase_end_slot = epoch_info.phases.active.end, + "Skipping re-finalization because not enough active-phase time remains for confirmation" + ); + registration_tracker.complete_refinalize(cached_weight); + return Ok(()); + }; + let payer = self.ctx.config.payer_keypair.pubkey(); + let signers = [&self.ctx.config.payer_keypair]; + match send_smart_transaction( + &mut *rpc, + SendSmartTransactionConfig { + instructions: vec![ix], + payer: &payer, + signers: &signers, + address_lookup_tables: &self.ctx.address_lookup_tables, + compute_budget: ComputeBudgetConfig { + compute_unit_price: priority_fee, + compute_unit_limit: Some(self.ctx.config.transaction_config.cu_limit), + }, + confirmation: Some(self.ctx.confirmation_config()), + confirmation_deadline: Some(confirmation_deadline), + }, + ) + .await + .map_err(RpcError::from) + { + Ok(_) => { + let post_finalize_weight = + match rpc.get_anchor_account::(&epoch_pda_address).await { + Ok(Some(pda)) => pda.registered_weight, + _ => on_chain_weight, + }; + info!( + event = "refinalize_registration_success", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch, + new_weight = post_finalize_weight, + "Re-finalized registration on-chain" + ); + registration_tracker.complete_refinalize(post_finalize_weight); + } + Err(e) => { + registration_tracker.complete_refinalize(cached_weight); + return Err(e.into()); + } + } + } else { + registration_tracker.wait_for_refinalize().await; + } + } + + let refreshed_epoch_pda: EpochPda = rpc + .get_anchor_account::(&epoch_pda_address) + .await? + .ok_or(RegistrationError::EpochPdaNotFound { + epoch: epoch_info.epoch, + pda_address: epoch_pda_address, + })?; + let updated_pda: ForesterEpochPda = rpc + .get_anchor_account::(&epoch_info.forester_epoch_pda) + .await? 
+ .ok_or(RegistrationError::ForesterEpochPdaNotFound { + epoch: epoch_info.epoch, + pda_address: epoch_info.forester_epoch_pda, + })?; + + let current_slot = self.ctx.slot_tracker.estimated_current_slot(); + let new_schedule = TreeForesterSchedule::new_with_schedule( + &tree_schedule.tree_accounts, + current_slot, + &updated_pda, + &refreshed_epoch_pda, + )?; + + *forester_epoch_pda = updated_pda; + *tree_schedule = new_schedule; + + info!( + event = "schedule_recomputed_after_refinalize", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch, + tree = %tree_schedule.tree_accounts.merkle_tree, + new_eligible_slots = tree_schedule.slots.iter().filter(|s| s.is_some()).count(), + "Recomputed schedule after re-finalization" + ); + + Ok(()) + } +} diff --git a/forester/src/epoch_manager/reporting.rs b/forester/src/epoch_manager/reporting.rs new file mode 100644 index 0000000000..6352854db3 --- /dev/null +++ b/forester/src/epoch_manager/reporting.rs @@ -0,0 +1,197 @@ +//! Work reporting: send metrics to channel and report on-chain. + +use light_client::{ + indexer::Indexer, + rpc::{LightClient, LightClientConfig, Rpc, RpcError}, +}; +use light_registry::{ + sdk::create_report_work_instruction, utils::get_forester_epoch_pda_from_authority, + ForesterEpochPda, +}; +use solana_program::instruction::InstructionError; +use solana_sdk::{signature::Signer, transaction::TransactionError}; +use tracing::{info, instrument}; + +use super::{EpochManager, WorkReport}; +use crate::{ + errors::{rpc_is_already_processed, ChannelError, ForesterError, WorkReportError}, + slot_tracker::wait_until_slot_reached, + smart_transaction::{send_smart_transaction, ComputeBudgetConfig, SendSmartTransactionConfig}, + ForesterEpochInfo, +}; + +impl EpochManager { + #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch))] + pub(crate) async fn wait_for_report_work_phase( + &self, + epoch_info: &ForesterEpochInfo, + ) -> crate::Result<()> { + info!( + event = "wait_for_report_work_phase", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch.epoch, + report_work_start_slot = epoch_info.epoch.phases.report_work.start, + "Waiting for report work phase" + ); + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + let report_work_start_slot = epoch_info.epoch.phases.report_work.start; + wait_until_slot_reached(&mut *rpc, &self.ctx.slot_tracker, report_work_start_slot).await?; + + info!( + event = "report_work_phase_ready", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch.epoch, + "Finished waiting for report work phase" + ); + Ok(()) + } + + #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch))] + pub(crate) async fn send_work_report( + &self, + epoch_info: &ForesterEpochInfo, + ) -> crate::Result<()> { + let report = WorkReport { + epoch: epoch_info.epoch.epoch, + processed_items: self + .epoch_tracker + .get_processed_items_count(epoch_info.epoch.epoch) + .await, + metrics: self + .epoch_tracker + .get_processing_metrics(epoch_info.epoch.epoch) + .await, + }; + + info!( + event = "work_report_sent_to_channel", + run_id = %self.ctx.run_id, + epoch = report.epoch, + items = report.processed_items, + total_circuit_inputs_ms = report.metrics.total_circuit_inputs().as_millis() as u64, + total_proof_generation_ms = report.metrics.total_proof_generation().as_millis() as u64, + total_round_trip_ms = 
report.metrics.total_round_trip().as_millis() as u64, + tx_sending_ms = report.metrics.tx_sending_duration.as_millis() as u64, + "Sending work report to channel" + ); + + self.work_report_sender + .send(report) + .await + .map_err(|e| ChannelError::WorkReportSend { + epoch: report.epoch, + error: e.to_string(), + })?; + self.ctx.heartbeat.increment_work_report_sent(); + + Ok(()) + } + + #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.ctx.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch))] + pub(crate) async fn report_work_onchain( + &self, + epoch_info: &ForesterEpochInfo, + ) -> crate::Result<()> { + info!( + event = "work_report_onchain_started", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch.epoch, + "Reporting work on-chain" + ); + let mut rpc = LightClient::new(LightClientConfig { + url: self.ctx.config.external_services.rpc_url.to_string(), + photon_url: self.ctx.config.external_services.indexer_url.clone(), + commitment_config: Some(solana_sdk::commitment_config::CommitmentConfig::processed()), + fetch_active_tree: false, + }) + .await?; + + let forester_epoch_pda_pubkey = get_forester_epoch_pda_from_authority( + &self.ctx.config.derivation_pubkey, + epoch_info.epoch.epoch, + ) + .0; + if let Some(forester_epoch_pda) = rpc + .get_anchor_account::(&forester_epoch_pda_pubkey) + .await? + { + if forester_epoch_pda.has_reported_work { + return Ok(()); + } + } + + let forester_epoch_pda = &epoch_info.forester_epoch_pda; + if forester_epoch_pda.has_reported_work { + return Ok(()); + } + + let ix = create_report_work_instruction( + &self.ctx.config.payer_keypair.pubkey(), + &self.ctx.config.derivation_pubkey, + epoch_info.epoch.epoch, + ); + + let priority_fee = self + .ctx + .resolve_epoch_priority_fee(&rpc, epoch_info.epoch.epoch) + .await?; + let payer = self.ctx.config.payer_keypair.pubkey(); + let signers = [&self.ctx.config.payer_keypair]; + match send_smart_transaction( + &mut rpc, + SendSmartTransactionConfig { + instructions: vec![ix], + payer: &payer, + signers: &signers, + address_lookup_tables: &self.ctx.address_lookup_tables, + compute_budget: ComputeBudgetConfig { + compute_unit_price: priority_fee, + compute_unit_limit: Some(self.ctx.config.transaction_config.cu_limit), + }, + confirmation: Some(self.ctx.confirmation_config()), + confirmation_deadline: None, + }, + ) + .await + .map_err(RpcError::from) + { + Ok(_) => { + info!( + event = "work_report_onchain_succeeded", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch.epoch, + "Work reported on-chain" + ); + } + Err(e) => { + if rpc_is_already_processed(&e) { + info!( + event = "work_report_onchain_already_reported", + run_id = %self.ctx.run_id, + epoch = epoch_info.epoch.epoch, + "Work already reported on-chain for epoch" + ); + return Ok(()); + } + if let RpcError::ClientError(client_error) = &e { + if let Some(TransactionError::InstructionError( + _, + InstructionError::Custom(error_code), + )) = client_error.get_transaction_error() + { + return WorkReportError::from_registry_error( + error_code, + epoch_info.epoch.epoch, + ) + .map_err(|e| anyhow::Error::from(ForesterError::from(e))); + } + } + return Err(anyhow::Error::from(WorkReportError::Transaction(Box::new( + e, + )))); + } + } + + Ok(()) + } +} diff --git a/forester/src/epoch_manager/tracker.rs b/forester/src/epoch_manager/tracker.rs new file mode 100644 index 0000000000..349ced4e98 --- /dev/null +++ b/forester/src/epoch_manager/tracker.rs @@ -0,0 +1,276 @@ +use std::{ + collections::HashMap, + sync::{ + 
+        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
+        Arc,
+    },
+};
+
+use dashmap::DashMap;
+use tokio::sync::Mutex;
+
+use super::ProcessingMetrics;
+
+/// Coordinates re-finalization across parallel `process_queue` tasks when new
+/// foresters register mid-epoch. Only one task performs the on-chain
+/// `finalize_registration` tx; others wait for it to complete.
+#[derive(Debug)]
+pub(crate) struct RegistrationTracker {
+    cached_registered_weight: AtomicU64,
+    refinalize_in_progress: AtomicBool,
+    refinalized: tokio::sync::Notify,
+}
+
+impl RegistrationTracker {
+    pub fn new(weight: u64) -> Self {
+        Self {
+            cached_registered_weight: AtomicU64::new(weight),
+            refinalize_in_progress: AtomicBool::new(false),
+            refinalized: tokio::sync::Notify::new(),
+        }
+    }
+
+    pub fn cached_weight(&self) -> u64 {
+        self.cached_registered_weight.load(Ordering::Acquire)
+    }
+
+    /// Returns `true` if this caller won the race to perform re-finalization.
+    pub fn try_claim_refinalize(&self) -> bool {
+        self.refinalize_in_progress
+            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
+            .is_ok()
+    }
+
+    /// Called by the winner after the on-chain tx succeeds.
+    pub fn complete_refinalize(&self, new_weight: u64) {
+        self.cached_registered_weight
+            .store(new_weight, Ordering::Release);
+        self.refinalize_in_progress.store(false, Ordering::Release);
+        self.refinalized.notify_waiters();
+    }
+
+    /// Called by non-winners to block until re-finalization is done.
+    pub async fn wait_for_refinalize(&self) {
+        if !self.refinalize_in_progress.load(Ordering::Acquire) {
+            return;
+        }
+        // `Notify::notified()` registers no interest until the future is polled
+        // or explicitly enabled, and `notify_waiters()` only wakes registered
+        // waiters. Enable before the second flag check so a completion that
+        // races in between cannot be missed, which would block this task forever.
+        let fut = self.refinalized.notified();
+        tokio::pin!(fut);
+        fut.as_mut().enable();
+        if !self.refinalize_in_progress.load(Ordering::Acquire) {
+            return;
+        }
+        fut.await;
+    }
+}
+
+/// Per-epoch coordination state: dedup flags, item counts, metrics, and
+/// registration trackers. Shared across parallel tree tasks within an epoch.
+#[derive(Debug, Clone)]
+pub(crate) struct EpochTracker {
+    processing_epochs: Arc<DashMap<u64, Arc<AtomicBool>>>,
+    processed_items: Arc<Mutex<HashMap<u64, AtomicUsize>>>,
+    processing_metrics: Arc<Mutex<HashMap<u64, ProcessingMetrics>>>,
+    registration_trackers: Arc<DashMap<u64, Arc<RegistrationTracker>>>,
+}
+
+impl EpochTracker {
+    pub fn new() -> Self {
+        Self {
+            processing_epochs: Arc::new(DashMap::new()),
+            processed_items: Arc::new(Mutex::new(HashMap::new())),
+            processing_metrics: Arc::new(Mutex::new(HashMap::new())),
+            registration_trackers: Arc::new(DashMap::new()),
+        }
+    }
+
+    /// Attempts to claim exclusive processing for an epoch.
+    /// Returns `Some(EpochGuard)` if this caller won, `None` if already claimed.
+    /// The guard resets the flag on drop.
+    pub fn try_claim_epoch(&self, epoch: u64) -> Option<EpochGuard> {
+        let flag = self
+            .processing_epochs
+            .entry(epoch)
+            .or_insert_with(|| Arc::new(AtomicBool::new(false)))
+            .clone();
+
+        if flag
+            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
+            .is_err()
+        {
+            return None;
+        }
+
+        Some(EpochGuard { flag })
+    }
+
+    pub async fn get_processed_items_count(&self, epoch: u64) -> usize {
+        let counts = self.processed_items.lock().await;
+        counts
+            .get(&epoch)
+            .map_or(0, |count| count.load(Ordering::Relaxed))
+    }
+
+    pub async fn increment_processed_items_count(&self, epoch: u64, increment_by: usize) {
+        let mut counts = self.processed_items.lock().await;
+        counts
+            .entry(epoch)
+            .or_insert_with(|| AtomicUsize::new(0))
+            .fetch_add(increment_by, Ordering::Relaxed);
+    }
+
+    pub async fn get_processing_metrics(&self, epoch: u64) -> ProcessingMetrics {
+        let metrics = self.processing_metrics.lock().await;
+        metrics.get(&epoch).copied().unwrap_or_default()
+    }
+
+    pub async fn add_processing_metrics(&self, epoch: u64, new_metrics: ProcessingMetrics) {
+        let mut metrics = self.processing_metrics.lock().await;
+        *metrics.entry(epoch).or_default() += new_metrics;
+    }
+
+    pub fn get_or_create_tracker(&self, epoch: u64, weight: u64) -> Arc<RegistrationTracker> {
+        self.registration_trackers
+            .entry(epoch)
+            .or_insert_with(|| Arc::new(RegistrationTracker::new(weight)))
+            .value()
+            .clone()
+    }
+
+    /// Removes all per-epoch state for the given epoch.
+    pub async fn cleanup(&self, epoch: u64) {
+        self.registration_trackers.remove(&epoch);
+        self.processing_epochs.remove(&epoch);
+        self.processed_items.lock().await.remove(&epoch);
+        self.processing_metrics.lock().await.remove(&epoch);
+    }
+}
+
+/// RAII guard that resets the epoch processing flag on drop.
+pub(crate) struct EpochGuard {
+    flag: Arc<AtomicBool>,
+}
+
+impl Drop for EpochGuard {
+    fn drop(&mut self) {
+        self.flag.store(false, Ordering::SeqCst);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn registration_tracker_initial_weight() {
+        let tracker = RegistrationTracker::new(1000);
+        assert_eq!(tracker.cached_weight(), 1000);
+    }
+
+    #[test]
+    fn registration_tracker_claim_returns_true_once() {
+        let tracker = RegistrationTracker::new(100);
+        assert!(tracker.try_claim_refinalize());
+        assert!(!tracker.try_claim_refinalize());
+    }
+
+    #[test]
+    fn registration_tracker_complete_releases_claim() {
+        let tracker = RegistrationTracker::new(100);
+        assert!(tracker.try_claim_refinalize());
+        tracker.complete_refinalize(200);
+        assert_eq!(tracker.cached_weight(), 200);
+        // Can claim again after complete
+        assert!(tracker.try_claim_refinalize());
+    }
+
+    #[tokio::test]
+    async fn registration_tracker_wait_returns_immediately_when_not_in_progress() {
+        let tracker = RegistrationTracker::new(100);
+        // Should return immediately; no one is re-finalizing
+        tracker.wait_for_refinalize().await;
+    }
+
+    #[tokio::test]
+    async fn registration_tracker_wait_blocks_until_complete() {
+        let tracker = Arc::new(RegistrationTracker::new(100));
+        assert!(tracker.try_claim_refinalize());
+
+        let tracker_clone = tracker.clone();
+        let waiter = tokio::spawn(async move {
+            tracker_clone.wait_for_refinalize().await;
+            tracker_clone.cached_weight()
+        });
+
+        // Give the waiter a moment to start waiting
+        tokio::task::yield_now().await;
+
+        tracker.complete_refinalize(500);
+
+        let result = waiter.await.unwrap();
+        assert_eq!(result, 500);
+    }
+
+    #[test]
+    fn epoch_tracker_claim_returns_none_on_duplicate() {
+        let tracker = EpochTracker::new();
+        let guard =
tracker.try_claim_epoch(1); + assert!(guard.is_some()); + assert!(tracker.try_claim_epoch(1).is_none()); + // Different epoch is fine + assert!(tracker.try_claim_epoch(2).is_some()); + } + + #[test] + fn epoch_guard_resets_on_drop() { + let tracker = EpochTracker::new(); + { + let _guard = tracker.try_claim_epoch(1); + assert!(tracker.try_claim_epoch(1).is_none()); + } + // After guard dropped, can claim again + assert!(tracker.try_claim_epoch(1).is_some()); + } + + #[tokio::test] + async fn epoch_tracker_items_and_metrics() { + let tracker = EpochTracker::new(); + assert_eq!(tracker.get_processed_items_count(1).await, 0); + + tracker.increment_processed_items_count(1, 10).await; + tracker.increment_processed_items_count(1, 5).await; + assert_eq!(tracker.get_processed_items_count(1).await, 15); + + // Different epoch is independent + assert_eq!(tracker.get_processed_items_count(2).await, 0); + + let metrics = tracker.get_processing_metrics(1).await; + assert_eq!(metrics.total(), std::time::Duration::ZERO); + + let new_metrics = ProcessingMetrics { + tx_sending_duration: std::time::Duration::from_secs(5), + ..Default::default() + }; + tracker.add_processing_metrics(1, new_metrics).await; + let metrics = tracker.get_processing_metrics(1).await; + assert_eq!( + metrics.tx_sending_duration, + std::time::Duration::from_secs(5) + ); + } + + #[tokio::test] + async fn epoch_tracker_cleanup() { + let tracker = EpochTracker::new(); + let _guard = tracker.try_claim_epoch(1); + tracker.increment_processed_items_count(1, 10).await; + tracker.get_or_create_tracker(1, 100); + + tracker.cleanup(1).await; + + // After cleanup, items and metrics are gone + assert_eq!(tracker.get_processed_items_count(1).await, 0); + assert_eq!( + tracker.get_processing_metrics(1).await.total(), + std::time::Duration::ZERO + ); + } +} diff --git a/forester/src/epoch_manager/v1.rs b/forester/src/epoch_manager/v1.rs new file mode 100644 index 0000000000..2e2020d876 --- /dev/null +++ b/forester/src/epoch_manager/v1.rs @@ -0,0 +1,221 @@ +//! V1 tree processing and rollover. 
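+//!
+//! Before building a batch, `process_v1` converts the scheduled slots that
+//! remain into a confirmation timeout via `scheduled_v1_batch_timeout`; if the
+//! budget is too small to confirm a transaction, the tree is skipped for this
+//! light slot. A minimal sketch of that gate (the ~450 ms slot time and the
+//! 8-slot floor are illustrative assumptions, not the real constants):
+//! ```ignore
+//! fn batch_timeout(slots_remaining: u64) -> Option<Duration> {
+//!     const MIN_CONFIRM_SLOTS: u64 = 8; // assumed floor
+//!     let slot = Duration::from_millis(450); // approximate Solana slot time
+//!     (slots_remaining >= MIN_CONFIRM_SLOTS).then(|| slot * slots_remaining as u32)
+//! }
+//! ```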
+ +use std::{sync::Arc, time::Duration}; + +use forester_utils::forester_epoch::{Epoch, ForesterSlot, TreeAccounts}; +use light_client::{indexer::Indexer, rpc::Rpc}; +use light_compressed_account::TreeType; +use light_registry::ForesterEpochPda; +use solana_sdk::signature::Keypair; +use tracing::{debug, error, info, warn}; + +use super::EpochManager; +use crate::{ + errors::ForesterError, + processor::v1::{ + config::{BuildTransactionBatchConfig, SendBatchedTransactionsConfig}, + send_transaction::send_batched_transactions, + tx_builder::EpochManagerTransactions, + }, + rollover::{ + is_tree_ready_for_rollover, perform_address_merkle_tree_rollover, + perform_state_merkle_tree_rollover_forester, + }, + transaction_timing::scheduled_v1_batch_timeout, +}; + +impl EpochManager { + pub(crate) async fn process_v1( + &self, + epoch_info: &Epoch, + epoch_pda: &ForesterEpochPda, + tree_accounts: &TreeAccounts, + forester_slot_details: &ForesterSlot, + current_solana_slot: u64, + ) -> std::result::Result { + let slots_remaining = forester_slot_details + .end_solana_slot + .saturating_sub(current_solana_slot); + let Some(remaining_time_timeout) = scheduled_v1_batch_timeout(slots_remaining) else { + debug!( + event = "v1_tree_skipped_low_slot_budget", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + slots_remaining, + "Skipping V1 tree: not enough scheduled slot budget left to confirm a transaction" + ); + return Ok(0); + }; + + let use_lookup_tables = !self.ctx.address_lookup_tables.is_empty(); + let enable_v1_multi_nullify = self.ctx.config.enable_v1_multi_nullify && use_lookup_tables; + + let batched_tx_config = SendBatchedTransactionsConfig { + num_batches: 1, + build_transaction_batch_config: BuildTransactionBatchConfig { + batch_size: self.ctx.config.transaction_config.legacy_ixs_per_tx as u64, + compute_unit_price: self + .ctx + .config + .transaction_config + .priority_fee_microlamports, + compute_unit_limit: Some(self.ctx.config.transaction_config.cu_limit), + enable_priority_fees: self.ctx.config.transaction_config.enable_priority_fees, + max_concurrent_sends: Some(self.ctx.config.transaction_config.max_concurrent_sends), + }, + queue_config: self.ctx.config.queue_config, + retry_config: light_client::rpc::RetryConfig { + timeout: remaining_time_timeout, + ..self.ctx.config.retry_config + }, + light_slot_length: epoch_pda.protocol_config.slot_length, + confirmation_poll_interval: Duration::from_millis( + self.ctx + .config + .transaction_config + .confirmation_poll_interval_ms, + ), + confirmation_max_attempts: self.ctx.config.transaction_config.confirmation_max_attempts + as usize, + min_queue_items: if use_lookup_tables { + self.ctx.config.min_queue_items + } else { + None + }, + enable_presort: enable_v1_multi_nullify, + work_item_batch_size: self.ctx.config.work_item_batch_size, + }; + + let transaction_builder = Arc::new(EpochManagerTransactions::new( + self.ctx.rpc_pool.clone(), + epoch_info.epoch, + self.tx_cache.clone(), + self.ctx.address_lookup_tables.as_ref().clone(), + enable_v1_multi_nullify, + )); + + let num_sent = send_batched_transactions( + &self.ctx.config.payer_keypair, + &self.ctx.config.derivation_pubkey, + self.ctx.rpc_pool.clone(), + &batched_tx_config, + *tree_accounts, + transaction_builder, + ) + .await?; + + if num_sent > 0 { + debug!( + event = "v1_tree_items_processed", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + items = num_sent, + "Processed items for V1 tree" + ); + } + + match 
self.rollover_if_needed(tree_accounts).await { + Ok(_) => Ok(num_sent), + Err(e) => { + error!( + event = "tree_rollover_failed", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + tree_type = ?tree_accounts.tree_type, + error = ?e, + "Failed to rollover tree" + ); + Err(e.into()) + } + } + } + + async fn rollover_if_needed(&self, tree_account: &TreeAccounts) -> crate::Result<()> { + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + if is_tree_ready_for_rollover(&mut *rpc, tree_account.merkle_tree, tree_account.tree_type) + .await? + { + info!( + event = "tree_rollover_started", + run_id = %self.ctx.run_id, + tree = %tree_account.merkle_tree, + tree_type = ?tree_account.tree_type, + "Starting tree rollover" + ); + self.perform_rollover(tree_account).await?; + } + Ok(()) + } + + async fn perform_rollover(&self, tree_account: &TreeAccounts) -> crate::Result<()> { + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + let (_, current_epoch) = self.ctx.get_current_slot_and_epoch().await?; + + let result = match tree_account.tree_type { + TreeType::AddressV1 => { + let new_nullifier_queue_keypair = Keypair::new(); + let new_merkle_tree_keypair = Keypair::new(); + + let rollover_signature = perform_address_merkle_tree_rollover( + &self.ctx.config.payer_keypair, + &self.ctx.config.derivation_pubkey, + &mut *rpc, + &new_nullifier_queue_keypair, + &new_merkle_tree_keypair, + &tree_account.merkle_tree, + &tree_account.queue, + current_epoch, + ) + .await?; + + info!( + event = "address_tree_rollover_succeeded", + run_id = %self.ctx.run_id, + tree = %tree_account.merkle_tree, + signature = %rollover_signature, + "Address tree rollover succeeded" + ); + Ok(()) + } + TreeType::StateV1 => { + let new_nullifier_queue_keypair = Keypair::new(); + let new_merkle_tree_keypair = Keypair::new(); + let new_cpi_signature_keypair = Keypair::new(); + + let rollover_signature = perform_state_merkle_tree_rollover_forester( + &self.ctx.config.payer_keypair, + &self.ctx.config.derivation_pubkey, + &mut *rpc, + &new_nullifier_queue_keypair, + &new_merkle_tree_keypair, + &new_cpi_signature_keypair, + &tree_account.merkle_tree, + &tree_account.queue, + &solana_program::pubkey::Pubkey::default(), + current_epoch, + ) + .await?; + + info!( + event = "state_tree_rollover_succeeded", + run_id = %self.ctx.run_id, + tree = %tree_account.merkle_tree, + signature = %rollover_signature, + "State tree rollover succeeded" + ); + + Ok(()) + } + _ => Err(ForesterError::InvalidTreeType(tree_account.tree_type)), + }; + + match result { + Ok(_) => debug!( + "{:?} tree rollover completed successfully", + tree_account.tree_type + ), + Err(e) => warn!("{:?} tree rollover failed: {:?}", tree_account.tree_type, e), + } + Ok(()) + } +} diff --git a/forester/src/epoch_manager/v2.rs b/forester/src/epoch_manager/v2.rs new file mode 100644 index 0000000000..42fb9c5927 --- /dev/null +++ b/forester/src/epoch_manager/v2.rs @@ -0,0 +1,630 @@ +//! V2 tree processing, proof cache prewarming, and cached proof sending. 
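+//!
+//! `process_v2` applies a three-way error triage; the helper names below are
+//! placeholders for what the real match arms do inline:
+//! ```ignore
+//! match processor.process().await {
+//!     Ok(res) => use_result(res),
+//!     // On-chain state diverged: drop the processor and proof cache, bubble up.
+//!     Err(e) if e.is_constraint() => drop_processor_and_cache(e),
+//!     // Proofs were built against a stale root: clear the cache and retry.
+//!     Err(e) if e.is_hashchain_mismatch() => clear_cache_and_retry(),
+//!     // Anything else is transient: warn (rate-limited) and retry next tick.
+//!     Err(_) => retry_next_tick(),
+//! }
+//! ```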
+ +use std::time::Duration; + +use anyhow::{anyhow, Context}; +use borsh::BorshSerialize; +use forester_utils::forester_epoch::{Epoch, TreeAccounts}; +use light_client::{indexer::Indexer, rpc::Rpc}; +use light_compressed_account::TreeType; +use light_registry::account_compression_cpi::sdk::{ + create_batch_append_instruction, create_batch_nullify_instruction, + create_batch_update_address_tree_instruction, +}; +use solana_sdk::signature::Signer; +use tokio::time::Instant; +use tracing::{debug, info, warn}; + +use super::{context::should_skip_tree, EpochManager}; +use crate::{ + errors::ForesterError, + logging::should_emit_rate_limited_warning, + processor::v2::{BatchInstruction, ProcessingResult}, + slot_tracker::slot_duration, + smart_transaction::{send_smart_transaction, ComputeBudgetConfig, SendSmartTransactionConfig}, + transaction_timing::scheduled_confirmation_deadline, + ForesterEpochInfo, +}; + +impl EpochManager { + pub(crate) async fn process_v2( + &self, + epoch_info: &Epoch, + tree_accounts: &TreeAccounts, + consecutive_eligibility_end: u64, + ) -> std::result::Result { + match tree_accounts.tree_type { + TreeType::StateV2 => { + let processor = self + .processor_pool + .get_or_create_state_processor( + &self.ctx, + epoch_info, + tree_accounts, + self.ops_cache.clone(), + ) + .await?; + + let cache = self + .processor_pool + .get_or_create_proof_cache(tree_accounts.merkle_tree); + + { + let mut proc = processor.lock().await; + proc.update_eligibility(consecutive_eligibility_end); + proc.set_proof_cache(cache); + } + + let mut proc = processor.lock().await; + match proc.process().await { + Ok(res) => Ok(res), + Err(error) if matches!(&error, ForesterError::V2(v2_error) if v2_error.is_constraint()) => + { + warn!( + event = "v2_state_constraint_error", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + error = %error, + "State processing hit constraint error. Dropping processor to flush cache." + ); + drop(proc); + self.processor_pool + .remove_state_processor(&tree_accounts.merkle_tree); + self.processor_pool + .remove_proof_cache(&tree_accounts.merkle_tree); + Err(error) + } + Err(ForesterError::V2(v2_error)) if v2_error.is_hashchain_mismatch() => { + let warning_key = + format!("v2_state_hashchain_mismatch:{}", tree_accounts.merkle_tree); + if should_emit_rate_limited_warning(warning_key, Duration::from_secs(15)) { + warn!( + event = "v2_state_hashchain_mismatch", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + error = %v2_error, + "State processing hit hashchain mismatch. Clearing cache and retrying." + ); + } + self.ctx.heartbeat.increment_v2_recoverable_error(); + proc.clear_cache().await; + Ok(ProcessingResult::default()) + } + Err(e) => { + let warning_key = + format!("v2_state_process_failed:{}", tree_accounts.merkle_tree); + if should_emit_rate_limited_warning(warning_key, Duration::from_secs(10)) { + warn!( + event = "v2_state_process_failed_retrying", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + error = %e, + "Failed to process state queue. Will retry next tick without dropping processor." 
+ ); + } + self.ctx.heartbeat.increment_v2_recoverable_error(); + Ok(ProcessingResult::default()) + } + } + } + TreeType::AddressV2 => { + let processor = self + .processor_pool + .get_or_create_address_processor( + &self.ctx, + epoch_info, + tree_accounts, + self.ops_cache.clone(), + ) + .await?; + + let cache = self + .processor_pool + .get_or_create_proof_cache(tree_accounts.merkle_tree); + + { + let mut proc = processor.lock().await; + proc.update_eligibility(consecutive_eligibility_end); + proc.set_proof_cache(cache); + } + + let mut proc = processor.lock().await; + match proc.process().await { + Ok(res) => Ok(res), + Err(error) if matches!(&error, ForesterError::V2(v2_error) if v2_error.is_constraint()) => + { + warn!( + event = "v2_address_constraint_error", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + error = %error, + "Address processing hit constraint error. Dropping processor to flush cache." + ); + drop(proc); + self.processor_pool + .remove_address_processor(&tree_accounts.merkle_tree); + self.processor_pool + .remove_proof_cache(&tree_accounts.merkle_tree); + Err(error) + } + Err(ForesterError::V2(v2_error)) if v2_error.is_hashchain_mismatch() => { + let warning_key = format!( + "v2_address_hashchain_mismatch:{}", + tree_accounts.merkle_tree + ); + if should_emit_rate_limited_warning(warning_key, Duration::from_secs(15)) { + warn!( + event = "v2_address_hashchain_mismatch", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + error = %v2_error, + "Address processing hit hashchain mismatch. Clearing cache and retrying." + ); + } + self.ctx.heartbeat.increment_v2_recoverable_error(); + proc.clear_cache().await; + Ok(ProcessingResult::default()) + } + Err(e) => { + let warning_key = + format!("v2_address_process_failed:{}", tree_accounts.merkle_tree); + if should_emit_rate_limited_warning(warning_key, Duration::from_secs(10)) { + warn!( + event = "v2_address_process_failed_retrying", + run_id = %self.ctx.run_id, + tree = %tree_accounts.merkle_tree, + error = %e, + "Failed to process address queue. Will retry next tick without dropping processor." 
+                        );
+                    }
+                    self.ctx.heartbeat.increment_v2_recoverable_error();
+                    Ok(ProcessingResult::default())
+                }
+            }
+        }
+        _ => {
+            warn!(
+                event = "v2_unsupported_tree_type",
+                run_id = %self.ctx.run_id,
+                tree_type = ?tree_accounts.tree_type,
+                "Unsupported tree type for V2 processing"
+            );
+            Ok(ProcessingResult::default())
+        }
+    }
+}
+
+    pub(crate) async fn prewarm_all_trees_during_wait(
+        &self,
+        epoch_info: &ForesterEpochInfo,
+        deadline_slot: u64,
+    ) {
+        let current_slot = self.ctx.slot_tracker.estimated_current_slot();
+        let slots_until_active = deadline_slot.saturating_sub(current_slot);
+
+        let trees = self.trees.lock().await;
+        let total_v2_state = trees
+            .iter()
+            .filter(|t| matches!(t.tree_type, TreeType::StateV2))
+            .count();
+        let v2_state_trees: Vec<_> = trees
+            .iter()
+            .filter(|t| {
+                matches!(t.tree_type, TreeType::StateV2)
+                    && !should_skip_tree(&self.ctx.config, &t.tree_type)
+            })
+            .cloned()
+            .collect();
+        let skipped_count = total_v2_state - v2_state_trees.len();
+        drop(trees);
+
+        if v2_state_trees.is_empty() {
+            if skipped_count > 0 {
+                info!(
+                    event = "prewarm_skipped_all_trees_filtered",
+                    run_id = %self.ctx.run_id,
+                    skipped_trees = skipped_count,
+                    "No trees to pre-warm; all StateV2 trees skipped by config"
+                );
+            }
+            return;
+        }
+
+        if slots_until_active < 15 {
+            info!(
+                event = "prewarm_skipped_not_enough_time",
+                run_id = %self.ctx.run_id,
+                slots_until_active,
+                min_required_slots = 15,
+                "Skipping pre-warming; not enough slots until active phase"
+            );
+            return;
+        }
+
+        let prewarm_futures: Vec<_> = v2_state_trees
+            .iter()
+            .map(|tree_accounts| {
+                let tree_pubkey = tree_accounts.merkle_tree;
+                let epoch_info = epoch_info.clone();
+                let tree_accounts = *tree_accounts;
+                let self_ref = self.clone();
+
+                async move {
+                    let cache = self_ref
+                        .processor_pool
+                        .get_or_create_proof_cache(tree_pubkey);
+
+                    let cache_len = cache.len().await;
+                    if cache_len > 0 && !cache.is_warming().await {
+                        let mut rpc = match self_ref.ctx.rpc_pool.get_connection().await {
+                            Ok(r) => r,
+                            Err(e) => {
+                                warn!(
+                                    event = "prewarm_cache_validation_rpc_failed",
+                                    run_id = %self_ref.ctx.run_id,
+                                    tree = %tree_pubkey,
+                                    error = ?e,
+                                    "Failed to get RPC for cache validation"
+                                );
+                                return;
+                            }
+                        };
+                        if let Ok(current_root) =
+                            self_ref.fetch_current_root(&mut *rpc, &tree_accounts).await
+                        {
+                            info!(
+                                event = "prewarm_skipped_cache_already_warm",
+                                run_id = %self_ref.ctx.run_id,
+                                tree = %tree_pubkey,
+                                cached_proofs = cache_len,
+                                root_prefix = ?&current_root[..4],
+                                "Tree already has cached proofs from previous epoch; skipping pre-warm"
+                            );
+                            return;
+                        }
+                    }
+
+                    let processor = match self_ref
+                        .processor_pool
+                        .get_or_create_state_processor(
+                            &self_ref.ctx,
+                            &epoch_info.epoch,
+                            &tree_accounts,
+                            self_ref.ops_cache.clone(),
+                        )
+                        .await
+                    {
+                        Ok(p) => p,
+                        Err(e) => {
+                            warn!(
+                                event = "prewarm_processor_create_failed",
+                                run_id = %self_ref.ctx.run_id,
+                                tree = %tree_pubkey,
+                                error = ?e,
+                                "Failed to create processor for pre-warming tree"
+                            );
+                            return;
+                        }
+                    };
+
+                    const PREWARM_MAX_BATCHES: usize = 4;
+                    let mut p = processor.lock().await;
+                    match p
+                        .prewarm_from_indexer(
+                            cache.clone(),
+                            light_compressed_account::QueueType::OutputStateV2,
+                            PREWARM_MAX_BATCHES,
+                        )
+                        .await
+                    {
+                        Ok(result) => {
+                            if result.items_processed > 0 {
+                                info!(
+                                    event = "prewarm_tree_completed",
+                                    run_id = %self_ref.ctx.run_id,
+                                    tree = %tree_pubkey,
+                                    items = result.items_processed,
+                                    "Pre-warmed items for tree during wait"
+                                );
+                                self_ref
+                                    .epoch_tracker
+                                    .add_processing_metrics(epoch_info.epoch.epoch, result.metrics)
+                                    .await;
+                            }
+                        }
+                        Err(e) => {
+                            debug!(
+                                "Pre-warming from indexer failed for tree {}: {:?}",
+                                tree_pubkey, e
+                            );
+                            cache.clear().await;
+                        }
+                    }
+                }
+            })
+            .collect();
+
+        let timeout_slots = slots_until_active.saturating_sub(5);
+        let timeout_duration =
+            (slot_duration() * timeout_slots as u32).min(Duration::from_secs(30));
+
+        info!(
+            event = "prewarm_started",
+            run_id = %self.ctx.run_id,
+            trees = v2_state_trees.len(),
+            skipped_trees = skipped_count,
+            timeout_ms = timeout_duration.as_millis() as u64,
+            "Starting pre-warming"
+        );
+
+        match tokio::time::timeout(timeout_duration, futures::future::join_all(prewarm_futures))
+            .await
+        {
+            Ok(_) => {
+                info!(
+                    event = "prewarm_completed",
+                    run_id = %self.ctx.run_id,
+                    trees = v2_state_trees.len(),
+                    "Completed pre-warming for all trees"
+                );
+            }
+            Err(_) => {
+                info!(
+                    event = "prewarm_timed_out",
+                    run_id = %self.ctx.run_id,
+                    timeout_ms = timeout_duration.as_millis() as u64,
+                    "Pre-warming timed out"
+                );
+            }
+        }
+    }
+
+    pub(crate) async fn try_send_cached_proofs(
+        &self,
+        epoch_info: &Epoch,
+        tree_accounts: &TreeAccounts,
+        consecutive_eligibility_end: u64,
+    ) -> crate::Result<Option<usize>> {
+        let tree_pubkey = tree_accounts.merkle_tree;
+
+        let current_slot = self.ctx.slot_tracker.estimated_current_slot();
+        if current_slot >= consecutive_eligibility_end {
+            debug!(
+                event = "cached_proofs_skipped_outside_eligibility",
+                run_id = %self.ctx.run_id,
+                tree = %tree_pubkey,
+                current_slot,
+                eligibility_end_slot = consecutive_eligibility_end,
+                "Skipping cached proof send because eligibility window has ended"
+            );
+            return Ok(None);
+        }
+
+        let Some(confirmation_deadline) = scheduled_confirmation_deadline(
+            consecutive_eligibility_end.saturating_sub(current_slot),
+        ) else {
+            debug!(
+                event = "cached_proofs_skipped_confirmation_budget_exhausted",
+                run_id = %self.ctx.run_id,
+                tree = %tree_pubkey,
+                current_slot,
+                eligibility_end_slot = consecutive_eligibility_end,
+                "Skipping cached proofs because not enough eligible slots remain for confirmation"
+            );
+            return Ok(None);
+        };
+
+        let cache = match self.processor_pool.get_proof_cache(&tree_pubkey) {
+            Some(c) => c,
+            None => return Ok(None),
+        };
+
+        if cache.is_warming().await {
+            debug!(
+                event = "cached_proofs_skipped_cache_warming",
+                run_id = %self.ctx.run_id,
+                tree = %tree_pubkey,
+                "Skipping cached proofs because cache is still warming"
+            );
+            return Ok(None);
+        }
+
+        let mut rpc = self.ctx.rpc_pool.get_connection().await?;
+        let current_root = match self.fetch_current_root(&mut *rpc, tree_accounts).await {
+            Ok(root) => root,
+            Err(e) => {
+                warn!(
+                    event = "cached_proofs_root_fetch_failed",
+                    run_id = %self.ctx.run_id,
+                    tree = %tree_pubkey,
+                    error = ?e,
+                    "Failed to fetch current root for tree"
+                );
+                return Ok(None);
+            }
+        };
+
+        let cached_proofs = match cache.take_if_valid(&current_root).await {
+            Some(proofs) => proofs,
+            None => {
+                debug!(
+                    event = "cached_proofs_not_available",
+                    run_id = %self.ctx.run_id,
+                    tree = %tree_pubkey,
+                    root_prefix = ?&current_root[..4],
+                    "No valid cached proofs for tree"
+                );
+                return Ok(None);
+            }
+        };
+
+        if cached_proofs.is_empty() {
+            return Ok(Some(0));
+        }
+
+        info!(
+            event = "cached_proofs_send_started",
+            run_id = %self.ctx.run_id,
+            tree = %tree_pubkey,
+            proofs = cached_proofs.len(),
+            root_prefix = ?&current_root[..4],
+            "Sending cached proofs for tree"
+        );
+
+        let items_sent = self
+            .send_cached_proofs_as_transactions(
+                epoch_info,
+                tree_accounts,
+                cached_proofs,
confirmation_deadline, + ) + .await?; + + Ok(Some(items_sent)) + } + + async fn fetch_current_root( + &self, + rpc: &mut impl Rpc, + tree_accounts: &TreeAccounts, + ) -> crate::Result<[u8; 32]> { + use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; + + let mut account = rpc + .get_account(tree_accounts.merkle_tree) + .await? + .ok_or_else(|| anyhow!("Tree account not found: {}", tree_accounts.merkle_tree))?; + + let tree = match tree_accounts.tree_type { + TreeType::StateV2 => BatchedMerkleTreeAccount::state_from_bytes( + &mut account.data, + &tree_accounts.merkle_tree.into(), + )?, + TreeType::AddressV2 => BatchedMerkleTreeAccount::address_from_bytes( + &mut account.data, + &tree_accounts.merkle_tree.into(), + )?, + _ => return Err(anyhow!("Unsupported tree type for root fetch")), + }; + + let root = tree.root_history.last().copied().unwrap_or([0u8; 32]); + Ok(root) + } + + async fn send_cached_proofs_as_transactions( + &self, + epoch_info: &Epoch, + tree_accounts: &TreeAccounts, + cached_proofs: Vec, + confirmation_deadline: Instant, + ) -> crate::Result { + let mut total_items = 0; + let authority = self.ctx.config.payer_keypair.pubkey(); + let derivation = self.ctx.config.derivation_pubkey; + + const PROOFS_PER_TX: usize = 4; + for chunk in cached_proofs.chunks(PROOFS_PER_TX) { + let mut instructions = Vec::new(); + let mut chunk_items = 0; + + for proof in chunk { + match &proof.instruction { + BatchInstruction::Append(data) => { + for d in data { + let serialized = d + .try_to_vec() + .with_context(|| "Failed to serialize batch append payload")?; + instructions.push(create_batch_append_instruction( + authority, + derivation, + tree_accounts.merkle_tree, + tree_accounts.queue, + epoch_info.epoch, + serialized, + )); + } + } + BatchInstruction::Nullify(data) => { + for d in data { + let serialized = d + .try_to_vec() + .with_context(|| "Failed to serialize batch nullify payload")?; + instructions.push(create_batch_nullify_instruction( + authority, + derivation, + tree_accounts.merkle_tree, + epoch_info.epoch, + serialized, + )); + } + } + BatchInstruction::AddressAppend(data) => { + for d in data { + let serialized = d.try_to_vec().with_context(|| { + "Failed to serialize batch address append payload" + })?; + instructions.push(create_batch_update_address_tree_instruction( + authority, + derivation, + tree_accounts.merkle_tree, + epoch_info.epoch, + serialized, + )); + } + } + } + chunk_items += proof.items; + } + + if !instructions.is_empty() { + let mut rpc = self.ctx.rpc_pool.get_connection().await?; + let priority_fee = self + .ctx + .resolve_tree_priority_fee(&*rpc, epoch_info.epoch, tree_accounts) + .await?; + let instruction_count = instructions.len(); + let payer = self.ctx.config.payer_keypair.pubkey(); + let signers = [&self.ctx.config.payer_keypair]; + match send_smart_transaction( + &mut *rpc, + SendSmartTransactionConfig { + instructions, + payer: &payer, + signers: &signers, + address_lookup_tables: &self.ctx.address_lookup_tables, + compute_budget: ComputeBudgetConfig { + compute_unit_price: priority_fee, + compute_unit_limit: Some(self.ctx.config.transaction_config.cu_limit), + }, + confirmation: Some(self.ctx.confirmation_config()), + confirmation_deadline: Some(confirmation_deadline), + }, + ) + .await + .map_err(light_client::rpc::RpcError::from) + { + Ok(sig) => { + info!( + event = "cached_proofs_tx_sent", + run_id = %self.ctx.run_id, + signature = %sig, + instruction_count, + "Sent cached proofs transaction" + ); + total_items += chunk_items; + 
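+                        // Counted only after a confirmed send; when the send fails
+                        // (Err arm below), the chunk's items are not counted and
+                        // the underlying queue items are left for the regular
+                        // processing path to retry.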
} + Err(e) => { + warn!( + event = "cached_proofs_tx_send_failed", + run_id = %self.ctx.run_id, + error = ?e, + "Failed to send cached proofs transaction" + ); + } + } + } + } + + Ok(total_items) + } +} diff --git a/forester/src/forester_status.rs b/forester/src/forester_status.rs index 80c4539075..d8f958134d 100644 --- a/forester/src/forester_status.rs +++ b/forester/src/forester_status.rs @@ -670,20 +670,22 @@ fn parse_tree_status( let fullness = next_index as f64 / capacity as f64 * 100.0; let (queue_len, queue_cap) = queue_account - .map(|acc| { - unsafe { parse_hash_set_from_bytes::(&acc.data) } - .ok() - .map(|hs| { + .map( + |acc| match unsafe { parse_hash_set_from_bytes::(&acc.data) } { + Ok(hs) => { let len = hs .iter() .filter(|(_, cell)| cell.sequence_number.is_none()) .count() as u64; let cap = hs.get_capacity() as u64; - (len, cap) - }) - .unwrap_or((0, 0)) - }) - .map(|(l, c)| (Some(l), Some(c))) + (Some(len), Some(cap)) + } + Err(error) => { + warn!(?error, "Failed to parse StateV1 queue hash set"); + (None, None) + } + }, + ) .unwrap_or((None, None)); ( @@ -725,20 +727,22 @@ fn parse_tree_status( let fullness = next_index as f64 / capacity as f64 * 100.0; let (queue_len, queue_cap) = queue_account - .map(|acc| { - unsafe { parse_hash_set_from_bytes::(&acc.data) } - .ok() - .map(|hs| { + .map( + |acc| match unsafe { parse_hash_set_from_bytes::(&acc.data) } { + Ok(hs) => { let len = hs .iter() .filter(|(_, cell)| cell.sequence_number.is_none()) .count() as u64; let cap = hs.get_capacity() as u64; - (len, cap) - }) - .unwrap_or((0, 0)) - }) - .map(|(l, c)| (Some(l), Some(c))) + (Some(len), Some(cap)) + } + Err(error) => { + warn!(?error, "Failed to parse AddressV1 queue hash set"); + (None, None) + } + }, + ) .unwrap_or((None, None)); ( diff --git a/forester/src/metrics.rs b/forester/src/metrics.rs index a7be12dd83..ece7d6a577 100644 --- a/forester/src/metrics.rs +++ b/forester/src/metrics.rs @@ -440,21 +440,19 @@ pub async fn metrics_handler() -> Result { if let Err(e) = encoder.encode(®ISTRY.gather(), &mut buffer) { error!("could not encode custom metrics: {}", e); }; - let mut res = String::from_utf8(buffer.clone()).unwrap_or_else(|e| { + let mut res = String::from_utf8(buffer).unwrap_or_else(|e| { error!("custom metrics could not be from_utf8'd: {}", e); String::new() }); - buffer.clear(); let mut buffer = Vec::new(); if let Err(e) = encoder.encode(&prometheus::gather(), &mut buffer) { error!("could not encode prometheus metrics: {}", e); }; - let res_prometheus = String::from_utf8(buffer.clone()).unwrap_or_else(|e| { + let res_prometheus = String::from_utf8(buffer).unwrap_or_else(|e| { error!("prometheus metrics could not be from_utf8'd: {}", e); String::new() }); - buffer.clear(); res.push_str(&res_prometheus); Ok(res) diff --git a/forester/src/priority_fee.rs b/forester/src/priority_fee.rs index c788712e1c..0b9ee4049c 100644 --- a/forester/src/priority_fee.rs +++ b/forester/src/priority_fee.rs @@ -208,7 +208,7 @@ pub async fn request_priority_fee_estimate( .map_err(|error| PriorityFeeEstimateError::ClientBuild(error.clone()))?; let response = http_client - .post(url.clone()) + .post(url.as_str()) .header("Content-Type", "application/json") .json(&rpc_request) .send() diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs deleted file mode 100644 index 66b574dc4e..0000000000 --- a/forester/src/processor/v2/helpers.rs +++ /dev/null @@ -1,797 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, Condvar, Mutex, MutexGuard}, 
-}; - -use anyhow::anyhow; -use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; -use light_client::{ - indexer::{AddressQueueData, Indexer, QueueElementsV2Options, StateQueueData}, - rpc::Rpc, -}; -use light_hasher::hash_chain::create_hash_chain_from_slice; - -use crate::processor::v2::{common::clamp_to_u16, BatchContext}; - -pub(crate) fn lock_recover<'a, T>(mutex: &'a Mutex, name: &'static str) -> MutexGuard<'a, T> { - match mutex.lock() { - Ok(guard) => guard, - Err(poisoned) => { - tracing::warn!("Poisoned mutex (recovering): {}", name); - poisoned.into_inner() - } - } -} - -#[derive(Debug, Clone)] -pub struct AddressBatchSnapshot { - pub addresses: Vec<[u8; 32]>, - pub low_element_values: Vec<[u8; 32]>, - pub low_element_next_values: Vec<[u8; 32]>, - pub low_element_indices: Vec, - pub low_element_next_indices: Vec, - pub low_element_proofs: Vec<[[u8; 32]; HEIGHT]>, - pub leaves_hashchain: [u8; 32], -} - -pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { - let rpc = context.rpc_pool.get_connection().await?; - let mut account = rpc - .get_account(context.merkle_tree) - .await? - .ok_or_else(|| anyhow!("Merkle tree account not found"))?; - - let tree = BatchedMerkleTreeAccount::state_from_bytes( - account.data.as_mut_slice(), - &context.merkle_tree.into(), - )?; - - let batch_index = tree.queue_batches.pending_batch_index; - let batch = tree - .queue_batches - .batches - .get(batch_index as usize) - .ok_or_else(|| anyhow!("Batch not found"))?; - - Ok(batch.zkp_batch_size) -} - -pub async fn fetch_onchain_state_root( - context: &BatchContext, -) -> crate::Result<[u8; 32]> { - let rpc = context.rpc_pool.get_connection().await?; - let mut account = rpc - .get_account(context.merkle_tree) - .await? - .ok_or_else(|| anyhow!("Merkle tree account not found"))?; - - let tree = BatchedMerkleTreeAccount::state_from_bytes( - account.data.as_mut_slice(), - &context.merkle_tree.into(), - )?; - - // Get the current root (last entry in root_history) - let root = tree - .root_history - .last() - .copied() - .ok_or_else(|| anyhow!("Root history is empty"))?; - - Ok(root) -} - -pub async fn fetch_address_zkp_batch_size(context: &BatchContext) -> crate::Result { - let rpc = context.rpc_pool.get_connection().await?; - let mut account = rpc - .get_account(context.merkle_tree) - .await? - .ok_or_else(|| anyhow!("Merkle tree account not found"))?; - - let tree = BatchedMerkleTreeAccount::address_from_bytes( - account.data.as_mut_slice(), - &context.merkle_tree.into(), - ) - .map_err(|e| anyhow!("Failed to deserialize address tree: {}", e))?; - - let batch_index = tree.queue_batches.pending_batch_index; - let batch = tree - .queue_batches - .batches - .get(batch_index as usize) - .ok_or_else(|| anyhow!("Batch not found"))?; - - Ok(batch.zkp_batch_size) -} - -pub async fn fetch_onchain_address_root( - context: &BatchContext, -) -> crate::Result<[u8; 32]> { - let rpc = context.rpc_pool.get_connection().await?; - let mut account = rpc - .get_account(context.merkle_tree) - .await? 
- .ok_or_else(|| anyhow!("Merkle tree account not found"))?; - - let tree = BatchedMerkleTreeAccount::address_from_bytes( - account.data.as_mut_slice(), - &context.merkle_tree.into(), - ) - .map_err(|e| anyhow!("Failed to deserialize address tree: {}", e))?; - - let root = tree - .root_history - .last() - .copied() - .ok_or_else(|| anyhow!("Root history is empty"))?; - - Ok(root) -} - -const INDEXER_FETCH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); -const ADDRESS_INDEXER_FETCH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); -const PAGE_SIZE_BATCHES: u64 = 5; -const ADDRESS_PAGE_SIZE_BATCHES: u64 = 5; - -pub async fn fetch_paginated_batches( - context: &BatchContext, - total_elements: u64, - zkp_batch_size: u64, -) -> crate::Result> { - if zkp_batch_size == 0 { - return Err(anyhow::anyhow!("zkp_batch_size cannot be zero")); - } - if total_elements == 0 { - return Ok(None); - } - - let page_size_elements = PAGE_SIZE_BATCHES * zkp_batch_size; - if total_elements <= page_size_elements { - tracing::debug!( - "fetch_paginated_batches: single page fetch with start_index=None, total_elements={}, page_size={}", - total_elements, page_size_elements - ); - return fetch_batches(context, None, None, total_elements, zkp_batch_size).await; - } - - let num_pages = total_elements.div_ceil(page_size_elements) as usize; - tracing::debug!( - "Parallel fetch: {} elements ({} batches) in {} pages of {} batches each", - total_elements, - total_elements / zkp_batch_size, - num_pages, - PAGE_SIZE_BATCHES - ); - - // Fetch first page with start_index=None to discover the actual first_queue_index - // (queue may have been pruned, so indices don't start at 0) - let first_page = fetch_batches(context, None, None, page_size_elements, zkp_batch_size).await?; - - let Some(first_page_data) = first_page else { - return Ok(None); - }; - - // Get the actual starting indices from the first page response - // IMPORTANT: Only use first_queue_index if the queue actually has elements. - // When queue is empty, photon returns default first_queue_index=0, which would - // cause subsequent pages to request start_index=2500 even though the actual - // queue might start at 149500 (if elements arrive between requests). 
- let output_first_index: Option = first_page_data - .output_queue - .as_ref() - .filter(|oq| !oq.leaf_indices.is_empty()) - .map(|oq| oq.first_queue_index); - let input_first_index: Option = first_page_data - .input_queue - .as_ref() - .filter(|iq| !iq.leaf_indices.is_empty()) - .map(|iq| iq.first_queue_index); - - tracing::debug!( - "First page fetched: output_first_index={:?}, input_first_index={:?}", - output_first_index, - input_first_index - ); - - // If only one page needed, return the first page result - if num_pages == 1 { - return Ok(Some(first_page_data)); - } - - // Fetch remaining pages in parallel with offsets relative to first_queue_index - // Only request queues for which we have valid first_queue_index from the first page - let mut fetch_futures = Vec::with_capacity(num_pages - 1); - let mut offset = page_size_elements; - - for _page_idx in 1..num_pages { - let page_size = (total_elements - offset).min(page_size_elements); - // Only use Some(index) for queues we actually got data for in the first page - // If first page had no data for a queue, we don't know its first_queue_index - let output_start = output_first_index.map(|idx| idx + offset); - let input_start = input_first_index.map(|idx| idx + offset); - - let ctx = context.clone(); - - fetch_futures.push(async move { - fetch_batches(&ctx, output_start, input_start, page_size, zkp_batch_size).await - }); - - offset += page_size; - } - - let results = futures::future::join_all(fetch_futures).await; - - // Initialize with first page data - let initial_root = first_page_data.initial_root; - let root_seq = first_page_data.root_seq; - let mut nodes_map: HashMap = HashMap::new(); - for (&idx, &hash) in first_page_data - .nodes - .iter() - .zip(first_page_data.node_hashes.iter()) - { - nodes_map.insert(idx, hash); - } - let mut output_queue = first_page_data.output_queue; - let mut input_queue = first_page_data.input_queue; - - // Merge remaining pages - for (page_idx, result) in results.into_iter().enumerate() { - let page = match result? 
{ - Some(data) => data, - None => continue, - }; - - if page.initial_root != initial_root { - tracing::warn!( - "Page {} has different root ({:?} vs {:?}), stopping merge", - page_idx + 1, - &page.initial_root[..4], - &initial_root[..4] - ); - break; - } - - for (&idx, &hash) in page.nodes.iter().zip(page.node_hashes.iter()) { - nodes_map.entry(idx).or_insert(hash); - } - - if let Some(page_oq) = page.output_queue { - if let Some(ref mut oq) = output_queue { - oq.leaf_indices.extend(page_oq.leaf_indices); - oq.account_hashes.extend(page_oq.account_hashes); - oq.old_leaves.extend(page_oq.old_leaves); - oq.leaves_hash_chains.extend(page_oq.leaves_hash_chains); - } else { - output_queue = Some(page_oq); - } - } - - if let Some(page_iq) = page.input_queue { - if let Some(ref mut iq) = input_queue { - iq.leaf_indices.extend(page_iq.leaf_indices); - iq.account_hashes.extend(page_iq.account_hashes); - iq.current_leaves.extend(page_iq.current_leaves); - iq.tx_hashes.extend(page_iq.tx_hashes); - iq.nullifiers.extend(page_iq.nullifiers); - iq.leaves_hash_chains.extend(page_iq.leaves_hash_chains); - } else { - input_queue = Some(page_iq); - } - } - } - - let mut nodes_vec: Vec<_> = nodes_map.into_iter().collect(); - nodes_vec.sort_by_key(|(idx, _)| *idx); - let (nodes, node_hashes): (Vec<_>, Vec<_>) = nodes_vec.into_iter().unzip(); - - tracing::debug!( - "Parallel fetch complete: {} nodes, output={}, input={}", - nodes.len(), - output_queue - .as_ref() - .map(|oq| oq.leaf_indices.len()) - .unwrap_or(0), - input_queue - .as_ref() - .map(|iq| iq.leaf_indices.len()) - .unwrap_or(0) - ); - - Ok(Some(StateQueueData { - nodes, - node_hashes, - initial_root, - root_seq, - output_queue, - input_queue, - })) -} - -pub async fn fetch_batches( - context: &BatchContext, - output_start_index: Option, - input_start_index: Option, - fetch_len: u64, - zkp_batch_size: u64, -) -> crate::Result> { - tracing::debug!( - "fetch_batches: tree={}, output_start={:?}, input_start={:?}, fetch_len={}, zkp_batch_size={}", - context.merkle_tree, output_start_index, input_start_index, fetch_len, zkp_batch_size - ); - - let fetch_len_u16 = clamp_to_u16(fetch_len, "fetch_len"); - let zkp_batch_size_u16 = clamp_to_u16(zkp_batch_size, "zkp_batch_size"); - - let mut rpc = context.rpc_pool.get_connection().await?; - let indexer = rpc.indexer_mut()?; - let options = QueueElementsV2Options::default() - .with_output_queue(output_start_index, Some(fetch_len_u16)) - .with_output_queue_batch_size(Some(zkp_batch_size_u16)) - .with_input_queue(input_start_index, Some(fetch_len_u16)) - .with_input_queue_batch_size(Some(zkp_batch_size_u16)); - - let fetch_future = indexer.get_queue_elements(context.merkle_tree.to_bytes(), options, None); - - let res = match tokio::time::timeout(INDEXER_FETCH_TIMEOUT, fetch_future).await { - Ok(result) => result?, - Err(_) => { - tracing::warn!( - "fetch_batches timed out after {:?} for tree {}", - INDEXER_FETCH_TIMEOUT, - context.merkle_tree - ); - return Err(anyhow::anyhow!( - "Indexer fetch timed out after {:?} for state tree {}", - INDEXER_FETCH_TIMEOUT, - context.merkle_tree - )); - } - }; - - Ok(res.value.state_queue) -} - -pub async fn fetch_address_batches( - context: &BatchContext, - output_start_index: Option, - fetch_len: u64, - zkp_batch_size: u64, -) -> crate::Result> { - let fetch_len_u16 = clamp_to_u16(fetch_len, "fetch_len"); - let zkp_batch_size_u16 = clamp_to_u16(zkp_batch_size, "zkp_batch_size"); - - let mut rpc = context.rpc_pool.get_connection().await?; - let indexer = rpc.indexer_mut()?; - - 
let options = QueueElementsV2Options::default() - .with_address_queue(output_start_index, Some(fetch_len_u16)) - .with_address_queue_batch_size(Some(zkp_batch_size_u16)); - - tracing::debug!( - "fetch_address_batches: tree={}, start={:?}, len={}, zkp_batch_size={}", - context.merkle_tree, - output_start_index, - fetch_len_u16, - zkp_batch_size_u16 - ); - - let fetch_future = indexer.get_queue_elements(context.merkle_tree.to_bytes(), options, None); - - let res = match tokio::time::timeout(ADDRESS_INDEXER_FETCH_TIMEOUT, fetch_future).await { - Ok(result) => result?, - Err(_) => { - tracing::warn!( - "fetch_address_batches timed out after {:?} for tree {}", - ADDRESS_INDEXER_FETCH_TIMEOUT, - context.merkle_tree - ); - return Err(anyhow::anyhow!( - "Indexer fetch timed out after {:?} for address tree {}", - ADDRESS_INDEXER_FETCH_TIMEOUT, - context.merkle_tree - )); - } - }; - - if let Some(ref aq) = res.value.address_queue { - tracing::debug!( - "fetch_address_batches response: address_queue present = true, addresses={}, subtrees={}, leaves_hash_chains={}, start_index={}", - aq.addresses.len(), - aq.subtrees.len(), - aq.leaves_hash_chains.len(), - aq.start_index - ); - } else { - tracing::debug!("fetch_address_batches response: address_queue present = false"); - } - - Ok(res.value.address_queue) -} - -/// Streams address queue data by fetching pages in the background. -/// -/// The first page is fetched synchronously, then subsequent pages are fetched -/// in a background task. Consumers can access data as it becomes available -/// without waiting for the entire fetch to complete. -#[derive(Debug)] -pub struct StreamingAddressQueue { - /// The accumulated address queue data from all fetched pages. - pub data: Arc>, - - /// Number of elements currently available for processing. - /// Paired with `data_ready` condvar for signaling new data. - available_elements: Arc>, - - /// Signaled when new elements become available. - /// Paired with `available_elements` mutex. - data_ready: Arc, - - /// Whether the background fetch has completed (all pages fetched or error). - /// Paired with `fetch_complete_condvar` for signaling completion. - fetch_complete: Arc>, - - /// Signaled when background fetch completes. - /// Paired with `fetch_complete` mutex. - fetch_complete_condvar: Arc, - - /// Number of elements per ZKP batch, used for batch boundary calculations. - zkp_batch_size: usize, -} - -impl StreamingAddressQueue { - /// Waits until at least `batch_end` elements are available or fetch completes. - /// - /// Uses a polling loop to avoid race conditions between the available_elements - /// and fetch_complete mutexes. Returns the number of available elements. - pub fn wait_for_batch(&self, batch_end: usize) -> usize { - const POLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); - let start = std::time::Instant::now(); - - loop { - let available = *lock_recover( - &self.available_elements, - "streaming_address_queue.available_elements", - ); - if available >= batch_end { - return available; - } - - let complete = *lock_recover( - &self.fetch_complete, - "streaming_address_queue.fetch_complete", - ); - if complete { - return available; - } - - if start.elapsed() > POLL_TIMEOUT { - tracing::warn!( - "wait_for_batch timed out after {:?} waiting for {} elements (available: {})", - POLL_TIMEOUT, - batch_end, - available - ); - return available; - } - - // Use condvar wait with timeout instead of thread::sleep to avoid - // blocking the thread and to wake up promptly when data arrives. 
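As a standalone illustration of that wait pattern, which survives unchanged in the new `streaming_queue.rs`, here is a minimal sketch (all names here are ours, not the forester's): the waiter re-checks its exit conditions on every wakeup, so a missed `notify_all()` costs at most one 50 ms poll interval rather than a hang.

```rust
use std::sync::{Condvar, Mutex};
use std::time::{Duration, Instant};

/// Block until `counter >= target` or `deadline` elapses, returning the
/// observed count. Each iteration sleeps at most 50 ms, waking early on
/// notify_all(); spurious wakeups are harmless because the condition is
/// re-checked under the lock every time.
fn wait_until(counter: &Mutex<usize>, ready: &Condvar, target: usize, deadline: Duration) -> usize {
    let start = Instant::now();
    let mut guard = counter.lock().unwrap();
    while *guard < target && start.elapsed() < deadline {
        // Returns on notify_all() or after 50 ms, whichever comes first.
        guard = ready.wait_timeout(guard, Duration::from_millis(50)).unwrap().0;
    }
    *guard
}
```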
- let guard = lock_recover( - &self.available_elements, - "streaming_address_queue.available_elements", - ); - let _ = self - .data_ready - .wait_timeout(guard, std::time::Duration::from_millis(50)); - } - } - - pub fn get_batch_snapshot( - &self, - start: usize, - end: usize, - hashchain_idx: usize, - ) -> crate::Result>> { - let available = self.wait_for_batch(end); - if available < end || start >= end { - return Ok(None); - } - let actual_end = end; - let data = lock_recover(&self.data, "streaming_address_queue.data"); - - let min_len = [ - data.addresses.len(), - data.low_element_values.len(), - data.low_element_next_values.len(), - data.low_element_indices.len(), - data.low_element_next_indices.len(), - ] - .into_iter() - .min() - .unwrap_or(0); - if min_len < actual_end { - return Err(anyhow!( - "incomplete batch data: min field length {} < required end {}", - min_len, - actual_end - )); - } - - let addresses = data.addresses[start..actual_end].to_vec(); - if addresses.is_empty() { - return Err(anyhow!("Empty batch at start={}", start)); - } - - let leaves_hashchain = match data.leaves_hash_chains.get(hashchain_idx).copied() { - Some(hashchain) => hashchain, - None => { - tracing::debug!( - "Missing leaves_hash_chain for batch {} (available: {}), deriving from addresses", - hashchain_idx, - data.leaves_hash_chains.len() - ); - create_hash_chain_from_slice(&addresses).map_err(|error| { - anyhow!( - "Failed to derive leaves_hash_chain for batch {} from {} addresses: {}", - hashchain_idx, - addresses.len(), - error - ) - })? - } - }; - - Ok(Some(AddressBatchSnapshot { - low_element_values: data.low_element_values[start..actual_end].to_vec(), - low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), - low_element_indices: data.low_element_indices[start..actual_end].to_vec(), - low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - low_element_proofs: data - .reconstruct_proofs::(start..actual_end) - .map_err(|error| { - anyhow!("incomplete batch data: failed to reconstruct proofs: {error}") - })?, - addresses, - leaves_hashchain, - })) - } - - pub fn into_data(self) -> AddressQueueData { - let mut complete = lock_recover( - &self.fetch_complete, - "streaming_address_queue.fetch_complete", - ); - while !*complete { - complete = match self.fetch_complete_condvar.wait_while(complete, |c| !*c) { - Ok(guard) => guard, - Err(poisoned) => { - tracing::warn!("Poisoned mutex while waiting (recovering): streaming_address_queue.fetch_complete"); - poisoned.into_inner() - } - }; - } - drop(complete); - match Arc::try_unwrap(self.data) { - Ok(mutex) => mutex.into_inner().unwrap_or_else(|poisoned| { - tracing::warn!("Poisoned mutex during into_data (recovering)"); - poisoned.into_inner() - }), - Err(arc) => lock_recover(arc.as_ref(), "streaming_address_queue.data_clone").clone(), - } - } - - pub fn initial_root(&self) -> [u8; 32] { - lock_recover(&self.data, "streaming_address_queue.data").initial_root - } - - pub fn start_index(&self) -> u64 { - lock_recover(&self.data, "streaming_address_queue.data").start_index - } - - pub fn tree_next_insertion_index(&self) -> u64 { - lock_recover(&self.data, "streaming_address_queue.data").tree_next_insertion_index - } - - pub fn subtrees(&self) -> Vec<[u8; 32]> { - lock_recover(&self.data, "streaming_address_queue.data") - .subtrees - .clone() - } - - pub fn root_seq(&self) -> u64 { - lock_recover(&self.data, "streaming_address_queue.data").root_seq - } - - pub fn available_batches(&self) -> usize { - 
debug_assert!(self.zkp_batch_size != 0, "zkp_batch_size must be non-zero"); - if self.zkp_batch_size == 0 { - tracing::error!("zkp_batch_size is zero, returning 0 batches to avoid panic"); - return 0; - } - let available = *lock_recover( - &self.available_elements, - "streaming_address_queue.available_elements", - ); - available / self.zkp_batch_size - } - - pub fn is_complete(&self) -> bool { - *lock_recover( - &self.fetch_complete, - "streaming_address_queue.fetch_complete", - ) - } -} - -pub async fn fetch_streaming_address_batches( - context: &BatchContext, - total_elements: u64, - zkp_batch_size: u64, -) -> crate::Result> { - if total_elements == 0 { - return Ok(None); - } - - let page_size_elements = ADDRESS_PAGE_SIZE_BATCHES * zkp_batch_size; - let num_pages = total_elements.div_ceil(page_size_elements) as usize; - - tracing::debug!( - "address fetch: {} elements ({} batches) in {} pages of {} batches each", - total_elements, - total_elements / zkp_batch_size, - num_pages, - ADDRESS_PAGE_SIZE_BATCHES - ); - - let first_page_size = page_size_elements.min(total_elements); - let first_page = - match fetch_address_batches(context, None, first_page_size, zkp_batch_size).await? { - Some(data) if !data.addresses.is_empty() => data, - _ => return Ok(None), - }; - - let initial_elements = first_page.addresses.len(); - let first_page_requested = first_page_size as usize; - - let queue_exhausted = initial_elements < first_page_requested; - - tracing::info!( - "First page fetched: {} addresses ({} batches ready), root={:?}[..4], queue_exhausted={}", - initial_elements, - initial_elements / zkp_batch_size as usize, - &first_page.initial_root[..4], - queue_exhausted - ); - - let streaming = StreamingAddressQueue { - data: Arc::new(Mutex::new(first_page)), - available_elements: Arc::new(Mutex::new(initial_elements)), - data_ready: Arc::new(Condvar::new()), - fetch_complete: Arc::new(Mutex::new(num_pages == 1 || queue_exhausted)), - fetch_complete_condvar: Arc::new(Condvar::new()), - zkp_batch_size: zkp_batch_size as usize, - }; - - if num_pages == 1 || queue_exhausted { - return Ok(Some(streaming)); - } - - let data = Arc::clone(&streaming.data); - let available = Arc::clone(&streaming.available_elements); - let ready = Arc::clone(&streaming.data_ready); - let complete = Arc::clone(&streaming.fetch_complete); - let complete_condvar = Arc::clone(&streaming.fetch_complete_condvar); - let ctx = context.clone(); - let initial_root = streaming.initial_root(); - - // Get the start_index from the first page to calculate offsets for subsequent pages - let first_page_start_index = streaming.start_index(); - - tokio::spawn(async move { - let mut offset = first_page_size; - - for page_idx in 1..num_pages { - let page_size = (total_elements - offset).min(page_size_elements); - // Use absolute index: first page's start_index + relative offset - let absolute_start = Some(first_page_start_index + offset); - - tracing::debug!( - "Fetching address page {}/{}: absolute_start={:?}, size={}", - page_idx + 1, - num_pages, - absolute_start, - page_size - ); - - match fetch_address_batches(&ctx, absolute_start, page_size, zkp_batch_size).await { - Ok(Some(page)) => { - if page.initial_root != initial_root { - tracing::warn!( - "Address page {} has different root ({:?} vs {:?}), stopping fetch", - page_idx + 1, - &page.initial_root[..4], - &initial_root[..4] - ); - break; - } - - let page_elements = page.addresses.len(); - let page_requested = page_size as usize; - - { - let mut data_guard = - lock_recover(data.as_ref(), 
"streaming_address_queue.data"); - data_guard.addresses.extend(page.addresses); - data_guard - .low_element_values - .extend(page.low_element_values); - data_guard - .low_element_next_values - .extend(page.low_element_next_values); - data_guard - .low_element_indices - .extend(page.low_element_indices); - data_guard - .low_element_next_indices - .extend(page.low_element_next_indices); - data_guard - .leaves_hash_chains - .extend(page.leaves_hash_chains); - let mut seen: HashSet = data_guard.nodes.iter().copied().collect(); - for (&idx, &hash) in page.nodes.iter().zip(page.node_hashes.iter()) { - if seen.insert(idx) { - data_guard.nodes.push(idx); - data_guard.node_hashes.push(hash); - } - } - } - - { - let mut avail = lock_recover( - available.as_ref(), - "streaming_address_queue.available_elements", - ); - *avail += page_elements; - tracing::debug!( - "Page {} fetched: {} elements, total available: {} ({} batches)", - page_idx + 1, - page_elements, - *avail, - *avail / zkp_batch_size as usize - ); - } - ready.notify_all(); - - if page_elements < page_requested { - tracing::debug!( - "Page {} returned fewer elements than requested ({} < {}), queue exhausted", - page_idx + 1, page_elements, page_requested - ); - break; - } - } - Ok(None) => { - tracing::debug!("Page {} returned empty, stopping fetch", page_idx + 1); - break; - } - Err(e) => { - tracing::warn!("Error fetching page {}: {}", page_idx + 1, e); - break; - } - } - - offset += page_size; - } - - { - let mut complete_guard = - lock_recover(complete.as_ref(), "streaming_address_queue.fetch_complete"); - *complete_guard = true; - } - ready.notify_all(); - complete_condvar.notify_all(); - tracing::debug!("Background address fetch complete"); - }); - - Ok(Some(streaming)) -} diff --git a/forester/src/processor/v2/indexer_fetch.rs b/forester/src/processor/v2/indexer_fetch.rs new file mode 100644 index 0000000000..ff7ffbbb62 --- /dev/null +++ b/forester/src/processor/v2/indexer_fetch.rs @@ -0,0 +1,292 @@ +//! Indexer batch fetching: paginated state queue and address queue data retrieval. 
+ +use std::collections::HashMap; + +use light_client::{ + indexer::{AddressQueueData, Indexer, QueueElementsV2Options, StateQueueData}, + rpc::Rpc, +}; + +use super::BatchContext; +use crate::processor::v2::common::clamp_to_u16; + +const INDEXER_FETCH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); +const ADDRESS_INDEXER_FETCH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); +const PAGE_SIZE_BATCHES: u64 = 5; +pub(super) const ADDRESS_PAGE_SIZE_BATCHES: u64 = 5; + +pub async fn fetch_paginated_batches( + context: &BatchContext, + total_elements: u64, + zkp_batch_size: u64, +) -> crate::Result> { + if zkp_batch_size == 0 { + return Err(anyhow::anyhow!("zkp_batch_size cannot be zero")); + } + if total_elements == 0 { + return Ok(None); + } + + let page_size_elements = PAGE_SIZE_BATCHES * zkp_batch_size; + if total_elements <= page_size_elements { + tracing::debug!( + "fetch_paginated_batches: single page fetch with start_index=None, total_elements={}, page_size={}", + total_elements, page_size_elements + ); + return fetch_batches(context, None, None, total_elements, zkp_batch_size).await; + } + + let num_pages = total_elements.div_ceil(page_size_elements) as usize; + tracing::debug!( + "Parallel fetch: {} elements ({} batches) in {} pages of {} batches each", + total_elements, + total_elements / zkp_batch_size, + num_pages, + PAGE_SIZE_BATCHES + ); + + // Fetch first page with start_index=None to discover the actual first_queue_index + // (queue may have been pruned, so indices don't start at 0) + let first_page = fetch_batches(context, None, None, page_size_elements, zkp_batch_size).await?; + + let Some(first_page_data) = first_page else { + return Ok(None); + }; + + // Get the actual starting indices from the first page response + // IMPORTANT: Only use first_queue_index if the queue actually has elements. + // When queue is empty, photon returns default first_queue_index=0, which would + // cause subsequent pages to request start_index=2500 even though the actual + // queue might start at 149500 (if elements arrive between requests). 
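The rule in that comment is easy to get wrong, so here it is in isolation as a minimal sketch (the function and parameter names are ours): later pages are anchored to the first page's discovered index, and only when that page proved the queue non-empty.

```rust
/// Start index for a follow-up page of a queue fetch. `first_queue_index`
/// comes from the first page's response; `None` means that page returned no
/// elements, so the true head of the queue is unknown and must be left for
/// the indexer to resolve.
fn page_start(first_queue_index: Option<u64>, offset: u64) -> Option<u64> {
    // An empty first page reports a default index of 0; trusting it would
    // point later pages at long-pruned slots, so None is propagated instead.
    first_queue_index.map(|idx| idx + offset)
}
```

This is exactly the shape of the `output_first_index.map(|idx| idx + offset)` calls in the loop below.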
+ let output_first_index: Option = first_page_data + .output_queue + .as_ref() + .filter(|oq| !oq.leaf_indices.is_empty()) + .map(|oq| oq.first_queue_index); + let input_first_index: Option = first_page_data + .input_queue + .as_ref() + .filter(|iq| !iq.leaf_indices.is_empty()) + .map(|iq| iq.first_queue_index); + + tracing::debug!( + "First page fetched: output_first_index={:?}, input_first_index={:?}", + output_first_index, + input_first_index + ); + + // If only one page needed, return the first page result + if num_pages == 1 { + return Ok(Some(first_page_data)); + } + + // Fetch remaining pages in parallel with offsets relative to first_queue_index + // Only request queues for which we have valid first_queue_index from the first page + let mut fetch_futures = Vec::with_capacity(num_pages - 1); + let mut offset = page_size_elements; + + for _page_idx in 1..num_pages { + let page_size = (total_elements - offset).min(page_size_elements); + // Only use Some(index) for queues we actually got data for in the first page + // If first page had no data for a queue, we don't know its first_queue_index + let output_start = output_first_index.map(|idx| idx + offset); + let input_start = input_first_index.map(|idx| idx + offset); + + let ctx = context.clone(); + + fetch_futures.push(async move { + fetch_batches(&ctx, output_start, input_start, page_size, zkp_batch_size).await + }); + + offset += page_size; + } + + let results = futures::future::join_all(fetch_futures).await; + + // Initialize with first page data + let initial_root = first_page_data.initial_root; + let root_seq = first_page_data.root_seq; + let mut nodes_map: HashMap = HashMap::new(); + for (&idx, &hash) in first_page_data + .nodes + .iter() + .zip(first_page_data.node_hashes.iter()) + { + nodes_map.insert(idx, hash); + } + let mut output_queue = first_page_data.output_queue; + let mut input_queue = first_page_data.input_queue; + + // Merge remaining pages + for (page_idx, result) in results.into_iter().enumerate() { + let page = match result? 
{ + Some(data) => data, + None => continue, + }; + + if page.initial_root != initial_root { + tracing::warn!( + "Page {} has different root ({:?} vs {:?}), stopping merge", + page_idx + 1, + &page.initial_root[..4], + &initial_root[..4] + ); + break; + } + + for (&idx, &hash) in page.nodes.iter().zip(page.node_hashes.iter()) { + nodes_map.entry(idx).or_insert(hash); + } + + if let Some(page_oq) = page.output_queue { + if let Some(ref mut oq) = output_queue { + oq.leaf_indices.extend(page_oq.leaf_indices); + oq.account_hashes.extend(page_oq.account_hashes); + oq.old_leaves.extend(page_oq.old_leaves); + oq.leaves_hash_chains.extend(page_oq.leaves_hash_chains); + } else { + output_queue = Some(page_oq); + } + } + + if let Some(page_iq) = page.input_queue { + if let Some(ref mut iq) = input_queue { + iq.leaf_indices.extend(page_iq.leaf_indices); + iq.account_hashes.extend(page_iq.account_hashes); + iq.current_leaves.extend(page_iq.current_leaves); + iq.tx_hashes.extend(page_iq.tx_hashes); + iq.nullifiers.extend(page_iq.nullifiers); + iq.leaves_hash_chains.extend(page_iq.leaves_hash_chains); + } else { + input_queue = Some(page_iq); + } + } + } + + let mut nodes_vec: Vec<_> = nodes_map.into_iter().collect(); + nodes_vec.sort_by_key(|(idx, _)| *idx); + let (nodes, node_hashes): (Vec<_>, Vec<_>) = nodes_vec.into_iter().unzip(); + + tracing::debug!( + "Parallel fetch complete: {} nodes, output={}, input={}", + nodes.len(), + output_queue + .as_ref() + .map(|oq| oq.leaf_indices.len()) + .unwrap_or(0), + input_queue + .as_ref() + .map(|iq| iq.leaf_indices.len()) + .unwrap_or(0) + ); + + Ok(Some(StateQueueData { + nodes, + node_hashes, + initial_root, + root_seq, + output_queue, + input_queue, + })) +} + +pub async fn fetch_batches( + context: &BatchContext, + output_start_index: Option, + input_start_index: Option, + fetch_len: u64, + zkp_batch_size: u64, +) -> crate::Result> { + tracing::debug!( + "fetch_batches: tree={}, output_start={:?}, input_start={:?}, fetch_len={}, zkp_batch_size={}", + context.merkle_tree, output_start_index, input_start_index, fetch_len, zkp_batch_size + ); + + let fetch_len_u16 = clamp_to_u16(fetch_len, "fetch_len"); + let zkp_batch_size_u16 = clamp_to_u16(zkp_batch_size, "zkp_batch_size"); + + let mut rpc = context.rpc_pool.get_connection().await?; + let indexer = rpc.indexer_mut()?; + let options = QueueElementsV2Options::default() + .with_output_queue(output_start_index, Some(fetch_len_u16)) + .with_output_queue_batch_size(Some(zkp_batch_size_u16)) + .with_input_queue(input_start_index, Some(fetch_len_u16)) + .with_input_queue_batch_size(Some(zkp_batch_size_u16)); + + let fetch_future = indexer.get_queue_elements(context.merkle_tree.to_bytes(), options, None); + + let res = match tokio::time::timeout(INDEXER_FETCH_TIMEOUT, fetch_future).await { + Ok(result) => result?, + Err(_) => { + tracing::warn!( + "fetch_batches timed out after {:?} for tree {}", + INDEXER_FETCH_TIMEOUT, + context.merkle_tree + ); + return Err(anyhow::anyhow!( + "Indexer fetch timed out after {:?} for state tree {}", + INDEXER_FETCH_TIMEOUT, + context.merkle_tree + )); + } + }; + + Ok(res.value.state_queue) +} + +pub async fn fetch_address_batches( + context: &BatchContext, + output_start_index: Option, + fetch_len: u64, + zkp_batch_size: u64, +) -> crate::Result> { + let fetch_len_u16 = clamp_to_u16(fetch_len, "fetch_len"); + let zkp_batch_size_u16 = clamp_to_u16(zkp_batch_size, "zkp_batch_size"); + + let mut rpc = context.rpc_pool.get_connection().await?; + let indexer = rpc.indexer_mut()?; + + 
let options = QueueElementsV2Options::default() + .with_address_queue(output_start_index, Some(fetch_len_u16)) + .with_address_queue_batch_size(Some(zkp_batch_size_u16)); + + tracing::debug!( + "fetch_address_batches: tree={}, start={:?}, len={}, zkp_batch_size={}", + context.merkle_tree, + output_start_index, + fetch_len_u16, + zkp_batch_size_u16 + ); + + let fetch_future = indexer.get_queue_elements(context.merkle_tree.to_bytes(), options, None); + + let res = match tokio::time::timeout(ADDRESS_INDEXER_FETCH_TIMEOUT, fetch_future).await { + Ok(result) => result?, + Err(_) => { + tracing::warn!( + "fetch_address_batches timed out after {:?} for tree {}", + ADDRESS_INDEXER_FETCH_TIMEOUT, + context.merkle_tree + ); + return Err(anyhow::anyhow!( + "Indexer fetch timed out after {:?} for address tree {}", + ADDRESS_INDEXER_FETCH_TIMEOUT, + context.merkle_tree + )); + } + }; + + if let Some(ref aq) = res.value.address_queue { + tracing::debug!( + "fetch_address_batches response: address_queue present = true, addresses={}, subtrees={}, leaves_hash_chains={}, start_index={}", + aq.addresses.len(), + aq.subtrees.len(), + aq.leaves_hash_chains.len(), + aq.start_index + ); + } else { + tracing::debug!("fetch_address_batches response: address_queue present = false"); + } + + Ok(res.value.address_queue) +} diff --git a/forester/src/processor/v2/mod.rs b/forester/src/processor/v2/mod.rs index ef1fdf10b1..929779937d 100644 --- a/forester/src/processor/v2/mod.rs +++ b/forester/src/processor/v2/mod.rs @@ -1,12 +1,14 @@ mod batch_job_builder; pub mod common; pub mod errors; -mod helpers; +mod indexer_fetch; mod processor; pub mod proof_cache; mod proof_worker; mod root_guard; pub mod strategy; +pub(crate) mod streaming_queue; +mod tree_data; mod tx_sender; pub use common::{BatchContext, ProverConfig}; diff --git a/forester/src/processor/v2/processor.rs b/forester/src/processor/v2/processor.rs index 3de6dea860..3484a8318d 100644 --- a/forester/src/processor/v2/processor.rs +++ b/forester/src/processor/v2/processor.rs @@ -118,6 +118,14 @@ where self.proof_cache = Some(cache); } + fn ensure_worker_pool(&mut self) -> crate::Result<()> { + if self.worker_pool.is_none() { + let job_tx = spawn_proof_workers(&self.context.prover_config)?; + self.worker_pool = Some(WorkerPool { job_tx }); + } + Ok(()) + } + pub async fn process(&mut self) -> std::result::Result { let queue_size = self.zkp_batch_size * self.context.max_batches_per_tree as u64; self.process_queue_update(queue_size).await @@ -131,10 +139,7 @@ where return Ok(ProcessingResult::default()); } - if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); - self.worker_pool = Some(WorkerPool { job_tx }); - } + self.ensure_worker_pool()?; if let Some(cached) = self.cached_state.take() { let actual_available = self @@ -531,10 +536,7 @@ where let max_batches = ((queue_size / self.zkp_batch_size) as usize).min(self.context.max_batches_per_tree); - if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); - self.worker_pool = Some(WorkerPool { job_tx }); - } + self.ensure_worker_pool()?; let queue_data = match self .strategy @@ -560,10 +562,7 @@ where let max_batches = max_batches.min(self.context.max_batches_per_tree); - if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); - self.worker_pool = Some(WorkerPool { job_tx }); - } + self.ensure_worker_pool()?; let queue_data = match self .strategy diff --git 
a/forester/src/processor/v2/proof_worker.rs b/forester/src/processor/v2/proof_worker.rs index b7afeacf0b..ba9f133aee 100644 --- a/forester/src/processor/v2/proof_worker.rs +++ b/forester/src/processor/v2/proof_worker.rs @@ -164,11 +164,13 @@ impl ProofClients { } } -pub fn spawn_proof_workers(config: &ProverConfig) -> async_channel::Sender { +pub fn spawn_proof_workers( + config: &ProverConfig, +) -> crate::Result> { let (job_tx, job_rx) = async_channel::bounded::(256); let clients = Arc::new(ProofClients::new(config)); tokio::spawn(async move { run_proof_pipeline(job_rx, clients).await }); - job_tx + Ok(job_tx) } async fn run_proof_pipeline( diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs index 51236c389b..f863e6ba15 100644 --- a/forester/src/processor/v2/strategy/address.rs +++ b/forester/src/processor/v2/strategy/address.rs @@ -15,13 +15,13 @@ use tracing::{debug, info, instrument}; use crate::processor::v2::{ batch_job_builder::BatchJobBuilder, errors::V2Error, - helpers::{ - fetch_address_zkp_batch_size, fetch_onchain_address_root, fetch_streaming_address_batches, - AddressBatchSnapshot, StreamingAddressQueue, - }, proof_worker::ProofInput, root_guard::{reconcile_alignment, AlignmentDecision}, strategy::{CircuitType, QueueData, TreeStrategy}, + streaming_queue::{ + fetch_streaming_address_batches, AddressBatchSnapshot, StreamingAddressQueue, + }, + tree_data::{fetch_address_zkp_batch_size, fetch_onchain_address_root}, BatchContext, }; diff --git a/forester/src/processor/v2/strategy/state.rs b/forester/src/processor/v2/strategy/state.rs index d517c51757..e67042e69e 100644 --- a/forester/src/processor/v2/strategy/state.rs +++ b/forester/src/processor/v2/strategy/state.rs @@ -11,10 +11,11 @@ use tracing::{debug, instrument}; use crate::processor::v2::{ batch_job_builder::BatchJobBuilder, common::{batch_range, get_leaves_hashchain}, - helpers::{fetch_onchain_state_root, fetch_paginated_batches, fetch_zkp_batch_size}, + indexer_fetch::fetch_paginated_batches, proof_worker::ProofInput, root_guard::{reconcile_alignment, AlignmentDecision}, strategy::{CircuitType, QueueData, TreeStrategy}, + tree_data::{fetch_onchain_state_root, fetch_zkp_batch_size}, BatchContext, }; diff --git a/forester/src/processor/v2/streaming_queue.rs b/forester/src/processor/v2/streaming_queue.rs new file mode 100644 index 0000000000..65defa64f7 --- /dev/null +++ b/forester/src/processor/v2/streaming_queue.rs @@ -0,0 +1,429 @@ +//! StreamingAddressQueue: background-fetched address queue with batch-level access. 
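For orientation, a sketch of how a consumer drives this type. The generic parameters, the `const HEIGHT` on `get_batch_snapshot`, and the tree height of 40 are our assumptions (the elided signatures above don't show them); error handling and the real proof-input construction are omitted.

```rust
// Hypothetical consumer loop, assuming BatchContext<R: Rpc> and
// get_batch_snapshot::<HEIGHT>; 40 is assumed from the v2 address trees.
async fn drain_address_queue<R: light_client::rpc::Rpc>(
    ctx: &BatchContext<R>,
    total_elements: u64,
    zkp_batch_size: u64,
) -> crate::Result<()> {
    let Some(queue) =
        fetch_streaming_address_batches(ctx, total_elements, zkp_batch_size).await?
    else {
        return Ok(()); // nothing queued
    };
    let batch = zkp_batch_size as usize;
    let mut i = 0;
    // get_batch_snapshot blocks (via wait_for_batch) until the background
    // fetch has produced `end` elements, then hands back an owned snapshot.
    while let Some(snapshot) = queue.get_batch_snapshot::<40>(i * batch, (i + 1) * batch, i)? {
        let _proof_inputs = snapshot.addresses.len(); // build ZKP inputs here
        i += 1;
    }
    Ok(())
}
```

The owned snapshot is what lets proof building proceed while the background task keeps appending pages under the same mutex.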
+ +use std::{ + collections::HashSet, + sync::{Arc, Condvar, Mutex, MutexGuard}, +}; + +use anyhow::anyhow; +use light_client::{indexer::AddressQueueData, rpc::Rpc}; +use light_hasher::hash_chain::create_hash_chain_from_slice; + +use super::{ + indexer_fetch::{fetch_address_batches, ADDRESS_PAGE_SIZE_BATCHES}, + BatchContext, +}; +use crate::logging::should_emit_rate_limited_warning; + +fn lock_recover<'a, T>(mutex: &'a Mutex, name: &'static str) -> MutexGuard<'a, T> { + match mutex.lock() { + Ok(guard) => guard, + Err(poisoned) => { + tracing::warn!("Poisoned mutex (recovering): {}", name); + poisoned.into_inner() + } + } +} + +#[derive(Debug, Clone)] +pub struct AddressBatchSnapshot { + pub addresses: Vec<[u8; 32]>, + pub low_element_values: Vec<[u8; 32]>, + pub low_element_next_values: Vec<[u8; 32]>, + pub low_element_indices: Vec, + pub low_element_next_indices: Vec, + pub low_element_proofs: Vec<[[u8; 32]; HEIGHT]>, + pub leaves_hashchain: [u8; 32], +} + +/// Streams address queue data by fetching pages in the background. +/// +/// The first page is fetched synchronously, then subsequent pages are fetched +/// in a background task. Consumers can access data as it becomes available +/// without waiting for the entire fetch to complete. +#[derive(Debug)] +pub struct StreamingAddressQueue { + /// The accumulated address queue data from all fetched pages. + pub data: Arc>, + + /// Number of elements currently available for processing. + available_elements: Arc>, + + /// Signaled when new elements become available. + data_ready: Arc, + + /// Whether the background fetch has completed (all pages fetched or error). + fetch_complete: Arc>, + + /// Signaled when background fetch completes. + fetch_complete_condvar: Arc, + + /// Number of elements per ZKP batch, used for batch boundary calculations. + zkp_batch_size: usize, +} + +impl StreamingAddressQueue { + /// Waits until at least `batch_end` elements are available or fetch completes. 
+ pub fn wait_for_batch(&self, batch_end: usize) -> usize { + const POLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); + let start = std::time::Instant::now(); + + loop { + let available = *lock_recover( + &self.available_elements, + "streaming_address_queue.available_elements", + ); + if available >= batch_end { + return available; + } + + let complete = *lock_recover( + &self.fetch_complete, + "streaming_address_queue.fetch_complete", + ); + if complete { + return available; + } + + if start.elapsed() > POLL_TIMEOUT { + tracing::warn!( + "wait_for_batch timed out after {:?} waiting for {} elements (available: {})", + POLL_TIMEOUT, + batch_end, + available + ); + return available; + } + + let guard = lock_recover( + &self.available_elements, + "streaming_address_queue.available_elements", + ); + let _ = self + .data_ready + .wait_timeout(guard, std::time::Duration::from_millis(50)); + } + } + + pub fn get_batch_snapshot( + &self, + start: usize, + end: usize, + hashchain_idx: usize, + ) -> crate::Result>> { + let available = self.wait_for_batch(end); + if available < end || start >= end { + return Ok(None); + } + let data = lock_recover(&self.data, "streaming_address_queue.data"); + + let range = start..end; + let ( + Some(addresses), + Some(low_element_values), + Some(low_element_next_values), + Some(low_element_indices), + Some(low_element_next_indices), + ) = ( + data.addresses.get(range.clone()).map(<[_]>::to_vec), + data.low_element_values + .get(range.clone()) + .map(<[_]>::to_vec), + data.low_element_next_values + .get(range.clone()) + .map(<[_]>::to_vec), + data.low_element_indices + .get(range.clone()) + .map(<[_]>::to_vec), + data.low_element_next_indices + .get(range.clone()) + .map(<[_]>::to_vec), + ) + else { + return Ok(None); + }; + + let low_element_proofs = match data.reconstruct_proofs::(range) { + Ok(proofs) => proofs, + Err(error) => { + if should_emit_rate_limited_warning( + "address_queue_proofs_not_ready", + std::time::Duration::from_secs(60), + ) { + tracing::warn!( + ?error, + start, + end, + "address proof reconstruction not ready, retrying" + ); + } + return Ok(None); + } + }; + + let leaves_hashchain = match data.leaves_hash_chains.get(hashchain_idx).copied() { + Some(hashchain) => hashchain, + None => { + tracing::debug!( + "Missing leaves_hash_chain for batch {} (available: {}), deriving from addresses", + hashchain_idx, + data.leaves_hash_chains.len() + ); + create_hash_chain_from_slice(&addresses).map_err(|error| { + anyhow!( + "Failed to derive leaves_hash_chain for batch {} from {} addresses: {}", + hashchain_idx, + addresses.len(), + error + ) + })? 
+ } + }; + + Ok(Some(AddressBatchSnapshot { + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + addresses, + leaves_hashchain, + })) + } + + pub fn into_data(self) -> AddressQueueData { + let mut complete = lock_recover( + &self.fetch_complete, + "streaming_address_queue.fetch_complete", + ); + while !*complete { + complete = match self.fetch_complete_condvar.wait_while(complete, |c| !*c) { + Ok(guard) => guard, + Err(poisoned) => { + tracing::warn!("Poisoned mutex while waiting (recovering): streaming_address_queue.fetch_complete"); + poisoned.into_inner() + } + }; + } + drop(complete); + match Arc::try_unwrap(self.data) { + Ok(mutex) => mutex.into_inner().unwrap_or_else(|poisoned| { + tracing::warn!("Poisoned mutex during into_data (recovering)"); + poisoned.into_inner() + }), + Err(arc) => lock_recover(arc.as_ref(), "streaming_address_queue.data_clone").clone(), + } + } + + pub fn initial_root(&self) -> [u8; 32] { + lock_recover(&self.data, "streaming_address_queue.data").initial_root + } + + pub fn start_index(&self) -> u64 { + lock_recover(&self.data, "streaming_address_queue.data").start_index + } + + pub fn tree_next_insertion_index(&self) -> u64 { + lock_recover(&self.data, "streaming_address_queue.data").tree_next_insertion_index + } + + pub fn subtrees(&self) -> Vec<[u8; 32]> { + lock_recover(&self.data, "streaming_address_queue.data") + .subtrees + .clone() + } + + pub fn root_seq(&self) -> u64 { + lock_recover(&self.data, "streaming_address_queue.data").root_seq + } + + pub fn available_batches(&self) -> usize { + debug_assert!(self.zkp_batch_size != 0, "zkp_batch_size must be non-zero"); + if self.zkp_batch_size == 0 { + tracing::error!("zkp_batch_size is zero, returning 0 batches to avoid panic"); + return 0; + } + let available = *lock_recover( + &self.available_elements, + "streaming_address_queue.available_elements", + ); + available / self.zkp_batch_size + } + + pub fn is_complete(&self) -> bool { + *lock_recover( + &self.fetch_complete, + "streaming_address_queue.fetch_complete", + ) + } +} + +pub async fn fetch_streaming_address_batches( + context: &BatchContext, + total_elements: u64, + zkp_batch_size: u64, +) -> crate::Result> { + if total_elements == 0 { + return Ok(None); + } + + let page_size_elements = ADDRESS_PAGE_SIZE_BATCHES * zkp_batch_size; + let num_pages = total_elements.div_ceil(page_size_elements) as usize; + + tracing::debug!( + "address fetch: {} elements ({} batches) in {} pages of {} batches each", + total_elements, + total_elements / zkp_batch_size, + num_pages, + ADDRESS_PAGE_SIZE_BATCHES + ); + + let first_page_size = page_size_elements.min(total_elements); + let first_page = + match fetch_address_batches(context, None, first_page_size, zkp_batch_size).await? 
{ + Some(data) if !data.addresses.is_empty() => data, + _ => return Ok(None), + }; + + let initial_elements = first_page.addresses.len(); + let first_page_requested = first_page_size as usize; + + let queue_exhausted = initial_elements < first_page_requested; + + tracing::info!( + "First page fetched: {} addresses ({} batches ready), root={:?}[..4], queue_exhausted={}", + initial_elements, + initial_elements / zkp_batch_size as usize, + &first_page.initial_root[..4], + queue_exhausted + ); + + let streaming = StreamingAddressQueue { + data: Arc::new(Mutex::new(first_page)), + available_elements: Arc::new(Mutex::new(initial_elements)), + data_ready: Arc::new(Condvar::new()), + fetch_complete: Arc::new(Mutex::new(num_pages == 1 || queue_exhausted)), + fetch_complete_condvar: Arc::new(Condvar::new()), + zkp_batch_size: zkp_batch_size as usize, + }; + + if num_pages == 1 || queue_exhausted { + return Ok(Some(streaming)); + } + + let data = Arc::clone(&streaming.data); + let available = Arc::clone(&streaming.available_elements); + let ready = Arc::clone(&streaming.data_ready); + let complete = Arc::clone(&streaming.fetch_complete); + let complete_condvar = Arc::clone(&streaming.fetch_complete_condvar); + let ctx = context.clone(); + let initial_root = streaming.initial_root(); + let first_page_start_index = streaming.start_index(); + + tokio::spawn(async move { + let mut offset = first_page_size; + + for page_idx in 1..num_pages { + let page_size = (total_elements - offset).min(page_size_elements); + let absolute_start = Some(first_page_start_index + offset); + + tracing::debug!( + "Fetching address page {}/{}: absolute_start={:?}, size={}", + page_idx + 1, + num_pages, + absolute_start, + page_size + ); + + match fetch_address_batches(&ctx, absolute_start, page_size, zkp_batch_size).await { + Ok(Some(page)) => { + if page.initial_root != initial_root { + tracing::warn!( + "Address page {} has different root ({:?} vs {:?}), stopping fetch", + page_idx + 1, + &page.initial_root[..4], + &initial_root[..4] + ); + break; + } + + let page_elements = page.addresses.len(); + let page_requested = page_size as usize; + + { + let mut data_guard = + lock_recover(data.as_ref(), "streaming_address_queue.data"); + data_guard.addresses.extend(page.addresses); + data_guard + .low_element_values + .extend(page.low_element_values); + data_guard + .low_element_next_values + .extend(page.low_element_next_values); + data_guard + .low_element_indices + .extend(page.low_element_indices); + data_guard + .low_element_next_indices + .extend(page.low_element_next_indices); + data_guard + .leaves_hash_chains + .extend(page.leaves_hash_chains); + let mut seen: HashSet = data_guard.nodes.iter().copied().collect(); + for (&idx, &hash) in page.nodes.iter().zip(page.node_hashes.iter()) { + if seen.insert(idx) { + data_guard.nodes.push(idx); + data_guard.node_hashes.push(hash); + } + } + } + + { + let mut avail = lock_recover( + available.as_ref(), + "streaming_address_queue.available_elements", + ); + *avail += page_elements; + tracing::debug!( + "Page {} fetched: {} elements, total available: {} ({} batches)", + page_idx + 1, + page_elements, + *avail, + *avail / zkp_batch_size as usize + ); + } + ready.notify_all(); + + if page_elements < page_requested { + tracing::debug!( + "Page {} returned fewer elements than requested ({} < {}), queue exhausted", + page_idx + 1, page_elements, page_requested + ); + break; + } + } + Ok(None) => { + tracing::debug!("Page {} returned empty, stopping fetch", page_idx + 1); + break; + } + 
Err(e) => { + tracing::warn!("Error fetching page {}: {}", page_idx + 1, e); + break; + } + } + + offset += page_size; + } + + { + let mut complete_guard = + lock_recover(complete.as_ref(), "streaming_address_queue.fetch_complete"); + *complete_guard = true; + } + ready.notify_all(); + complete_condvar.notify_all(); + tracing::debug!("Background address fetch complete"); + }); + + Ok(Some(streaming)) +} diff --git a/forester/src/processor/v2/tree_data.rs b/forester/src/processor/v2/tree_data.rs new file mode 100644 index 0000000000..44cce84e39 --- /dev/null +++ b/forester/src/processor/v2/tree_data.rs @@ -0,0 +1,99 @@ +//! On-chain tree data reads: batch sizes and roots from Merkle tree accounts. + +use anyhow::anyhow; +use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; +use light_client::rpc::Rpc; + +use super::BatchContext; + +pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { + let rpc = context.rpc_pool.get_connection().await?; + let mut account = rpc + .get_account(context.merkle_tree) + .await? + .ok_or_else(|| anyhow!("Merkle tree account not found"))?; + + let tree = BatchedMerkleTreeAccount::state_from_bytes( + account.data.as_mut_slice(), + &context.merkle_tree.into(), + )?; + + let batch_index = tree.queue_batches.pending_batch_index; + let batch = tree + .queue_batches + .batches + .get(batch_index as usize) + .ok_or_else(|| anyhow!("Batch not found"))?; + + Ok(batch.zkp_batch_size) +} + +pub async fn fetch_onchain_state_root( + context: &BatchContext, +) -> crate::Result<[u8; 32]> { + let rpc = context.rpc_pool.get_connection().await?; + let mut account = rpc + .get_account(context.merkle_tree) + .await? + .ok_or_else(|| anyhow!("Merkle tree account not found"))?; + + let tree = BatchedMerkleTreeAccount::state_from_bytes( + account.data.as_mut_slice(), + &context.merkle_tree.into(), + )?; + + let root = tree + .root_history + .last() + .copied() + .ok_or_else(|| anyhow!("Root history is empty"))?; + + Ok(root) +} + +pub async fn fetch_address_zkp_batch_size(context: &BatchContext) -> crate::Result { + let rpc = context.rpc_pool.get_connection().await?; + let mut account = rpc + .get_account(context.merkle_tree) + .await? + .ok_or_else(|| anyhow!("Merkle tree account not found"))?; + + let tree = BatchedMerkleTreeAccount::address_from_bytes( + account.data.as_mut_slice(), + &context.merkle_tree.into(), + ) + .map_err(|e| anyhow!("Failed to deserialize address tree: {}", e))?; + + let batch_index = tree.queue_batches.pending_batch_index; + let batch = tree + .queue_batches + .batches + .get(batch_index as usize) + .ok_or_else(|| anyhow!("Batch not found"))?; + + Ok(batch.zkp_batch_size) +} + +pub async fn fetch_onchain_address_root( + context: &BatchContext, +) -> crate::Result<[u8; 32]> { + let rpc = context.rpc_pool.get_connection().await?; + let mut account = rpc + .get_account(context.merkle_tree) + .await? 
+ .ok_or_else(|| anyhow!("Merkle tree account not found"))?; + + let tree = BatchedMerkleTreeAccount::address_from_bytes( + account.data.as_mut_slice(), + &context.merkle_tree.into(), + ) + .map_err(|e| anyhow!("Failed to deserialize address tree: {}", e))?; + + let root = tree + .root_history + .last() + .copied() + .ok_or_else(|| anyhow!("Root history is empty"))?; + + Ok(root) +} diff --git a/forester/tests/e2e_test.rs b/forester/tests/e2e_test.rs index c11500f29a..6a60e31ca4 100644 --- a/forester/tests/e2e_test.rs +++ b/forester/tests/e2e_test.rs @@ -269,16 +269,9 @@ async fn e2e_test() { slot_update_interval_seconds: 10, tree_discovery_interval_seconds: 5, enable_metrics: false, - skip_v1_state_trees: false, - skip_v2_state_trees: false, - skip_v1_address_trees: false, - skip_v2_address_trees: false, - tree_ids: vec![], sleep_after_processing_ms: 50, sleep_when_idle_ms: 100, - queue_polling_mode: Default::default(), - group_authority: None, - helius_rpc: false, + ..Default::default() }, rpc_pool_config: RpcPoolConfig { max_size: 50, diff --git a/forester/tests/legacy/batched_state_async_indexer_test.rs b/forester/tests/legacy/batched_state_async_indexer_test.rs index fe599a39a8..cffd8c92a2 100644 --- a/forester/tests/legacy/batched_state_async_indexer_test.rs +++ b/forester/tests/legacy/batched_state_async_indexer_test.rs @@ -23,9 +23,7 @@ use light_compressed_account::{ use light_compressed_token::process_transfer::{ transfer_sdk::create_transfer_instruction, TokenTransferOutputData, }; -use light_token::compat::TokenDataWithMerkleContext; use light_program_test::accounts::test_accounts::TestAccounts; -use light_prover_client::prover::spawn_prover; use light_registry::{ protocol_config::state::{ProtocolConfig, ProtocolConfigPda}, utils::get_protocol_config_pda_address, @@ -34,6 +32,7 @@ use light_test_utils::{ conversions::sdk_to_program_token_data, spl::create_mint_helper_with_keypair, system_program::create_invoke_instruction, }; +use light_token::compat::TokenDataWithMerkleContext; use rand::{prelude::SliceRandom, rngs::StdRng, Rng, SeedableRng}; use serial_test::serial; use solana_program::{native_token::LAMPORTS_PER_SOL, pubkey::Pubkey}; @@ -87,7 +86,6 @@ async fn test_state_indexer_async_batched() { validator_args: vec![], })) .await; - spawn_prover().await; let env = TestAccounts::get_local_test_validator_accounts(); let mut config = forester_config(); @@ -306,10 +304,7 @@ async fn wait_for_slot(rpc: &mut LightClient, target_slot: u64) { return; } Err(e) => { - println!( - "warp_to_slot unavailable ({}), falling back to polling", - e - ); + println!("warp_to_slot unavailable ({}), falling back to polling", e); } } while rpc.get_slot().await.unwrap() < target_slot { diff --git a/forester/tests/test_compressible_ctoken.rs b/forester/tests/test_compressible_ctoken.rs index 569278286f..a3d1c04624 100644 --- a/forester/tests/test_compressible_ctoken.rs +++ b/forester/tests/test_compressible_ctoken.rs @@ -442,7 +442,7 @@ async fn test_compressible_ctoken_compression() { .expect("Failed to register forester"); let rpc_from_pool = ctx.rpc_pool.get_connection().await.unwrap(); let current_slot = rpc_from_pool.get_slot().await.unwrap(); - let ready_accounts = tracker.get_ready_to_compress(current_slot); + let ready_accounts = tracker.get_ready_states(current_slot); assert_eq!(ready_accounts.len(), 1, "Should have 1 account ready"); assert_eq!(ready_accounts[0].pubkey, token_account_pubkey_2); @@ -453,7 +453,7 @@ async fn test_compressible_ctoken_compression() { let compressor = 
CTokenCompressor::new( ctx.rpc_pool.clone(), tracker.clone(), - ctx.forester_keypair, + Arc::new(ctx.forester_keypair), forester::smart_transaction::TransactionPolicy::default(), ); let compressor_handle = tokio::spawn(async move { diff --git a/forester/tests/test_compressible_mint.rs b/forester/tests/test_compressible_mint.rs index 248db07251..1cffd22be9 100644 --- a/forester/tests/test_compressible_mint.rs +++ b/forester/tests/test_compressible_mint.rs @@ -389,7 +389,7 @@ async fn test_compressible_mint_compression() { rpc.warp_to_slot(future_slot).await.expect("warp_to_slot"); let current_slot = rpc.get_slot().await.unwrap(); - let ready_accounts = tracker.get_ready_to_compress(current_slot); + let ready_accounts = tracker.get_ready_states(current_slot); println!("Ready to compress: {} mints", ready_accounts.len()); assert!( @@ -401,7 +401,7 @@ async fn test_compressible_mint_compression() { let compressor = MintCompressor::new( rpc_pool.clone(), tracker.clone(), - payer.insecure_clone(), + Arc::new(payer.insecure_clone()), forester::smart_transaction::TransactionPolicy::default(), ); @@ -593,7 +593,7 @@ async fn test_compressible_mint_subscription() { // Get ready-to-compress accounts let current_slot = rpc.get_slot().await.unwrap(); - let ready_accounts = tracker.get_ready_to_compress(current_slot); + let ready_accounts = tracker.get_ready_states(current_slot); println!( "Ready to compress: {} mints (current_slot: {})", ready_accounts.len(), @@ -611,7 +611,7 @@ async fn test_compressible_mint_subscription() { let compressor = MintCompressor::new( rpc_pool.clone(), tracker.clone(), - payer.insecure_clone(), + Arc::new(payer.insecure_clone()), forester::smart_transaction::TransactionPolicy::default(), ); @@ -656,7 +656,7 @@ async fn test_compressible_mint_subscription() { println!("Tracker updated: now has {} mint(s)", tracker.len()); // Verify the remaining mint is the second one - let remaining_accounts = tracker.get_ready_to_compress(current_slot); + let remaining_accounts = tracker.get_ready_states(current_slot); assert_eq!(remaining_accounts.len(), 1); assert_eq!( remaining_accounts[0].pubkey, mint_pda_2, diff --git a/forester/tests/test_compressible_pda.rs b/forester/tests/test_compressible_pda.rs index 97783d537c..78fe1d5e5b 100644 --- a/forester/tests/test_compressible_pda.rs +++ b/forester/tests/test_compressible_pda.rs @@ -604,8 +604,7 @@ async fn test_compressible_pda_compression() { // compressible_slot threshold for testing. We can't warp slots in the test validator, // so this tricks get_ready_to_compress_for_program into returning accounts as if // enough time has passed (ready_accounts will include accounts where compressible_slot < current_slot + 1000). 
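The `current_slot + 1000` trick described in the comment below reduces to a one-line predicate; a hedged sketch, with the field semantics inferred from the surrounding tests:

```rust
/// An account becomes ready once its compressible_slot is in the past, so
/// querying with an inflated "current" slot simulates elapsed time without
/// warping the validator.
fn is_ready(compressible_slot: u64, current_slot: u64) -> bool {
    compressible_slot < current_slot
}

fn main() {
    let (compressible_slot, real_slot) = (1_500, 1_000);
    assert!(!is_ready(compressible_slot, real_slot)); // not yet, in real time
    assert!(is_ready(compressible_slot, real_slot + 1_000)); // the test's shortcut
}
```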
- let ready_accounts = - tracker.get_ready_to_compress_for_program(&program_id, current_slot + 1000); + let ready_accounts = tracker.get_ready_states_for_program(&program_id, current_slot + 1000); println!("Ready to compress: {} accounts", ready_accounts.len()); if !ready_accounts.is_empty() { @@ -613,7 +612,7 @@ async fn test_compressible_pda_compression() { let compressor = PdaCompressor::new( ctx.rpc_pool.clone(), tracker.clone(), - ctx.forester_keypair.insecure_clone(), + Arc::new(ctx.forester_keypair.insecure_clone()), forester::smart_transaction::TransactionPolicy::default(), ); @@ -919,8 +918,7 @@ async fn test_compressible_pda_subscription() { let current_slot = rpc_from_pool.get_slot().await.unwrap(); // These should be ready since they're rent-free PDAs - let ready_accounts = - tracker.get_ready_to_compress_for_program(&program_id, current_slot + 1000); + let ready_accounts = tracker.get_ready_states_for_program(&program_id, current_slot + 1000); println!( "Ready to compress: {} PDAs (current_slot: {})", ready_accounts.len(), @@ -936,7 +934,7 @@ async fn test_compressible_pda_subscription() { let compressor = PdaCompressor::new( ctx.rpc_pool.clone(), tracker.clone(), - ctx.forester_keypair.insecure_clone(), + Arc::new(ctx.forester_keypair.insecure_clone()), forester::smart_transaction::TransactionPolicy::default(), ); @@ -977,7 +975,7 @@ async fn test_compressible_pda_subscription() { println!("Tracker updated: now has {} PDA(s)", tracker.len()); // Verify the remaining PDA is the second one - let remaining = tracker.get_ready_to_compress_for_program(&program_id, current_slot + 1000); + let remaining = tracker.get_ready_states_for_program(&program_id, current_slot + 1000); assert_eq!(remaining.len(), 1); assert_eq!( remaining[0].pubkey, record_pda_2, diff --git a/forester/tests/test_utils.rs b/forester/tests/test_utils.rs index 4ae9352482..a36e8585cb 100644 --- a/forester/tests/test_utils.rs +++ b/forester/tests/test_utils.rs @@ -105,19 +105,12 @@ pub fn forester_config() -> ForesterConfig { indexer_config: Default::default(), transaction_config: Default::default(), general_config: GeneralConfig { - slot_update_interval_seconds: 10, tree_discovery_interval_seconds: 5, enable_metrics: false, - skip_v1_state_trees: false, - skip_v2_state_trees: false, - skip_v1_address_trees: false, - skip_v2_address_trees: false, - tree_ids: vec![], sleep_after_processing_ms: 50, sleep_when_idle_ms: 100, queue_polling_mode: QueuePollingMode::OnChain, - group_authority: None, - helius_rpc: false, + ..Default::default() }, rpc_pool_config: RpcPoolConfig { max_size: 50, diff --git a/js/token-interface/src/instructions/mint-to-compressed.ts b/js/token-interface/src/instructions/mint-to-compressed.ts index d31a91b900..410505ba7a 100644 --- a/js/token-interface/src/instructions/mint-to-compressed.ts +++ b/js/token-interface/src/instructions/mint-to-compressed.ts @@ -1,7 +1,4 @@ -import { - SystemProgram, - TransactionInstruction, -} from '@solana/web3.js'; +import { SystemProgram, TransactionInstruction } from '@solana/web3.js'; import { Buffer } from 'buffer'; import { LIGHT_TOKEN_PROGRAM_ID, @@ -125,9 +122,17 @@ export function createMintToCompressedInstruction({ isSigner: false, isWritable: false, }, - { pubkey: SystemProgram.programId, isSigner: false, isWritable: false }, + { + pubkey: SystemProgram.programId, + isSigner: false, + isWritable: false, + }, { pubkey: outputQueue, isSigner: false, isWritable: true }, - { pubkey: merkleContext.treeInfo.tree, isSigner: false, isWritable: true }, + { + 
pubkey: merkleContext.treeInfo.tree, + isSigner: false, + isWritable: true, + }, { pubkey: merkleContext.treeInfo.queue, isSigner: false, diff --git a/js/token-interface/tests/e2e/ata-read.test.ts b/js/token-interface/tests/e2e/ata-read.test.ts index 8159e81738..2c709f9f3e 100644 --- a/js/token-interface/tests/e2e/ata-read.test.ts +++ b/js/token-interface/tests/e2e/ata-read.test.ts @@ -1,6 +1,10 @@ import { describe, expect, it } from 'vitest'; import { newAccountWithLamports } from '@lightprotocol/stateless.js'; -import { createAtaInstructions, getAta, getAssociatedTokenAddress } from '../../src'; +import { + createAtaInstructions, + getAta, + getAssociatedTokenAddress, +} from '../../src'; import { createMintFixture, sendInstructions } from './helpers'; describe('ata creation and reads', () => { diff --git a/js/token-interface/tests/e2e/mint-to-compressed.test.ts b/js/token-interface/tests/e2e/mint-to-compressed.test.ts index 0bf4e08abe..59e040c79e 100644 --- a/js/token-interface/tests/e2e/mint-to-compressed.test.ts +++ b/js/token-interface/tests/e2e/mint-to-compressed.test.ts @@ -52,7 +52,12 @@ describe('mint-to-compressed instruction', () => { [COMPRESSED_MINT_SEED, mintSigner.publicKey.toBuffer()], LIGHT_TOKEN_PROGRAM_ID, ); - const mintInfo = await getMint(rpc, mint, undefined, LIGHT_TOKEN_PROGRAM_ID); + const mintInfo = await getMint( + rpc, + mint, + undefined, + LIGHT_TOKEN_PROGRAM_ID, + ); if (!mintInfo.merkleContext || !mintInfo.mintContext) { throw new Error('Light mint context missing.'); } @@ -103,7 +108,12 @@ describe('mint-to-compressed instruction', () => { recipientB.publicKey, { mint }, ); - const mintAfter = await getMint(rpc, mint, undefined, LIGHT_TOKEN_PROGRAM_ID); + const mintAfter = await getMint( + rpc, + mint, + undefined, + LIGHT_TOKEN_PROGRAM_ID, + ); const amountA = aAccounts.items.reduce( (sum, account) => sum + BigInt(account.parsed.amount.toString()), diff --git a/js/token-interface/tests/unit/instruction-builders.test.ts b/js/token-interface/tests/unit/instruction-builders.test.ts index b70bf20790..ba36c7e433 100644 --- a/js/token-interface/tests/unit/instruction-builders.test.ts +++ b/js/token-interface/tests/unit/instruction-builders.test.ts @@ -670,7 +670,9 @@ describe('instruction builders', () => { mintSigner: Keypair.generate().publicKey.toBytes(), bump: 255, }, - recipients: [{ recipient: Keypair.generate().publicKey, amount: 42n }], + recipients: [ + { recipient: Keypair.generate().publicKey, amount: 42n }, + ], }); expect(instruction.programId.equals(LIGHT_TOKEN_PROGRAM_ID)).toBe(true); diff --git a/justfile b/justfile index dbbb007cdf..580d8a7241 100644 --- a/justfile +++ b/justfile @@ -44,7 +44,7 @@ lint-readmes: set -e echo "Checking READMEs are up-to-date..." if ! 
command -v cargo-rdme &> /dev/null; then - cargo install cargo-rdme + cargo install cargo-rdme --locked fi for toml in $(find program-libs sdk-libs -name '.cargo-rdme.toml' -type f); do crate_dir=$(dirname "$toml") diff --git a/program-tests/utils/src/actions/legacy/instructions/transfer2.rs b/program-tests/utils/src/actions/legacy/instructions/transfer2.rs index 1ff92eeda9..b006824302 100644 --- a/program-tests/utils/src/actions/legacy/instructions/transfer2.rs +++ b/program-tests/utils/src/actions/legacy/instructions/transfer2.rs @@ -169,13 +169,23 @@ pub async fn create_generic_transfer2_instruction( payer: Pubkey, should_filter_zero_outputs: bool, ) -> Result { - // // Get a single shared output queue for ALL compress/compress-and-close operations - // // This prevents reordering issues caused by the sort_by_key at the end - // let shared_output_queue = rpc - // .get_random_state_tree_info() - // .unwrap() - // .get_output_pubkey() - // .unwrap(); + // Transfer2 supports a single output queue per instruction. Legacy helpers accept + // per-action queues, but normalize them down to one shared queue for the IX. + let mut explicit_output_queue = None; + for action in &actions { + let candidate = match action { + Transfer2InstructionType::Compress(input) => Some(input.output_queue), + Transfer2InstructionType::CompressAndClose(input) => Some(input.output_queue), + Transfer2InstructionType::Decompress(_) + | Transfer2InstructionType::Transfer(_) + | Transfer2InstructionType::Approve(_) => None, + }; + if let Some(candidate) = candidate { + if explicit_output_queue.is_none() { + explicit_output_queue = Some(candidate); + } + } + } let mut hashes = Vec::new(); actions.iter().for_each(|account| match account { @@ -210,24 +220,16 @@ pub async fn create_generic_transfer2_instruction( .value; let mut packed_tree_accounts = PackedAccounts::default(); - // tree infos must be packed before packing the token input accounts - let packed_tree_infos = rpc_proof_result.pack_tree_infos(&mut packed_tree_accounts); + // Pack only input state tree infos. Grouped transfer2 proofs can span multiple output trees. + let packed_tree_infos = rpc_proof_result.pack_state_tree_infos(&mut packed_tree_accounts); - // We use a single shared output queue for all compress/compress-and-close operations to avoid ordering failures. - let shared_output_queue = if packed_tree_infos.address_trees.is_empty() { - let shared_output_queue = rpc - .get_random_state_tree_info() + let shared_output_queue = explicit_output_queue.unwrap_or_else(|| { + rpc.get_random_state_tree_info() .unwrap() .get_output_pubkey() - .unwrap(); - packed_tree_accounts.insert_or_get(shared_output_queue) - } else { - packed_tree_infos - .state_trees - .as_ref() .unwrap() - .output_tree_index - }; + }); + let shared_output_queue = packed_tree_accounts.insert_or_get(shared_output_queue); let mut inputs_offset = 0; let mut in_lamports = Vec::new(); @@ -242,14 +244,7 @@ pub async fn create_generic_transfer2_instruction( if let Some(ref input_token_account) = input.compressed_token_account { let token_data = input_token_account .iter() - .zip( - packed_tree_infos - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[inputs_offset..] 
- .iter(), - ) + .zip(packed_tree_infos[inputs_offset..].iter()) .map(|(account, rpc_account)| { if input.to != account.token.owner { return Err(TokenSdkError::InvalidCompressInputOwner); @@ -391,14 +386,7 @@ pub async fn create_generic_transfer2_instruction( let token_data = input .compressed_token_account .iter() - .zip( - packed_tree_infos - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[inputs_offset..] - .iter(), - ) + .zip(packed_tree_infos[inputs_offset..].iter()) .map(|(account, rpc_account)| { pack_input_token_account( account, @@ -460,14 +448,7 @@ pub async fn create_generic_transfer2_instruction( let token_data = input .compressed_token_account .iter() - .zip( - packed_tree_infos - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[inputs_offset..] - .iter(), - ) + .zip(packed_tree_infos[inputs_offset..].iter()) .map(|(account, rpc_account)| { pack_input_token_account( account, @@ -542,14 +523,7 @@ pub async fn create_generic_transfer2_instruction( let token_data = input .compressed_token_account .iter() - .zip( - packed_tree_infos - .state_trees - .as_ref() - .unwrap() - .packed_tree_infos[inputs_offset..] - .iter(), - ) + .zip(packed_tree_infos[inputs_offset..].iter()) .map(|(account, rpc_account)| { pack_input_token_account( account, diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index edb94fdf48..482eb5eaaa 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -835,7 +835,7 @@ where .map_err(|error| RpcError::CustomError(error.to_string())) .unwrap(); let (proof_a, proof_b, proof_c) = - proof_from_json_struct(proof_json); + proof_from_json_struct(proof_json).unwrap(); let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); let instruction_data = InstructionDataBatchNullifyInputs { diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 9a20b8958e..f1d77b3bfa 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -2,7 +2,7 @@ use std::process::Command; use light_hasher::{Hasher, Poseidon}; use light_sparse_merkle_tree::changelog::ChangelogEntry; -use num_bigint::{BigInt, BigUint}; +use num_bigint::{BigInt, BigUint, Sign}; use num_traits::{Num, ToPrimitive}; use serde::Serialize; @@ -35,10 +35,17 @@ pub fn convert_endianness_128(bytes: &[u8]) -> Vec { .collect::>() } -pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], Box> { - let (_, bytes_be) = n.to_bytes_be(); +pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], ProverClientError> { + let (sign, bytes_be) = n.to_bytes_be(); + if sign == Sign::Minus { + return Err(ProverClientError::InvalidProofData( + "negative integers are not valid field elements".to_string(), + )); + } if bytes_be.len() > 32 { - Err("Number too large to fit in [u8; 32]")?; + return Err(ProverClientError::InvalidProofData( + "number too large to fit in [u8; 32]".to_string(), + )); } let mut array = [0; 32]; let bytes = &bytes_be[..bytes_be.len()]; diff --git a/prover/client/src/proof.rs b/prover/client/src/proof.rs index cc66f3aed6..da9b913830 100644 --- a/prover/client/src/proof.rs +++ b/prover/client/src/proof.rs @@ -12,6 +12,9 @@ use solana_bn254::compression::prelude::{ alt_bn128_g2_decompress_be, convert_endianness, }; +pub type CompressedProofBytes = ([u8; 32], [u8; 64], [u8; 32]); +pub type UncompressedProofBytes = ([u8; 64], [u8; 128], [u8; 64]); + #[derive(Debug, Clone, Copy)] pub struct ProofCompressed { pub a: [u8; 32], @@ -66,16 +69,27 @@ pub fn 
deserialize_gnark_proof_json(json_data: &str) -> serde_json::Result [u8; 32] { - let trimmed_str = hex_str.trim_start_matches("0x"); - let big_int = num_bigint::BigInt::from_str_radix(trimmed_str, 16).unwrap(); - let big_int_bytes = big_int.to_bytes_be().1; - if big_int_bytes.len() < 32 { +pub fn deserialize_hex_string_to_be_bytes(hex_str: &str) -> Result<[u8; 32], ProverClientError> { + let trimmed_str = hex_str + .strip_prefix("0x") + .or_else(|| hex_str.strip_prefix("0X")) + .unwrap_or(hex_str); + let big_uint = num_bigint::BigUint::from_str_radix(trimmed_str, 16) + .map_err(|error| ProverClientError::InvalidHexString(format!("{hex_str}: {error}")))?; + let big_uint_bytes = big_uint.to_bytes_be(); + if big_uint_bytes.len() > 32 { + return Err(ProverClientError::InvalidHexString(format!( + "{hex_str}: exceeds 32 bytes" + ))); + } + if big_uint_bytes.len() < 32 { let mut result = [0u8; 32]; - result[32 - big_int_bytes.len()..].copy_from_slice(&big_int_bytes); - result + result[32 - big_uint_bytes.len()..].copy_from_slice(&big_uint_bytes); + Ok(result) } else { - big_int_bytes.try_into().unwrap() + big_uint_bytes.try_into().map_err(|_| { + ProverClientError::InvalidHexString(format!("{hex_str}: invalid 32-byte encoding")) + }) } } @@ -90,40 +104,72 @@ pub fn compress_proof( (proof_a, proof_b, proof_c) } -pub fn proof_from_json_struct(json: GnarkProofJson) -> ([u8; 64], [u8; 128], [u8; 64]) { - let proof_a_x = deserialize_hex_string_to_be_bytes(&json.ar[0]); - let proof_a_y = deserialize_hex_string_to_be_bytes(&json.ar[1]); - let proof_a: [u8; 64] = [proof_a_x, proof_a_y].concat().try_into().unwrap(); - let proof_a = negate_g1(&proof_a); - let proof_b_x_0 = deserialize_hex_string_to_be_bytes(&json.bs[0][0]); - let proof_b_x_1 = deserialize_hex_string_to_be_bytes(&json.bs[0][1]); - let proof_b_y_0 = deserialize_hex_string_to_be_bytes(&json.bs[1][0]); - let proof_b_y_1 = deserialize_hex_string_to_be_bytes(&json.bs[1][1]); +pub fn proof_from_json_struct( + json: GnarkProofJson, +) -> Result { + let proof_a_x = deserialize_hex_string_to_be_bytes(json.ar.first().ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof A x coordinate".to_string()) + })?)?; + let proof_a_y = deserialize_hex_string_to_be_bytes(json.ar.get(1).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof A y coordinate".to_string()) + })?)?; + let proof_a: [u8; 64] = [proof_a_x, proof_a_y] + .concat() + .try_into() + .map_err(|_| ProverClientError::InvalidProofData("invalid proof A length".to_string()))?; + let proof_a = negate_g1(&proof_a)?; + let proof_b_x_0 = deserialize_hex_string_to_be_bytes( + json.bs.first().and_then(|row| row.first()).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B x0 coordinate".to_string()) + })?, + )?; + let proof_b_x_1 = deserialize_hex_string_to_be_bytes( + json.bs.first().and_then(|row| row.get(1)).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B x1 coordinate".to_string()) + })?, + )?; + let proof_b_y_0 = deserialize_hex_string_to_be_bytes( + json.bs.get(1).and_then(|row| row.first()).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B y0 coordinate".to_string()) + })?, + )?; + let proof_b_y_1 = + deserialize_hex_string_to_be_bytes(json.bs.get(1).and_then(|row| row.get(1)).ok_or_else( + || ProverClientError::InvalidProofData("missing proof B y1 coordinate".to_string()), + )?)?; let proof_b: [u8; 128] = [proof_b_x_0, proof_b_x_1, proof_b_y_0, proof_b_y_1] .concat() .try_into() - .unwrap(); + .map_err(|_| 
ProverClientError::InvalidProofData("invalid proof B length".to_string()))?; - let proof_c_x = deserialize_hex_string_to_be_bytes(&json.krs[0]); - let proof_c_y = deserialize_hex_string_to_be_bytes(&json.krs[1]); - let proof_c: [u8; 64] = [proof_c_x, proof_c_y].concat().try_into().unwrap(); - (proof_a, proof_b, proof_c) + let proof_c_x = deserialize_hex_string_to_be_bytes(json.krs.first().ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof C x coordinate".to_string()) + })?)?; + let proof_c_y = deserialize_hex_string_to_be_bytes(json.krs.get(1).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof C y coordinate".to_string()) + })?)?; + let proof_c: [u8; 64] = [proof_c_x, proof_c_y] + .concat() + .try_into() + .map_err(|_| ProverClientError::InvalidProofData("invalid proof C length".to_string()))?; + Ok((proof_a, proof_b, proof_c)) } -pub fn negate_g1(g1_be: &[u8; 64]) -> [u8; 64] { +pub fn negate_g1(g1_be: &[u8; 64]) -> Result<[u8; 64], ProverClientError> { let g1_le = convert_endianness::<32, 64>(g1_be); - let g1: G1 = G1::deserialize_with_mode(g1_le.as_slice(), Compress::No, Validate::No).unwrap(); + let g1: G1 = G1::deserialize_with_mode(g1_le.as_slice(), Compress::No, Validate::Yes) + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; let g1_neg = g1.neg(); let mut g1_neg_be = [0u8; 64]; g1_neg .x .serialize_with_mode(&mut g1_neg_be[..32], Compress::No) - .unwrap(); + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; g1_neg .y .serialize_with_mode(&mut g1_neg_be[32..], Compress::No) - .unwrap(); + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; let g1_neg_be: [u8; 64] = convert_endianness::<32, 64>(&g1_neg_be); - g1_neg_be + Ok(g1_neg_be) } diff --git a/prover/client/src/proof_client.rs b/prover/client/src/proof_client.rs index 859ea4917f..e6c5008c39 100644 --- a/prover/client/src/proof_client.rs +++ b/prover/client/src/proof_client.rs @@ -654,7 +654,7 @@ impl ProofClient { ProverClientError::ProverServerError(format!("Failed to deserialize proof JSON: {}", e)) })?; - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json)?; let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); Ok(ProofResult { diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index fdc50621cc..bf414b160d 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -93,7 +93,6 @@ pub struct BatchAddressAppendInputs { pub hashchain_hash: BigUint, pub low_element_values: Vec, pub low_element_indices: Vec, - pub low_element_next_indices: Vec, pub low_element_next_values: Vec, pub low_element_proofs: Vec>, pub new_element_values: Vec, @@ -105,84 +104,6 @@ pub struct BatchAddressAppendInputs { pub tree_height: usize, } -impl BatchAddressAppendInputs { - #[allow(clippy::too_many_arguments)] - pub fn new( - batch_size: usize, - leaves_hashchain: [u8; 32], - low_element_values: &[[u8; 32]], - low_element_indices: &[u64], - low_element_next_indices: &[u64], - low_element_next_values: &[[u8; 32]], - low_element_proofs: Vec>, - new_element_values: &[[u8; 32]], - new_element_proofs: Vec>, - new_root: [u8; 32], - old_root: [u8; 32], - start_index: usize, - ) -> Result { - let hash_chain_inputs = [ - old_root, - new_root, - 
leaves_hashchain, - bigint_to_be_bytes_array::<32>(&start_index.into())?, - ]; - let public_input_hash = create_hash_chain_from_array(hash_chain_inputs)?; - - let low_element_proofs_bigint: Vec> = low_element_proofs - .into_iter() - .map(|proof| { - proof - .into_iter() - .map(|p| BigUint::from_bytes_be(&p)) - .collect() - }) - .collect(); - - let new_element_proofs_bigint: Vec> = new_element_proofs - .into_iter() - .map(|proof| { - proof - .into_iter() - .map(|p| BigUint::from_bytes_be(&p)) - .collect() - }) - .collect(); - - Ok(Self { - batch_size, - hashchain_hash: BigUint::from_bytes_be(&leaves_hashchain), - low_element_values: low_element_values - .iter() - .map(|v| BigUint::from_bytes_be(v)) - .collect(), - low_element_indices: low_element_indices - .iter() - .map(|&i| BigUint::from(i)) - .collect(), - low_element_next_indices: low_element_next_indices - .iter() - .map(|&i| BigUint::from(i)) - .collect(), - low_element_next_values: low_element_next_values - .iter() - .map(|v| BigUint::from_bytes_be(v)) - .collect(), - low_element_proofs: low_element_proofs_bigint, - new_element_values: new_element_values - .iter() - .map(|v| BigUint::from_bytes_be(v)) - .collect(), - new_element_proofs: new_element_proofs_bigint, - new_root: BigUint::from_bytes_be(&new_root), - old_root: BigUint::from_bytes_be(&old_root), - public_input_hash: BigUint::from_bytes_be(&public_input_hash), - start_index, - tree_height: HEIGHT, - }) - } -} - #[allow(clippy::too_many_arguments)] pub fn get_batch_address_append_circuit_inputs( next_index: usize, @@ -199,32 +120,31 @@ pub fn get_batch_address_append_circuit_inputs( changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - if zkp_batch_size > new_element_values.len() - || zkp_batch_size > low_element_values.len() - || zkp_batch_size > low_element_indices.len() - || zkp_batch_size > low_element_next_indices.len() - || zkp_batch_size > low_element_next_values.len() - || zkp_batch_size > low_element_proofs.len() - { - return Err(ProverClientError::GenericError(format!( - "zkp_batch_size {} exceeds input slice lengths \ - (new_element_values={}, low_element_values={}, low_element_indices={}, \ - low_element_next_indices={}, low_element_next_values={}, low_element_proofs={})", - zkp_batch_size, - new_element_values.len(), - low_element_values.len(), - low_element_indices.len(), - low_element_next_indices.len(), - low_element_next_values.len(), - low_element_proofs.len(), - ))); + for (name, len) in [ + ("new_element_values", new_element_values.len()), + ("low_element_values", low_element_values.len()), + ("low_element_next_values", low_element_next_values.len()), + ("low_element_indices", low_element_indices.len()), + ("low_element_next_indices", low_element_next_indices.len()), + ("low_element_proofs", low_element_proofs.len()), + ] { + if len < zkp_batch_size { + return Err(ProverClientError::GenericError(format!( + "truncated batch from indexer: {} len {} < required batch size {}", + name, len, zkp_batch_size + ))); + } } + let new_element_values = &new_element_values[..zkp_batch_size]; + let mut staged_changelog = changelog.clone(); + let mut staged_indexed_changelog = indexed_changelog.clone(); + let mut staged_sparse_merkle_tree = sparse_merkle_tree.clone(); + let initial_changelog_len = staged_changelog.len(); let mut new_root = [0u8; 32]; let mut low_element_circuit_merkle_proofs = Vec::with_capacity(zkp_batch_size); let mut new_element_circuit_merkle_proofs = Vec::with_capacity(zkp_batch_size); let mut patched_low_element_next_values = 
Vec::with_capacity(zkp_batch_size); - let mut patched_low_element_next_indices = Vec::with_capacity(zkp_batch_size); let mut patched_low_element_values = Vec::with_capacity(zkp_batch_size); let mut patched_low_element_indices = Vec::with_capacity(zkp_batch_size); @@ -256,9 +176,9 @@ pub fn get_batch_address_append_circuit_inputs( next_index ); - let mut patcher = ChangelogProofPatcher::new::(changelog); + let mut patcher = ChangelogProofPatcher::new::(&staged_changelog); - let is_first_batch = indexed_changelog.is_empty(); + let is_first_batch = staged_indexed_changelog.is_empty(); let mut expected_root_for_low = current_root; for i in 0..zkp_batch_size { @@ -294,7 +214,7 @@ pub fn get_batch_address_append_circuit_inputs( patch_indexed_changelogs( 0, &mut changelog_index, - indexed_changelog, + &mut staged_indexed_changelog, &mut low_element, &mut new_element, &mut low_element_next_value, @@ -308,7 +228,6 @@ pub fn get_batch_address_append_circuit_inputs( })?; patched_low_element_next_values .push(bigint_to_be_bytes_array::<32>(&low_element_next_value)?); - patched_low_element_next_indices.push(low_element.next_index()); patched_low_element_indices.push(low_element.index); patched_low_element_values.push(bigint_to_be_bytes_array::<32>(&low_element.value)?); @@ -386,7 +305,7 @@ pub fn get_batch_address_append_circuit_inputs( new_low_element.index, )?; - patcher.push_changelog_entry::(changelog, changelog_entry); + patcher.push_changelog_entry::(&mut staged_changelog, changelog_entry); low_element_circuit_merkle_proofs.push( merkle_proof .iter() @@ -399,10 +318,10 @@ pub fn get_batch_address_append_circuit_inputs( let low_element_changelog_entry = IndexedChangelogEntry { element: new_low_element_raw, proof: low_element_changelog_proof, - changelog_index: indexed_changelog.len(), //change_log_index, + changelog_index: staged_indexed_changelog.len(), }; - indexed_changelog.push(low_element_changelog_entry); + staged_indexed_changelog.push(low_element_changelog_entry); { let new_element_next_value = low_element_next_value; @@ -412,10 +331,10 @@ pub fn get_batch_address_append_circuit_inputs( ProverClientError::GenericError(format!("Failed to hash new element: {}", e)) })?; - let sparse_root_before = sparse_merkle_tree.root(); - let sparse_next_idx_before = sparse_merkle_tree.get_next_index(); + let sparse_root_before = staged_sparse_merkle_tree.root(); + let sparse_next_idx_before = staged_sparse_merkle_tree.get_next_index(); - let mut merkle_proof_array = sparse_merkle_tree.append(new_element_leaf_hash); + let mut merkle_proof_array = staged_sparse_merkle_tree.append(new_element_leaf_hash); let current_index = next_index + i; @@ -427,7 +346,7 @@ pub fn get_batch_address_append_circuit_inputs( current_index, )?; - if i == 0 && changelog.len() == 1 { + if i == 0 && staged_changelog.len() == initial_changelog_len + 1 { if sparse_next_idx_before != current_index { return Err(ProverClientError::GenericError(format!( "sparse index mismatch: sparse tree next_index={} but expected current_index={}", @@ -486,7 +405,7 @@ pub fn get_batch_address_append_circuit_inputs( new_root = updated_root; - patcher.push_changelog_entry::(changelog, changelog_entry); + patcher.push_changelog_entry::(&mut staged_changelog, changelog_entry); new_element_circuit_merkle_proofs.push( merkle_proof_array .iter() @@ -504,9 +423,9 @@ pub fn get_batch_address_append_circuit_inputs( let new_element_changelog_entry = IndexedChangelogEntry { element: new_element_raw, proof: merkle_proof_array, - changelog_index: 
indexed_changelog.len(), + changelog_index: staged_indexed_changelog.len(), }; - indexed_changelog.push(new_element_changelog_entry); + staged_indexed_changelog.push(new_element_changelog_entry); } } @@ -542,18 +461,18 @@ pub fn get_batch_address_append_circuit_inputs( patcher.hits, patcher.misses, patcher.overwrites, - changelog.len(), - indexed_changelog.len() + staged_changelog.len(), + staged_indexed_changelog.len() ); - if patcher.hits == 0 && !changelog.is_empty() { + if patcher.hits == 0 && !staged_changelog.is_empty() { tracing::warn!( "Address proof patcher had 0 cache hits despite non-empty changelog (changelog_len={}, indexed_changelog_len={})", - changelog.len(), - indexed_changelog.len() + staged_changelog.len(), + staged_indexed_changelog.len() ); } - Ok(BatchAddressAppendInputs { + let inputs = BatchAddressAppendInputs { batch_size: patched_low_element_values.len(), hashchain_hash: BigUint::from_bytes_be(&leaves_hashchain), low_element_values: patched_low_element_values @@ -564,16 +483,12 @@ pub fn get_batch_address_append_circuit_inputs( .iter() .map(|&i| BigUint::from(i)) .collect(), - low_element_next_indices: patched_low_element_next_indices - .iter() - .map(|&i| BigUint::from(i)) - .collect(), low_element_next_values: patched_low_element_next_values .iter() .map(|v| BigUint::from_bytes_be(v)) .collect(), low_element_proofs: low_element_circuit_merkle_proofs, - new_element_values: new_element_values[0..] + new_element_values: new_element_values .iter() .map(|v| BigUint::from_bytes_be(v)) .collect(), @@ -583,5 +498,11 @@ pub fn get_batch_address_append_circuit_inputs( public_input_hash: BigUint::from_bytes_be(&public_input_hash), start_index: next_index, tree_height: HEIGHT, - }) + }; + + *changelog = staged_changelog; + *indexed_changelog = staged_indexed_changelog; + *sparse_merkle_tree = staged_sparse_merkle_tree; + + Ok(inputs) } diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index 41a6dcfcd6..186086a6a5 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -128,9 +128,20 @@ pub fn get_batch_append_inputs( batch_size: u32, previous_changelogs: &[ChangelogEntry], ) -> Result<(BatchAppendsCircuitInputs, Vec>), ProverClientError> { + let batch_size_usize = batch_size as usize; + if old_leaves.len() != batch_size_usize + || leaves.len() != batch_size_usize + || merkle_proofs.len() != batch_size_usize + { + return Err(ProverClientError::InvalidProofData(format!( + "batch append input length mismatch: old_leaves={}, leaves={}, merkle_proofs={}, expected batch_size={}", + old_leaves.len(), leaves.len(), merkle_proofs.len(), batch_size + ))); + } + let mut new_root = [0u8; 32]; let mut changelog: Vec> = Vec::new(); - let mut circuit_merkle_proofs = Vec::with_capacity(batch_size as usize); + let mut circuit_merkle_proofs = Vec::with_capacity(batch_size_usize); for (i, (old_leaf, (new_leaf, mut merkle_proof))) in old_leaves .iter() diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 2ada02b92b..f5467184aa 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -31,8 +31,12 @@ pub struct BatchUpdateCircuitInputs { } impl BatchUpdateCircuitInputs { - pub fn public_inputs_arr(&self) -> [u8; 32] { - 
bigint_to_u8_32(&self.public_input_hash).unwrap() + pub fn public_inputs_arr(&self) -> Result<[u8; 32], ProverClientError> { + bigint_to_u8_32(&self.public_input_hash).map_err(|error| { + ProverClientError::GenericError(format!( + "failed to serialize batch update public input: {error}" + )) + }) } pub fn new( @@ -112,9 +116,17 @@ impl BatchUpdateCircuitInputs { pub struct BatchUpdateInputs<'a>(pub &'a [BatchUpdateCircuitInputs]); impl BatchUpdateInputs<'_> { - pub fn public_inputs(&self) -> Vec<[u8; 32]> { - // Concatenate all public inputs into a single flat vector - vec![self.0[0].public_inputs_arr()] + pub fn public_inputs(&self) -> Result, ProverClientError> { + if self.0.is_empty() { + return Err(ProverClientError::GenericError( + "batch update inputs cannot be empty".to_string(), + )); + } + + self.0 + .iter() + .map(BatchUpdateCircuitInputs::public_inputs_arr) + .collect() } } diff --git a/prover/client/src/proof_types/combined/v2/json.rs b/prover/client/src/proof_types/combined/v2/json.rs index 322a5ee8ec..71de6b2ed3 100644 --- a/prover/client/src/proof_types/combined/v2/json.rs +++ b/prover/client/src/proof_types/combined/v2/json.rs @@ -2,6 +2,7 @@ use serde::Serialize; use crate::{ constants::{DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, DEFAULT_BATCH_STATE_TREE_HEIGHT}, + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{ circuit_type::CircuitType, @@ -29,21 +30,22 @@ pub struct CombinedJsonStruct { } impl CombinedJsonStruct { - pub fn from_combined_inputs(inputs: &CombinedProofInputs) -> Self { + pub fn from_combined_inputs(inputs: &CombinedProofInputs) -> Result { let inclusion_parameters = BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inputs.inclusion_parameters); - let non_inclusion_parameters = BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( - &inputs.non_inclusion_parameters, - ); + let non_inclusion_parameters = + BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( + &inputs.non_inclusion_parameters, + )?; - Self { + Ok(Self { circuit_type: CircuitType::Combined.to_string(), state_tree_height: DEFAULT_BATCH_STATE_TREE_HEIGHT, address_tree_height: DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, public_input_hash: big_int_to_string(&inputs.public_input_hash), inclusion: inclusion_parameters.inputs, non_inclusion: non_inclusion_parameters.inputs, - } + }) } #[allow(clippy::inherent_to_string)] diff --git a/prover/client/src/proof_types/non_inclusion/v2/json.rs b/prover/client/src/proof_types/non_inclusion/v2/json.rs index f6174e724d..9a556843af 100644 --- a/prover/client/src/proof_types/non_inclusion/v2/json.rs +++ b/prover/client/src/proof_types/non_inclusion/v2/json.rs @@ -2,6 +2,7 @@ use num_traits::ToPrimitive; use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{circuit_type::CircuitType, non_inclusion::v2::NonInclusionProofInputs}, }; @@ -24,7 +25,7 @@ pub struct NonInclusionJsonStruct { pub value: String, #[serde(rename(serialize = "pathIndex"))] - pub path_index: u32, + pub path_index: u64, #[serde(rename(serialize = "pathElements"))] pub path_elements: Vec, @@ -42,13 +43,23 @@ impl BatchNonInclusionJsonStruct { create_json_from_struct(&self) } - pub fn from_non_inclusion_proof_inputs(inputs: &NonInclusionProofInputs) -> Self { + pub fn from_non_inclusion_proof_inputs( + inputs: &NonInclusionProofInputs, + ) -> Result { let mut proof_inputs: Vec = Vec::new(); for input in inputs.inputs.iter() { let prof_input = 
NonInclusionJsonStruct { root: big_int_to_string(&input.root), value: big_int_to_string(&input.value), - path_index: input.index_hashed_indexed_element_leaf.to_u32().unwrap(), + path_index: input + .index_hashed_indexed_element_leaf + .to_u64() + .ok_or_else(|| { + ProverClientError::IntegerConversion(format!( + "failed to convert path index {} to u64", + input.index_hashed_indexed_element_leaf + )) + })?, path_elements: input .merkle_proof_hashed_indexed_element_leaf .iter() @@ -60,11 +71,11 @@ impl BatchNonInclusionJsonStruct { proof_inputs.push(prof_input); } - Self { + Ok(Self { circuit_type: CircuitType::NonInclusion.to_string(), address_tree_height: 40, public_input_hash: big_int_to_string(&inputs.public_input_hash), inputs: proof_inputs, - } + }) } } diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 7b8ceaa5f9..8e354db75a 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -212,7 +212,6 @@ pub fn get_test_batch_address_append_inputs( let mut low_element_values = Vec::new(); let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); let mut low_element_proofs = Vec::new(); let mut new_element_values = Vec::new(); @@ -230,7 +229,6 @@ pub fn get_test_batch_address_append_inputs( &non_inclusion_proof.leaf_lower_range_value, )); low_element_indices.push(non_inclusion_proof.leaf_index.into()); - low_element_next_indices.push(non_inclusion_proof.next_index.into()); low_element_next_values.push(BigUint::from_bytes_be( &non_inclusion_proof.leaf_higher_range_value, )); @@ -275,7 +273,6 @@ pub fn get_test_batch_address_append_inputs( hashchain_hash: BigUint::from_bytes_be(&leaves_hashchain), low_element_values, low_element_indices, - low_element_next_indices, low_element_next_values, low_element_proofs, new_element_values, diff --git a/prover/client/tests/init_merkle_tree.rs b/prover/client/tests/init_merkle_tree.rs index 3bb5584cd3..1cba92ad6b 100644 --- a/prover/client/tests/init_merkle_tree.rs +++ b/prover/client/tests/init_merkle_tree.rs @@ -221,7 +221,7 @@ pub fn non_inclusion_new_with_public_inputs_v2( .collect(), path_index: merkle_inputs .index_hashed_indexed_element_leaf - .to_u32() + .to_u64() .unwrap(), leaf_lower_range_value: big_int_to_string(&merkle_inputs.leaf_lower_range_value), leaf_higher_range_value: big_int_to_string(&merkle_inputs.leaf_higher_range_value), diff --git a/prover/server/prover/common/types.go b/prover/server/prover/common/types.go index 65668a764e..2d077e77e7 100644 --- a/prover/server/prover/common/types.go +++ b/prover/server/prover/common/types.go @@ -14,7 +14,7 @@ const ( // JSON input structures (these are not in circuit_utils.go) type InclusionProofInputsJSON struct { Root string `json:"root"` - PathIndex uint32 `json:"pathIndex"` + PathIndex uint64 `json:"pathIndex"` PathElements []string `json:"pathElements"` Leaf string `json:"leaf"` } diff --git a/prover/server/prover/v1/inclusion_proving_system.go b/prover/server/prover/v1/inclusion_proving_system.go index ce5d30b253..a9f075f6fe 100644 --- a/prover/server/prover/v1/inclusion_proving_system.go +++ b/prover/server/prover/v1/inclusion_proving_system.go @@ -14,7 +14,7 @@ import ( type InclusionInputs struct { Root big.Int - PathIndex uint32 + PathIndex uint64 PathElements []big.Int Leaf big.Int } diff --git a/prover/server/prover/v1/marshal_non_inclusion.go b/prover/server/prover/v1/marshal_non_inclusion.go 
index 583a2348f5..336f928524 100644 --- a/prover/server/prover/v1/marshal_non_inclusion.go +++ b/prover/server/prover/v1/marshal_non_inclusion.go @@ -9,7 +9,7 @@ import ( type NonInclusionProofInputsJSON struct { Root string `json:"root"` Value string `json:"value"` - PathIndex uint32 `json:"pathIndex"` + PathIndex uint64 `json:"pathIndex"` PathElements []string `json:"pathElements"` LeafLowerRangeValue string `json:"leafLowerRangeValue"` LeafHigherRangeValue string `json:"leafHigherRangeValue"` diff --git a/prover/server/prover/v1/non_inclusion_proving_system.go b/prover/server/prover/v1/non_inclusion_proving_system.go index b31d800e93..b79fe9c8c2 100644 --- a/prover/server/prover/v1/non_inclusion_proving_system.go +++ b/prover/server/prover/v1/non_inclusion_proving_system.go @@ -15,7 +15,7 @@ import ( type NonInclusionInputs struct { Root big.Int Value big.Int - PathIndex uint32 + PathIndex uint64 PathElements []big.Int LeafLowerRangeValue big.Int diff --git a/prover/server/prover/v1/non_inclusion_test.go b/prover/server/prover/v1/non_inclusion_test.go index 48f77f0669..8a50708b92 100644 --- a/prover/server/prover/v1/non_inclusion_test.go +++ b/prover/server/prover/v1/non_inclusion_test.go @@ -226,7 +226,7 @@ func TestNonInclusionCircuit(t *testing.T) { LeafLowerRangeValue: *leafLowerRangeValue, LeafHigherRangeValue: *leafHigherRangeValue, NextIndex: uint32(0), - PathIndex: uint32(pathIndex), + PathIndex: uint64(pathIndex), PathElements: pathElements, } diff --git a/prover/server/prover/v1/test_data_helpers.go b/prover/server/prover/v1/test_data_helpers.go index f888422b02..f93d58ec00 100644 --- a/prover/server/prover/v1/test_data_helpers.go +++ b/prover/server/prover/v1/test_data_helpers.go @@ -29,7 +29,7 @@ func BuildTestTree(depth int, numberOfCompressedAccounts int, random bool) Inclu for i := 0; i < numberOfCompressedAccounts; i++ { inputs[i].Leaf = *leaf - inputs[i].PathIndex = uint32(pathIndex) + inputs[i].PathIndex = uint64(pathIndex) inputs[i].PathElements = tree.Update(pathIndex, *leaf) inputs[i].Root = tree.Root.Value() } @@ -96,7 +96,7 @@ func BuildTestNonInclusionTree(depth int, numberOfCompressedAccounts int, random inputs[i].LeafLowerRangeValue = *leafLower inputs[i].LeafHigherRangeValue = *leafUpper inputs[i].NextIndex = uint32(0) // Set NextIndex explicitly - inputs[i].PathIndex = uint32(pathIndex) + inputs[i].PathIndex = uint64(pathIndex) inputs[i].PathElements = pathElements } diff --git a/prover/server/prover/v2/inclusion_proving_system.go b/prover/server/prover/v2/inclusion_proving_system.go index c0c9627d16..1b53448b66 100644 --- a/prover/server/prover/v2/inclusion_proving_system.go +++ b/prover/server/prover/v2/inclusion_proving_system.go @@ -16,7 +16,7 @@ import ( type InclusionInputs struct { Root big.Int - PathIndex uint32 + PathIndex uint64 PathElements []big.Int Leaf big.Int PublicInputHash big.Int diff --git a/prover/server/prover/v2/marshal_inclusion.go b/prover/server/prover/v2/marshal_inclusion.go index 560eed5b5f..cc1b8465c0 100644 --- a/prover/server/prover/v2/marshal_inclusion.go +++ b/prover/server/prover/v2/marshal_inclusion.go @@ -9,7 +9,7 @@ import ( type InclusionProofInputsJSON struct { Root string `json:"root"` - PathIndex uint32 `json:"pathIndex"` + PathIndex uint64 `json:"pathIndex"` PathElements []string `json:"pathElements"` Leaf string `json:"leaf"` } diff --git a/prover/server/prover/v2/marshal_non_inclusion.go b/prover/server/prover/v2/marshal_non_inclusion.go index 9a4e1e2977..a2ae2b46b9 100644 --- 
a/prover/server/prover/v2/marshal_non_inclusion.go +++ b/prover/server/prover/v2/marshal_non_inclusion.go @@ -10,7 +10,7 @@ import ( type NonInclusionProofInputsJSON struct { Root string `json:"root"` Value string `json:"value"` - PathIndex uint32 `json:"pathIndex"` + PathIndex uint64 `json:"pathIndex"` PathElements []string `json:"pathElements"` LeafLowerRangeValue string `json:"leafLowerRangeValue"` LeafHigherRangeValue string `json:"leafHigherRangeValue"` diff --git a/prover/server/prover/v2/non_inclusion_proving_system.go b/prover/server/prover/v2/non_inclusion_proving_system.go index 952ea922bd..380c8d56f8 100644 --- a/prover/server/prover/v2/non_inclusion_proving_system.go +++ b/prover/server/prover/v2/non_inclusion_proving_system.go @@ -17,7 +17,7 @@ import ( type NonInclusionInputs struct { Root big.Int Value big.Int - PathIndex uint32 + PathIndex uint64 PathElements []big.Int LeafLowerRangeValue big.Int diff --git a/prover/server/prover/v2/test_data_helpers.go b/prover/server/prover/v2/test_data_helpers.go index d6c670f082..09caeee544 100644 --- a/prover/server/prover/v2/test_data_helpers.go +++ b/prover/server/prover/v2/test_data_helpers.go @@ -31,7 +31,7 @@ func BuildTestTree(depth int, numberOfCompressedAccounts int, random bool) Inclu for i := 0; i < numberOfCompressedAccounts; i++ { inputs[i].Leaf = *leaf - inputs[i].PathIndex = uint32(pathIndex) + inputs[i].PathIndex = uint64(pathIndex) inputs[i].PathElements = tree.Update(pathIndex, *leaf) inputs[i].Root = tree.Root.Value() leaves[i] = leaf @@ -92,7 +92,7 @@ func BuildTestNonInclusionTree(depth int, numberOfCompressedAccounts int, random } inputs[i].Value = *value - inputs[i].PathIndex = uint32(pathIndex) + inputs[i].PathIndex = uint64(pathIndex) inputs[i].PathElements = tree.Update(pathIndex, *leaf) inputs[i].Root = tree.Root.Value() inputs[i].LeafLowerRangeValue = *leafLower diff --git a/scripts/format.sh b/scripts/format.sh index 5a4449c473..001e4c2370 100755 --- a/scripts/format.sh +++ b/scripts/format.sh @@ -28,7 +28,7 @@ CARGO_BUILD_JOBS="$CLIPPY_JOBS" cargo clippy \ # Regenerate READMEs with cargo-rdme echo "Regenerating READMEs..." if ! command -v cargo-rdme &> /dev/null; then - cargo install cargo-rdme + cargo install cargo-rdme --locked fi for toml in $(find program-libs sdk-libs -name '.cargo-rdme.toml' -type f); do crate_dir=$(dirname "$toml") diff --git a/scripts/lint.sh b/scripts/lint.sh index b440380898..bdf6d0fbb5 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -20,7 +20,7 @@ cargo clippy --workspace --all-features --all-targets -- -D warnings # Check that READMEs are up-to-date with cargo-rdme echo "Checking READMEs are up-to-date..." if ! 
command -v cargo-rdme &> /dev/null; then - cargo install cargo-rdme + cargo install cargo-rdme --locked fi for toml in $(find program-libs sdk-libs -name '.cargo-rdme.toml' -type f); do crate_dir=$(dirname "$toml") diff --git a/sdk-libs/client/src/indexer/types/proof.rs b/sdk-libs/client/src/indexer/types/proof.rs index 0b45e00986..75291e8cff 100644 --- a/sdk-libs/client/src/indexer/types/proof.rs +++ b/sdk-libs/client/src/indexer/types/proof.rs @@ -189,41 +189,57 @@ pub struct PackedTreeInfos { } impl ValidityProofWithContext { - pub fn pack_tree_infos(&self, packed_accounts: &mut PackedAccounts) -> PackedTreeInfos { - let mut packed_tree_infos = Vec::new(); - let mut address_trees = Vec::new(); - let mut output_tree_index = None; - for account in self.accounts.iter() { - // Pack TreeInfo - let merkle_tree_pubkey_index = packed_accounts.insert_or_get(account.tree_info.tree); - let queue_pubkey_index = packed_accounts.insert_or_get(account.tree_info.queue); - let tree_info_packed = PackedStateTreeInfo { - root_index: account.root_index.root_index, - merkle_tree_pubkey_index, - queue_pubkey_index, + pub fn pack_state_tree_infos( + &self, + packed_accounts: &mut PackedAccounts, + ) -> Vec { + self.accounts + .iter() + .map(|account| PackedStateTreeInfo { + root_index: account.root_index.root_index().unwrap_or_default(), + merkle_tree_pubkey_index: packed_accounts.insert_or_get(account.tree_info.tree), + queue_pubkey_index: packed_accounts.insert_or_get(account.tree_info.queue), leaf_index: account.leaf_index as u32, prove_by_index: account.root_index.proof_by_index(), - }; - packed_tree_infos.push(tree_info_packed); + }) + .collect() + } + pub fn pack_tree_infos( + &self, + packed_accounts: &mut PackedAccounts, + ) -> Result { + let packed_tree_infos = self.pack_state_tree_infos(packed_accounts); + let mut address_trees = Vec::new(); + let mut output_tree_index = None; + for account in self.accounts.iter() { // If a next Merkle tree exists the Merkle tree is full -> use the next Merkle tree for new state. // Else use the current Merkle tree for new state. if let Some(next) = account.tree_info.next_tree_info { // SAFETY: account will always have a state Merkle tree context. // pack_output_tree_index only panics on an address Merkle tree context. - let index = next.pack_output_tree_index(packed_accounts).unwrap(); - if output_tree_index.is_none() { - output_tree_index = Some(index); + let index = next.pack_output_tree_index(packed_accounts)?; + match output_tree_index { + Some(existing) if existing != index => { + return Err(IndexerError::InvalidParameters(format!( + "mixed output tree indices in state proof: {existing} != {index}" + ))); + } + Some(_) => {} + None => output_tree_index = Some(index), } } else { // SAFETY: account will always have a state Merkle tree context. // pack_output_tree_index only panics on an address Merkle tree context. 
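Reviewer note: the `next_tree_info` branch above and the fallback branch below enforce the same invariant, so the logic reads as one fold over the accounts' resolved output indices: the first index wins, and any later account resolving to a different index is now a hard error instead of being silently ignored. A condensed sketch of that rule, with a hypothetical `u8` index standing in for the packed account index:

```rust
// Condensed form of the invariant: fold the accounts' packed output indices
// into one value, erroring on the first mismatch.
fn fold_output_index(
    indices: impl IntoIterator<Item = u8>,
) -> Result<Option<u8>, String> {
    let mut output_tree_index: Option<u8> = None;
    for index in indices {
        match output_tree_index {
            Some(existing) if existing != index => {
                return Err(format!(
                    "mixed output tree indices in state proof: {existing} != {index}"
                ));
            }
            Some(_) => {}
            None => output_tree_index = Some(index),
        }
    }
    Ok(output_tree_index)
}

fn main() {
    assert_eq!(fold_output_index([3, 3, 3]), Ok(Some(3)));
    assert!(fold_output_index([3, 5]).is_err());
    // No accounts: no output index, which the caller later rejects for
    // non-empty state proofs.
    assert_eq!(fold_output_index(Vec::<u8>::new()), Ok(None));
}
```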
- let index = account - .tree_info - .pack_output_tree_index(packed_accounts) - .unwrap(); - if output_tree_index.is_none() { - output_tree_index = Some(index); + let index = account.tree_info.pack_output_tree_index(packed_accounts)?; + match output_tree_index { + Some(existing) if existing != index => { + return Err(IndexerError::InvalidParameters(format!( + "mixed output tree indices in state proof: {existing} != {index}" + ))); + } + Some(_) => {} + None => output_tree_index = Some(index), } } } @@ -244,13 +260,17 @@ impl ValidityProofWithContext { } else { Some(PackedStateTreeInfos { packed_tree_infos, - output_tree_index: output_tree_index.unwrap(), + output_tree_index: output_tree_index.ok_or_else(|| { + IndexerError::InvalidParameters( + "missing output tree index for non-empty state proof".to_string(), + ) + })?, }) }; - PackedTreeInfos { + Ok(PackedTreeInfos { state_trees: packed_tree_infos, address_trees, - } + }) } pub fn from_api_model( @@ -365,3 +385,133 @@ impl ValidityProofWithContext { }) } } + +#[cfg(test)] +mod tests { + use light_compressed_account::TreeType; + + use super::*; + use crate::indexer::NextTreeInfo; + + /// Helper to build an `AccountProofInputs` with a given tree and queue pubkey. + fn account_with_tree(tree: Pubkey, queue: Pubkey) -> AccountProofInputs { + AccountProofInputs { + hash: [0u8; 32], + root: [0u8; 32], + root_index: RootIndex::new_some(0), + leaf_index: 0, + tree_info: TreeInfo { + tree, + queue, + cpi_context: None, + next_tree_info: None, + tree_type: TreeType::StateV1, + }, + } + } + + #[test] + fn test_pack_tree_infos_mixed_output_tree_indices_error() { + let tree_a = Pubkey::new_unique(); + let tree_b = Pubkey::new_unique(); + let queue_a = Pubkey::new_unique(); + let queue_b = Pubkey::new_unique(); + + let proof_ctx = ValidityProofWithContext { + proof: ValidityProof::default(), + accounts: vec![ + account_with_tree(tree_a, queue_a), + account_with_tree(tree_b, queue_b), + ], + addresses: vec![], + }; + + let mut packed = PackedAccounts::default(); + let result = proof_ctx.pack_tree_infos(&mut packed); + + assert!(result.is_err(), "expected error for mixed output trees"); + let err = result.unwrap_err(); + match &err { + IndexerError::InvalidParameters(msg) => { + assert!( + msg.contains("mixed output tree indices"), + "unexpected error message: {msg}" + ); + } + other => panic!("expected InvalidParameters, got: {other:?}"), + } + } + + #[test] + fn test_pack_tree_infos_mixed_next_tree_indices_error() { + // Both accounts use the same input tree but have different *next* trees. 
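Reviewer note on why this second case must fail: per the comment above, `next_tree_info` marks a full Merkle tree whose new state rolls over to a successor, so two inputs rolling over to different successors can never share one output queue. A toy model of that routing rule under the same assumption; the types here are illustrative, not the client's real `TreeInfo`.

```rust
// Toy model of the rollover rule the error guards: if a next tree is set,
// the current tree is full and new state routes to the successor.
#[derive(Clone, Copy, PartialEq, Debug)]
struct Tree(u64);

struct TreeInfo {
    tree: Tree,
    next_tree_info: Option<Tree>,
}

fn output_tree(info: &TreeInfo) -> Tree {
    info.next_tree_info.unwrap_or(info.tree)
}

fn main() {
    let rolled_over = TreeInfo { tree: Tree(1), next_tree_info: Some(Tree(2)) };
    let stable = TreeInfo { tree: Tree(1), next_tree_info: None };
    assert_eq!(output_tree(&rolled_over), Tree(2));
    assert_eq!(output_tree(&stable), Tree(1));
    // Two inputs rolling over to different successors yield distinct output
    // trees, which pack_tree_infos now rejects.
}
```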
+ let input_tree = Pubkey::new_unique(); + let input_queue = Pubkey::new_unique(); + let next_tree_a = Pubkey::new_unique(); + let next_queue_a = Pubkey::new_unique(); + let next_tree_b = Pubkey::new_unique(); + let next_queue_b = Pubkey::new_unique(); + + let mut acc1 = account_with_tree(input_tree, input_queue); + acc1.tree_info.next_tree_info = Some(NextTreeInfo { + tree: next_tree_a, + queue: next_queue_a, + cpi_context: None, + tree_type: TreeType::StateV1, + }); + + let mut acc2 = account_with_tree(input_tree, input_queue); + acc2.tree_info.next_tree_info = Some(NextTreeInfo { + tree: next_tree_b, + queue: next_queue_b, + cpi_context: None, + tree_type: TreeType::StateV1, + }); + + let proof_ctx = ValidityProofWithContext { + proof: ValidityProof::default(), + accounts: vec![acc1, acc2], + addresses: vec![], + }; + + let mut packed = PackedAccounts::default(); + let result = proof_ctx.pack_tree_infos(&mut packed); + + assert!( + result.is_err(), + "expected error for mixed next-tree output indices" + ); + let err = result.unwrap_err(); + match &err { + IndexerError::InvalidParameters(msg) => { + assert!( + msg.contains("mixed output tree indices"), + "unexpected error message: {msg}" + ); + } + other => panic!("expected InvalidParameters, got: {other:?}"), + } + } + + #[test] + fn test_pack_tree_infos_same_output_tree_ok() { + let tree = Pubkey::new_unique(); + let queue = Pubkey::new_unique(); + + let proof_ctx = ValidityProofWithContext { + proof: ValidityProof::default(), + accounts: vec![ + account_with_tree(tree, queue), + account_with_tree(tree, queue), + ], + addresses: vec![], + }; + + let mut packed = PackedAccounts::default(); + let result = proof_ctx.pack_tree_infos(&mut packed); + assert!( + result.is_ok(), + "same output trees should succeed: {result:?}" + ); + } +} diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index 45f1094b6f..41ceeadfc4 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -115,6 +115,7 @@ impl AddressQueueData { pub fn reconstruct_all_proofs( &self, ) -> Result, IndexerError> { + self.validate_proof_height::()?; self.reconstruct_proofs::(0..self.addresses.len()) } @@ -135,6 +136,7 @@ impl AddressQueueData { address_idx: usize, node_lookup: &HashMap, ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), diff --git a/sdk-libs/client/src/interface/initialize_config.rs b/sdk-libs/client/src/interface/initialize_config.rs index 7b5919cdb1..9fbeacfe89 100644 --- a/sdk-libs/client/src/interface/initialize_config.rs +++ b/sdk-libs/client/src/interface/initialize_config.rs @@ -7,6 +7,8 @@ use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSeria use solana_instruction::{AccountMeta, Instruction}; use solana_pubkey::Pubkey; +use crate::interface::instructions::INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR; + /// Default address tree v2 pubkey. 
pub const ADDRESS_TREE_V2: Pubkey = solana_pubkey::pubkey!("amt2kaJA14v3urZbZvnc5v2np8jqvc4Z8zDep5wbtzx"); @@ -115,16 +117,14 @@ impl InitializeRentFreeConfig { address_space: self.address_space, }; - // Anchor discriminator for "initialize_compression_config" - // SHA256("global:initialize_compression_config")[..8] - const DISCRIMINATOR: [u8; 8] = [133, 228, 12, 169, 56, 76, 222, 61]; - let serialized_data = instruction_data .try_to_vec() .expect("Failed to serialize instruction data"); - let mut data = Vec::with_capacity(DISCRIMINATOR.len() + serialized_data.len()); - data.extend_from_slice(&DISCRIMINATOR); + let mut data = Vec::with_capacity( + INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR.len() + serialized_data.len(), + ); + data.extend_from_slice(&INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR); data.extend_from_slice(&serialized_data); let instruction = Instruction { diff --git a/sdk-libs/client/src/interface/instructions.rs b/sdk-libs/client/src/interface/instructions.rs index e80e7b72c1..bb4056ceae 100644 --- a/sdk-libs/client/src/interface/instructions.rs +++ b/sdk-libs/client/src/interface/instructions.rs @@ -234,12 +234,7 @@ where let output_queue = get_output_queue(&cold_accounts[0].0.tree_info); let output_state_tree_index = remaining_accounts.insert_or_get(output_queue); - let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts); - let tree_infos = &packed_tree_infos - .state_trees - .as_ref() - .ok_or("missing state_trees in packed_tree_infos")? - .packed_tree_infos; + let tree_infos = proof.pack_state_tree_infos(&mut remaining_accounts); let mut accounts = program_account_metas.to_vec(); let mut typed_accounts = Vec::with_capacity(cold_accounts.len()); @@ -313,14 +308,9 @@ pub fn build_compress_accounts_idempotent( let output_queue = get_output_queue(&proof.accounts[0].tree_info); let output_state_tree_index = remaining_accounts.insert_or_get(output_queue); - let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts); - let tree_infos = packed_tree_infos - .state_trees - .as_ref() - .ok_or("missing state_trees in packed_tree_infos")?; + let tree_infos = proof.pack_state_tree_infos(&mut remaining_accounts); let cold_metas: Vec<_> = tree_infos - .packed_tree_infos .iter() .map(|tree_info| CompressedAccountMetaNoLamportsNoAddress { tree_info: *tree_info, diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index 0d564734a2..01d91364c1 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -220,7 +220,7 @@ fn group_pda_specs<'a, V>( specs: &[&'a PdaSpec], max_per_group: usize, ) -> Vec>> { - assert!(max_per_group > 0, "max_per_group must be non-zero"); + debug_assert!(max_per_group > 0, "max_per_group must be non-zero"); if specs.is_empty() { return Vec::new(); } @@ -424,11 +424,7 @@ fn build_transfer2( fee_payer: Pubkey, ) -> Result { let mut packed = PackedAccounts::default(); - let packed_trees = proof.pack_tree_infos(&mut packed); - let tree_infos = packed_trees - .state_trees - .as_ref() - .ok_or_else(|| LoadAccountsError::BuildInstruction("no state trees".into()))?; + let tree_infos = proof.pack_state_tree_infos(&mut packed); let mut token_accounts = Vec::with_capacity(contexts.len()); let mut tlv_data: Vec> = Vec::with_capacity(contexts.len()); @@ -436,12 +432,12 @@ fn build_transfer2( for (i, ctx) in contexts.iter().enumerate() { let token = &ctx.compressed.token; - let tree = tree_infos.packed_tree_infos.get(i).ok_or( - 
LoadAccountsError::TreeInfoIndexOutOfBounds { + let tree = tree_infos + .get(i) + .ok_or(LoadAccountsError::TreeInfoIndexOutOfBounds { index: i, - len: tree_infos.packed_tree_infos.len(), - }, - )?; + len: tree_infos.len(), + })?; let owner_idx = packed.insert_or_get_config(ctx.wallet_owner, true, false); let ata_idx = packed.insert_or_get(derive_token_ata(&ctx.wallet_owner, &ctx.mint)); diff --git a/sdk-libs/client/src/interface/pack.rs b/sdk-libs/client/src/interface/pack.rs index 804a48751d..97505adabe 100644 --- a/sdk-libs/client/src/interface/pack.rs +++ b/sdk-libs/client/src/interface/pack.rs @@ -12,6 +12,9 @@ use crate::indexer::{TreeInfo, ValidityProofWithContext}; pub enum PackError { #[error("Failed to add system accounts: {0}")] SystemAccounts(#[from] light_sdk::error::LightSdkError), + + #[error("Failed to pack tree infos: {0}")] + Indexer(#[from] crate::indexer::IndexerError), } /// Packed state tree infos from validity proof. @@ -87,7 +90,7 @@ fn pack_proof_internal( // For mint creation: pack address tree first (index 1), then state tree. let (client_packed_tree_infos, state_tree_index) = if include_state_tree { // Pack tree infos first to ensure address tree is at index 1 - let tree_infos = proof.pack_tree_infos(&mut packed); + let tree_infos = proof.pack_tree_infos(&mut packed)?; // Then add state tree (will be after address tree) let state_tree = output_tree @@ -99,7 +102,7 @@ fn pack_proof_internal( (tree_infos, Some(state_idx)) } else { - let tree_infos = proof.pack_tree_infos(&mut packed); + let tree_infos = proof.pack_tree_infos(&mut packed)?; (tree_infos, None) }; let (remaining_accounts, system_offset, _) = packed.to_account_metas(); diff --git a/sdk-libs/client/src/local_test_validator.rs b/sdk-libs/client/src/local_test_validator.rs index 36ed7c04b3..bc6f0817da 100644 --- a/sdk-libs/client/src/local_test_validator.rs +++ b/sdk-libs/client/src/local_test_validator.rs @@ -1,6 +1,7 @@ -use std::process::{Command, Stdio}; +use std::process::Stdio; use light_prover_client::helpers::get_project_root; +use tokio::process::Command; /// Configuration for an upgradeable program to deploy to the validator. 
diff --git a/sdk-libs/client/src/local_test_validator.rs b/sdk-libs/client/src/local_test_validator.rs
index 36ed7c04b3..bc6f0817da 100644
--- a/sdk-libs/client/src/local_test_validator.rs
+++ b/sdk-libs/client/src/local_test_validator.rs
@@ -1,6 +1,7 @@
-use std::process::{Command, Stdio};
+use std::process::Stdio;
 
 use light_prover_client::helpers::get_project_root;
+use tokio::process::Command;
 
 /// Configuration for an upgradeable program to deploy to the validator.
 #[derive(Debug, Clone)]
@@ -57,71 +58,78 @@ impl Default for LightValidatorConfig {
 pub async fn spawn_validator(config: LightValidatorConfig) {
     if let Some(project_root) = get_project_root() {
-        let path = "cli/test_bin/run test-validator";
-        let mut path = format!("{}/{}", project_root.trim(), path);
+        let project_root = project_root.trim_end_matches(['\n', '\r']);
+        let executable = format!("{}/cli/test_bin/run", project_root);
+        let mut args = vec!["test-validator".to_string()];
         if !config.enable_indexer {
-            path.push_str(" --skip-indexer");
+            args.push("--skip-indexer".to_string());
         }
         if let Some(limit_ledger_size) = config.limit_ledger_size {
-            path.push_str(&format!(" --limit-ledger-size {}", limit_ledger_size));
+            args.push("--limit-ledger-size".to_string());
+            args.push(limit_ledger_size.to_string());
         }
         for sbf_program in config.sbf_programs.iter() {
-            path.push_str(&format!(
-                " --sbf-program {} {}",
-                sbf_program.0, sbf_program.1
-            ));
+            args.push("--sbf-program".to_string());
+            args.push(sbf_program.0.clone());
+            args.push(sbf_program.1.clone());
         }
         for upgradeable_program in config.upgradeable_programs.iter() {
-            path.push_str(&format!(
-                " --upgradeable-program {} {} {}",
-                upgradeable_program.program_id,
-                upgradeable_program.program_path,
-                upgradeable_program.upgrade_authority
-            ));
+            args.push("--upgradeable-program".to_string());
+            args.push(upgradeable_program.program_id.clone());
+            args.push(upgradeable_program.program_path.clone());
+            args.push(upgradeable_program.upgrade_authority.clone());
         }
         if !config.enable_prover {
-            path.push_str(" --skip-prover");
+            args.push("--skip-prover".to_string());
        }
         if config.use_surfpool {
-            path.push_str(" --use-surfpool");
+            args.push("--use-surfpool".to_string());
         }
         for arg in config.validator_args.iter() {
-            path.push_str(&format!(" {}", arg));
+            args.push(arg.clone());
         }
 
-        println!("Starting validator with command: {}", path);
+        println!(
+            "Starting validator with command: {} {}",
+            executable,
+            args.join(" ")
+        );
 
         if config.use_surfpool {
             // The CLI starts surfpool, prover, and photon, then exits once all
             // services are ready. Wait for it to finish so we know everything
             // is up before the test proceeds.
-            let mut child = Command::new("sh")
-                .arg("-c")
-                .arg(path)
+            let mut child = Command::new(&executable)
+                .args(&args)
                 .stdin(Stdio::null())
                 .stdout(Stdio::inherit())
                 .stderr(Stdio::inherit())
                 .spawn()
                 .expect("Failed to start server process");
-            let status = child.wait().expect("Failed to wait for CLI process");
+            let status = child.wait().await.expect("Failed to wait for CLI process");
             assert!(status.success(), "CLI exited with error: {}", status);
         } else {
-            let child = Command::new("sh")
-                .arg("-c")
-                .arg(path)
+            let mut child = Command::new(&executable)
+                .args(&args)
                 .stdin(Stdio::null())
                 .stdout(Stdio::null())
                 .stderr(Stdio::null())
                 .spawn()
                 .expect("Failed to start server process");
-            std::mem::drop(child);
             tokio::time::sleep(tokio::time::Duration::from_secs(config.wait_time)).await;
+            if let Some(status) = child.try_wait().expect("Failed to poll validator process") {
+                assert!(
+                    status.success(),
+                    "Validator exited early with error: {}",
+                    status
+                );
+            }
         }
     }
 }
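Two things change at once in this file: the CLI is no longer launched through `sh -c` with a hand-concatenated command string (so paths and arguments no longer need shell quoting and cannot be word-split), and the wait is the async `tokio::process` one, so the runtime thread is not blocked while the CLI runs. A condensed, self-contained sketch of the same pattern (names are illustrative, not from the diff; requires tokio with the `process` feature):

```rust
use std::process::Stdio;

use tokio::process::Command;

// Spawn an executable with explicit args (no shell involved) and await it.
// `executable` and `args` stand in for the values spawn_validator builds.
async fn run_to_completion(executable: &str, args: &[String]) -> std::io::Result<()> {
    let mut child = Command::new(executable)
        .args(args)
        .stdin(Stdio::null())
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit())
        .spawn()?;
    // Unlike std::process::Child::wait, this yields to the executor
    // instead of parking the thread for the duration of the child process.
    let status = child.wait().await?;
    assert!(status.success(), "process exited with error: {}", status);
    Ok(())
}
```

The non-surfpool branch additionally keeps the `Child` handle alive and polls it with `try_wait` after the startup sleep, so a validator that dies during startup now fails the test immediately instead of timing out later.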
diff --git a/sdk-libs/client/src/utils.rs b/sdk-libs/client/src/utils.rs
index b8f2e05ecb..0055f8dbea 100644
--- a/sdk-libs/client/src/utils.rs
+++ b/sdk-libs/client/src/utils.rs
@@ -15,8 +15,11 @@ pub fn find_light_bin() -> Option<PathBuf> {
     if !output.status.success() {
         return None;
     }
-    // Convert the output into a string (removing any trailing newline)
-    let light_path = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    let light_path = std::str::from_utf8(&output.stdout)
+        .ok()?
+        .trim_end_matches("\r\n")
+        .trim_end_matches('\n')
+        .to_string();
     // Get the parent directory of the 'light' binary
     let mut light_bin_path = PathBuf::from(light_path);
     light_bin_path.pop(); // Remove the 'light' binary itself
@@ -30,16 +33,16 @@ pub fn find_light_bin() -> Option<PathBuf> {
     #[cfg(feature = "devenv")]
     {
         println!("Use only in light protocol monorepo. Using 'git rev-parse --show-toplevel' to find the location of 'light' binary");
-        let light_protocol_toplevel = String::from_utf8_lossy(
-            &std::process::Command::new("git")
-                .arg("rev-parse")
-                .arg("--show-toplevel")
-                .output()
-                .expect("Failed to get top-level directory")
-                .stdout,
-        )
-        .trim()
-        .to_string();
+        let output = std::process::Command::new("git")
+            .arg("rev-parse")
+            .arg("--show-toplevel")
+            .output()
+            .expect("Failed to get top-level directory");
+        let light_protocol_toplevel = std::str::from_utf8(&output.stdout)
+            .ok()?
+            .trim_end_matches("\r\n")
+            .trim_end_matches('\n')
+            .to_string();
         let light_path = PathBuf::from(format!("{}/target/deploy/", light_protocol_toplevel));
         Some(light_path)
     }
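The replacement for `from_utf8_lossy(..).trim()` is deliberately narrower on both axes: strict UTF-8, so a malformed path yields `None` instead of silently acquiring replacement characters, and trailing-newline-only trimming, so whitespace that is genuinely part of the path survives. A small self-contained sketch of the trimming half (helper name is illustrative):

```rust
// Strip only the line terminator that command stdout typically carries;
// `.trim()` would also eat leading/trailing spaces that may belong to the path.
fn strip_trailing_newline(raw: &str) -> &str {
    raw.trim_end_matches("\r\n").trim_end_matches('\n')
}

#[test]
fn keeps_meaningful_whitespace() {
    assert_eq!(strip_trailing_newline("/opt/light-bin\r\n"), "/opt/light-bin");
    // Spaces inside or around the path are preserved, unlike with `.trim()`.
    assert_eq!(strip_trailing_newline(" spaced path \n"), " spaced path ");
}
```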
diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs
index b2d86635b1..176aa94373 100644
--- a/sdk-libs/program-test/src/indexer/test_indexer.rs
+++ b/sdk-libs/program-test/src/indexer/test_indexer.rs
@@ -95,6 +95,20 @@ use crate::accounts::{
 };
 use crate::indexer::TestIndexerExtensions;
 
+fn build_compressed_proof(body: &str) -> Result<CompressedProof, IndexerError> {
+    let proof_json = deserialize_gnark_proof_json(body)
+        .map_err(|error| IndexerError::CustomError(error.to_string()))?;
+    let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json)
+        .map_err(|error| IndexerError::CustomError(error.to_string()))?;
+    let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c);
+
+    Ok(CompressedProof {
+        a: proof_a,
+        b: proof_b,
+        c: proof_c,
+    })
+}
+
 #[derive(Debug)]
 pub struct TestIndexer {
     pub state_merkle_trees: Vec<StateMerkleTreeBundle>,
@@ -1706,11 +1720,13 @@ impl TestIndexer {
                     DEFAULT_BATCH_ROOT_HISTORY_LEN,
                 ));
 
-                (FeeConfig::test_batched().state_merkle_tree_rollover as i64,merkle_tree, Some(params.output_queue_batch_size as usize))
+                (FeeConfig::test_batched().state_merkle_tree_rollover as i64, merkle_tree, Some(params.output_queue_batch_size as usize))
                 }
 
                 #[cfg(not(feature = "devenv"))]
-                panic!("Batched state merkle trees require the 'devenv' feature to be enabled")
+                {
+                    panic!("Batched state merkle trees require the 'devenv' feature to be enabled")
+                }
             }
             _ => panic!(
                 "add_state_merkle_tree: tree_type not supported, {}. tree_type: 1 concurrent, 2 batched",
@@ -2367,7 +2383,8 @@ impl TestIndexer {
                 Some(
                     BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs(
                         &non_inclusion_proof_inputs,
-                    ),
+                    )
+                    .map_err(|error| IndexerError::CustomError(error.to_string()))?,
                 ),
                 None,
             )
@@ -2595,20 +2612,10 @@ impl TestIndexer {
             })?;
 
             if status.is_success() {
-                let proof_json = deserialize_gnark_proof_json(&body)
-                    .map_err(|error| IndexerError::CustomError(error.to_string()))?;
-                let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json);
-                let (proof_a, proof_b, proof_c) =
-                    compress_proof(&proof_a, &proof_b, &proof_c);
                 return Ok(ValidityProofWithContext {
                     accounts: account_proof_inputs,
                     addresses: address_proof_inputs,
-                    proof: CompressedProof {
-                        a: proof_a,
-                        b: proof_b,
-                        c: proof_c,
-                    }
-                    .into(),
+                    proof: build_compressed_proof(&body)?.into(),
                });
             }
diff --git a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs
index 154f4e2045..e36b1ef30c 100644
--- a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs
+++ b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs
@@ -127,7 +127,7 @@ async fn create_compressed_account(
         )
         .await?
         .value;
-    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap();
 
     let output_tree_index = rpc
         .get_random_state_tree_info()
@@ -178,6 +178,7 @@ async fn read_sha256_light_system_cpi(
 
     let packed_tree_accounts = rpc_result
         .pack_tree_infos(&mut remaining_accounts)
+        .unwrap()
         .state_trees
         .unwrap();
 
@@ -231,6 +232,7 @@ async fn read_sha256_lowlevel(
 
     let packed_tree_accounts = rpc_result
         .pack_tree_infos(&mut remaining_accounts)
+        .unwrap()
         .state_trees
         .unwrap();
 
@@ -289,7 +291,7 @@ async fn create_compressed_account_poseidon(
         )
         .await?
         .value;
-    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap();
 
     let output_tree_index = rpc
         .get_random_state_tree_info()
@@ -340,6 +342,7 @@ async fn read_poseidon_light_system_cpi(
 
     let packed_tree_accounts = rpc_result
         .pack_tree_infos(&mut remaining_accounts)
+        .unwrap()
         .state_trees
         .unwrap();
 
@@ -393,6 +396,7 @@ async fn read_poseidon_lowlevel(
 
     let packed_tree_accounts = rpc_result
         .pack_tree_infos(&mut remaining_accounts)
+        .unwrap()
         .state_trees
         .unwrap();
diff --git a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs
index e19d0742de..a5a4655db2 100644
--- a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs
+++ b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs
@@ -171,7 +171,7 @@ async fn create_compressed_account(
         )
         .await?
         .value;
-    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap();
 
     let output_tree_index = rpc
         .get_random_state_tree_info()
@@ -223,6 +223,7 @@ async fn update_compressed_account(
 
     let packed_tree_accounts = rpc_result
         .pack_tree_infos(&mut remaining_accounts)
+        .unwrap()
         .state_trees
         .unwrap();
 
@@ -277,6 +278,7 @@ async fn close_compressed_account(
 
     let packed_tree_accounts = rpc_result
         .pack_tree_infos(&mut remaining_accounts)
+        .unwrap()
         .state_trees
         .unwrap();
 
@@ -340,6 +342,7 @@ async fn reinit_closed_account(
 
     let packed_tree_accounts = rpc_result
         .pack_tree_infos(&mut remaining_accounts)
+        .unwrap()
         .state_trees
         .unwrap();
 
@@ -388,6 +391,7 @@ async fn close_compressed_account_permanent(
 
     let packed_tree_accounts = rpc_result
         .pack_tree_infos(&mut remaining_accounts)
+        .unwrap()
         .state_trees
         .unwrap();
diff --git a/sdk-tests/sdk-native-test/tests/test.rs b/sdk-tests/sdk-native-test/tests/test.rs
index 30d792487f..ac519f0e57 100644
--- a/sdk-tests/sdk-native-test/tests/test.rs
+++ b/sdk-tests/sdk-native-test/tests/test.rs
@@ -103,7 +103,10 @@ pub async fn create_pda(
         .value;
 
     let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey);
-    let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0];
+    let packed_address_tree_info = rpc_result
+        .pack_tree_infos(&mut accounts)
+        .unwrap()
+        .address_trees[0];
     let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas();
 
     let instruction_data = CreatePdaInstructionData {
@@ -147,6 +150,7 @@ pub async fn update_pda(
 
     let packed_accounts = rpc_result
         .pack_tree_infos(&mut accounts)
+        .unwrap()
         .state_trees
         .unwrap();
diff --git a/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs b/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs
index 0ae7f5c029..17d2c07b1e 100644
--- a/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs
+++ b/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs
@@ -101,7 +101,10 @@ pub async fn create_pda(
         .value;
 
     let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey);
-    let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0];
+    let packed_address_tree_info = rpc_result
+        .pack_tree_infos(&mut accounts)
+        .unwrap()
+        .address_trees[0];
     let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas();
     let instruction_data = CreatePdaInstructionData {
         proof: rpc_result.proof,
@@ -145,6 +148,7 @@ pub async fn update_pda(
 
     let packed_accounts = rpc_result
         .pack_tree_infos(&mut accounts)
+        .unwrap()
         .state_trees
         .unwrap();
diff --git a/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs b/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs
index 59a0562c63..510c98b2b5 100644
--- a/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs
+++ b/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs
@@ -111,7 +111,8 @@ pub async fn create_pda(
         .value;
 
     let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey);
-    let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0];
+    let packed_tree_infos = rpc_result.pack_tree_infos(&mut accounts)?;
+    let packed_address_tree_info = packed_tree_infos.address_trees[0];
     let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas();
     let instruction_data = CreatePdaInstructionData {
         proof: rpc_result.proof,
@@ -154,7 +155,7 @@ pub async fn update_pda(
         .value;
 
     let packed_accounts = rpc_result
-        .pack_tree_infos(&mut accounts)
+        .pack_tree_infos(&mut accounts)?
         .state_trees
         .unwrap();
diff --git a/sdk-tests/sdk-token-test/tests/ctoken_pda.rs b/sdk-tests/sdk-token-test/tests/ctoken_pda.rs
index 8e2b595285..e17308bd62 100644
--- a/sdk-tests/sdk-token-test/tests/ctoken_pda.rs
+++ b/sdk-tests/sdk-token-test/tests/ctoken_pda.rs
@@ -156,7 +156,7 @@ pub async fn create_mint(
     let config = SystemAccountMetaConfig::new_with_cpi_context(ID, tree_info.cpi_context.unwrap());
     packed_accounts.add_system_accounts_v2(config).unwrap();
     // packed_accounts.insert_or_get(tree_info.get_output_pubkey()?);
-    rpc_result.pack_tree_infos(&mut packed_accounts);
+    rpc_result.pack_tree_infos(&mut packed_accounts).unwrap();
 
     // Create PDA parameters
     let pda_amount = 100u64;
diff --git a/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs b/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs
index 5f096af560..bd511f2b2b 100644
--- a/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs
+++ b/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs
@@ -213,7 +213,7 @@ async fn test_decompress_full_cpi() {
         .unwrap()
         .value;
 
-    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_tree_infos = rpc_result.pack_state_tree_infos(&mut remaining_accounts);
     let config = DecompressFullAccounts::new(None);
     remaining_accounts
         .add_custom_system_accounts(config)
@@ -236,12 +236,7 @@ async fn test_decompress_full_cpi() {
     let indices: Vec<_> = token_data
         .iter()
         .zip(
-            packed_tree_info
-                .state_trees
-                .as_ref()
-                .unwrap()
-                .packed_tree_infos
-                .iter(),
+            packed_tree_infos.iter(),
         )
         .zip(ctx.destination_accounts.iter())
         .zip(versions.iter())
@@ -370,7 +365,7 @@ async fn test_decompress_full_cpi_with_context() {
         .value;
 
     // Add tree accounts first, then custom system accounts (no CPI context since params is None)
-    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_tree_infos = rpc_result.pack_state_tree_infos(&mut remaining_accounts);
     let config = DecompressFullAccounts::new(None);
     remaining_accounts
         .add_custom_system_accounts(config)
@@ -393,12 +388,7 @@ async fn test_decompress_full_cpi_with_context() {
     let indices: Vec<_> = token_data
         .iter()
         .zip(
-            packed_tree_info
-                .state_trees
-                .as_ref()
-                .unwrap()
-                .packed_tree_infos
-                .iter(),
+            packed_tree_infos.iter(),
         )
         .zip(ctx.destination_accounts.iter())
         .zip(versions.iter())
diff --git a/sdk-tests/sdk-token-test/tests/pda_ctoken.rs b/sdk-tests/sdk-token-test/tests/pda_ctoken.rs
index 91e0f2db9e..130c1caa0d 100644
--- a/sdk-tests/sdk-token-test/tests/pda_ctoken.rs
+++ b/sdk-tests/sdk-token-test/tests/pda_ctoken.rs
@@ -214,7 +214,7 @@ pub async fn create_mint(
     let mut packed_accounts = PackedAccounts::default();
     let config = SystemAccountMetaConfig::new_with_cpi_context(ID, tree_info.cpi_context.unwrap());
     packed_accounts.add_system_accounts_v2(config).unwrap();
-    rpc_result.pack_tree_infos(&mut packed_accounts);
+    rpc_result.pack_tree_infos(&mut packed_accounts).unwrap();
 
     // Create PDA parameters
     let pda_amount = 100u64;
diff --git a/sdk-tests/sdk-token-test/tests/test.rs b/sdk-tests/sdk-token-test/tests/test.rs
index 3c6941881d..9bebbcf897 100644
--- a/sdk-tests/sdk-token-test/tests/test.rs
+++ b/sdk-tests/sdk-token-test/tests/test.rs
@@ -367,7 +367,7 @@ async fn transfer_compressed_tokens(
         .await?
         .value;
 
-    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap();
     let output_tree_index = packed_tree_info
         .state_trees
         .as_ref()
@@ -433,7 +433,7 @@ async fn decompress_compressed_tokens(
         .await?
         .value;
 
-    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap();
     let output_tree_index = packed_tree_info
         .state_trees
         .as_ref()
diff --git a/sdk-tests/sdk-token-test/tests/test_4_invocations.rs b/sdk-tests/sdk-token-test/tests/test_4_invocations.rs
index 9e70170056..3c0bce8241 100644
--- a/sdk-tests/sdk-token-test/tests/test_4_invocations.rs
+++ b/sdk-tests/sdk-token-test/tests/test_4_invocations.rs
@@ -22,6 +22,21 @@ use solana_sdk::{
     signature::{Keypair, Signature, Signer},
 };
 
+fn pack_selected_output_tree_index(
+    tree_info: light_client::indexer::TreeInfo,
+    remaining_accounts: &mut PackedAccounts,
+) -> Result<u8, Box<RpcError>> {
+    tree_info
+        .next_tree_info
+        .map(|next| next.pack_output_tree_index(remaining_accounts))
+        .unwrap_or_else(|| tree_info.pack_output_tree_index(remaining_accounts))
+        .map_err(|error| {
+            Box::new(RpcError::CustomError(format!(
+                "Failed to pack output tree index: {error}"
+            )))
+        })
+}
+
 #[ignore = "fix cpi context usage"]
 #[tokio::test]
 async fn test_4_invocations() {
@@ -389,7 +404,7 @@ async fn create_compressed_escrow_pda(
         .await?
         .value;
 
-    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap();
 
     let new_address_params = packed_tree_info.address_trees[0]
         .into_new_address_params_assigned_packed(address_seed, Some(0));
@@ -495,29 +510,18 @@ async fn test_four_invokes_instruction(
     )
     .await?
     .value;
-    // We need to pack the tree after the cpi context.
-    remaining_accounts.insert_or_get(rpc_result.accounts[0].tree_info.tree);
-
-    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts);
-    let output_tree_index = packed_tree_info
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .output_tree_index;
+    let output_tree_index = pack_selected_output_tree_index(
+        mint2_token_account.account.tree_info,
+        &mut remaining_accounts,
+    )
+    .map_err(|error| *error)?;
+    let packed_tree_infos = rpc_result.pack_state_tree_infos(&mut remaining_accounts);
 
     // Create token metas from compressed accounts - each uses its respective tree info index
     // Index 0: escrow PDA, Index 1: mint2 token account, Index 2: mint3 token account
-    let mint2_tree_info = packed_tree_info
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[1];
+    let mint2_tree_info = packed_tree_infos[1];
 
-    let mint3_tree_info = packed_tree_info
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[2];
+    let mint3_tree_info = packed_tree_infos[2];
 
     // Create FourInvokesParams
     let four_invokes_params = sdk_token_test::FourInvokesParams {
@@ -557,11 +561,7 @@ async fn test_four_invokes_instruction(
     };
 
     // Create PdaParams - escrow PDA uses tree info index 0
-    let escrow_tree_info = packed_tree_info
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[0];
+    let escrow_tree_info = packed_tree_infos[0];
 
     let pda_params = sdk_token_test::PdaParams {
         account_meta: light_sdk::instruction::account_meta::CompressedAccountMeta {
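The helper introduced above encodes the rollover rule these token tests need: if the indexer reports a `next_tree_info`, new outputs must be packed into the successor tree, otherwise into the current one. Written as a `match` instead of the `map`/`unwrap_or_else` chain, the selection logic is (a sketch under the same assumptions as the test, with `tree_info: TreeInfo` and `remaining_accounts: &mut PackedAccounts` in scope):

```rust
let output_tree_index = match tree_info.next_tree_info {
    // The current tree is scheduled to roll over: outputs go to its successor.
    Some(next) => next.pack_output_tree_index(remaining_accounts),
    // No rollover pending: keep writing outputs to the current tree.
    None => tree_info.pack_output_tree_index(remaining_accounts),
}
.map_err(|error| RpcError::CustomError(format!("Failed to pack output tree index: {error}")))?;
```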
diff --git a/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs b/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs
index d7ef38a08c..1e17b61ce1 100644
--- a/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs
+++ b/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs
@@ -27,6 +27,20 @@ use solana_sdk::{
     signature::{Keypair, Signer},
 };
 
+#[allow(clippy::result_large_err)]
+fn pack_selected_output_tree_index(
+    tree_info: light_client::indexer::TreeInfo,
+    remaining_accounts: &mut PackedAccounts,
+) -> Result<u8, RpcError> {
+    tree_info
+        .next_tree_info
+        .map(|next| next.pack_output_tree_index(remaining_accounts))
+        .unwrap_or_else(|| tree_info.pack_output_tree_index(remaining_accounts))
+        .map_err(|error| {
+            RpcError::CustomError(format!("Failed to pack output tree index: {error}"))
+        })
+}
+
 #[tokio::test]
 async fn test_4_transfer2() {
     // Initialize the test environment
@@ -339,7 +353,7 @@ async fn create_compressed_escrow_pda(
         .await?
         .value;
 
-    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap();
 
     let new_address_params = packed_tree_info.address_trees[0]
         .into_new_address_params_assigned_packed(address_seed, Some(0));
@@ -435,29 +449,17 @@ async fn test_four_transfer2_instruction(
     )
     .await?
     .value;
-    // We need to pack the tree after the cpi context.
-    remaining_accounts.insert_or_get(rpc_result.accounts[0].tree_info.tree);
-
-    let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts);
-    let output_tree_index = packed_tree_info
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .output_tree_index;
+    let output_tree_index = pack_selected_output_tree_index(
+        mint2_token_account.account.tree_info,
+        &mut remaining_accounts,
+    )?;
+    let packed_tree_infos = rpc_result.pack_state_tree_infos(&mut remaining_accounts);
 
     // Create token metas from compressed accounts - each uses its respective tree info index
     // Index 0: escrow PDA, Index 1: mint2 token account, Index 2: mint3 token account
-    let mint2_tree_info = packed_tree_info
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[1];
+    let mint2_tree_info = packed_tree_infos[1];
 
-    let mint3_tree_info = packed_tree_info
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[2];
+    let mint3_tree_info = packed_tree_infos[2];
 
     // Create FourTransfer2Params
     let four_transfer2_params = sdk_token_test::process_four_transfer2::FourTransfer2Params {
@@ -491,11 +493,7 @@
     };
 
     // Create PdaParams - escrow PDA uses tree info index 0
-    let escrow_tree_info = packed_tree_info
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[0];
+    let escrow_tree_info = packed_tree_infos[0];
 
     let pda_params = sdk_token_test::PdaParams {
         account_meta: light_sdk::instruction::account_meta::CompressedAccountMeta {
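Note that the two copies of `pack_selected_output_tree_index` quiet `clippy::result_large_err` differently: `test_4_invocations.rs` boxes the error so the `Result` stays small, while this file keeps the plain `RpcError` under a local `#[allow]`. Both are standard answers to that lint; a self-contained toy comparison (placeholder error type, not the real `RpcError`):

```rust
// Stand-in with a large payload, enough to trigger clippy::result_large_err.
#[allow(dead_code)]
#[derive(Debug)]
pub enum BigError {
    Custom(String),
    Raw([u8; 256]),
}

// Option 1: box the error. The Err path costs a heap allocation, but the
// Result itself stays small, so the lint is satisfied structurally.
pub fn boxed() -> Result<u8, Box<BigError>> {
    Err(Box::new(BigError::Custom("example".into())))
}

// Option 2: keep the plain error type and acknowledge the lint locally.
#[allow(clippy::result_large_err)]
pub fn plain() -> Result<u8, BigError> {
    Err(BigError::Custom("example".into()))
}
```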
diff --git a/sdk-tests/sdk-token-test/tests/test_deposit.rs b/sdk-tests/sdk-token-test/tests/test_deposit.rs
index 9ebcbd8549..7353b08d92 100644
--- a/sdk-tests/sdk-token-test/tests/test_deposit.rs
+++ b/sdk-tests/sdk-token-test/tests/test_deposit.rs
@@ -23,6 +23,39 @@ use solana_sdk::{
     signature::{Keypair, Signature, Signer},
 };
 
+#[allow(clippy::result_large_err)]
+fn pack_selected_output_tree_context(
+    tree_info: light_client::indexer::TreeInfo,
+    remaining_accounts: &mut PackedAccounts,
+) -> Result<(u8, u8, u8), RpcError> {
+    let (tree, queue, output_state_tree_index) = if let Some(next) = tree_info.next_tree_info {
+        (
+            next.tree,
+            next.queue,
+            next.pack_output_tree_index(remaining_accounts)
+                .map_err(|error| {
+                    RpcError::CustomError(format!("Failed to pack output tree index: {error}"))
+                })?,
+        )
+    } else {
+        (
+            tree_info.tree,
+            tree_info.queue,
+            tree_info
+                .pack_output_tree_index(remaining_accounts)
+                .map_err(|error| {
+                    RpcError::CustomError(format!("Failed to pack output tree index: {error}"))
+                })?,
+        )
+    };
+
+    Ok((
+        remaining_accounts.insert_or_get(tree),
+        remaining_accounts.insert_or_get(queue),
+        output_state_tree_index,
+    ))
+}
+
 #[ignore = "fix cpi context usage"]
 #[tokio::test]
 async fn test_deposit_compressed_account() {
@@ -206,7 +239,7 @@ async fn create_deposit_compressed_account(
     )
     .await?
     .value;
-    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts);
+    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap();
     println!("packed_accounts {:?}", packed_accounts.state_trees);
 
     // Create token meta from compressed account
@@ -302,9 +335,14 @@ async fn update_deposit_compressed_account(
         "rpc_result.accounts[0].tree_info.queue {:?}",
         rpc_result.accounts[0].tree_info.queue.to_bytes()
     );
-    // We need to pack the tree after the cpi context.
-    let index = remaining_accounts.insert_or_get(rpc_result.accounts[0].tree_info.tree);
-    println!("index {}", index);
+    let (output_tree_index, output_tree_queue_index, output_state_tree_index) =
+        pack_selected_output_tree_context(
+            rpc_result.accounts[0].tree_info,
+            &mut remaining_accounts,
+        )?;
+    println!("output_tree_index {}", output_tree_index);
+    println!("output_tree_queue_index {}", output_tree_queue_index);
+    println!("output_state_tree_index {}", output_state_tree_index);
     // Get mint from the compressed token account
     let mint = deposit_ctoken_account.token.mint;
     println!(
@@ -318,15 +356,11 @@ async fn update_deposit_compressed_account(
     // Get validity proof for the compressed token account and new address
     println!("rpc_result {:?}", rpc_result);
-    let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts);
-    println!("packed_accounts {:?}", packed_accounts.state_trees);
+    let packed_tree_infos = rpc_result.pack_state_tree_infos(&mut remaining_accounts);
+    println!("packed_tree_infos {:?}", packed_tree_infos);
     // TODO: investigate why packed_tree_infos seem to be out of order
     // Create token meta from compressed account
-    let tree_info = packed_accounts
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[1];
+    let tree_info = packed_tree_infos[1];
     let depositing_token_metas = vec![TokenAccountMeta {
         amount: deposit_ctoken_account.token.amount,
         delegate_index: None,
         tlv: None,
     }];
     println!("depositing_token_metas {:?}", depositing_token_metas);
-    let tree_info = packed_accounts
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[2];
+    let tree_info = packed_tree_infos[2];
     let escrowed_token_meta = TokenAccountMeta {
         amount: escrow_ctoken_account.token.amount,
         delegate_index: None,
@@ -354,19 +384,11 @@
     let system_accounts_start_offset = system_accounts_start_offset as u8;
     println!("remaining_accounts {:?}", remaining_accounts);
 
-    let tree_info = packed_accounts
-        .state_trees
-        .as_ref()
-        .unwrap()
-        .packed_tree_infos[0];
+    let tree_info = packed_tree_infos[0];
     let account_meta = CompressedAccountMeta {
         tree_info,
         address: escrow_pda.address.unwrap(),
-        output_state_tree_index: packed_accounts
-            .state_trees
-            .as_ref()
-            .unwrap()
-            .output_tree_index,
+        output_state_tree_index,
     };
 
     let instruction = Instruction {
@@ -381,14 +403,8 @@
         .concat(),
         data: sdk_token_test::instruction::UpdateDeposit {
             proof: rpc_result.proof,
-            output_tree_index: packed_accounts
-                .state_trees
-                .as_ref()
-                .unwrap()
-                .packed_tree_infos[0]
-                .merkle_tree_pubkey_index,
-            output_tree_queue_index: packed_accounts.state_trees.unwrap().packed_tree_infos[0]
-                .queue_pubkey_index,
+            output_tree_index,
+            output_tree_queue_index,
             system_accounts_start_offset,
             token_params: sdk_token_test::TokenParams {
                 deposit_amount: amount,
diff --git a/sdk-tests/sdk-v1-native-test/tests/test.rs b/sdk-tests/sdk-v1-native-test/tests/test.rs
index a93beab599..2e10e61e14 100644
--- a/sdk-tests/sdk-v1-native-test/tests/test.rs
+++ b/sdk-tests/sdk-v1-native-test/tests/test.rs
@@ -94,7 +94,8 @@ pub async fn create_pda(
         .value;
 
     let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey);
-    let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0];
+    let packed_tree_infos = rpc_result.pack_tree_infos(&mut accounts)?;
+    let packed_address_tree_info = packed_tree_infos.address_trees[0];
     let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas();
 
     let instruction_data = CreatePdaInstructionData {
@@ -137,7 +138,7 @@ pub async fn update_pda(
         .value;
 
     let packed_accounts = rpc_result
-        .pack_tree_infos(&mut accounts)
+        .pack_tree_infos(&mut accounts)?
         .state_trees
         .unwrap();
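Taken together, the test changes converge on one packing flow: pack the proof's state trees once, index into the resulting slice in the same order the input accounts were requested, and take the output-tree index from the rollover-aware helper rather than from a hand-inserted tree account. A condensed sketch of that end state, under the same assumptions as the tests above (`rpc_result: ValidityProofWithContext` and `remaining_accounts: PackedAccounts` in scope, the escrow PDA requested first; `escrow_address` is illustrative):

```rust
// One packed tree info per input compressed account, in request order.
let packed_tree_infos = rpc_result.pack_state_tree_infos(&mut remaining_accounts);
let escrow_tree_info = packed_tree_infos[0];

// Output tree selection goes through the rollover-aware helper from these
// tests instead of inserting the tree pubkey manually before packing.
let output_state_tree_index = pack_selected_output_tree_index(
    rpc_result.accounts[0].tree_info,
    &mut remaining_accounts,
)?;

let account_meta = CompressedAccountMeta {
    tree_info: escrow_tree_info,
    address: escrow_address,
    output_state_tree_index,
};
```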