Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions extensions/sha2/circuit/cuda/include/block_hasher/variant.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ namespace sha2 {

// Common VM constants across SHA-2 variants.
inline constexpr size_t SHA2_REGISTER_READS = 3;
inline constexpr size_t SHA2_READ_SIZE = 4;
inline constexpr size_t SHA2_WRITE_SIZE = 4;
inline constexpr size_t SHA2_MAIN_READ_SIZE = 4;
inline constexpr size_t SHA2_READ_SIZE = 8;
inline constexpr size_t SHA2_WRITE_SIZE = 8;
Comment on lines +10 to +11
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, we should consolidate these constants at some point so there's a single source of truth. It feels like this should just be `default_block_size`.


template <
typename WordT,
Expand Down Expand Up @@ -43,8 +42,8 @@ struct Sha2VariantBase {

static constexpr size_t NUM_READ_ROWS = BLOCK_U8S / SHA2_READ_SIZE;
static constexpr size_t STATE_BYTES = HASH_WORDS * WORD_U8S;
static constexpr size_t BLOCK_READS = BLOCK_U8S / SHA2_MAIN_READ_SIZE;
static constexpr size_t STATE_READS = STATE_BYTES / SHA2_MAIN_READ_SIZE;
static constexpr size_t BLOCK_READS = BLOCK_U8S / SHA2_READ_SIZE;
static constexpr size_t STATE_READS = STATE_BYTES / SHA2_READ_SIZE;
static constexpr size_t STATE_WRITES = STATE_BYTES / SHA2_WRITE_SIZE;
static constexpr size_t TIMESTAMP_DELTA =
BLOCK_READS + STATE_READS + STATE_WRITES + SHA2_REGISTER_READS;
Expand Down
6 changes: 3 additions & 3 deletions extensions/sha2/circuit/cuda/include/main/columns.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ template <typename T> struct Sha2MainInstructionCols {
T dst_reg_ptr;
T state_reg_ptr;
T input_reg_ptr;
T dst_ptr_limbs[RV32_REGISTER_NUM_LIMBS];
T state_ptr_limbs[RV32_REGISTER_NUM_LIMBS];
T input_ptr_limbs[RV32_REGISTER_NUM_LIMBS];
T dst_ptr_limbs[RV64_WORD_NUM_LIMBS];
T state_ptr_limbs[RV64_WORD_NUM_LIMBS];
T input_ptr_limbs[RV64_WORD_NUM_LIMBS];
};

template <typename V, typename T> struct Sha2MainMemoryCols {
Expand Down
16 changes: 8 additions & 8 deletions extensions/sha2/circuit/cuda/src/sha2_main.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ static __device__ __forceinline__ void sha2_main_row_body(
SHA2_MAIN_WRITE_INSTR(V, row, state_reg_ptr, header->state_reg_ptr);
SHA2_MAIN_WRITE_INSTR(V, row, input_reg_ptr, header->input_reg_ptr);

uint8_t dst_ptr_bytes[RV32_REGISTER_NUM_LIMBS];
uint8_t state_ptr_bytes[RV32_REGISTER_NUM_LIMBS];
uint8_t input_ptr_bytes[RV32_REGISTER_NUM_LIMBS];
uint8_t dst_ptr_bytes[RV64_WORD_NUM_LIMBS];
uint8_t state_ptr_bytes[RV64_WORD_NUM_LIMBS];
uint8_t input_ptr_bytes[RV64_WORD_NUM_LIMBS];
memcpy(dst_ptr_bytes, &header->dst_ptr, sizeof(uint32_t));
memcpy(state_ptr_bytes, &header->state_ptr, sizeof(uint32_t));
memcpy(input_ptr_bytes, &header->input_ptr, sizeof(uint32_t));
Expand All @@ -66,12 +66,12 @@ static __device__ __forceinline__ void sha2_main_row_body(

// Range checks on top limbs
uint8_t needs_range_check[4] = {
dst_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1],
state_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1],
input_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1],
input_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1],
dst_ptr_bytes[RV64_WORD_NUM_LIMBS - 1],
state_ptr_bytes[RV64_WORD_NUM_LIMBS - 1],
input_ptr_bytes[RV64_WORD_NUM_LIMBS - 1],
input_ptr_bytes[RV64_WORD_NUM_LIMBS - 1],
};
uint32_t shift = 1u << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - ptr_max_bits);
uint32_t shift = 1u << (RV64_WORD_NUM_LIMBS * RV64_CELL_BITS - ptr_max_bits);
for (int i = 0; i < 4; i += 2) {
bitwise_lookup.add_range(
static_cast<uint32_t>(needs_range_check[i]) * shift,
Expand Down
8 changes: 4 additions & 4 deletions extensions/sha2/circuit/src/extension/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,18 @@ impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Sha2> for S
}
}

pub struct Sha2Rv32GpuBuilder;
pub struct Sha2Rv64GpuBuilder;

type E = GpuBabyBearPoseidon2Engine;

impl VmBuilder<E> for Sha2Rv32GpuBuilder {
type VmConfig = Sha2Rv32Config;
impl VmBuilder<E> for Sha2Rv64GpuBuilder {
type VmConfig = Sha2Rv64Config;
type SystemChipInventory = SystemChipInventoryGPU;
type RecordArena = DenseRecordArena;

fn create_chip_complex(
&self,
config: &Sha2Rv32Config,
config: &Sha2Rv64Config,
circuit: AirInventory<<E as StarkEngine>::SC>,
device_ctx: &openvm_stark_backend::EngineDeviceCtx<E>,
) -> Result<
Expand Down
2 changes: 1 addition & 1 deletion extensions/sha2/circuit/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ cfg_if::cfg_if! {
mod cuda;
pub use self::cuda::*;
pub use self::cuda::Sha2GpuProverExt as Sha2ProverExt;
pub use self::cuda::Sha2Rv32GpuBuilder as Sha2Rv64Builder;
pub use self::cuda::Sha2Rv64GpuBuilder as Sha2Rv64Builder;
} else {
pub use self::Sha2CpuProverExt as Sha2ProverExt;
pub use self::Sha2Rv64CpuBuilder as Sha2Rv64Builder;
Expand Down
Loading