diff --git a/Cargo.lock b/Cargo.lock index a83e37c45..04bcabaf7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2237,7 +2237,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "once_cell", "p3", @@ -3243,7 +3243,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "bincode 1.3.3", "clap", @@ -3267,7 +3267,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "either", "ff_ext", @@ -4558,7 +4558,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "p3-air", "p3-baby-bear", @@ -5126,7 +5126,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "ff_ext", "p3", @@ -6083,7 +6083,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "cfg-if", "dashu", @@ -6208,7 +6208,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "either", "ff_ext", @@ -6226,7 +6226,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "itertools 0.13.0", "p3", @@ -6633,7 +6633,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -6927,7 +6927,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "bincode 1.3.3", "clap", @@ -7214,7 +7214,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fmle_no_padding#1fc9f700b54dfb63415e3d4115d778fc10ad9131" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 8cc5823a5..59a7e8653 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,16 +27,16 @@ version = "0.1.0" ceno_crypto_primitives = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_crypto_primitives", branch = "main" } ceno_syscall = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_syscall", branch = "main" } -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.24" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.24" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.24" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.24" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.24" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.24" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.24" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.24" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.24" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.24" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", branch = "feat/mle_no_padding" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", branch = "feat/mle_no_padding" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", branch = "feat/mle_no_padding" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", branch = "feat/mle_no_padding" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", branch = "feat/mle_no_padding" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", branch = "feat/mle_no_padding" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", branch = "feat/mle_no_padding" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", branch = "feat/mle_no_padding" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", branch = "feat/mle_no_padding" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", branch = "feat/mle_no_padding" } anyhow = { version = "1.0", default-features = false } bincode = "1" @@ -129,7 +129,7 @@ lto = "thin" #[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] #ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } - +# #[patch."https://github.com/scroll-tech/gkr-backend"] #ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } #mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 95b15b581..c489567a7 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -5,8 +5,7 @@ use ceno_zkvm::print_allocated_bytes; use ceno_zkvm::{ e2e::{ Checkpoint, FieldType, MultiProver, PcsKind, Preset, public_io_words_to_digest_words, - run_e2e_full_trace_verify, run_e2e_single_shard_debug_verify, run_e2e_with_checkpoint, - setup_platform, setup_platform_debug, + run_e2e_with_checkpoint, setup_platform, setup_platform_debug, }, scheme::{ ZKVMProof, constants::MAX_NUM_VARIABLES, create_backend, create_prover, hal::ProverDevice, @@ -352,17 +351,10 @@ fn run_inner< fs::write(&vk_file, vk_bytes).unwrap(); if checkpoint > Checkpoint::PrepVerify { + // `run_e2e_with_checkpoint` already performs the real verification for the + // complete flow. Re-running it here without the emulation exit code causes + // a false "Unfinished execution" error to be logged. let verifier = ZKVMVerifier::new(vk); - if target_shard_id.is_some() { - run_e2e_single_shard_debug_verify( - &verifier, - zkvm_proofs.first().cloned().expect("missing shard proof"), - None, - max_steps, - ); - } else { - run_e2e_full_trace_verify(&verifier, zkvm_proofs.clone(), None, max_steps); - } soundness_test(zkvm_proofs.first().cloned().unwrap(), &verifier); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs index 565e0dffa..4dc1bf289 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -348,8 +348,7 @@ fn replay_keccak_witness_only_from_packed( ) -> Result, ZKVMError> { use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; - let num_padded_instances = num_instances.next_power_of_two().max(2); - let num_padded_rows = num_padded_instances * 32; + let num_rows = num_instances * 32; let rotation = KECCAK_ROUNDS_CEIL_LOG2; let col_map = info_span!("col_map").in_scope(|| extract_keccak_column_map(config, num_witin)); @@ -358,7 +357,7 @@ fn replay_keccak_witness_only_from_packed( .witgen_keccak( &col_map, packed_instances, - num_padded_rows, + num_rows, shard_offset, fetch_base_pc, fetch_num_slots, @@ -372,9 +371,10 @@ fn replay_keccak_witness_only_from_packed( let raw_witin = if crate::instructions::gpu::config::is_debug_compare_enabled() || !should_materialize_witness_on_gpu() { - info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin).in_scope(|| { + let produced_rows = gpu_result.witness.num_rows; + info_span!("transpose_d2h", rows = produced_rows, cols = num_witin).in_scope(|| { let mut rmm_buffer = hal - .alloc_elems_on_device(num_padded_rows * num_witin, false, None) + .alloc_elems_on_device(produced_rows * num_witin, false, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) })?; @@ -382,7 +382,7 @@ fn replay_keccak_witness_only_from_packed( &hal.inner, &mut rmm_buffer, &gpu_result.witness.device_buffer, - num_padded_rows, + produced_rows, num_witin, ) .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; @@ -445,8 +445,7 @@ fn gpu_assign_keccak_inner( use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; let num_instances = step_indices.len(); - let num_padded_instances = num_instances.next_power_of_two().max(2); - let num_padded_rows = num_padded_instances * 32; // 2^5 = 32 rows per instance + let num_rows = num_instances * 32; // 2^5 = 32 rows per instance let rotation = KECCAK_ROUNDS_CEIL_LOG2; // = 5 let materialize_initial_witness = crate::instructions::gpu::config::is_debug_compare_enabled() || should_materialize_witness_on_initial_assign(); @@ -479,7 +478,7 @@ fn gpu_assign_keccak_inner( .witgen_keccak( &col_map, &packed_instances, - num_padded_rows, + num_rows, shard_ctx.current_shard_offset_cycle(), fetch_base_pc, fetch_num_slots, @@ -565,9 +564,10 @@ fn gpu_assign_keccak_inner( } else if crate::instructions::gpu::config::is_debug_compare_enabled() || !should_materialize_witness_on_gpu() { - info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin).in_scope(|| { + let produced_rows = gpu_result.witness.num_rows; + info_span!("transpose_d2h", rows = produced_rows, cols = num_witin).in_scope(|| { let mut rmm_buffer = hal - .alloc_elems_on_device(num_padded_rows * num_witin, false, None) + .alloc_elems_on_device(produced_rows * num_witin, false, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) })?; @@ -575,7 +575,7 @@ fn gpu_assign_keccak_inner( &hal.inner, &mut rmm_buffer, &gpu_result.witness.device_buffer, - num_padded_rows, + produced_rows, num_witin, ) .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; diff --git a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs index 21f1f89a0..0813449e1 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs @@ -439,11 +439,11 @@ pub(crate) fn try_gpu_assign_shard_ram( { let struct_data = tracing::info_span!( "gpu_shard_ram_structural_transpose_d2h", - num_rows_padded, + rows = gpu_structural.num_rows, num_structural_witin, ) .in_scope(|| -> Result<_, ZKVMError> { - let wit_num_rows = num_rows_padded; + let wit_num_rows = gpu_structural.num_rows; let struct_num_cols = num_structural_witin; let mut struct_rmm_buf = hal .witgen @@ -684,11 +684,11 @@ pub(crate) fn try_gpu_assign_shard_ram_from_device( { let struct_data = tracing::info_span!( "gpu_shard_ram_structural_transpose_d2h_from_device", - num_rows_padded, + rows = gpu_structural.num_rows, num_structural_witin, ) .in_scope(|| -> Result<_, ZKVMError> { - let wit_num_rows = num_rows_padded; + let wit_num_rows = gpu_structural.num_rows; let struct_num_cols = num_structural_witin; let mut struct_rmm_buf = hal .witgen diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index be51108c4..f7cf58969 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -481,7 +481,7 @@ fn gpu_assign_instances_inner>( total_instances, num_witin, I::padding_strategy(), - ) + )? }; if materialize_initial_witness { raw_witin.padding_by_strategy(); @@ -1484,7 +1484,7 @@ fn replay_gpu_witness_from_resident_raw>( total_instances, replay.num_witin, I::padding_strategy(), - ); + )?; // Keep replayed witness immutable after attaching the col-major device backing. // Mutating/padding a RowMajorMatrix clears device metadata, but replay consumers diff --git a/ceno_zkvm/src/instructions/gpu/utils/d2h.rs b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs index fc558046d..5647cef12 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/d2h.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs @@ -303,9 +303,9 @@ pub(crate) fn gpu_witness_to_rmm( num_rows: usize, num_cols: usize, padding: InstancePaddingStrategy, -) -> RowMajorMatrix { +) -> Result, ZKVMError> { let mut rmm = RowMajorMatrix::::new(num_rows, num_cols, padding); // Keep the original col-major witness buffer as the source of truth for GPU commit. rmm.set_device_backing(gpu_result.device_buffer, DeviceMatrixLayout::ColMajor); - rmm + Ok(rmm) } diff --git a/ceno_zkvm/src/scheme/gpu/memory.rs b/ceno_zkvm/src/scheme/gpu/memory.rs index 6421ad3c9..f61bce28f 100644 --- a/ceno_zkvm/src/scheme/gpu/memory.rs +++ b/ceno_zkvm/src/scheme/gpu/memory.rs @@ -11,11 +11,16 @@ use ceno_gpu::{ estimate_build_tower_memory, estimate_prove_tower_memory, estimate_sumcheck_memory, }; use ff_ext::ExtensionField; -use gkr_iop::gpu::{ - BB31Base, GpuBackend, - gpu_prover::{ - BB31Ext, CacheLevel, CudaHalBB31, MemTracker, get_gpu_cache_level, get_mem_tracking_mode, +use gkr_iop::{ + evaluation::EvalExpression, + gpu::{ + BB31Base, GpuBackend, + gpu_prover::{ + BB31Ext, CacheLevel, CudaHalBB31, MemTracker, get_gpu_cache_level, + get_mem_tracking_mode, + }, }, + hal::MultilinearPolynomial, }; use mpcs::PolynomialCommitmentScheme; @@ -40,15 +45,92 @@ pub fn init_gpu_mem_tracker<'a>( const ESTIMATION_TOLERANCE_BYTES: usize = 2 * 1024 * 1024; // max under-estimation error: 2 MB const ESTIMATION_SAFETY_MARGIN_BYTES: usize = 10 * 1024 * 1024; // reserved headroom / allowed over-estimate margin: 10 MB +const SCHEDULER_ESTIMATION_WARNING_MARGIN_BYTES: usize = 512 * 1024 * 1024; /// Validate that the estimated GPU memory matches actual usage within tolerance. /// - Under-estimate (actual > estimated): diff must be <= `ESTIMATION_TOLERANCE_BYTES` /// - Over-estimate (estimated > actual): diff must be <= `ESTIMATION_SAFETY_MARGIN_BYTES` pub fn check_gpu_mem_estimation(mem_tracker: Option, estimated_bytes: usize) { + check_gpu_mem_estimation_with_context(mem_tracker, estimated_bytes, None); +} + +pub fn check_gpu_mem_estimation_with_context( + mem_tracker: Option, + estimated_bytes: usize, + context: Option<&str>, +) { + check_gpu_mem_estimation_with_margins( + mem_tracker, + estimated_bytes, + context, + ESTIMATION_TOLERANCE_BYTES, + ESTIMATION_SAFETY_MARGIN_BYTES, + ); +} + +pub fn check_gpu_scheduler_mem_estimation_with_context( + mem_tracker: Option, + estimated_bytes: usize, + context: Option<&str>, +) { + // Scheduler estimates are admission-control estimates, not exact stage-local allocation + // estimates. They intentionally include safety margins and conservative lifetime overlap, so + // large over-estimates should be surfaced as warnings rather than failing the proof. Under- + // estimates remain hard failures because they can admit unsafe concurrent work. + if let Some(mem_tracker) = mem_tracker { + const ONE_MB: usize = 1024 * 1024; + let label = mem_tracker.name(); + let label = context + .filter(|context| !context.is_empty()) + .map(|context| format!("{label}[{context}]")) + .unwrap_or_else(|| label.to_string()); + let mem_stats = mem_tracker.finish(); + let actual_bytes = mem_stats.mem_occupancy as usize; + let diff = estimated_bytes as isize - actual_bytes as isize; + let to_mb = |b: usize| b as f64 / ONE_MB as f64; + let diff_mb = diff as f64 / ONE_MB as f64; + tracing::info!( + "[memcheck] {label}: scheduler_estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB", + to_mb(estimated_bytes), + to_mb(actual_bytes), + diff_mb + ); + if diff < 0 { + assert!( + (-diff) as usize <= ESTIMATION_TOLERANCE_BYTES, + "[memcheck] {label}: scheduler under-estimate! estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB, tolerance={:.2}MB", + to_mb(estimated_bytes), + to_mb(actual_bytes), + diff_mb, + to_mb(ESTIMATION_TOLERANCE_BYTES), + ); + } else if diff as usize > SCHEDULER_ESTIMATION_WARNING_MARGIN_BYTES { + tracing::warn!( + "[memcheck] {label}: scheduler over-estimate warning: estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB, warning_margin={:.2}MB", + to_mb(estimated_bytes), + to_mb(actual_bytes), + diff_mb, + to_mb(SCHEDULER_ESTIMATION_WARNING_MARGIN_BYTES), + ); + } + } +} + +fn check_gpu_mem_estimation_with_margins( + mem_tracker: Option, + estimated_bytes: usize, + context: Option<&str>, + under_tolerance_bytes: usize, + over_tolerance_bytes: usize, +) { // `mem_tracker will` be Some only in sequential mode with mem tracking enabled, so if it's None, do nothing if let Some(mem_tracker) = mem_tracker { const ONE_MB: usize = 1024 * 1024; let label = mem_tracker.name(); + let label = context + .filter(|context| !context.is_empty()) + .map(|context| format!("{label}[{context}]")) + .unwrap_or_else(|| label.to_string()); let mem_stats = mem_tracker.finish(); let actual_bytes = mem_stats.mem_occupancy as usize; let diff = estimated_bytes as isize - actual_bytes as isize; @@ -63,22 +145,22 @@ pub fn check_gpu_mem_estimation(mem_tracker: Option, estimated_bytes if diff < 0 { // Under-estimate: actual exceeds estimated assert!( - (-diff) as usize <= ESTIMATION_TOLERANCE_BYTES, + (-diff) as usize <= under_tolerance_bytes, "[memcheck] {label}: under-estimate! estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB, tolerance={:.2}MB", to_mb(estimated_bytes), to_mb(actual_bytes), diff_mb, - to_mb(ESTIMATION_TOLERANCE_BYTES), + to_mb(under_tolerance_bytes), ); } else { // Over-estimate: estimated exceeds actual assert!( - diff as usize <= ESTIMATION_SAFETY_MARGIN_BYTES, + diff as usize <= over_tolerance_bytes, "[memcheck] {label}: over-estimate! estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB, margin={:.2}MB", to_mb(estimated_bytes), to_mb(actual_bytes), diff_mb, - to_mb(ESTIMATION_SAFETY_MARGIN_BYTES), + to_mb(over_tolerance_bytes), ); } } @@ -110,12 +192,14 @@ pub fn estimate_chip_proof_memory(replay_plan: &GpuReplayPlan) -> usize { + match replay_plan.kind { + GpuWitgenKind::Keccak => replay_plan + .keccak_instances + .as_ref() + .map(|instances| instances.len() * 32) + .unwrap_or(replay_plan.trace_height), + GpuWitgenKind::ShardRam => replay_plan.trace_height, + _ => replay_plan.step_indices.len(), + } +} + +fn replay_plan_actual_structural_rows(replay_plan: &GpuReplayPlan) -> usize { + match replay_plan.kind { + GpuWitgenKind::ShardRam => replay_plan.shard_ram_num_records, + _ => replay_plan.trace_height, + } +} + pub fn estimate_replay_materialization_bytes( num_witin: usize, _num_structural_witin: usize, @@ -273,7 +383,7 @@ pub fn estimate_replay_materialization_bytes_for_plan( _num_vars: usize, ) -> usize { let elem_size = std::mem::size_of::(); - let witness_bytes = replay_plan.trace_height * replay_plan.num_witin * elem_size; + let witness_bytes = replay_plan_actual_rows(replay_plan) * replay_plan.num_witin * elem_size; let replay_temp_bytes = match replay_plan.kind { GpuWitgenKind::Keccak => replay_plan .keccak_instances @@ -299,6 +409,7 @@ pub fn estimate_replay_materialization_bytes_for_plan( pub(crate) fn estimate_trace_bytes>( composed_cs: &ComposedConstrainSystem, input: &ProofInput<'_, GpuBackend>, + replay_plan: Option<&GpuReplayPlan>, witness_replayable: bool, structural_cached_on_device: bool, ) -> TraceEstimate { @@ -308,14 +419,37 @@ pub(crate) fn estimate_trace_bytes() + }) + .unwrap_or_else(|| { + estimate_structural_mle_bytes( + cs.num_structural_witin as usize, + num_var_with_rotation, + ) + }) } else { estimate_structural_mle_bytes(cs.num_structural_witin as usize, num_var_with_rotation) }; - let (witness_mle_bytes, trace_temporary_bytes) = estimate_trace_extraction_bytes( - cs.num_witin as usize, - num_var_with_rotation, - witness_replayable, - ); + let (witness_mle_bytes, trace_temporary_bytes) = + if should_materialize_witness_on_gpu() && witness_replayable { + let base_elem_size = std::mem::size_of::(); + let actual_rows = replay_plan + .map(replay_plan_actual_rows) + .unwrap_or(1usize << num_var_with_rotation); + (cs.num_witin as usize * actual_rows * base_elem_size, 0) + } else { + estimate_trace_extraction_bytes( + cs.num_witin as usize, + num_var_with_rotation, + input.num_instances() << composed_cs.rotation_vars().unwrap_or(0), + witness_replayable, + ) + }; TraceEstimate { trace_resident_bytes: witness_mle_bytes + structural_mle_bytes, @@ -325,11 +459,72 @@ pub(crate) fn estimate_trace_bytes( composed_cs: &ComposedConstrainSystem, - num_var_with_rotation: usize, + output_rows: usize, ) -> usize { let elem_size = std::mem::size_of::(); - let record_len = 1usize << num_var_with_rotation; - tower_output_count(composed_cs) * record_len * elem_size + main_witness_materialized_output_count(composed_cs) * output_rows * elem_size +} + +fn main_witness_materialized_output_count( + composed_cs: &ComposedConstrainSystem, +) -> usize { + let Some(gkr_circuit) = composed_cs.gkr_circuit.as_ref() else { + return 0; + }; + let final_layer_output_count = tower_output_count(composed_cs); + + gkr_circuit + .layers + .iter() + .enumerate() + .map(|(layer_index, layer)| { + let final_layer = layer_index == 0; + let out_evals = layer + .out_sel_and_eval_exprs + .iter() + .flat_map(|(_, out_eval)| out_eval.iter()); + + if final_layer { + out_evals + .take(final_layer_output_count) + .filter(|out_eval| main_witness_materializes_output(out_eval)) + .count() + } else { + out_evals + .filter(|out_eval| main_witness_materializes_output(out_eval)) + .count() + } + }) + .sum() +} + +fn main_witness_materializes_output(out_eval: &EvalExpression) -> bool { + matches!( + out_eval, + EvalExpression::Single(_) | EvalExpression::Linear(_, _, _) + ) +} + +pub fn main_witness_output_rows>( + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'_, GpuBackend>, +) -> usize { + if composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.last()) + .is_some_and(|input_layer| input_layer.in_eval_expr.is_empty()) + { + if let Some(structural_mle) = input.structural_witness.first() { + return structural_mle.evaluations_len(); + } + } + + input + .witness + .first() + .map(|mle| mle.evaluations_len()) + .unwrap_or_else(|| input.num_instances() << composed_cs.rotation_vars().unwrap_or(0)) } pub(crate) fn estimate_main_constraints_bytes< @@ -403,7 +598,7 @@ pub(crate) fn estimate_main_constraints_bytes< fn estimate_tower_stage_components>( composed_cs: &ComposedConstrainSystem, input: &ProofInput<'_, GpuBackend>, -) -> (usize, usize, usize) { +) -> (usize, usize, usize, usize) { let cs = &composed_cs.zkvm_v1_css; let num_prod_towers = composed_cs.num_reads() + composed_cs.num_writes(); let num_logup_towers = if composed_cs.is_with_lk_table() { @@ -418,31 +613,46 @@ fn estimate_tower_stage_components(); let has_logup_numerator = composed_cs.is_with_lk_table(); + let occupied_rows = input.num_instances() << composed_cs.rotation_vars().unwrap_or(0); let build_est = estimate_build_tower_memory( num_prod_towers, num_logup_towers, num_vars, num_vars, + occupied_rows, elem_size, has_logup_numerator, ); + let shard_ram_tower_batch_overhead = composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.first()) + .is_some_and(|layer| layer.name == "ShardRamCircuit_main") + .then_some(10 * 1024 * 1024) + .unwrap_or(0); + let build_bytes = build_est.total_bytes + shard_ram_tower_batch_overhead; let prove_est = estimate_prove_tower_memory( num_prod_towers, num_logup_towers, num_vars, num_vars, + occupied_rows, NUM_FANIN, elem_size, + has_logup_numerator, ); let tower_input_live_bytes = prove_est.prod_tower_buffer_bytes + prove_est.logup_tower_buffer_bytes; + let borrowed_input_bytes = + prove_est.prod_borrowed_input_bytes + prove_est.logup_borrowed_input_bytes; let prove_local_bytes = prove_est.total_bytes.saturating_sub(tower_input_live_bytes); ( - build_est.total_bytes, + build_bytes, prove_local_bytes, tower_input_live_bytes, + borrowed_input_bytes, ) } @@ -452,7 +662,8 @@ pub(crate) fn estimate_tower_stage_bytes, input: &ProofInput<'_, GpuBackend>, ) -> (usize, usize) { - let (build_bytes, prove_local_bytes, _) = estimate_tower_stage_components(composed_cs, input); + let (build_bytes, prove_local_bytes, _, _) = + estimate_tower_stage_components(composed_cs, input); (build_bytes, prove_local_bytes) } @@ -460,7 +671,7 @@ pub(crate) fn estimate_tower_bytes, input: &ProofInput<'_, GpuBackend>, ) -> usize { - let (build_bytes, prove_local_bytes, tower_input_live_bytes) = + let (build_bytes, prove_local_bytes, tower_input_live_bytes, _) = estimate_tower_stage_components(composed_cs, input); build_bytes.max(tower_input_live_bytes + prove_local_bytes) } @@ -475,11 +686,13 @@ pub(crate) fn estimate_tower_bytes (usize, usize) { let base_elem_size = std::mem::size_of::(); let mle_len = 1usize << num_vars; - let poly_bytes = num_witin * mle_len * base_elem_size; + let compact_poly_bytes = num_witin * occupied_rows * base_elem_size; + let logical_poly_bytes = num_witin * mle_len * base_elem_size; if should_materialize_witness_on_gpu() { if should_retain_witness_device_backing_after_commit() { @@ -495,19 +708,19 @@ pub(crate) fn estimate_trace_extraction_bytes( // duration of the chip proof. There is no separate extraction temp // buffer, but the replayed witness itself must be accounted for as // resident task memory. - return (poly_bytes, 0); + return (compact_poly_bytes, 0); } // GPU witgen alone does not imply replayability. Non-replayable traces // still go through basefold::get_trace in cache-none mode, which // allocates the extracted witness plus a temporary 2x transpose buffer. - return (poly_bytes, 2 * poly_bytes); + return (compact_poly_bytes, 2 * logical_poly_bytes); } if matches!(get_gpu_cache_level(), CacheLevel::None) { // Default cache level is None // get_trace allocates poly copies (resident) + temp_buffer (2x, freed after) - (poly_bytes, 2 * poly_bytes) + (compact_poly_bytes, 2 * logical_poly_bytes) } else { (0, 0) } diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 142f894fe..986b04439 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -42,7 +42,7 @@ use multilinear_extensions::{ util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; -use p3::matrix::Matrix; +use p3::{field::FieldAlgebra, matrix::Matrix}; use rayon::iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, @@ -68,11 +68,13 @@ use witness::DeviceMatrixLayout; use gkr_iop::gpu::gpu_prover::*; mod memory; + mod util; pub(crate) use memory::{ - check_gpu_mem_estimation, estimate_chip_proof_memory, estimate_main_witness_bytes, - estimate_replay_materialization_bytes_for_plan, estimate_tower_bytes, - estimate_tower_stage_bytes, init_gpu_mem_tracker, + check_gpu_mem_estimation, check_gpu_mem_estimation_with_context, + check_gpu_scheduler_mem_estimation_with_context, estimate_chip_proof_memory, + estimate_main_witness_bytes, estimate_replay_materialization_bytes_for_plan, + estimate_tower_bytes, estimate_tower_stage_bytes, init_gpu_mem_tracker, }; use memory::{ estimate_ecc_quark_bytes_from_num_vars, estimate_main_constraints_bytes, @@ -101,6 +103,30 @@ struct PcsResidentStats { total_rmms: usize, } +fn rmm_device_backing_bytes(rmm: &witness::RowMajorMatrix) -> usize +where + T: FieldAlgebra + Default + Sync + Clone + Send + Copy + 'static, +{ + rmm.device_backing_ref::>() + .map(|device_buffer| device_buffer.len() * std::mem::size_of::()) + .unwrap_or(0) +} + +fn rmm_col_major_device_rows(rmm: &witness::RowMajorMatrix) -> Option +where + T: FieldAlgebra + Default + Sync + Clone + Send + Copy + 'static, +{ + if rmm.device_backing_layout() != Some(DeviceMatrixLayout::ColMajor) { + return None; + } + let cols = rmm.width(); + if cols == 0 { + return Some(0); + } + let device_buffer = rmm.device_backing_ref::>()?; + Some(device_buffer.len() / cols) +} + fn pcs_resident_stats( pcs_data_basefold: &BasefoldCommitmentWithWitnessGpu< BB31Base, @@ -141,7 +167,7 @@ fn pcs_resident_stats( ( rmms.iter() .filter(|rmm| rmm.has_device_backing()) - .map(|rmm| rmm.height() * rmm.width() * std::mem::size_of::()) + .map(rmm_device_backing_bytes) .sum::(), rmms.iter().filter(|rmm| rmm.has_device_backing()).count(), rmms.len(), @@ -331,24 +357,12 @@ pub fn prove_tower_relation_impl> = Vec::new(); - let mut _ones_buffer: Vec> = Vec::new(); - let mut _view_last_layers: Vec>>> = Vec::new(); let (prod_gpu, logup_gpu) = info_span!("[ceno] build_tower_witness_gpu").in_scope(|| { - build_tower_witness_gpu( - composed_cs, - input, - records, - challenges, - cuda_hal, - &mut _big_buffers, - &mut _ones_buffer, - &mut _view_last_layers, - ) - .map_err(|e| format!("build_tower_witness_gpu failed: {}", e)) - .unwrap() + build_tower_witness_gpu(composed_cs, input, records, challenges, cuda_hal) + .map_err(|e| format!("build_tower_witness_gpu failed: {}", e)) + .unwrap() }); exit_span!(span); @@ -380,7 +394,7 @@ pub fn prove_tower_relation_impl let wit = LayerWitness( chain!(&input.witness, &input.fixed, &input.structural_witness) .cloned() - .collect_vec(), + .map(|mle| unsafe { std::mem::transmute(mle) }) + .collect(), ); let (proof, points) = gkr_iop::gkr::layer::gpu::prove_rotation_gpu::( @@ -694,7 +709,8 @@ pub fn prove_main_constraints_impl< layers: vec![LayerWitness( chain!(&input.witness, &input.fixed, &input.structural_witness,) .cloned() - .collect_vec(), + .map(|mle| unsafe { std::mem::transmute(mle) }) + .collect(), )], }, &out_evals, @@ -1272,9 +1288,13 @@ impl> TraceCommitter>> = poly_group @@ -1367,7 +1392,8 @@ where let device_buffer = witness_rmm .device_backing_ref::>() .unwrap_or_else(|| panic!("col-major replay witness device backing type mismatch")); - let rows = witness_rmm.height(); + let rows = rmm_col_major_device_rows(&witness_rmm) + .unwrap_or_else(|| panic!("col-major replay witness device backing row count mismatch")); let cols = witness_rmm.width(); let poly_len_bytes = rows * std::mem::size_of::(); @@ -1380,14 +1406,11 @@ where (0..cols) .map(|col_idx| { let src_byte_offset = col_idx * poly_len_bytes; - // Keep an owned handle to the parent GPU allocation instead of a - // borrowed CudaView. The resulting MLE outlives this helper. let view_buf = device_buffer.owned_subrange(src_byte_offset..src_byte_offset + poly_len_bytes); - let view_poly = GpuPolynomial::new(view_buf, rows.trailing_zeros() as usize); - let view_poly_static: GpuPolynomial<'static> = - unsafe { std::mem::transmute(view_poly) }; - let mle_static = MultilinearExtensionGpu::from_ceno_gpu_base(view_poly_static); + let view_poly = GpuPolynomial::new(view_buf, witness_rmm.num_vars()); + let poly_static: GpuPolynomial<'static> = unsafe { std::mem::transmute(view_poly) }; + let mle_static = MultilinearExtensionGpu::from_ceno_gpu_base(poly_static); Arc::new(unsafe { std::mem::transmute::< MultilinearExtensionGpu<'static, E>, @@ -1421,7 +1444,7 @@ pub fn clear_replayable_trace_device_backing( let before_device_bytes = rmms .iter() .filter(|rmm| rmm.has_device_backing()) - .map(|rmm| rmm.height() * rmm.width() * std::mem::size_of::()) + .map(rmm_device_backing_bytes) .sum::(); for (trace_idx, _) in replayable_traces { @@ -1432,7 +1455,7 @@ pub fn clear_replayable_trace_device_backing( let after_device_bytes = rmms .iter() .filter(|rmm| rmm.has_device_backing()) - .map(|rmm| rmm.height() * rmm.width() * std::mem::size_of::()) + .map(rmm_device_backing_bytes) .sum::(); tracing::info!( "[gpu] cleared replayable PCS RMM device backing: replayable_traces={}, rmms_device_before={:.2}MB ({}) -> after={:.2}MB ({})", @@ -1508,7 +1531,8 @@ where let device_buffer = structural_rmm .device_backing_ref::>() .unwrap_or_else(|| panic!("col-major structural device backing type mismatch")); - let rows = structural_rmm.height(); + let rows = rmm_col_major_device_rows(structural_rmm) + .unwrap_or_else(|| panic!("col-major structural device backing row count mismatch")); let cols = structural_rmm.width(); let poly_len_bytes = rows * std::mem::size_of::(); let total_bytes = cols * poly_len_bytes; @@ -1517,8 +1541,7 @@ where total_bytes, "structural col-major buffer size mismatch" ); - let num_vars_in_poly = rows.trailing_zeros() as usize; - assert_eq!(rows, 1usize << num_vars_in_poly); + let num_vars_in_poly = structural_rmm.num_vars(); (0..cols) .map(|col_idx| { @@ -1559,25 +1582,22 @@ where } #[allow(clippy::too_many_arguments)] -pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( +pub(crate) fn build_tower_witness_gpu( composed_cs: &ComposedConstrainSystem, input: &ProofInput<'_, GpuBackend>>, records: &[ArcMultilinearExtensionGpu<'_, E>], challenges: &[E; 2], cuda_hal: &CudaHalBB31, - big_buffers: &'buf mut Vec>, - ones_buffer: &mut Vec>, - view_last_layers: &mut Vec>>>, ) -> Result< ( - Vec>, - Vec>, + Vec>, + Vec>, ), String, > { let stream = gkr_iop::gpu::get_thread_stream(); use crate::scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP}; - use ceno_gpu::{CudaHal as _, bb31::GpuPolynomialExt}; + use ceno_gpu::bb31::GpuPolynomialExt; use p3::field::FieldAlgebra; let ComposedConstrainSystem { @@ -1585,7 +1605,7 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( } = composed_cs; let _num_instances_with_rotation = input.num_instances() << composed_cs.rotation_vars().unwrap_or(0); - let _chip_record_alpha = challenges[0]; + let chip_record_alpha: BB31Ext = unsafe { std::mem::transmute_copy(&challenges[0]) }; // SAFETY: The `records` slice is borrowed for the duration of this function call. // The lifetime is erased to 'static only to satisfy GPU API signatures that require @@ -1616,46 +1636,56 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( &records[offset..][..cs.lk_expressions.len()] }; - assert_eq!(big_buffers.len(), 0, "expect no big buffers"); - - // prod: last layes & buffer - let mut is_prod_buffer_exists = false; + // prod: split last layer once, then build compact tower layers. let prod_last_layers = r_set_wit .iter() .chain(w_set_wit.iter()) - .map(|wit| wit.as_view_chunks(NUM_FANIN)) - .collect::>(); + .map(|wit| match wit.inner() { + gkr_iop::gpu::GpuFieldType::Ext(poly) => cuda_hal + .tower + .masked_mle_view_chunks(&*cuda_hal, poly, NUM_FANIN, BB31Ext::ONE, stream.as_ref()) + .map_err(|e| format!("Failed to split compact prod tower input: {e}")), + _ => return Err("tower witness expects extension-field record MLEs".to_string()), + }) + .collect::, String>>()?; if !prod_last_layers.is_empty() { let first_layer = &prod_last_layers[0]; assert_eq!(first_layer.len(), 2, "prod last_layer must have 2 MLEs"); - let num_vars = first_layer[0].num_vars(); - let num_towers = prod_last_layers.len(); - view_last_layers.push(prod_last_layers); - - // Allocate one big buffer for all product towers and add it to big_buffers - let tower_size = 1 << (num_vars + 1); // 2 * mle_len elements per tower - let total_buffer_size = num_towers * tower_size; - tracing::debug!( - "prod tower request buffer size: {:.2} MB", - (total_buffer_size * std::mem::size_of::()) as f64 / (1024.0 * 1024.0) - ); - let big_buffer = cuda_hal - .alloc_ext_elems_on_device(total_buffer_size, false, stream.as_ref()) - .map_err(|e| format!("Failed to allocate prod GPU buffer: {:?}", e))?; - big_buffers.push(big_buffer); - is_prod_buffer_exists = true; } - // logup: last layes - let mut is_logup_buffer_exists = false; + // logup: split last layer once, then build compact tower layers. let lk_numerator_last_layer = lk_n_wit .iter() - .map(|wit| wit.as_view_chunks(NUM_FANIN_LOGUP)) - .collect::>(); + .map(|wit| match wit.inner() { + gkr_iop::gpu::GpuFieldType::Ext(poly) => cuda_hal + .tower + .masked_mle_view_chunks( + &*cuda_hal, + poly, + NUM_FANIN_LOGUP, + chip_record_alpha, + stream.as_ref(), + ) + .map_err(|e| format!("Failed to split compact logup numerator: {e}")), + _ => Err("tower witness expects extension-field logup numerator MLEs".to_string()), + }) + .collect::, String>>()?; let lk_denominator_last_layer = lk_d_wit .iter() - .map(|wit| wit.as_view_chunks(NUM_FANIN_LOGUP)) - .collect::>(); + .map(|wit| match wit.inner() { + gkr_iop::gpu::GpuFieldType::Ext(poly) => cuda_hal + .tower + .masked_mle_view_chunks( + &*cuda_hal, + poly, + NUM_FANIN_LOGUP, + chip_record_alpha, + stream.as_ref(), + ) + .map_err(|e| format!("Failed to split compact logup denominator: {e}")), + _ => Err("tower witness expects extension-field logup denominator MLEs".to_string()), + }) + .collect::, String>>()?; let logup_last_layers = if !lk_numerator_last_layer.is_empty() { // Case when we have both numerator and denominator // Combine [p1, p2] from numerator and [q1, q2] from denominator @@ -1665,100 +1695,58 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( .map(|(lk_n_chunks, lk_d_chunks)| { let mut last_layer = lk_n_chunks; last_layer.extend(lk_d_chunks); - last_layer + Ok(last_layer) }) - .collect::>() + .collect::, String>>()? } else if lk_denominator_last_layer.is_empty() { vec![] } else { - // Case when numerator is empty - create shared ones_buffer and use views - // This saves memory by having all p1, p2 polynomials reference the same buffer + // Case when numerator is empty: share one scalar compact polynomial. + // Its tail default is also ONE, so all logical numerator entries read as ONE + // without materializing per-chunk denominator-sized buffers. let nv = lk_denominator_last_layer[0][0].num_vars(); + let ones_poly = GpuPolynomialExt::new_with_scalar_len( + &cuda_hal.inner, + nv, + 1, + BB31Ext::ONE, + stream.as_ref(), + ) + .map_err(|e| format!("Failed to create compact shared ones numerator: {e:?}"))?; + let ones_poly: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(ones_poly) }; + let one_len_bytes = ones_poly.buf.len() * std::mem::size_of::(); - // Create one shared ones_buffer as Owned (can be 'static) - let ones_poly = - GpuPolynomialExt::new_with_scalar(&cuda_hal.inner, nv, BB31Ext::ONE, stream.as_ref()) - .map_err(|e| format!("Failed to create shared ones_buffer: {:?}", e)) - .unwrap(); - // SAFETY: Owned buffer can be safely treated as 'static - let ones_poly_static: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(ones_poly) }; - ones_buffer.push(ones_poly_static); - - // Get reference from storage to ensure proper lifetime - let ones_poly_ref = ones_buffer.last().unwrap(); - let mle_len_bytes = ones_poly_ref.evaluations().len() * std::mem::size_of::(); - - // Create views referencing the shared ones_buffer for each tower's p1, p2 lk_denominator_last_layer .into_iter() .map(|lk_d_chunks| { - // Create views of ones_buffer for p1 and p2 - let p1_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); - let p2_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); - let p1_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p1_view), nv); - let p2_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p2_view), nv); - // SAFETY: views from 'static buffer can be 'static - let p1_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p1_gpu) }; - let p2_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p2_gpu) }; - // Use [p1, p2, q1, q2] format for the last layer + let p1_gpu = GpuPolynomialExt::new_with_tail_default( + ones_poly.buf.owned_subrange(0..one_len_bytes), + nv, + BB31Ext::ONE, + ); + let p2_gpu = GpuPolynomialExt::new_with_tail_default( + ones_poly.buf.owned_subrange(0..one_len_bytes), + nv, + BB31Ext::ONE, + ); let mut last_layer = vec![p1_gpu, p2_gpu]; last_layer.extend(lk_d_chunks); - last_layer + Ok(last_layer) }) - .collect::>() + .collect::, String>>()? }; if !logup_last_layers.is_empty() { let first_layer = &logup_last_layers[0]; assert_eq!(first_layer.len(), 4, "logup last_layer must have 4 MLEs"); - let num_vars = first_layer[0].num_vars(); - let num_towers = logup_last_layers.len(); - view_last_layers.push(logup_last_layers); - - // Allocate one big buffer for all towers and add it to big_buffers - let tower_size = 1 << (num_vars + 2); // 4 * mle_len elements per tower - let total_buffer_size = num_towers * tower_size; - tracing::debug!( - "logup tower request buffer size: {:.2} MB", - (total_buffer_size * std::mem::size_of::()) as f64 / (1024.0 * 1024.0) - ); - let big_buffer = cuda_hal - .alloc_ext_elems_on_device(total_buffer_size, false, stream.as_ref()) - .unwrap(); - big_buffers.push(big_buffer); - is_logup_buffer_exists = true; } - let (_, pushed_big_buffers) = big_buffers.split_at_mut(0); - let (prod_big_buffer, logup_big_buffer) = match ( - is_prod_buffer_exists, - is_logup_buffer_exists, - pushed_big_buffers, - ) { - (false, false, []) => (None, None), - (true, false, [prod]) => (Some(prod), None), - (false, true, [logup]) => (None, Some(logup)), - (true, true, [prod, logup]) => (Some(prod), Some(logup)), - (prod_flag, logup_flag, slice) => { - panic!( - "unexpected state: prod={}, logup={}, newly_pushed_len={}", - prod_flag, - logup_flag, - slice.len() - ); - } - }; - // Build product GpuProverSpecs let mut prod_gpu_specs = Vec::new(); - if is_prod_buffer_exists { - let prod_last_layers = &view_last_layers[0]; + if !prod_last_layers.is_empty() { let first_layer = &prod_last_layers[0]; assert_eq!(first_layer.len(), 2, "prod last_layer must have 2 MLEs"); let num_vars = first_layer[0].num_vars(); let num_towers = prod_last_layers.len(); - let Some(prod_big_buffer) = prod_big_buffer else { - panic!("prod big buffer not found"); - }; let span_prod = entered_span!( "build_prod_tower", @@ -1768,31 +1756,31 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( let last_layers_refs: Vec<&[GpuPolynomialExt<'_>]> = prod_last_layers.iter().map(|v| v.as_slice()).collect(); let gpu_specs = { - cuda_hal.tower.build_prod_tower_from_gpu_polys_batch( + cuda_hal.tower.build_prod_tower_dense_from_gpu_polys_batch( cuda_hal, - prod_big_buffer, &last_layers_refs, num_vars, num_towers, stream.as_ref(), ) } - .map_err(|e| format!("build_prod_tower_from_gpu_polys_batch failed: {:?}", e))?; + .map_err(|e| { + format!( + "build_prod_tower_dense_from_gpu_polys_batch failed: {:?}", + e + ) + })?; prod_gpu_specs.extend(gpu_specs); exit_span!(span_prod); } // Build logup GpuProverSpecs let mut logup_gpu_specs = Vec::new(); - if is_logup_buffer_exists { - let logup_last_layers = view_last_layers.last().unwrap(); + if !logup_last_layers.is_empty() { let first_layer = &logup_last_layers[0]; assert_eq!(first_layer.len(), 4, "logup last_layer must have 4 MLEs"); let num_vars = first_layer[0].num_vars(); let num_towers = logup_last_layers.len(); - let Some(logup_big_buffer) = logup_big_buffer else { - panic!("logup big buffer not found"); - }; let span_logup = entered_span!( "build_logup_tower", @@ -1803,16 +1791,19 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( logup_last_layers.iter().map(|v| v.as_slice()).collect(); let gpu_specs = cuda_hal .tower - .build_logup_tower_from_gpu_polys_batch( + .build_logup_tower_dense_from_gpu_polys_batch( cuda_hal, - logup_big_buffer, &last_layers_refs, num_vars, num_towers, stream.as_ref(), ) - .map_err(|e| format!("build_logup_tower_from_gpu_polys_batch failed: {:?}", e))?; - + .map_err(|e| { + format!( + "build_logup_tower_dense_from_gpu_polys_batch failed: {:?}", + e + ) + })?; logup_gpu_specs.extend(gpu_specs); exit_span!(span_logup); } @@ -1878,7 +1869,15 @@ impl> TowerProver(composed_cs, input); - check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_bytes, + composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.first()) + .map(|layer| layer.name.as_str()), + ); res } @@ -1927,7 +1926,15 @@ impl> MainSumcheckProver(composed_cs, input); - check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_bytes, + composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.first()) + .map(|layer| layer.name.as_str()), + ); res } @@ -1964,7 +1971,15 @@ impl> EccQuarkProver> OpeningProver> task.circuit_name, estimated_replay_bytes as f64 / (1024.0 * 1024.0), ); - let witness_rmm = replay_plan.replay_witness().expect("GPU raw replay failed"); - check_gpu_mem_estimation(gpu_mem_tracker, estimated_replay_bytes); - task.input.witness = info_span!("[ceno] replay_gpu_witness_from_raw") - .in_scope(|| extract_witness_mles_for_trace_rmm::(witness_rmm)); + task.input.witness = if let Some(trace_idx) = task.witness_trace_idx { + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + 0, + Some(task.circuit_name.as_str()), + ); + info_span!("[ceno] extract_witness_mles").in_scope(|| { + extract_witness_mles_for_trace::( + pcs_data, + trace_idx, + task.num_witin, + num_vars, + ) + }) + } else { + let witness_rmm = replay_plan.replay_witness().expect("GPU raw replay failed"); + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_replay_bytes, + Some(task.circuit_name.as_str()), + ); + info_span!("[ceno] replay_gpu_witness_from_raw") + .in_scope(|| extract_witness_mles_for_trace_rmm::(witness_rmm)) + }; if let Some(rmm) = task.structural_rmm.as_ref() { task.input.structural_witness = info_span!("[ceno] transport_structural_witness") .in_scope(|| { diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 516dee741..33627ae68 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -603,6 +603,12 @@ impl< task.circuit_idx as u64, )); + let task_name = task.circuit_name.clone(); + let estimated_memory_bytes = task.estimated_memory_bytes as usize; + let cuda_hal = gkr_iop::gpu::get_cuda_hal().expect("Failed to get CUDA HAL"); + let chip_mem_tracker = + crate::scheme::gpu::init_gpu_mem_tracker(&cuda_hal, "create_chip_proof"); + let gpu_input: ProofInput<'static, gkr_iop::gpu::GpuBackend> = unsafe { std::mem::transmute(task.input) }; @@ -619,6 +625,11 @@ impl< task.num_witin, task.structural_rmm, )?; + crate::scheme::gpu::check_gpu_scheduler_mem_estimation_with_context( + chip_mem_tracker, + estimated_memory_bytes, + Some(task_name.as_str()), + ); Ok(ChipTaskResult { task_id: task.task_id, @@ -1117,12 +1128,11 @@ where scheme::{ constants::NUM_FANIN, gpu::{ - build_tower_witness_gpu, check_gpu_mem_estimation, - estimate_replay_materialization_bytes_for_plan, estimate_tower_stage_bytes, - extract_out_evals_from_gpu_towers, extract_witness_mles_for_trace, - log_gpu_device_state, log_gpu_pool_usage, prove_ec_sum_quark_impl, - prove_main_constraints_impl, prove_rotation_impl, prove_tower_relation_impl, - transport_structural_witness_to_gpu, + build_tower_witness_gpu, estimate_replay_materialization_bytes_for_plan, + estimate_tower_stage_bytes, extract_out_evals_from_gpu_towers, + extract_witness_mles_for_trace, log_gpu_device_state, log_gpu_pool_usage, + prove_ec_sum_quark_impl, prove_main_constraints_impl, prove_rotation_impl, + prove_tower_relation_impl, transport_structural_witness_to_gpu, }, }, }; @@ -1164,7 +1174,11 @@ where log_gpu_device_state(&format!("{name}:before_replay")); log_gpu_pool_usage(&format!("{name}:before_replay")); let witness_rmm = replay_plan.replay_witness()?; - check_gpu_mem_estimation(gpu_mem_tracker, estimated_replay_bytes); + crate::scheme::gpu::check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_replay_bytes, + Some(name), + ); input.witness = info_span!("[ceno] replay_gpu_witness_from_raw") .in_scope(|| crate::scheme::gpu::extract_witness_mles_for_trace_rmm::(witness_rmm)); if let Some(structural_rmm_cached) = structural_rmm.as_ref() { @@ -1254,32 +1268,22 @@ where let span = entered_span!("prove_tower_relation", profiling_2 = true); let r_set_len = cs.zkvm_v1_css.r_expressions.len() + cs.zkvm_v1_css.r_table_expressions.len(); - let (tower_build_estimated_bytes, tower_prove_estimated_bytes) = + let (tower_build_estimated_bytes, tower_prove_prebuild_estimated_bytes) = estimate_tower_stage_bytes::(cs, &input); tracing::info!( "[gpu tower][{}] estimated: build_tower={:.2}MB, prove_tower={:.2}MB", name, tower_build_estimated_bytes as f64 / (1024.0 * 1024.0), - tower_prove_estimated_bytes as f64 / (1024.0 * 1024.0), + tower_prove_prebuild_estimated_bytes as f64 / (1024.0 * 1024.0), ); let tower_build_mem_tracker = crate::scheme::gpu::init_gpu_mem_tracker(&cuda_hal, "build_tower_witness_gpu"); - let mut big_buffers = Vec::new(); - let mut ones_buffer = Vec::new(); - let mut view_last_layers = Vec::new(); log_gpu_device_state(&format!("{name}:before_build_tower_witness")); log_gpu_pool_usage(&format!("{name}:before_build_tower_witness")); let (prod_gpu, logup_gpu, lk_out_evals, w_out_evals, r_out_evals) = info_span!("[ceno] build_tower_witness_gpu").in_scope(|| { let (prod_gpu, logup_gpu) = build_tower_witness_gpu( - cs, - &input, - &records, - challenges, - &cuda_hal, - &mut big_buffers, - &mut ones_buffer, - &mut view_last_layers, + cs, &input, &records, challenges, &cuda_hal, ) .map_err(|e| { ZKVMError::InvalidWitness(format!("build_tower_witness_gpu failed: {e}").into()) @@ -1288,7 +1292,11 @@ where extract_out_evals_from_gpu_towers(&prod_gpu, &logup_gpu, r_set_len); Ok::<_, ZKVMError>((prod_gpu, logup_gpu, lk_out_evals, w_out_evals, r_out_evals)) })?; - check_gpu_mem_estimation(tower_build_mem_tracker, tower_build_estimated_bytes); + crate::scheme::gpu::check_gpu_mem_estimation_with_context( + tower_build_mem_tracker, + tower_build_estimated_bytes, + Some(name), + ); log_gpu_device_state(&format!("{name}:after_build_tower_witness")); log_gpu_pool_usage(&format!("{name}:after_build_tower_witness")); @@ -1308,6 +1316,27 @@ where prod_specs: prod_gpu, logup_specs: logup_gpu, }; + let tower_prove_estimate = cuda_hal + .tower + .estimate_memory_requirements(&tower_input, NUM_FANIN); + let tower_input_live_bytes = tower_prove_estimate.prod_tower_buffer_bytes + + tower_prove_estimate.logup_tower_buffer_bytes; + let runtime_layout_prove_bytes = tower_prove_estimate + .total_bytes + .saturating_sub(tower_input_live_bytes); + let release_adjusted_prebuild_bytes = + tower_prove_prebuild_estimated_bytes / NUM_FANIN + 4 * 1024 * 1024; + let tower_prove_estimated_bytes = + runtime_layout_prove_bytes.max(release_adjusted_prebuild_bytes); + tracing::info!( + "[gpu tower][{}] refined prove_tower estimate: prebuild={:.2}MB, runtime_layout={:.2}MB, release_adjusted={:.2}MB, local={:.2}MB, tower_live={:.2}MB", + name, + tower_prove_prebuild_estimated_bytes as f64 / (1024.0 * 1024.0), + runtime_layout_prove_bytes as f64 / (1024.0 * 1024.0), + release_adjusted_prebuild_bytes as f64 / (1024.0 * 1024.0), + tower_prove_estimated_bytes as f64 / (1024.0 * 1024.0), + tower_input_live_bytes as f64 / (1024.0 * 1024.0), + ); let tower_prove_mem_tracker = crate::scheme::gpu::init_gpu_mem_tracker(&cuda_hal, "prove_tower_relation_gpu"); log_gpu_device_state(&format!("{name}:before_prove_tower")); @@ -1318,7 +1347,7 @@ where .tower .create_proof( &cuda_hal, - &tower_input, + tower_input, NUM_FANIN, basic_tr, gkr_iop::gpu::get_thread_stream().as_ref(), @@ -1329,12 +1358,12 @@ where log_gpu_pool_usage(&format!("{name}:after_prove_tower")); let rt_tower: Point = unsafe { std::mem::transmute(rt_tower_gl) }; let tower_proof: TowerProofs = unsafe { std::mem::transmute(tower_proof_gpu) }; - check_gpu_mem_estimation(tower_prove_mem_tracker, tower_prove_estimated_bytes); + crate::scheme::gpu::check_gpu_mem_estimation_with_context( + tower_prove_mem_tracker, + tower_prove_estimated_bytes, + Some(name), + ); drop(records); - drop(tower_input); - drop(big_buffers); - drop(ones_buffer); - drop(view_last_layers); log_gpu_device_state(&format!("{name}:after_drop_tower")); exit_span!(span); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index ead260f7d..4921d7f8c 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -680,7 +680,10 @@ pub fn build_main_witness< .iter() .chain(&input.structural_witness) .chain(&input.fixed) - .all(|v| { v.evaluations_len() == 1 << num_var_with_rotation }) + .all(|v| { + v.num_vars() == num_var_with_rotation + && v.evaluations_len() <= (1 << num_var_with_rotation) + }) ); // GPU memory estimation @@ -704,9 +707,28 @@ pub fn build_main_witness< // GPU memory check: validate estimation against actual usage #[cfg(feature = "gpu")] { + let input_layer_has_only_structural_inputs = composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.last()) + .is_some_and(|input_layer| input_layer.in_eval_expr.is_empty()); + let output_rows = if input_layer_has_only_structural_inputs { + input + .structural_witness + .first() + .map(|mle| mle.evaluations_len()) + } else { + None + } + .or_else(|| input.witness.first().map(|mle| mle.evaluations_len())) + .unwrap_or_else(|| input.num_instances() << composed_cs.rotation_vars().unwrap_or(0)); let estimated_bytes = - crate::scheme::gpu::estimate_main_witness_bytes(composed_cs, num_var_with_rotation); - crate::scheme::gpu::check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + crate::scheme::gpu::estimate_main_witness_bytes(composed_cs, output_rows); + crate::scheme::gpu::check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_bytes, + gkr_circuit.layers.first().map(|layer| layer.name.as_str()), + ); } gkr_circuit_out.0.0 diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 028c2d551..e45cae99d 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -340,6 +340,13 @@ impl> vm_proof: ZKVMProof, mut transcript: impl ForkableTranscript, ) -> Result, ZKVMError> { + tracing::info!( + "verifying shard proof: expected_shard_id={}, proof_shard_id={}, chip_groups={}", + shard_id, + vm_proof.public_values.shard_id, + vm_proof.chip_proofs.len() + ); + // main invariant between opcode circuits and table circuits let mut prod_r = E::ONE; let mut prod_w = E::ONE; @@ -548,6 +555,14 @@ impl> .into(), )); }; + if q1 == E::ZERO || q2 == E::ZERO { + return Err(ZKVMError::InvalidProof( + format!( + "{shard_id}th shard {circuit_name} has zero logup denominator in lk_out_evals: {evals:?}" + ) + .into(), + )); + } Ok(p1 * q1.inverse() + p2 * q2.inverse()) }) .sum::>()?; diff --git a/gkr_iop/src/gkr/layer/gpu/utils.rs b/gkr_iop/src/gkr/layer/gpu/utils.rs index e67153cc2..3e185da47 100644 --- a/gkr_iop/src/gkr/layer/gpu/utils.rs +++ b/gkr_iop/src/gkr/layer/gpu/utils.rs @@ -251,8 +251,9 @@ pub fn build_rotation_mles_gpu panic!("should be base field"), _ => panic!("unimplemented input mle"), }; + let logical_len = 1usize << input_mle.mle.num_vars(); let mut output_buf = cuda_hal - .alloc_elems_on_device(input_buf.len(), false, stream.as_ref()) + .alloc_elems_on_device(logical_len, false, stream.as_ref()) .unwrap(); // Safety: GPU buffers are actually 'static lifetime. We only read from input_buf @@ -294,8 +295,8 @@ pub fn build_rotation_selector_gpu MultilinearExtensionGpu<'static, E> { let stream = crate::gpu::get_thread_stream(); - let total_len = wit[0].evaluations_len(); // Take first mle just to retrieve total length - assert!(total_len.is_power_of_two()); + let num_vars = wit[0].num_vars(); + let total_len = 1usize << num_vars; let mut output_buf = cuda_hal .alloc_ext_elems_on_device(total_len, false, stream.as_ref()) .unwrap(); @@ -322,10 +323,8 @@ pub fn build_rotation_selector_gpu, diff --git a/gkr_iop/src/gpu/mod.rs b/gkr_iop/src/gpu/mod.rs index 54ae3744a..62d19ca85 100644 --- a/gkr_iop/src/gpu/mod.rs +++ b/gkr_iop/src/gpu/mod.rs @@ -222,7 +222,7 @@ impl<'a, E: ExtensionField> MultilinearExtensionGpu<'a, E> { let cpu_evaluations = poly.to_cpu_vec(stream.as_ref()); let cpu_evaluations_base: Vec = unsafe { std::mem::transmute(cpu_evaluations) }; - MultilinearExtension::from_evaluations_vec( + MultilinearExtension::from_evaluations_vec_compact( self.mle.num_vars(), cpu_evaluations_base, ) @@ -230,7 +230,7 @@ impl<'a, E: ExtensionField> MultilinearExtensionGpu<'a, E> { GpuFieldType::Ext(poly) => { let cpu_evaluations = poly.to_cpu_vec(stream.as_ref()); let cpu_evaluations_ext: Vec = unsafe { std::mem::transmute(cpu_evaluations) }; - MultilinearExtension::from_evaluations_ext_vec( + MultilinearExtension::from_evaluations_ext_vec_compact( self.mle.num_vars(), cpu_evaluations_ext, ) @@ -506,13 +506,23 @@ impl> let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu> = unsafe { std::mem::transmute(all_witins_gpu) }; let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec(); + // Match the CPU witness inference path: layer outputs are materialized over + // the occupied prefix of the layer witness domain, not the maximum length of + // any referenced structural/fixed MLE. + let output_len = all_witins_gpu_gl64 + .first() + .map(|mle| mle.evaluations_len()) + .unwrap_or(0); + let output_lengths = + std::iter::repeat_n(output_len, mle_indices_per_term.len()).collect_vec(); // buffer for output witness from gpu let cuda_hal = get_cuda_hal().unwrap(); - let mut next_witness_buf = (0..num_non_zero_expr) - .map(|_| { + let mut next_witness_buf = output_lengths + .iter() + .map(|&output_len| { cuda_hal - .alloc_ext_elems_on_device(1 << num_vars, false, stream.as_ref()) + .alloc_ext_elems_on_device(output_len, false, stream.as_ref()) .map_err(|e| format!("Failed to allocate prod GPU buffer: {:?}", e)) }) .collect::, _>>() diff --git a/gkr_iop/src/utils.rs b/gkr_iop/src/utils.rs index e1c8d7453..f133ae126 100644 --- a/gkr_iop/src/utils.rs +++ b/gkr_iop/src/utils.rs @@ -5,7 +5,6 @@ use itertools::Itertools; use multilinear_extensions::{ Fixed, WitIn, WitnessId, mle::{ArcMultilinearExtension, MultilinearExtension}, - util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; use p3::field::FieldAlgebra; @@ -49,7 +48,7 @@ pub fn rotation_next_base_mle<'a, E: ExtensionField>( rotate_chunk[to] = original_chunk[from]; } }); - MultilinearExtension::from_evaluation_vec_smart(mle.num_vars(), rotated_mle_evals) + MultilinearExtension::from_evaluation_vec_smart_compact(mle.num_vars(), rotated_mle_evals) } pub fn rotation_selector<'a, E: ExtensionField>( @@ -59,7 +58,6 @@ pub fn rotation_selector<'a, E: ExtensionField>( cyclic_group_log2_size: usize, total_len: usize, ) -> MultilinearExtension<'a, E> { - assert!(total_len.is_power_of_two()); let cyclic_group_size = 1 << cyclic_group_log2_size; assert!(cyclic_subgroup_size <= cyclic_group_size); let rotation_index = bh.into_iter().take(cyclic_subgroup_size).collect_vec(); @@ -74,7 +72,10 @@ pub fn rotation_selector<'a, E: ExtensionField>( rotate_chunk[to] = eq_chunk[to]; } }); - MultilinearExtension::from_evaluation_vec_smart(ceil_log2(total_len), rotated_mle_evals) + MultilinearExtension::from_evaluation_vec_smart_compact( + eq.len().ilog2() as usize, + rotated_mle_evals, + ) } /// sel(rx)