From f68a3af915b2602e40bac5c84c317b3422a81dba Mon Sep 17 00:00:00 2001 From: 876pol Date: Thu, 23 Apr 2026 21:53:40 +0000 Subject: [PATCH] feat: rv64 cuda sha2 --- .../cuda/include/block_hasher/variant.cuh | 9 ++++----- .../sha2/circuit/cuda/include/main/columns.cuh | 6 +++--- extensions/sha2/circuit/cuda/src/sha2_main.cu | 16 ++++++++-------- extensions/sha2/circuit/src/extension/cuda.rs | 8 ++++---- extensions/sha2/circuit/src/extension/mod.rs | 2 +- 5 files changed, 20 insertions(+), 21 deletions(-) diff --git a/extensions/sha2/circuit/cuda/include/block_hasher/variant.cuh b/extensions/sha2/circuit/cuda/include/block_hasher/variant.cuh index ec36a9dd1c..b95482214b 100644 --- a/extensions/sha2/circuit/cuda/include/block_hasher/variant.cuh +++ b/extensions/sha2/circuit/cuda/include/block_hasher/variant.cuh @@ -7,9 +7,8 @@ namespace sha2 { // Common VM constants across SHA-2 variants. inline constexpr size_t SHA2_REGISTER_READS = 3; -inline constexpr size_t SHA2_READ_SIZE = 4; -inline constexpr size_t SHA2_WRITE_SIZE = 4; -inline constexpr size_t SHA2_MAIN_READ_SIZE = 4; +inline constexpr size_t SHA2_READ_SIZE = 8; +inline constexpr size_t SHA2_WRITE_SIZE = 8; template < typename WordT, @@ -43,8 +42,8 @@ struct Sha2VariantBase { static constexpr size_t NUM_READ_ROWS = BLOCK_U8S / SHA2_READ_SIZE; static constexpr size_t STATE_BYTES = HASH_WORDS * WORD_U8S; - static constexpr size_t BLOCK_READS = BLOCK_U8S / SHA2_MAIN_READ_SIZE; - static constexpr size_t STATE_READS = STATE_BYTES / SHA2_MAIN_READ_SIZE; + static constexpr size_t BLOCK_READS = BLOCK_U8S / SHA2_READ_SIZE; + static constexpr size_t STATE_READS = STATE_BYTES / SHA2_READ_SIZE; static constexpr size_t STATE_WRITES = STATE_BYTES / SHA2_WRITE_SIZE; static constexpr size_t TIMESTAMP_DELTA = BLOCK_READS + STATE_READS + STATE_WRITES + SHA2_REGISTER_READS; diff --git a/extensions/sha2/circuit/cuda/include/main/columns.cuh b/extensions/sha2/circuit/cuda/include/main/columns.cuh index 88e054d4e3..892ffc1f96 100644 --- a/extensions/sha2/circuit/cuda/include/main/columns.cuh +++ b/extensions/sha2/circuit/cuda/include/main/columns.cuh @@ -24,9 +24,9 @@ template struct Sha2MainInstructionCols { T dst_reg_ptr; T state_reg_ptr; T input_reg_ptr; - T dst_ptr_limbs[RV32_REGISTER_NUM_LIMBS]; - T state_ptr_limbs[RV32_REGISTER_NUM_LIMBS]; - T input_ptr_limbs[RV32_REGISTER_NUM_LIMBS]; + T dst_ptr_limbs[RV64_WORD_NUM_LIMBS]; + T state_ptr_limbs[RV64_WORD_NUM_LIMBS]; + T input_ptr_limbs[RV64_WORD_NUM_LIMBS]; }; template struct Sha2MainMemoryCols { diff --git a/extensions/sha2/circuit/cuda/src/sha2_main.cu b/extensions/sha2/circuit/cuda/src/sha2_main.cu index c2a61c09d5..6158ed5db7 100644 --- a/extensions/sha2/circuit/cuda/src/sha2_main.cu +++ b/extensions/sha2/circuit/cuda/src/sha2_main.cu @@ -53,9 +53,9 @@ static __device__ __forceinline__ void sha2_main_row_body( SHA2_MAIN_WRITE_INSTR(V, row, state_reg_ptr, header->state_reg_ptr); SHA2_MAIN_WRITE_INSTR(V, row, input_reg_ptr, header->input_reg_ptr); - uint8_t dst_ptr_bytes[RV32_REGISTER_NUM_LIMBS]; - uint8_t state_ptr_bytes[RV32_REGISTER_NUM_LIMBS]; - uint8_t input_ptr_bytes[RV32_REGISTER_NUM_LIMBS]; + uint8_t dst_ptr_bytes[RV64_WORD_NUM_LIMBS]; + uint8_t state_ptr_bytes[RV64_WORD_NUM_LIMBS]; + uint8_t input_ptr_bytes[RV64_WORD_NUM_LIMBS]; memcpy(dst_ptr_bytes, &header->dst_ptr, sizeof(uint32_t)); memcpy(state_ptr_bytes, &header->state_ptr, sizeof(uint32_t)); memcpy(input_ptr_bytes, &header->input_ptr, sizeof(uint32_t)); @@ -66,12 +66,12 @@ static __device__ __forceinline__ void sha2_main_row_body( // Range checks on top limbs uint8_t needs_range_check[4] = { - dst_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1], - state_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1], - input_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1], - input_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1], + dst_ptr_bytes[RV64_WORD_NUM_LIMBS - 1], + state_ptr_bytes[RV64_WORD_NUM_LIMBS - 1], + input_ptr_bytes[RV64_WORD_NUM_LIMBS - 1], + input_ptr_bytes[RV64_WORD_NUM_LIMBS - 1], }; - uint32_t shift = 1u << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - ptr_max_bits); + uint32_t shift = 1u << (RV64_WORD_NUM_LIMBS * RV64_CELL_BITS - ptr_max_bits); for (int i = 0; i < 4; i += 2) { bitwise_lookup.add_range( static_cast(needs_range_check[i]) * shift, diff --git a/extensions/sha2/circuit/src/extension/cuda.rs b/extensions/sha2/circuit/src/extension/cuda.rs index 2f3551be95..4d2e092a7d 100644 --- a/extensions/sha2/circuit/src/extension/cuda.rs +++ b/extensions/sha2/circuit/src/extension/cuda.rs @@ -77,18 +77,18 @@ impl VmProverExtension for S } } -pub struct Sha2Rv32GpuBuilder; +pub struct Sha2Rv64GpuBuilder; type E = GpuBabyBearPoseidon2Engine; -impl VmBuilder for Sha2Rv32GpuBuilder { - type VmConfig = Sha2Rv32Config; +impl VmBuilder for Sha2Rv64GpuBuilder { + type VmConfig = Sha2Rv64Config; type SystemChipInventory = SystemChipInventoryGPU; type RecordArena = DenseRecordArena; fn create_chip_complex( &self, - config: &Sha2Rv32Config, + config: &Sha2Rv64Config, circuit: AirInventory<::SC>, device_ctx: &openvm_stark_backend::EngineDeviceCtx, ) -> Result< diff --git a/extensions/sha2/circuit/src/extension/mod.rs b/extensions/sha2/circuit/src/extension/mod.rs index 97525666f8..8448ef2946 100644 --- a/extensions/sha2/circuit/src/extension/mod.rs +++ b/extensions/sha2/circuit/src/extension/mod.rs @@ -35,7 +35,7 @@ cfg_if::cfg_if! { mod cuda; pub use self::cuda::*; pub use self::cuda::Sha2GpuProverExt as Sha2ProverExt; - pub use self::cuda::Sha2Rv32GpuBuilder as Sha2Rv64Builder; + pub use self::cuda::Sha2Rv64GpuBuilder as Sha2Rv64Builder; } else { pub use self::Sha2CpuProverExt as Sha2ProverExt; pub use self::Sha2Rv64CpuBuilder as Sha2Rv64Builder;