Unify Gemm #411

Draft
eugenebokhan wants to merge 75 commits into main from unify-matmul

Conversation

@eugenebokhan
Contributor

No description provided.

Cleans up leftover references from the main merge: unified gemm and matmul kernel
signatures, plus the size() deref tweak in dense_buffer.rs.
DataType now derives Serialize/Deserialize directly with per-variant
#[serde(rename)] ("bfloat16", "float16", "int8", ...), so the wire format
matches what ConfigDataType produced. ConfigDataType and config/common.rs are
deleted; every caller now uses crate::DataType.
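The wire-format mapping can be sketched without pulling in serde. In the real code the enum derives serde's Serialize/Deserialize with #[serde(rename = "...")] on each variant; the three variants shown and the plain wire_name helper below are illustrative stand-ins, not the full DataType:

```rust
// Sketch of the per-variant wire names DataType now serializes to.
// The actual enum uses #[serde(rename = "...")] attributes; this plain
// method just makes the mapping visible. The variant set is an assumption.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    BFloat16,
    Float16,
    Int8,
}

impl DataType {
    /// Wire name matching what ConfigDataType used to produce.
    pub fn wire_name(self) -> &'static str {
        match self {
            DataType::BFloat16 => "bfloat16",
            DataType::Float16 => "float16",
            DataType::Int8 => "int8",
        }
    }
}

fn main() {
    assert_eq!(DataType::BFloat16.wire_name(), "bfloat16");
    assert_eq!(DataType::Int8.wire_name(), "int8");
    println!("ok");
}
```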

QuantizationMode variants renamed UINT4/INT8/UINT8 -> U4/I8/U8 to match the
DataType convention, with #[serde(rename)] preserving the JSON wire format.
Includes the matching kernel-side updates in quant_embedding.metal and the
regenerated quantization.h header.
The Metal codegen emits the full path
crate::backends::common::gpu_types::quantization::QuantizationMode for the
SPECIALIZE arg, and the build script's cross-backend signature equality check
requires both backends to agree on the type string.
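The cross-backend agreement can be pictured as comparing the emitted type strings per kernel. check_signatures and the map shape below are hypothetical illustrations, not the build script's actual API:

```rust
use std::collections::BTreeMap;

// Hypothetical sketch of a cross-backend signature equality check: both
// backends must emit the same type string for each kernel (e.g. the fully
// qualified QuantizationMode path for the SPECIALIZE arg). The function
// name and map layout are assumptions for illustration.
fn check_signatures(
    metal: &BTreeMap<String, String>,
    cpu: &BTreeMap<String, String>,
) -> Result<(), String> {
    for (kernel, metal_sig) in metal {
        match cpu.get(kernel) {
            Some(cpu_sig) if cpu_sig == metal_sig => {}
            Some(cpu_sig) => {
                return Err(format!(
                    "signature mismatch for `{kernel}`: metal `{metal_sig}` vs cpu `{cpu_sig}`"
                ));
            }
            None => return Err(format!("kernel `{kernel}` missing on cpu backend")),
        }
    }
    Ok(())
}

fn main() {
    let path = "crate::backends::common::gpu_types::quantization::QuantizationMode";
    let mut metal = BTreeMap::new();
    metal.insert("gemm".to_string(), path.to_string());
    let cpu = metal.clone();
    assert!(check_signatures(&metal, &cpu).is_ok());

    // An unqualified type string on one backend fails the check.
    let mut bad = cpu.clone();
    bad.insert("gemm".to_string(), "QuantizationMode".to_string());
    assert!(check_signatures(&metal, &bad).is_err());
    println!("ok");
}
```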
gpu_types:
- promote QuantizedFormat (MLX | AWQ) into gpu_types/, replacing the kernel-
  layer QuantizedMatmulType { Mlx, ZeroPoint }; loader and matmul callers
  speak the same type now
- promote GemmTilingConfig into gpu_types/unified_gemm/; collapse the three
  separate Threadgroup/Simdgroup/Fragment tile structs into the one
  11-u32-field aggregate that already lived kernel-side
- drop the BitsPerWeight enum and its companion bits_per_weight.h header;
  bit width is derived on demand from QuantizationMode via DataType::size_in_bits
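Deriving the bit width on demand, instead of carrying a BitsPerWeight enum, might look like the sketch below; the QuantizationMode -> DataType mapping shown is an assumption for illustration:

```rust
// Sketch: bit width derived on demand from QuantizationMode via
// DataType::size_in_bits, replacing the dropped BitsPerWeight enum.
// The mode -> data-type mapping here is assumed, not taken from the PR.
#[derive(Debug, Clone, Copy)]
enum DataType { U4, I8, U8 }

impl DataType {
    fn size_in_bits(self) -> u32 {
        match self {
            DataType::U4 => 4,
            DataType::I8 | DataType::U8 => 8,
        }
    }
}

#[derive(Debug, Clone, Copy)]
enum QuantizationMode { U4, I8, U8 }

impl QuantizationMode {
    fn data_type(self) -> DataType {
        match self {
            QuantizationMode::U4 => DataType::U4,
            QuantizationMode::I8 => DataType::I8,
            QuantizationMode::U8 => DataType::U8,
        }
    }

    /// Stands in for the removed BitsPerWeight enum.
    fn bits_per_weight(self) -> u32 {
        self.data_type().size_in_bits()
    }
}

fn main() {
    assert_eq!(QuantizationMode::U4.bits_per_weight(), 4);
    assert_eq!(QuantizationMode::I8.bits_per_weight(), 8);
    println!("ok");
}
```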

kernel layer:
- flatten WeightsStorageFormat to FullPrecision | Quantized { format, mode,
  group_size }; collapse the quantized_storage/ subdirectory into a single
  weights_storage_format.rs at gemm/
- gemm.metal kernel signature: rename a/b/d -> activations/weights/result;
  drop the separate full-precision-only b buffer (weights is always present
  and reinterpreted in body); add scales/biases/zero_points OPTIONAL slots
  gated on use_mlx_quant / use_zero_points bool SPECIALIZEs; collapse the
  three tile constant-buffer args into one GemmTilingConfig
- introduce GemmWeightsBuffers enum (FullPrecision | Mlx | Awq) so the
  Rust-side encode() takes one typed bundle instead of four loose buffers
- inline GemmTile::validate into UnifiedGemmSpecialization::validate
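The flattened storage format and the typed buffer bundle could be sketched as follows. Buffer is a placeholder for the backend's GPU buffer handle, and the exact per-format buffer sets (scales/biases for MLX, scales/zero_points for AWQ) are assumptions inferred from the optional slots described above:

```rust
// Sketch of the flattened kernel-layer types; names follow the PR
// description, field contents are illustrative assumptions.
#[allow(dead_code)]
struct Buffer; // placeholder for the real GPU buffer type

#[allow(dead_code)]
enum QuantizedFormat { Mlx, Awq }

#[allow(dead_code)]
enum QuantizationMode { U4, I8, U8 }

#[allow(dead_code)]
enum WeightsStorageFormat {
    FullPrecision,
    Quantized { format: QuantizedFormat, mode: QuantizationMode, group_size: u32 },
}

// One typed bundle passed to encode() instead of four loose buffers.
enum GemmWeightsBuffers {
    FullPrecision { weights: Buffer },
    Mlx { weights: Buffer, scales: Buffer, biases: Buffer },
    Awq { weights: Buffer, scales: Buffer, zero_points: Buffer },
}

impl GemmWeightsBuffers {
    /// Number of weight-related buffers the bundle binds (assumed).
    fn buffer_count(&self) -> usize {
        match self {
            GemmWeightsBuffers::FullPrecision { .. } => 1,
            GemmWeightsBuffers::Mlx { .. } | GemmWeightsBuffers::Awq { .. } => 3,
        }
    }
}

fn main() {
    let full = GemmWeightsBuffers::FullPrecision { weights: Buffer };
    assert_eq!(full.buffer_count(), 1);
    let mlx = GemmWeightsBuffers::Mlx { weights: Buffer, scales: Buffer, biases: Buffer };
    assert_eq!(mlx.buffer_count(), 3);
    println!("ok");
}
```

An enum like this makes invalid combinations (e.g. an MLX bundle without scales) unrepresentable at the encode() call site, which is the point of replacing four loose buffer parameters.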
# Conflicts:
#	Cargo.lock
#	crates/backend-uzu/build/cpu/compiler.rs
#	crates/backend-uzu/build/metal/bindgen.rs
#	crates/backend-uzu/build/metal/compiler.rs
#	crates/backend-uzu/build/metal/mod.rs
#	crates/backend-uzu/build/metal/wrapper.rs
#	crates/backend-uzu/src/backends/common/gpu_types/mod.rs
#	crates/backend-uzu/src/backends/common/kernel/quant_matmul.rs
#	crates/backend-uzu/src/backends/metal/metal_extensions/function_constant_values_extensions_set_value.rs
#	crates/backend-uzu/src/encodable_block/embedding.rs
#	crates/backend-uzu/src/encodable_block/linear/quantized.rs
#	crates/backend-uzu/tests/unit/backends/common/kernel/quant_matmul_test.rs
# Conflicts:
#	crates/backend-uzu/src/backends/common/kernel/quant_matmul.rs
#	crates/backend-uzu/src/encodable_block/embedding.rs
#	crates/backend-uzu/src/encodable_block/linear/quantized.rs
#	crates/backend-uzu/tests/unit/backends/common/kernel/quant_matmul_test.rs
#	crates/backend-uzu/tests/unit/kernel/quant_matmul/qmm_transposed_test.rs
#	crates/backend-uzu/tests/unit/kernel/quant_matmul/qmv_fast_test.rs
#	crates/backend-uzu/tests/unit/kernel/quant_matmul/qmv_test.rs
# Conflicts:
#	crates/backend-uzu/src/backends/metal/kernel/generated/quantization_method.h
#	crates/backend-uzu/tests/unit/kernel/quant_matmul/qmm_transposed_test.rs
#	crates/backend-uzu/tests/unit/kernel/quant_matmul/qmv_fast_test.rs
# Conflicts:
#	crates/backend-uzu/src/backends/metal/kernel/matmul/gemm.metal
#	crates/backend-uzu/src/backends/metal/kernel/matmul/gemm.rs
#	crates/backend-uzu/src/backends/metal/kernel/matmul/mod.rs
#	crates/backend-uzu/tests/performance/matmul/bench.rs
#	crates/backend-uzu/tests/unit/kernel/matmul/gemm_mpp_test.rs
#	crates/backend-uzu/tests/unit/kernel/matmul/gemm_test.rs
