From 11b947ab81d8f9f054762bb0db4b231d844bb6cc Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 1 May 2026 22:04:15 -0700 Subject: [PATCH 01/12] Upgrade XLA to 20a3e2cdd937 --- crates/ryft-xla-sys/WORKSPACE | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/ryft-xla-sys/WORKSPACE b/crates/ryft-xla-sys/WORKSPACE index 7fcf4794..4635571b 100644 --- a/crates/ryft-xla-sys/WORKSPACE +++ b/crates/ryft-xla-sys/WORKSPACE @@ -6,9 +6,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # XLA # ---------------------------------------------------- -XLA_COMMIT = "7b1be14958aac5c83f1b9f7bcdfc51fdbd29acba" +XLA_COMMIT = "20a3e2cdd937424f351533165b3ac8e0589e5957" -XLA_SHA256 = "9ed81c034535e6398e8463882d6b517e303e0abdc2d8147dd66fd245ace33b54" +XLA_SHA256 = "54f6bb1a23ee2fd753901d211e87526f84e56add0556d802f5319ace74750cef" JAX_COMMIT = "a33ed614c58ee8a10d0b7536c50c2609c38500c1" From 33f488dcc8ec355ff308a42ccf650a2b45af07a2 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 1 May 2026 22:17:41 -0700 Subject: [PATCH 02/12] . --- crates/ryft-pjrt/src/devices.rs | 23 ++++++++++ crates/ryft-pjrt/src/lib.rs | 1 + crates/ryft-pjrt/src/versions.rs | 4 +- crates/ryft-xla-sys/WORKSPACE | 23 +++++++--- crates/ryft-xla-sys/src/bindings.rs | 68 +++++++++++++++++++++++++++++ crates/ryft-xla-sys/src/protos.rs | 51 ++++++++++++++++++---- 6 files changed, 153 insertions(+), 17 deletions(-) diff --git a/crates/ryft-pjrt/src/devices.rs b/crates/ryft-pjrt/src/devices.rs index 858712c2..e9f981cd 100644 --- a/crates/ryft-pjrt/src/devices.rs +++ b/crates/ryft-pjrt/src/devices.rs @@ -254,6 +254,13 @@ impl Device<'_> { ) } + /// Clears the memory/allocator statistics for this [`Device`]. Note that not all PJRT [`Plugin`]s support this + /// functionality, and this function may return [`Error::Unimplemented`] for plugins where it is not supported. + pub fn clear_memory_statistics(&self) -> Result<(), Error> { + use ffi::PJRT_Device_ClearMemoryStats_Args; + invoke_pjrt_api_error_fn!(self.api(), PJRT_Device_ClearMemoryStats, { device = self.to_c_api() }) + } + /// _Poisons_ the earliest execution on this [`Device`] with the provided launch ID if it is not finished /// yet (i.e., sets the resulting [`Buffer`](crate::Buffer) to an error buffer; refer to the documentation of /// [`Client::error_buffer`] for more information on buffer _poisoning_). Returns `true` if the execution was @@ -983,6 +990,22 @@ pub(crate) mod ffi { pub type PJRT_Device_MemoryStats = unsafe extern "C" fn(args: *mut PJRT_Device_MemoryStats_Args) -> *mut PJRT_Error; + #[repr(C)] + pub struct PJRT_Device_ClearMemoryStats_Args { + pub struct_size: usize, + pub extension_start: *mut PJRT_Extension_Base, + pub device: *mut PJRT_Device, + } + + impl PJRT_Device_ClearMemoryStats_Args { + pub fn new(device: *mut PJRT_Device) -> Self { + Self { struct_size: size_of::(), extension_start: std::ptr::null_mut(), device } + } + } + + pub type PJRT_Device_ClearMemoryStats = + unsafe extern "C" fn(args: *mut PJRT_Device_ClearMemoryStats_Args) -> *mut PJRT_Error; + #[repr(C)] pub struct PJRT_Device_PoisonExecution_Args { pub struct_size: usize, diff --git a/crates/ryft-pjrt/src/lib.rs b/crates/ryft-pjrt/src/lib.rs index 3cfd8f27..c91dd8da 100644 --- a/crates/ryft-pjrt/src/lib.rs +++ b/crates/ryft-pjrt/src/lib.rs @@ -361,6 +361,7 @@ pub(crate) mod ffi { pub PJRT_Error_ForEachPayload: Option, pub PJRT_TopologyDescription_Fingerprint: Option, pub PJRT_Executable_ParameterMemoryKinds: Option, + pub PJRT_Device_ClearMemoryStats: Option, } } diff --git a/crates/ryft-pjrt/src/versions.rs b/crates/ryft-pjrt/src/versions.rs index 4342b75e..0b3f8763 100644 --- a/crates/ryft-pjrt/src/versions.rs +++ b/crates/ryft-pjrt/src/versions.rs @@ -31,7 +31,7 @@ pub(crate) mod ffi { use crate::ffi::PJRT_Extension_Base; pub const PJRT_API_MAJOR: u32 = 0; - pub const PJRT_API_MINOR: u32 = 104; + pub const PJRT_API_MINOR: u32 = 107; #[repr(C)] pub struct PJRT_Api_Version { @@ -61,6 +61,6 @@ mod tests { #[test] fn test_version_display() { - assert_eq!(format!("{VERSION}"), "0.104"); + assert_eq!(format!("{VERSION}"), "0.107"); } } diff --git a/crates/ryft-xla-sys/WORKSPACE b/crates/ryft-xla-sys/WORKSPACE index 4635571b..a2c76f2a 100644 --- a/crates/ryft-xla-sys/WORKSPACE +++ b/crates/ryft-xla-sys/WORKSPACE @@ -38,10 +38,10 @@ http_archive( http_archive( name = "rules_ml_toolchain", - sha256 = "f2c924e85a22ba2eaa0c08657e5f5467fedbc3d0506f9cc0c69dd97ed9fbaf28", - strip_prefix = "rules_ml_toolchain-99c43dfe995a0e81c767d5b6d686191992672fe6", + sha256 = "0b42f693a60c6050d87db1e0a0eaeb84ab3f54191fce094d86334faedc807da0", + strip_prefix = "rules_ml_toolchain-398d613aea7a4c294da49b79a6d6f3f8732bd84c", urls = [ - "https://github.com/google-ml-infra/rules_ml_toolchain/archive/99c43dfe995a0e81c767d5b6d686191992672fe6.tar.gz", + "https://github.com/google-ml-infra/rules_ml_toolchain/archive/398d613aea7a4c294da49b79a6d6f3f8732bd84c.tar.gz", ], ) @@ -74,19 +74,20 @@ load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") python_init_rules() -load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") +load("@rules_ml_toolchain//py:python_init_repositories.bzl", "python_init_repositories") python_init_repositories( requirements = { "3.11": "@xla//:requirements_lock_3_11.txt", + "3.12": "@xla//:requirements_lock_3_12.txt", }, ) -load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") +load("@rules_ml_toolchain//py:python_register_toolchain.bzl", "python_register_toolchain") -python_init_toolchains() +python_register_toolchain() -load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") +load("@rules_ml_toolchain//py:python_init_pip.bzl", "python_init_pip") python_init_pip() @@ -179,6 +180,14 @@ nvshmem_redist_init_repository( nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS, ) +load("@xla//build_tools/pjrt_wheels:nightly.bzl", "nightly_timestamp_repo") + +nightly_timestamp_repo(name = "nightly_timestamp") + +load("@xla//build_tools/pjrt_wheels:release_candidate.bzl", "rc_number_repo") + +rc_number_repo(name = "rc_number") + load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() diff --git a/crates/ryft-xla-sys/src/bindings.rs b/crates/ryft-xla-sys/src/bindings.rs index c1d60f6f..b6215228 100644 --- a/crates/ryft-xla-sys/src/bindings.rs +++ b/crates/ryft-xla-sys/src/bindings.rs @@ -3203,6 +3203,74 @@ unsafe extern "C" { unsafe extern "C" { pub fn stablehloResultAccuracyAttrGetMode(attr: MlirAttribute) -> MlirAttribute; } +unsafe extern "C" { + pub fn stablehloSubAxisInfoAttrGet(ctx: MlirContext, preSize: i64, size: i64) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloAttributeIsASubAxisInfoAttr(attr: MlirAttribute) -> bool; +} +unsafe extern "C" { + pub fn stablehloSubAxisInfoAttrGetPreSize(attr: MlirAttribute) -> i64; +} +unsafe extern "C" { + pub fn stablehloSubAxisInfoAttrGetSize(attr: MlirAttribute) -> i64; +} +unsafe extern "C" { + pub fn stablehloAxisRefAttrGet( + ctx: MlirContext, + name: MlirStringRef, + subAxisInfo: MlirAttribute, + ) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloAttributeIsAnAxisRefAttr(attr: MlirAttribute) -> bool; +} +unsafe extern "C" { + pub fn stablehloAxisRefAttrGetName(attr: MlirAttribute) -> MlirStringRef; +} +unsafe extern "C" { + pub fn stablehloAxisRefAttrGetSubAxisInfo(attr: MlirAttribute) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloReplicaGroupMeshAxesAttrGet( + ctx: MlirContext, + mesh: MlirAttribute, + axes: MlirAttribute, + ) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloAttributeIsAReplicaGroupMeshAxesAttr(attr: MlirAttribute) -> bool; +} +unsafe extern "C" { + pub fn stablehloReplicaGroupMeshAxesAttrGetMesh(attr: MlirAttribute) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloReplicaGroupMeshAxesAttrGetAxes(attr: MlirAttribute) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloMeshAxisAttrGet(ctx: MlirContext, name: MlirStringRef, size: i64) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloAttributeIsAMeshAxisAttr(attr: MlirAttribute) -> bool; +} +unsafe extern "C" { + pub fn stablehloMeshAxisAttrGetName(attr: MlirAttribute) -> MlirStringRef; +} +unsafe extern "C" { + pub fn stablehloMeshAxisAttrGetSize(attr: MlirAttribute) -> i64; +} +unsafe extern "C" { + pub fn stablehloMeshAttrGet(ctx: MlirContext, axes: MlirAttribute, deviceIds: MlirAttribute) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloAttributeIsAMeshAttr(attr: MlirAttribute) -> bool; +} +unsafe extern "C" { + pub fn stablehloMeshAttrGetAxes(attr: MlirAttribute) -> MlirAttribute; +} +unsafe extern "C" { + pub fn stablehloMeshAttrGetDeviceIds(attr: MlirAttribute) -> MlirAttribute; +} unsafe extern "C" { pub fn mlirGetDialectHandle__stablehlo__() -> MlirDialectHandle; } diff --git a/crates/ryft-xla-sys/src/protos.rs b/crates/ryft-xla-sys/src/protos.rs index 88a165c8..e2612d58 100644 --- a/crates/ryft-xla-sys/src/protos.rs +++ b/crates/ryft-xla-sys/src/protos.rs @@ -934,6 +934,25 @@ pub enum CollectiveOperationType { AllCollectives = 8, } +/// Memory mode for XLA GPU collective operations. +/// +/// This type corresponds to `DebugOptions.CollectivesMode` in [XLA](https://github.com/openxla/xla). +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Enumeration)] +#[repr(i32)] +pub enum CollectivesMode { + /// Invalid or unrecognized collective memory mode. + Invalid = 0, + + /// Collective operations use per-device private memory. + PrivateMemory = 1, + + /// Collective operations use symmetric device memory across participating ranks. + SymmetricMemory = 2, + + /// Collective operations use peer memory access directly. + PeerMemory = 3, +} + /// GPU command types for command buffer recording and execution. Command buffers allow batching multiple GPU operations /// together for more efficient execution. This enum identifies the type of command being recorded. /// @@ -1428,6 +1447,10 @@ pub struct DebugOptions { #[prost(int64, optional, tag = "237")] pub xla_gpu_collective_permute_decomposer_threshold: Option, + /// Memory mode for collective-permute buffers. + #[prost(enumeration = "CollectivesMode", optional, tag = "473")] + pub xla_gpu_collective_permute_mode: Option, + /// If `true`, collective cliques will not be locked for each XLA GPU execution, using permanent cliques instead. /// This disables deadlock prevention. #[prost(bool, optional, tag = "354")] @@ -1746,14 +1769,6 @@ pub struct DebugOptions { #[prost(bool, optional, tag = "388")] pub xla_gpu_experimental_enable_nvshmem: Option, - /// If `true`, OneHot patterns will be rewritten into Gather operations during GPU lowering. - #[prost(bool, optional, tag = "458")] - pub xla_gpu_experimental_enable_onehot_rewriter: Option, - - /// If `true`, GEMMs that underutilize the GPU will be split along the K dimension. - #[prost(bool, optional, tag = "386")] - pub xla_gpu_experimental_enable_split_k_rewrite: Option, - /// If `true`, fusion for subchannel dequantization sequences will be enabled. #[prost(bool, optional, tag = "368")] pub xla_gpu_experimental_enable_subchannel_dequantisation_fusion: Option, @@ -1766,10 +1781,22 @@ pub struct DebugOptions { #[prost(bool, optional, tag = "421")] pub xla_gpu_experimental_enable_triton_warp_specialization: Option, + /// Forces a specific split-K value. Zero means the heuristic is used. + #[prost(int32, optional, tag = "472")] + pub xla_gpu_experimental_force_split_k: Option, + + /// If `true`, the GEMM fusion v2 pass will build Triton fusions. + #[prost(bool, optional, tag = "475")] + pub xla_gpu_experimental_gemm_fusion_v2: Option, + /// Maximum unroll factor to allow on Blackwell architectures. #[prost(int32, optional, tag = "459")] pub xla_gpu_experimental_max_unroll_factor: Option, + /// If `true`, GEMM and convolution autotuning will run after fusion passes. + #[prost(bool, optional, tag = "477")] + pub xla_gpu_experimental_move_gemm_conv_autotuner: Option, + /// If `true`, sub-byte dot operands will be laid out along the contracting (K) dimension. #[prost(bool, optional, tag = "362")] pub xla_gpu_experimental_pack_dot_operands_along_k_dimension: Option, @@ -2159,6 +2186,10 @@ pub struct DebugOptions { #[prost(bool, optional, tag = "131")] pub xla_dump_include_timestamp: Option, + /// If `true`, HLO modules will be dumped in a subfolder named after the module. + #[prost(bool, optional, tag = "502")] + pub xla_dump_hlo_to_subfolder: Option, + /// Maximum number of HLO modules to dump per directory. A negative value means unbounded. #[prost(int32, optional, tag = "132")] pub xla_dump_max_hlo_modules: Option, @@ -2311,6 +2342,10 @@ pub struct DebugOptions { #[prost(bool, optional, tag = "456")] pub xla_gpu_experimental_enable_tiling_propagation: Option, + /// Cost model options for experimental GEMM fusion tiling decisions. + #[prost(map = "string, string", tag = "474")] + pub xla_gpu_experimental_cost_model_gemm_tiling_options: HashMap, + /// Extra backend-specific options as key-value pairs. #[prost(map = "string, string", tag = "500")] pub xla_backend_extra_options: HashMap, From d0138ddb309e3e57cd317cb14d9c02f9a01343a2 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 1 May 2026 22:23:23 -0700 Subject: [PATCH 03/12] . --- crates/ryft-xla-sys/src/bindings.rs | 6 +- .../src/c++/mlir/dialects/triton.cc | 59 +++++++++++++++++-- .../src/c++/mlir/dialects/triton.h | 10 +++- .../src/mlir/dialects/triton/tt.rs | 11 +++- 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/crates/ryft-xla-sys/src/bindings.rs b/crates/ryft-xla-sys/src/bindings.rs index b6215228..b21005a6 100644 --- a/crates/ryft-xla-sys/src/bindings.rs +++ b/crates/ryft-xla-sys/src/bindings.rs @@ -3216,11 +3216,7 @@ unsafe extern "C" { pub fn stablehloSubAxisInfoAttrGetSize(attr: MlirAttribute) -> i64; } unsafe extern "C" { - pub fn stablehloAxisRefAttrGet( - ctx: MlirContext, - name: MlirStringRef, - subAxisInfo: MlirAttribute, - ) -> MlirAttribute; + pub fn stablehloAxisRefAttrGet(ctx: MlirContext, name: MlirStringRef, subAxisInfo: MlirAttribute) -> MlirAttribute; } unsafe extern "C" { pub fn stablehloAttributeIsAnAxisRefAttr(attr: MlirAttribute) -> bool; diff --git a/crates/ryft-xla-sys/src/c++/mlir/dialects/triton.cc b/crates/ryft-xla-sys/src/c++/mlir/dialects/triton.cc index 7ea78a71..55248869 100644 --- a/crates/ryft-xla-sys/src/c++/mlir/dialects/triton.cc +++ b/crates/ryft-xla-sys/src/c++/mlir/dialects/triton.cc @@ -83,15 +83,64 @@ bool mlirTypeIsATritonTtTensorDescType(MlirType type) { return type.ptr != nullptr && llvm::isa(unwrap(type)); } -MlirType mlirTritonTtTensorDescTypeGet(MlirType blockType) { - if (blockType.ptr == nullptr) { +MlirType mlirTritonTtTensorDescTypeGet( + const int64_t *shape, + intptr_t shapeSize, + MlirType elementType, + MlirAttribute sharedLayout) { + if (shape == nullptr || shapeSize < 0 || elementType.ptr == nullptr) { return {nullptr}; } - auto rankedTensorType = llvm::dyn_cast(unwrap(blockType)); - if (!rankedTensorType) { + llvm::ArrayRef shapeRef(shape, static_cast(shapeSize)); + mlir::Attribute sharedLayoutAttribute; + if (sharedLayout.ptr != nullptr) { + sharedLayoutAttribute = unwrap(sharedLayout); + } + return wrap(mlir::triton::TensorDescType::get(shapeRef, unwrap(elementType), sharedLayoutAttribute)); +} + +intptr_t mlirTritonTtTensorDescTypeGetNumDims(MlirType type) { + if (type.ptr == nullptr) { + return 0; + } + auto tensorDescType = llvm::dyn_cast(unwrap(type)); + if (!tensorDescType) { + return 0; + } + return static_cast(tensorDescType.getShape().size()); +} + +int64_t mlirTritonTtTensorDescTypeGetDimSize(MlirType type, intptr_t dimension) { + if (type.ptr == nullptr) { + return 0; + } + auto tensorDescType = llvm::dyn_cast(unwrap(type)); + if (!tensorDescType || dimension < 0 || dimension >= static_cast(tensorDescType.getShape().size())) { + return 0; + } + return tensorDescType.getShape()[static_cast(dimension)]; +} + +MlirType mlirTritonTtTensorDescTypeGetElementType(MlirType type) { + if (type.ptr == nullptr) { + return {nullptr}; + } + auto tensorDescType = llvm::dyn_cast(unwrap(type)); + if (!tensorDescType) { + return {nullptr}; + } + return wrap(tensorDescType.getElementType()); +} + +MlirAttribute mlirTritonTtTensorDescTypeGetSharedLayout(MlirType type) { + if (type.ptr == nullptr) { + return {nullptr}; + } + auto tensorDescType = llvm::dyn_cast(unwrap(type)); + if (!tensorDescType) { return {nullptr}; } - return wrap(mlir::triton::TensorDescType::get(rankedTensorType.getContext(), rankedTensorType)); + return wrap(tensorDescType.getSharedLayout()); } MlirType mlirTritonTtTensorDescTypeGetBlockType(MlirType type) { diff --git a/crates/ryft-xla-sys/src/c++/mlir/dialects/triton.h b/crates/ryft-xla-sys/src/c++/mlir/dialects/triton.h index e261cad0..443d55dd 100644 --- a/crates/ryft-xla-sys/src/c++/mlir/dialects/triton.h +++ b/crates/ryft-xla-sys/src/c++/mlir/dialects/triton.h @@ -33,7 +33,15 @@ MlirType mlirTritonTtPointerTypeGetPointeeType(MlirType type); int32_t mlirTritonTtPointerTypeGetAddressSpace(MlirType type); bool mlirTypeIsATritonTtTensorDescType(MlirType type); -MlirType mlirTritonTtTensorDescTypeGet(MlirType blockType); +MlirType mlirTritonTtTensorDescTypeGet( + const int64_t *shape, + intptr_t shapeSize, + MlirType elementType, + MlirAttribute sharedLayout); +intptr_t mlirTritonTtTensorDescTypeGetNumDims(MlirType type); +int64_t mlirTritonTtTensorDescTypeGetDimSize(MlirType type, intptr_t dimension); +MlirType mlirTritonTtTensorDescTypeGetElementType(MlirType type); +MlirAttribute mlirTritonTtTensorDescTypeGetSharedLayout(MlirType type); MlirType mlirTritonTtTensorDescTypeGetBlockType(MlirType type); bool mlirAttributeIsATritonTtEnumAttr(MlirAttribute attribute, enum MlirTritonTtEnumAttribute kind); diff --git a/crates/ryft-xla-sys/src/mlir/dialects/triton/tt.rs b/crates/ryft-xla-sys/src/mlir/dialects/triton/tt.rs index 3fe935ff..1e318d80 100644 --- a/crates/ryft-xla-sys/src/mlir/dialects/triton/tt.rs +++ b/crates/ryft-xla-sys/src/mlir/dialects/triton/tt.rs @@ -28,7 +28,16 @@ unsafe extern "C" { pub fn mlirTritonTtPointerTypeGetAddressSpace(r#type: MlirType) -> i32; pub fn mlirTypeIsATritonTtTensorDescType(r#type: MlirType) -> bool; - pub fn mlirTritonTtTensorDescTypeGet(block_type: MlirType) -> MlirType; + pub fn mlirTritonTtTensorDescTypeGet( + shape: *const i64, + shape_size: isize, + element_type: MlirType, + shared_layout: MlirAttribute, + ) -> MlirType; + pub fn mlirTritonTtTensorDescTypeGetNumDims(r#type: MlirType) -> isize; + pub fn mlirTritonTtTensorDescTypeGetDimSize(r#type: MlirType, dimension: isize) -> i64; + pub fn mlirTritonTtTensorDescTypeGetElementType(r#type: MlirType) -> MlirType; + pub fn mlirTritonTtTensorDescTypeGetSharedLayout(r#type: MlirType) -> MlirAttribute; pub fn mlirTritonTtTensorDescTypeGetBlockType(r#type: MlirType) -> MlirType; pub fn mlirAttributeIsATritonTtEnumAttr(attribute: MlirAttribute, kind: MlirTritonTtEnumAttribute) -> bool; From 2bb44e021b13dea3f10ab95c5483eef4652fc525 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 1 May 2026 22:27:30 -0700 Subject: [PATCH 04/12] . --- .../src/dialects/stable_hlo/attributes.rs | 581 +++++++++++++++++- .../src/dialects/triton/tt/operations.rs | 2 +- .../ryft-mlir/src/dialects/triton/tt/types.rs | 114 ++-- 3 files changed, 647 insertions(+), 50 deletions(-) diff --git a/crates/ryft-mlir/src/dialects/stable_hlo/attributes.rs b/crates/ryft-mlir/src/dialects/stable_hlo/attributes.rs index b6e17b8d..f22e058a 100644 --- a/crates/ryft-mlir/src/dialects/stable_hlo/attributes.rs +++ b/crates/ryft-mlir/src/dialects/stable_hlo/attributes.rs @@ -1,9 +1,19 @@ use ryft_xla_sys::bindings::{ - MlirAttribute, mlirShapedTypeGetDynamicSize, stablehloAttributeIsTypeExtensions, stablehloTypeExtensionsGet, - stablehloTypeExtensionsGetBoundsElem, stablehloTypeExtensionsGetBoundsSize, + MlirAttribute, mlirShapedTypeGetDynamicSize, stablehloAttributeIsAMeshAttr, stablehloAttributeIsAMeshAxisAttr, + stablehloAttributeIsAReplicaGroupMeshAxesAttr, stablehloAttributeIsASubAxisInfoAttr, + stablehloAttributeIsAnAxisRefAttr, stablehloAttributeIsTypeExtensions, stablehloAxisRefAttrGet, + stablehloAxisRefAttrGetName, stablehloAxisRefAttrGetSubAxisInfo, stablehloMeshAttrGet, stablehloMeshAttrGetAxes, + stablehloMeshAttrGetDeviceIds, stablehloMeshAxisAttrGet, stablehloMeshAxisAttrGetName, + stablehloMeshAxisAttrGetSize, stablehloReplicaGroupMeshAxesAttrGet, stablehloReplicaGroupMeshAxesAttrGetAxes, + stablehloReplicaGroupMeshAxesAttrGetMesh, stablehloSubAxisInfoAttrGet, stablehloSubAxisInfoAttrGetPreSize, + stablehloSubAxisInfoAttrGetSize, stablehloTypeExtensionsGet, stablehloTypeExtensionsGetBoundsElem, + stablehloTypeExtensionsGetBoundsSize, }; -use crate::{Attribute, Context, DialectHandle, mlir_subtype_trait_impls}; +use crate::{ + ArrayAttributeRef, Attribute, AttributeRef, Context, DenseIntegerElementsAttributeRef, DialectHandle, StringRef, + mlir_subtype_trait_impls, +}; /// StableHLO [`Attribute`] that is used to extend the built-in MLIR [`TensorTypeRef`](crate::TensorTypeRef) with /// StableHLO tensor-specific properties. These properties are not modeled in the built-in MLIR type. This is included @@ -91,6 +101,309 @@ impl<'t> Context<'t> { } } +/// StableHLO [`Attribute`] that identifies a contiguous sub-axis derived from a full mesh axis. +#[derive(Copy, Clone)] +pub struct SubAxisInfoAttributeRef<'c, 't> { + /// Handle that represents this [`Attribute`] in the MLIR C API. + handle: MlirAttribute, + + /// [`Context`] that owns this [`Attribute`]. + context: &'c Context<'t>, +} + +impl<'c, 't> SubAxisInfoAttributeRef<'c, 't> { + /// Returns the product of the sizes of the sub-axes that appear before this sub-axis. + pub fn pre_size(&self) -> i64 { + unsafe { stablehloSubAxisInfoAttrGetPreSize(self.handle) } + } + + /// Returns the size of this sub-axis. + pub fn size(&self) -> i64 { + unsafe { stablehloSubAxisInfoAttrGetSize(self.handle) } + } +} + +impl<'c, 't> Attribute<'c, 't> for SubAxisInfoAttributeRef<'c, 't> { + unsafe fn from_c_api(handle: MlirAttribute, context: &'c Context<'t>) -> Option { + if !handle.ptr.is_null() && unsafe { stablehloAttributeIsASubAxisInfoAttr(handle) } { + Some(Self { handle, context }) + } else { + None + } + } + + unsafe fn to_c_api(&self) -> MlirAttribute { + self.handle + } + + fn context(&self) -> &'c Context<'t> { + self.context + } +} + +mlir_subtype_trait_impls!(SubAxisInfoAttributeRef<'c, 't> as Attribute, mlir_type = Attribute); + +/// StableHLO [`Attribute`] that references either a full mesh axis or a split sub-axis. +#[derive(Copy, Clone)] +pub struct AxisRefAttributeRef<'c, 't> { + /// Handle that represents this [`Attribute`] in the MLIR C API. + handle: MlirAttribute, + + /// [`Context`] that owns this [`Attribute`]. + context: &'c Context<'t>, +} + +impl<'c, 't> AxisRefAttributeRef<'c, 't> { + /// Returns the referenced axis name. + pub fn name(&self) -> StringRef<'c> { + unsafe { StringRef::from_c_api(stablehloAxisRefAttrGetName(self.handle)) } + } + + /// Returns split metadata when this references a sub-axis. + pub fn sub_axis_info(&self) -> Option> { + unsafe { SubAxisInfoAttributeRef::from_c_api(stablehloAxisRefAttrGetSubAxisInfo(self.handle), self.context) } + } +} + +impl<'c, 't> Attribute<'c, 't> for AxisRefAttributeRef<'c, 't> { + unsafe fn from_c_api(handle: MlirAttribute, context: &'c Context<'t>) -> Option { + if !handle.ptr.is_null() && unsafe { stablehloAttributeIsAnAxisRefAttr(handle) } { + Some(Self { handle, context }) + } else { + None + } + } + + unsafe fn to_c_api(&self) -> MlirAttribute { + self.handle + } + + fn context(&self) -> &'c Context<'t> { + self.context + } +} + +mlir_subtype_trait_impls!(AxisRefAttributeRef<'c, 't> as Attribute, mlir_type = Attribute); + +/// StableHLO [`Attribute`] that represents replica groups using a mesh and referenced mesh axes. +#[derive(Copy, Clone)] +pub struct ReplicaGroupMeshAxesAttributeRef<'c, 't> { + /// Handle that represents this [`Attribute`] in the MLIR C API. + handle: MlirAttribute, + + /// [`Context`] that owns this [`Attribute`]. + context: &'c Context<'t>, +} + +impl<'c, 't> ReplicaGroupMeshAxesAttributeRef<'c, 't> { + /// Returns the mesh attribute, which may be a symbol reference or an inline mesh. + pub fn mesh(&self) -> AttributeRef<'c, 't> { + unsafe { + AttributeRef::from_c_api(stablehloReplicaGroupMeshAxesAttrGetMesh(self.handle), self.context) + .expect("invalid StableHLO replica-group mesh") + } + } + + /// Returns the array of axes used to form replica groups. + pub fn axes(&self) -> ArrayAttributeRef<'c, 't> { + unsafe { + ArrayAttributeRef::from_c_api(stablehloReplicaGroupMeshAxesAttrGetAxes(self.handle), self.context) + .expect("invalid StableHLO replica-group axes") + } + } +} + +impl<'c, 't> Attribute<'c, 't> for ReplicaGroupMeshAxesAttributeRef<'c, 't> { + unsafe fn from_c_api(handle: MlirAttribute, context: &'c Context<'t>) -> Option { + if !handle.ptr.is_null() && unsafe { stablehloAttributeIsAReplicaGroupMeshAxesAttr(handle) } { + Some(Self { handle, context }) + } else { + None + } + } + + unsafe fn to_c_api(&self) -> MlirAttribute { + self.handle + } + + fn context(&self) -> &'c Context<'t> { + self.context + } +} + +mlir_subtype_trait_impls!(ReplicaGroupMeshAxesAttributeRef<'c, 't> as Attribute, mlir_type = Attribute); + +/// StableHLO [`Attribute`] that defines a single named mesh axis and its size. +#[derive(Copy, Clone)] +pub struct MeshAxisAttributeRef<'c, 't> { + /// Handle that represents this [`Attribute`] in the MLIR C API. + handle: MlirAttribute, + + /// [`Context`] that owns this [`Attribute`]. + context: &'c Context<'t>, +} + +impl<'c, 't> MeshAxisAttributeRef<'c, 't> { + /// Returns the mesh axis name. + pub fn name(&self) -> StringRef<'c> { + unsafe { StringRef::from_c_api(stablehloMeshAxisAttrGetName(self.handle)) } + } + + /// Returns the mesh axis size. + pub fn size(&self) -> i64 { + unsafe { stablehloMeshAxisAttrGetSize(self.handle) } + } +} + +impl<'c, 't> Attribute<'c, 't> for MeshAxisAttributeRef<'c, 't> { + unsafe fn from_c_api(handle: MlirAttribute, context: &'c Context<'t>) -> Option { + if !handle.ptr.is_null() && unsafe { stablehloAttributeIsAMeshAxisAttr(handle) } { + Some(Self { handle, context }) + } else { + None + } + } + + unsafe fn to_c_api(&self) -> MlirAttribute { + self.handle + } + + fn context(&self) -> &'c Context<'t> { + self.context + } +} + +mlir_subtype_trait_impls!(MeshAxisAttributeRef<'c, 't> as Attribute, mlir_type = Attribute); + +/// StableHLO [`Attribute`] that defines an inline device mesh. +#[derive(Copy, Clone)] +pub struct MeshAttributeRef<'c, 't> { + /// Handle that represents this [`Attribute`] in the MLIR C API. + handle: MlirAttribute, + + /// [`Context`] that owns this [`Attribute`]. + context: &'c Context<'t>, +} + +impl<'c, 't> MeshAttributeRef<'c, 't> { + /// Returns the array of mesh axis attributes. + pub fn axes(&self) -> ArrayAttributeRef<'c, 't> { + unsafe { + ArrayAttributeRef::from_c_api(stablehloMeshAttrGetAxes(self.handle), self.context) + .expect("invalid StableHLO mesh axes") + } + } + + /// Returns the optional dense device-id tensor for this mesh. + pub fn device_ids(&self) -> Option> { + unsafe { + DenseIntegerElementsAttributeRef::from_c_api(stablehloMeshAttrGetDeviceIds(self.handle), self.context) + } + } +} + +impl<'c, 't> Attribute<'c, 't> for MeshAttributeRef<'c, 't> { + unsafe fn from_c_api(handle: MlirAttribute, context: &'c Context<'t>) -> Option { + if !handle.ptr.is_null() && unsafe { stablehloAttributeIsAMeshAttr(handle) } { + Some(Self { handle, context }) + } else { + None + } + } + + unsafe fn to_c_api(&self) -> MlirAttribute { + self.handle + } + + fn context(&self) -> &'c Context<'t> { + self.context + } +} + +mlir_subtype_trait_impls!(MeshAttributeRef<'c, 't> as Attribute, mlir_type = Attribute); + +impl<'t> Context<'t> { + /// Creates a new StableHLO [`SubAxisInfoAttributeRef`] owned by this [`Context`]. + pub fn stable_hlo_sub_axis_info<'c>(&'c self, pre_size: i64, size: i64) -> SubAxisInfoAttributeRef<'c, 't> { + self.load_dialect(DialectHandle::stable_hlo()); + unsafe { + SubAxisInfoAttributeRef::from_c_api( + stablehloSubAxisInfoAttrGet(*self.handle.borrow(), pre_size, size), + self, + ) + .unwrap() + } + } + + /// Creates a new StableHLO [`AxisRefAttributeRef`] owned by this [`Context`]. + pub fn stable_hlo_axis_ref<'c, N: AsRef>( + &'c self, + name: N, + sub_axis_info: Option>, + ) -> AxisRefAttributeRef<'c, 't> { + self.load_dialect(DialectHandle::stable_hlo()); + unsafe { + AxisRefAttributeRef::from_c_api( + stablehloAxisRefAttrGet( + *self.handle.borrow(), + StringRef::from(name.as_ref()).to_c_api(), + sub_axis_info.map(|value| value.to_c_api()).unwrap_or(self.null_attribute().to_c_api()), + ), + self, + ) + .unwrap() + } + } + + /// Creates a new StableHLO [`ReplicaGroupMeshAxesAttributeRef`] owned by this [`Context`]. + pub fn stable_hlo_replica_group_mesh_axes<'c, M: Attribute<'c, 't>>( + &'c self, + mesh: M, + axes: ArrayAttributeRef<'c, 't>, + ) -> ReplicaGroupMeshAxesAttributeRef<'c, 't> { + self.load_dialect(DialectHandle::stable_hlo()); + unsafe { + ReplicaGroupMeshAxesAttributeRef::from_c_api( + stablehloReplicaGroupMeshAxesAttrGet(*self.handle.borrow(), mesh.to_c_api(), axes.to_c_api()), + self, + ) + .unwrap() + } + } + + /// Creates a new StableHLO [`MeshAxisAttributeRef`] owned by this [`Context`]. + pub fn stable_hlo_mesh_axis<'c, N: AsRef>(&'c self, name: N, size: i64) -> MeshAxisAttributeRef<'c, 't> { + self.load_dialect(DialectHandle::stable_hlo()); + unsafe { + MeshAxisAttributeRef::from_c_api( + stablehloMeshAxisAttrGet(*self.handle.borrow(), StringRef::from(name.as_ref()).to_c_api(), size), + self, + ) + .unwrap() + } + } + + /// Creates a new StableHLO [`MeshAttributeRef`] owned by this [`Context`]. + pub fn stable_hlo_mesh<'c>( + &'c self, + axes: ArrayAttributeRef<'c, 't>, + device_ids: Option>, + ) -> MeshAttributeRef<'c, 't> { + self.load_dialect(DialectHandle::stable_hlo()); + unsafe { + MeshAttributeRef::from_c_api( + stablehloMeshAttrGet( + *self.handle.borrow(), + axes.to_c_api(), + device_ids.map(|value| value.to_c_api()).unwrap_or(self.null_attribute().to_c_api()), + ), + self, + ) + .unwrap() + } + } +} + #[cfg(test)] mod tests { use crate::attributes::tests::{test_attribute_casting, test_attribute_display_and_debug}; @@ -138,4 +451,266 @@ mod tests { let attribute = context.stable_hlo_tensor_type_extensions(&[Some(10), None, Some(20), None]); test_attribute_casting(attribute); } + + #[test] + fn test_sub_axis_info_attribute() { + let context = Context::new(); + let attribute = context.stable_hlo_sub_axis_info(2, 4); + assert_eq!(&context, attribute.context()); + assert_eq!(attribute.pre_size(), 2); + assert_eq!(attribute.size(), 4); + } + + #[test] + fn test_sub_axis_info_attribute_equality() { + let context = Context::new(); + + // Same attributes from the same context must be equal because they are "uniqued". + let attribute_1 = context.stable_hlo_sub_axis_info(2, 4); + let attribute_2 = context.stable_hlo_sub_axis_info(2, 4); + assert_eq!(attribute_1, attribute_2); + + // Different attributes from the same context must not be equal. + let attribute_2 = context.stable_hlo_sub_axis_info(1, 4); + assert_ne!(attribute_1, attribute_2); + + // Same attributes from different contexts must not be equal. + let context = Context::new(); + let attribute_2 = context.stable_hlo_sub_axis_info(2, 4); + assert_ne!(attribute_1, attribute_2); + } + + #[test] + fn test_sub_axis_info_attribute_display_and_debug() { + let context = Context::new(); + let attribute = context.stable_hlo_sub_axis_info(2, 4); + test_attribute_display_and_debug(attribute, "#stablehlo"); + } + + #[test] + fn test_sub_axis_info_attribute_casting() { + let context = Context::new(); + let attribute = context.stable_hlo_sub_axis_info(2, 4); + test_attribute_casting(attribute); + } + + #[test] + fn test_axis_ref_attribute() { + let context = Context::new(); + let sub_axis_info = context.stable_hlo_sub_axis_info(2, 4); + let attribute = context.stable_hlo_axis_ref("x", Some(sub_axis_info)); + assert_eq!(&context, attribute.context()); + assert_eq!(attribute.name().as_str().unwrap(), "x"); + assert_eq!(attribute.sub_axis_info(), Some(sub_axis_info)); + + let attribute = context.stable_hlo_axis_ref("y", None); + assert_eq!(attribute.name().as_str().unwrap(), "y"); + assert_eq!(attribute.sub_axis_info(), None); + } + + #[test] + fn test_axis_ref_attribute_equality() { + let context = Context::new(); + let sub_axis_info = context.stable_hlo_sub_axis_info(2, 4); + + // Same attributes from the same context must be equal because they are "uniqued". + let attribute_1 = context.stable_hlo_axis_ref("x", Some(sub_axis_info)); + let attribute_2 = context.stable_hlo_axis_ref("x", Some(sub_axis_info)); + assert_eq!(attribute_1, attribute_2); + + // Different attributes from the same context must not be equal. + let attribute_2 = context.stable_hlo_axis_ref("y", Some(sub_axis_info)); + assert_ne!(attribute_1, attribute_2); + + // Same attributes from different contexts must not be equal. + let context = Context::new(); + let sub_axis_info = context.stable_hlo_sub_axis_info(2, 4); + let attribute_2 = context.stable_hlo_axis_ref("x", Some(sub_axis_info)); + assert_ne!(attribute_1, attribute_2); + } + + #[test] + fn test_axis_ref_attribute_display_and_debug() { + let context = Context::new(); + let sub_axis_info = context.stable_hlo_sub_axis_info(2, 4); + let attribute = context.stable_hlo_axis_ref("x", Some(sub_axis_info)); + test_attribute_display_and_debug(attribute, "#stablehlo.axis_ref"); + } + + #[test] + fn test_axis_ref_attribute_casting() { + let context = Context::new(); + let sub_axis_info = context.stable_hlo_sub_axis_info(2, 4); + let attribute = context.stable_hlo_axis_ref("x", Some(sub_axis_info)); + test_attribute_casting(attribute); + } + + #[test] + fn test_replica_group_mesh_axes_attribute() { + let context = Context::new(); + let mesh = context.flat_symbol_ref_attribute("mesh"); + let axis_x = context.stable_hlo_axis_ref("x", Some(context.stable_hlo_sub_axis_info(2, 4))); + let axis_y = context.stable_hlo_axis_ref("y", None); + let axes = context.array_attribute(&[axis_x, axis_y]); + let attribute = context.stable_hlo_replica_group_mesh_axes(mesh, axes); + assert_eq!(&context, attribute.context()); + assert_eq!(attribute.mesh(), mesh.as_ref()); + assert_eq!(attribute.axes(), axes); + } + + #[test] + fn test_replica_group_mesh_axes_attribute_equality() { + let context = Context::new(); + let mesh = context.flat_symbol_ref_attribute("mesh"); + let axis_x = context.stable_hlo_axis_ref("x", None); + let axis_y = context.stable_hlo_axis_ref("y", None); + let axes = context.array_attribute(&[axis_x, axis_y]); + + // Same attributes from the same context must be equal because they are "uniqued". + let attribute_1 = context.stable_hlo_replica_group_mesh_axes(mesh, axes); + let attribute_2 = context.stable_hlo_replica_group_mesh_axes(mesh, axes); + assert_eq!(attribute_1, attribute_2); + + // Different attributes from the same context must not be equal. + let axes = context.array_attribute(&[axis_x]); + let attribute_2 = context.stable_hlo_replica_group_mesh_axes(mesh, axes); + assert_ne!(attribute_1, attribute_2); + + // Same attributes from different contexts must not be equal. + let context = Context::new(); + let mesh = context.flat_symbol_ref_attribute("mesh"); + let axis_x = context.stable_hlo_axis_ref("x", None); + let axis_y = context.stable_hlo_axis_ref("y", None); + let axes = context.array_attribute(&[axis_x, axis_y]); + let attribute_2 = context.stable_hlo_replica_group_mesh_axes(mesh, axes); + assert_ne!(attribute_1, attribute_2); + } + + #[test] + fn test_replica_group_mesh_axes_attribute_display_and_debug() { + let context = Context::new(); + let mesh = context.flat_symbol_ref_attribute("mesh"); + let axis_x = context.stable_hlo_axis_ref("x", Some(context.stable_hlo_sub_axis_info(2, 4))); + let axis_y = context.stable_hlo_axis_ref("y", None); + let axes = context.array_attribute(&[axis_x, axis_y]); + let attribute = context.stable_hlo_replica_group_mesh_axes(mesh, axes); + test_attribute_display_and_debug( + attribute, + "#stablehlo.replica_group_mesh_axes, #stablehlo.axis_ref]>", + ); + } + + #[test] + fn test_replica_group_mesh_axes_attribute_casting() { + let context = Context::new(); + let mesh = context.flat_symbol_ref_attribute("mesh"); + let axis = context.stable_hlo_axis_ref("x", None); + let axes = context.array_attribute(&[axis]); + let attribute = context.stable_hlo_replica_group_mesh_axes(mesh, axes); + test_attribute_casting(attribute); + } + + #[test] + fn test_mesh_axis_attribute() { + let context = Context::new(); + let attribute = context.stable_hlo_mesh_axis("x", 2); + assert_eq!(&context, attribute.context()); + assert_eq!(attribute.name().as_str().unwrap(), "x"); + assert_eq!(attribute.size(), 2); + } + + #[test] + fn test_mesh_axis_attribute_equality() { + let context = Context::new(); + + // Same attributes from the same context must be equal because they are "uniqued". + let attribute_1 = context.stable_hlo_mesh_axis("x", 2); + let attribute_2 = context.stable_hlo_mesh_axis("x", 2); + assert_eq!(attribute_1, attribute_2); + + // Different attributes from the same context must not be equal. + let attribute_2 = context.stable_hlo_mesh_axis("y", 2); + assert_ne!(attribute_1, attribute_2); + + // Same attributes from different contexts must not be equal. + let context = Context::new(); + let attribute_2 = context.stable_hlo_mesh_axis("x", 2); + assert_ne!(attribute_1, attribute_2); + } + + #[test] + fn test_mesh_axis_attribute_display_and_debug() { + let context = Context::new(); + let attribute = context.stable_hlo_mesh_axis("x", 2); + test_attribute_display_and_debug(attribute, "#stablehlo.mesh_axis"); + } + + #[test] + fn test_mesh_axis_attribute_casting() { + let context = Context::new(); + let attribute = context.stable_hlo_mesh_axis("x", 2); + test_attribute_casting(attribute); + } + + #[test] + fn test_mesh_attribute() { + let context = Context::new(); + let axis_x = context.stable_hlo_mesh_axis("x", 2); + let axis_y = context.stable_hlo_mesh_axis("y", 4); + let axes = context.array_attribute(&[axis_x, axis_y]); + let attribute = context.stable_hlo_mesh(axes, None); + assert_eq!(&context, attribute.context()); + assert_eq!(attribute.axes(), axes); + assert_eq!(attribute.device_ids(), None); + } + + #[test] + fn test_mesh_attribute_equality() { + let context = Context::new(); + let axis_x = context.stable_hlo_mesh_axis("x", 2); + let axis_y = context.stable_hlo_mesh_axis("y", 4); + let axes = context.array_attribute(&[axis_x, axis_y]); + + // Same attributes from the same context must be equal because they are "uniqued". + let attribute_1 = context.stable_hlo_mesh(axes, None); + let attribute_2 = context.stable_hlo_mesh(axes, None); + assert_eq!(attribute_1, attribute_2); + + // Different attributes from the same context must not be equal. + let axes = context.array_attribute(&[axis_x]); + let attribute_2 = context.stable_hlo_mesh(axes, None); + assert_ne!(attribute_1, attribute_2); + + // Same attributes from different contexts must not be equal. + let context = Context::new(); + let axis_x = context.stable_hlo_mesh_axis("x", 2); + let axis_y = context.stable_hlo_mesh_axis("y", 4); + let axes = context.array_attribute(&[axis_x, axis_y]); + let attribute_2 = context.stable_hlo_mesh(axes, None); + assert_ne!(attribute_1, attribute_2); + } + + #[test] + fn test_mesh_attribute_display_and_debug() { + let context = Context::new(); + let axis_x = context.stable_hlo_mesh_axis("x", 2); + let axis_y = context.stable_hlo_mesh_axis("y", 4); + let axes = context.array_attribute(&[axis_x, axis_y]); + let attribute = context.stable_hlo_mesh(axes, None); + test_attribute_display_and_debug( + attribute, + "#stablehlo.mesh, #stablehlo.mesh_axis]>", + ); + } + + #[test] + fn test_mesh_attribute_casting() { + let context = Context::new(); + let axis = context.stable_hlo_mesh_axis("x", 2); + let axes = context.array_attribute(&[axis]); + let attribute = context.stable_hlo_mesh(axes, None); + test_attribute_casting(attribute); + } } diff --git a/crates/ryft-mlir/src/dialects/triton/tt/operations.rs b/crates/ryft-mlir/src/dialects/triton/tt/operations.rs index b437def7..e3dd647e 100644 --- a/crates/ryft-mlir/src/dialects/triton/tt/operations.rs +++ b/crates/ryft-mlir/src/dialects/triton/tt/operations.rs @@ -2884,7 +2884,7 @@ mod tests { let tensor_i32_type = context.tensor_type(i32_type, &[Size::Static(4)], None, location).unwrap(); let tensor_f32_type = context.tensor_type(f32_type, &[Size::Static(4)], None, location).unwrap(); let pointer_type = context.triton_tt_pointer_type(f32_type, 1); - let tensor_desc_type = context.triton_tt_tensor_desc_type(tensor_f32_type); + let tensor_desc_type = context.triton_tt_tensor_desc_type(&[Size::Static(4)], f32_type, None); let function_type = context.function_type(&[i32_type], &[i32_type]); Self { diff --git a/crates/ryft-mlir/src/dialects/triton/tt/types.rs b/crates/ryft-mlir/src/dialects/triton/tt/types.rs index 7e198aa9..38c05fb8 100644 --- a/crates/ryft-mlir/src/dialects/triton/tt/types.rs +++ b/crates/ryft-mlir/src/dialects/triton/tt/types.rs @@ -1,11 +1,14 @@ use ryft_xla_sys::bindings::MlirType; use ryft_xla_sys::mlir::dialects::triton::tt::{ mlirTritonTtPointerTypeGet, mlirTritonTtPointerTypeGetAddressSpace, mlirTritonTtPointerTypeGetPointeeType, - mlirTritonTtTensorDescTypeGet, mlirTritonTtTensorDescTypeGetBlockType, mlirTypeIsATritonTtPointerType, - mlirTypeIsATritonTtTensorDescType, + mlirTritonTtTensorDescTypeGet, mlirTritonTtTensorDescTypeGetBlockType, mlirTritonTtTensorDescTypeGetDimSize, + mlirTritonTtTensorDescTypeGetElementType, mlirTritonTtTensorDescTypeGetNumDims, + mlirTritonTtTensorDescTypeGetSharedLayout, mlirTypeIsATritonTtPointerType, mlirTypeIsATritonTtTensorDescType, }; -use crate::{Context, DialectHandle, TensorTypeRef, Type, TypeRef, mlir_subtype_trait_impls}; +use crate::{ + Attribute, AttributeRef, Context, DialectHandle, Size, TensorTypeRef, Type, TypeRef, mlir_subtype_trait_impls, +}; /// Triton `tt` pointer [`Type`]. Pointer types represent addresses in a Triton address space and may point only to /// scalar element types. @@ -56,9 +59,8 @@ impl<'c, 't> Type<'c, 't> for PointerTypeRef<'c, 't> { mlir_subtype_trait_impls!(PointerTypeRef<'c, 't> as Type, mlir_type = Type); -/// Triton `tt` tensor descriptor [`Type`]. Tensor descriptors represent tiled tensor memory access metadata. -/// -/// The Triton version pinned by this repository models descriptors with a ranked tensor block type. +/// Triton `tt` tensor descriptor [`Type`]. Tensor descriptors represent tiled tensor memory access metadata and are +/// parameterized by a block shape, an element [`Type`], and an optional shared-memory layout attribute. /// /// Refer to the [official Triton dialect documentation](https://triton-lang.org/main/dialects/TritonDialect.html) /// for more information. @@ -72,7 +74,28 @@ pub struct TensorDescTypeRef<'c, 't> { } impl<'c, 't> TensorDescTypeRef<'c, 't> { - /// Returns the ranked tensor block [`Type`] described by this descriptor. + /// Returns the block shape described by this descriptor. + pub fn shape(&self) -> Vec { + let dimension_count = unsafe { mlirTritonTtTensorDescTypeGetNumDims(self.handle) }; + (0..dimension_count) + .map(|dimension| unsafe { Size::from_c_api(mlirTritonTtTensorDescTypeGetDimSize(self.handle, dimension)) }) + .collect() + } + + /// Returns the element [`Type`] described by this descriptor. + pub fn element_type(&self) -> TypeRef<'c, 't> { + unsafe { + TypeRef::from_c_api(mlirTritonTtTensorDescTypeGetElementType(self.handle), self.context) + .expect("invalid `!tt.tensordesc` element type") + } + } + + /// Returns the optional shared-memory layout attribute described by this descriptor. + pub fn shared_layout(&self) -> Option> { + unsafe { AttributeRef::from_c_api(mlirTritonTtTensorDescTypeGetSharedLayout(self.handle), self.context) } + } + + /// Returns the ranked tensor block [`Type`] derived from this descriptor's shape and element type. pub fn block_type(&self) -> TensorTypeRef<'c, 't> { unsafe { TensorTypeRef::from_c_api(mlirTritonTtTensorDescTypeGetBlockType(self.handle), self.context) @@ -116,11 +139,25 @@ impl<'t> Context<'t> { } /// Creates a new Triton `tt` [`TensorDescTypeRef`] owned by this [`Context`]. - pub fn triton_tt_tensor_desc_type<'c>(&'c self, block_type: TensorTypeRef<'c, 't>) -> TensorDescTypeRef<'c, 't> { + pub fn triton_tt_tensor_desc_type<'c, T: Type<'c, 't>>( + &'c self, + shape: &[Size], + element_type: T, + shared_layout: Option>, + ) -> TensorDescTypeRef<'c, 't> { self.load_dialect(DialectHandle::triton_tt()); + let dimensions = shape.iter().map(|dimension| unsafe { dimension.to_c_api() }).collect::>(); unsafe { - TensorDescTypeRef::from_c_api(mlirTritonTtTensorDescTypeGet(block_type.to_c_api()), self) - .expect("invalid arguments to `Context::triton_tt_tensor_desc_type`") + TensorDescTypeRef::from_c_api( + mlirTritonTtTensorDescTypeGet( + dimensions.as_ptr(), + dimensions.len().cast_signed(), + element_type.to_c_api(), + shared_layout.unwrap_or_else(|| self.null_attribute()).to_c_api(), + ), + self, + ) + .expect("invalid arguments to `Context::triton_tt_tensor_desc_type`") } } } @@ -192,76 +229,61 @@ mod tests { fn test_tensor_desc_type() { let context = Context::new(); let location = context.unknown_location(); - let block_type = context - .tensor_type(context.float32_type(), &[Size::Static(16), Size::Static(32)], None, location) - .unwrap(); - let tensor_desc_type = context.triton_tt_tensor_desc_type(block_type); + let shape = [Size::Static(16), Size::Static(32)]; + let block_type = context.tensor_type(context.float32_type(), &shape, None, location).unwrap(); + let tensor_desc_type = context.triton_tt_tensor_desc_type(&shape, context.float32_type(), None); assert_eq!(&context, tensor_desc_type.context()); assert_eq!(tensor_desc_type.dialect().namespace().unwrap(), "tt"); + assert_eq!(tensor_desc_type.shape(), shape.to_vec()); + assert_eq!(tensor_desc_type.element_type(), context.float32_type()); + assert_eq!(tensor_desc_type.shared_layout(), None); assert_eq!(tensor_desc_type.block_type(), block_type); } #[test] fn test_tensor_desc_type_equality() { let context = Context::new(); - let location = context.unknown_location(); - let block_type = context - .tensor_type(context.float32_type(), &[Size::Static(16), Size::Static(32)], None, location) - .unwrap(); + let shape = [Size::Static(16), Size::Static(32)]; // Same types from the same context must be equal because they are "uniqued". - let tensor_desc_type_1 = context.triton_tt_tensor_desc_type(block_type); - let tensor_desc_type_2 = context.triton_tt_tensor_desc_type(block_type); + let tensor_desc_type_1 = context.triton_tt_tensor_desc_type(&shape, context.float32_type(), None); + let tensor_desc_type_2 = context.triton_tt_tensor_desc_type(&shape, context.float32_type(), None); assert_eq!(tensor_desc_type_1, tensor_desc_type_2); // Different types from the same context must not be equal. - let block_type = context - .tensor_type(context.float32_type(), &[Size::Static(8), Size::Static(32)], None, location) - .unwrap(); - let tensor_desc_type_2 = context.triton_tt_tensor_desc_type(block_type); + let shape = [Size::Static(8), Size::Static(32)]; + let tensor_desc_type_2 = context.triton_tt_tensor_desc_type(&shape, context.float32_type(), None); assert_ne!(tensor_desc_type_1, tensor_desc_type_2); // Same types from different contexts must not be equal. let context = Context::new(); - let location = context.unknown_location(); - let block_type = context - .tensor_type(context.float32_type(), &[Size::Static(16), Size::Static(32)], None, location) - .unwrap(); - let tensor_desc_type_2 = context.triton_tt_tensor_desc_type(block_type); + let shape = [Size::Static(16), Size::Static(32)]; + let tensor_desc_type_2 = context.triton_tt_tensor_desc_type(&shape, context.float32_type(), None); assert_ne!(tensor_desc_type_1, tensor_desc_type_2); } #[test] fn test_tensor_desc_type_display_and_debug() { let context = Context::new(); - let location = context.unknown_location(); - let block_type = context - .tensor_type(context.float32_type(), &[Size::Static(16), Size::Static(32)], None, location) - .unwrap(); - let tensor_desc_type = context.triton_tt_tensor_desc_type(block_type); - test_type_display_and_debug(tensor_desc_type, "!tt.tensordesc>"); + let shape = [Size::Static(16), Size::Static(32)]; + let tensor_desc_type = context.triton_tt_tensor_desc_type(&shape, context.float32_type(), None); + test_type_display_and_debug(tensor_desc_type, "!tt.tensordesc<16x32xf32>"); } #[test] fn test_tensor_desc_type_parsing() { let context = Context::new(); context.load_dialect(DialectHandle::triton_tt()); - let location = context.unknown_location(); - let block_type = context - .tensor_type(context.float32_type(), &[Size::Static(16), Size::Static(32)], None, location) - .unwrap(); - let tensor_desc_type = context.triton_tt_tensor_desc_type(block_type); - assert_eq!(context.parse_type("!tt.tensordesc>").unwrap(), tensor_desc_type); + let shape = [Size::Static(16), Size::Static(32)]; + let tensor_desc_type = context.triton_tt_tensor_desc_type(&shape, context.float32_type(), None); + assert_eq!(context.parse_type("!tt.tensordesc<16x32xf32>").unwrap(), tensor_desc_type); } #[test] fn test_tensor_desc_type_casting() { let context = Context::new(); - let location = context.unknown_location(); - let block_type = context - .tensor_type(context.float32_type(), &[Size::Static(16), Size::Static(32)], None, location) - .unwrap(); - let tensor_desc_type = context.triton_tt_tensor_desc_type(block_type); + let shape = [Size::Static(16), Size::Static(32)]; + let tensor_desc_type = context.triton_tt_tensor_desc_type(&shape, context.float32_type(), None); test_type_casting(tensor_desc_type); } } From f61975a2ecee70eaf16a9d34c1aa4cdf70d3a635 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 1 May 2026 22:31:29 -0700 Subject: [PATCH 05/12] . --- crates/ryft-xla-sys/CHANGELOG.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/ryft-xla-sys/CHANGELOG.md b/crates/ryft-xla-sys/CHANGELOG.md index bb210244..92fe7f92 100644 --- a/crates/ryft-xla-sys/CHANGELOG.md +++ b/crates/ryft-xla-sys/CHANGELOG.md @@ -16,8 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/). ### Changed -- Upgraded the OpenXLA dependency pin to commit `7b1be14958aac5c83f1b9f7bcdfc51fdbd29acba`. -- Synchronized the mirrored protobuf definitions with the upstream PJRT and StreamExecutor schema changes. +- Upgraded the OpenXLA dependency pin to commit `20a3e2cdd937424f351533165b3ac8e0589e5957`. +- Synchronized the mirrored `DebugOptions` protobuf definitions with upstream `xla.proto` changes. +- Synchronized StableHLO C API bindings with upstream mesh and sub-axis attributes. - Pinned macOS Bazel artifacts to a macOS `11.0` deployment target so the published static library remains linkable from Rust consumers that target the workspace baseline. From 481f4f8eb0805df683085ae85bf3570eb33866d0 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 1 May 2026 22:31:51 -0700 Subject: [PATCH 06/12] . --- crates/ryft-mlir/CHANGELOG.md | 1 + crates/ryft-pjrt/CHANGELOG.md | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/ryft-mlir/CHANGELOG.md b/crates/ryft-mlir/CHANGELOG.md index d6e723bd..aefe5ac8 100644 --- a/crates/ryft-mlir/CHANGELOG.md +++ b/crates/ryft-mlir/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/). - Added support for the Triton `tt` dialect. - Added support for the Mosaic GPU dialect. - Added support for the Mosaic TPU dialect. +- Added StableHLO mesh and sub-axis attribute wrappers. ### Changed diff --git a/crates/ryft-pjrt/CHANGELOG.md b/crates/ryft-pjrt/CHANGELOG.md index 88313ec2..24670b2a 100644 --- a/crates/ryft-pjrt/CHANGELOG.md +++ b/crates/ryft-pjrt/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/). ### Added - Added support for the new `PJRT_Buffer_Bitcast` C API function. +- Added support for the new `PJRT_Device_ClearMemoryStats` C API function. - Added support for the new `PJRT_Error_ForEachPayload` C API function and for providing payload-aware safe Rust wrappers for error buffers and execution poisoning. - Added support for querying executable parameter memory kinds and topology fingerprints. @@ -18,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/). ### Changed -- Updated our PJRT C API bindings for version `0.104`. +- Updated our PJRT C API bindings for version `0.107`. - Expanded executable compiled-memory statistics support to include total allocator bytes, indefinite allocations, and peak unpadded heap bytes. - Changed `TiledLayout::minor_to_major` to `Vec` from `Vec`. From 5cbbddee4f8c35dc5f4223aacc230d706d07297f Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 1 May 2026 23:14:12 -0700 Subject: [PATCH 07/12] Update PJRT error bridge for XLA 20a3e2 --- crates/ryft-xla-sys/BUILD.bazel | 1 + crates/ryft-xla-sys/src/c++/common.h | 37 ++++++++++------------ crates/ryft-xla-sys/src/c++/distributed.cc | 16 +++++----- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/crates/ryft-xla-sys/BUILD.bazel b/crates/ryft-xla-sys/BUILD.bazel index b0aa7ad3..2a19ffde 100644 --- a/crates/ryft-xla-sys/BUILD.bazel +++ b/crates/ryft-xla-sys/BUILD.bazel @@ -95,6 +95,7 @@ XLA_DEPENDENCIES = [ "@stablehlo//:stablehlo_dialect_capi", "@stablehlo//:vhlo_capi", "@xla//xla/ffi/api:c_api", + "@xla//xla/pjrt/c:pjrt_c_api_status_utils", "@xla//xla/mlir_hlo:CAPI", "@xla//xla/pjrt/distributed", "@xla//xla/service/spmd/shardy/integrations/c:xla_sdy_capi", diff --git a/crates/ryft-xla-sys/src/c++/common.h b/crates/ryft-xla-sys/src/c++/common.h index 58359942..ccc3766f 100644 --- a/crates/ryft-xla-sys/src/c++/common.h +++ b/crates/ryft-xla-sys/src/c++/common.h @@ -9,6 +9,7 @@ #include "absl/status/status.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_status_utils.h" #include "xla/pjrt/distributed/distributed.h" extern "C" { @@ -20,12 +21,9 @@ typedef xla::DistributedRuntimeClient PJRT_Distributed_Runtime_Client; #else typedef struct _DistributedRuntimeService PJRT_Distributed_Runtime_Service; typedef struct _DistributedRuntimeClient PJRT_Distributed_Runtime_Client; +typedef struct PJRT_Error PJRT_Error; #endif -struct PJRT_Error { - absl::Status status; -}; - #ifdef __cplusplus } #endif @@ -35,27 +33,26 @@ struct PJRT_Error { #ifdef __cplusplus -#define PJRT_RETURN_IF_ERROR(expr) \ - do { \ - absl::Status _status = (expr); \ - if (!_status.ok()) { \ - PJRT_Error *_c_status = new PJRT_Error{std::move(_status)}; \ - return _c_status; \ - } \ +#define PJRT_RETURN_IF_ERROR(expr) \ + do { \ + absl::Status _status = (expr); \ + if (!_status.ok()) { \ + PJRT_Error *_c_status = pjrt::StatusToPjRtError(std::move(_status)); \ + return _c_status; \ + } \ } while (false) -#define PJRT_ASSIGN_OR_RETURN(lhs, rexpr) \ +#define PJRT_ASSIGN_OR_RETURN(lhs, rexpr) \ _PJRT_ASSIGN_OR_RETURN_IMPL(_PJRT_CONCAT(_status_or_value, __COUNTER__), lhs, \ - rexpr, \ + rexpr, \ _PJRT_CONCAT(_c_status, __COUNTER__)); -#define _PJRT_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr, c_status) \ - auto statusor = (rexpr); \ - if (!statusor.ok()) { \ - PJRT_Error *c_status = new PJRT_Error(); \ - c_status->status = statusor.status(); \ - return c_status; \ - } \ +#define _PJRT_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr, c_status) \ + auto statusor = (rexpr); \ + if (!statusor.ok()) { \ + PJRT_Error *c_status = pjrt::StatusToPjRtError(statusor.status()); \ + return c_status; \ + } \ lhs = std::move(*statusor) #define _PJRT_CONCAT(x, y) _PJRT_CONCAT_IMPL(x, y) diff --git a/crates/ryft-xla-sys/src/c++/distributed.cc b/crates/ryft-xla-sys/src/c++/distributed.cc index f0f3841d..a453290e 100644 --- a/crates/ryft-xla-sys/src/c++/distributed.cc +++ b/crates/ryft-xla-sys/src/c++/distributed.cc @@ -16,7 +16,7 @@ const PJRT_Error *PJRT_Distributed_Runtime_Service_New(PJRT_Distributed_Runtime_ std::unique_ptr service, GetDistributedRuntimeService(std::string(args->address), options)); args->service = service.release(); - return new PJRT_Error{absl::Status()}; + return nullptr; } void PJRT_Distributed_Runtime_Service_Shutdown(PJRT_Distributed_Runtime_Service_Shutdown_Args *args) { @@ -37,7 +37,7 @@ const PJRT_Error *PJRT_Distributed_Runtime_Client_New(PJRT_Distributed_Runtime_C options.missed_heartbeat_callback = [user_arg = args->missed_heartbeat_callback_user_arg, callback = args->missed_heartbeat_callback](absl::Status status) { - auto error = new PJRT_Error{status}; + PJRT_Error *error = pjrt::StatusToPjRtError(std::move(status)); callback(error, user_arg); }; options.shutdown_on_destruction = args->shutdown_on_destruction; @@ -45,11 +45,11 @@ const PJRT_Error *PJRT_Distributed_Runtime_Client_New(PJRT_Distributed_Runtime_C auto channel = xla::GetDistributedRuntimeClientChannel( std::string(args->address), tsl::GetClientCredentials(false), args->use_compression); args->client = GetDistributedRuntimeClient(channel, options).release(); - return new PJRT_Error{absl::Status()}; + return nullptr; } PJRT_Error *PJRT_Distributed_Runtime_Client_Connect(PJRT_Distributed_Runtime_Client_Connect_Args *args) { - return new PJRT_Error{args->client->Connect()}; + return pjrt::StatusToPjRtError(args->client->Connect()); } PJRT_Error *PJRT_Distributed_Runtime_Client_Blocking_Key_Value_Get( @@ -60,7 +60,7 @@ PJRT_Error *PJRT_Distributed_Runtime_Client_Blocking_Key_Value_Get( char *value = new char[_value.size() + 1]; std::strcpy(value, _value.c_str()); args->value = value; - return new PJRT_Error{absl::Status()}; + return nullptr; } PJRT_Error *PJRT_Distributed_Runtime_Client_Key_Value_Try_Get( @@ -69,16 +69,16 @@ PJRT_Error *PJRT_Distributed_Runtime_Client_Key_Value_Try_Get( char *value = new char[_value.size() + 1]; std::strcpy(value, _value.c_str()); args->value = value; - return new PJRT_Error{absl::Status()}; + return nullptr; } PJRT_Error *PJRT_Distributed_Runtime_Client_Key_Value_Set( PJRT_Distributed_Runtime_Client_Key_Value_Set_Args *args) { - return new PJRT_Error{args->client->KeyValueSet(std::string(args->key), std::string(args->value))}; + return pjrt::StatusToPjRtError(args->client->KeyValueSet(std::string(args->key), std::string(args->value))); } PJRT_Error *PJRT_Distributed_Runtime_Client_Shutdown(PJRT_Distributed_Runtime_Client_Shutdown_Args *args) { - return new PJRT_Error{args->client->Shutdown()}; + return pjrt::StatusToPjRtError(args->client->Shutdown()); } void PJRT_Distributed_Runtime_Client_Destroy(PJRT_Distributed_Runtime_Client_Destroy_Args *args) { From 2c42dc80b8cd38cad3d74dfea008f8338bef5c6b Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Sat, 2 May 2026 08:15:55 -0700 Subject: [PATCH 08/12] . --- .../src/dialects/stable_hlo/attributes.rs | 3 +-- crates/ryft-pjrt/CHANGELOG.md | 2 ++ crates/ryft-pjrt/src/clients.rs | 4 ++-- crates/ryft-pjrt/src/distributed.rs | 1 + crates/ryft-pjrt/src/lib.rs | 4 ++-- crates/ryft-xla-sys/build.rs | 18 +++++++++--------- 6 files changed, 17 insertions(+), 15 deletions(-) diff --git a/crates/ryft-mlir/src/dialects/stable_hlo/attributes.rs b/crates/ryft-mlir/src/dialects/stable_hlo/attributes.rs index f22e058a..3862c46a 100644 --- a/crates/ryft-mlir/src/dialects/stable_hlo/attributes.rs +++ b/crates/ryft-mlir/src/dialects/stable_hlo/attributes.rs @@ -700,8 +700,7 @@ mod tests { let attribute = context.stable_hlo_mesh(axes, None); test_attribute_display_and_debug( attribute, - "#stablehlo.mesh, #stablehlo.mesh_axis]>", + "#stablehlo.mesh, ]>", ); } diff --git a/crates/ryft-pjrt/CHANGELOG.md b/crates/ryft-pjrt/CHANGELOG.md index 24670b2a..2c5a8c2b 100644 --- a/crates/ryft-pjrt/CHANGELOG.md +++ b/crates/ryft-pjrt/CHANGELOG.md @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/). ### Changed - Updated our PJRT C API bindings for version `0.107`. +- Changed `BufferSpecification` to carry a concrete `Layout`, materializing dense defaults during construction and + parsing before values cross layout-sensitive PJRT C API calls. - Expanded executable compiled-memory statistics support to include total allocator bytes, indefinite allocations, and peak unpadded heap bytes. - Changed `TiledLayout::minor_to_major` to `Vec` from `Vec`. diff --git a/crates/ryft-pjrt/src/clients.rs b/crates/ryft-pjrt/src/clients.rs index 342d9246..6527784b 100644 --- a/crates/ryft-pjrt/src/clients.rs +++ b/crates/ryft-pjrt/src/clients.rs @@ -1325,14 +1325,14 @@ mod tests { fn test_client() { let plugin = test_cpu_plugin(); let client = test_cpu_client(); - assert_eq!(client.attribute("stablehlo_current_version"), Ok(Value::i64_list([1, 16, 0]))); + assert_eq!(client.attribute("stablehlo_current_version"), Ok(Value::i64_list([1, 16, 2]))); assert_eq!(client.attribute("stablehlo_minimum_version"), Ok(Value::i64_list([0, 9, 0]))); assert_eq!(client.attribute("xla_version"), Ok(Value::i64(2))); assert!(matches!( client.attribute("__missing__"), Err(Error::NotFound { message, .. }) if message.contains("__missing__"))); let attributes = client.attributes().unwrap(); - assert_eq!(attributes.get("stablehlo_current_version"), Some(&Value::i64_list([1, 16, 0]))); + assert_eq!(attributes.get("stablehlo_current_version"), Some(&Value::i64_list([1, 16, 2]))); assert_eq!(attributes.get("stablehlo_minimum_version"), Some(&Value::i64_list([0, 9, 0]))); assert_eq!(attributes.get("xla_version"), Some(&Value::i64(2))); assert_eq!(attributes.get("__missing__"), None); diff --git a/crates/ryft-pjrt/src/distributed.rs b/crates/ryft-pjrt/src/distributed.rs index 065e2d92..7ed835ad 100644 --- a/crates/ryft-pjrt/src/distributed.rs +++ b/crates/ryft-pjrt/src/distributed.rs @@ -514,6 +514,7 @@ mod tests { let _service = plugin.distributed_runtime_service(&address, service_options).unwrap(); let client_options = DistributedRuntimeClientOptions::default(); let client = plugin.distributed_runtime_client(&address, client_options).unwrap(); + client.connect().unwrap(); let store = DistributedKeyValueStore::new(client); // Test using valid keys and values. diff --git a/crates/ryft-pjrt/src/lib.rs b/crates/ryft-pjrt/src/lib.rs index c91dd8da..3d5bac57 100644 --- a/crates/ryft-pjrt/src/lib.rs +++ b/crates/ryft-pjrt/src/lib.rs @@ -573,7 +573,7 @@ mod tests { let plugin = test_cpu_plugin(); let api = plugin.api(); - assert_eq!(plugin.attribute("stablehlo_current_version"), Ok(Value::i64_list([1, 16, 0]))); + assert_eq!(plugin.attribute("stablehlo_current_version"), Ok(Value::i64_list([1, 16, 2]))); assert_eq!(plugin.attribute("stablehlo_minimum_version"), Ok(Value::i64_list([0, 9, 0]))); assert_eq!(plugin.attribute("xla_version"), Ok(Value::i64(2))); assert_eq!(plugin.attribute("xla_version"), api.attribute("xla_version")); @@ -581,7 +581,7 @@ mod tests { plugin.attribute("__missing__"), Err(Error::NotFound { message, .. }) if message.contains("__missing__"))); let attributes = plugin.attributes().unwrap(); - assert_eq!(attributes.get("stablehlo_current_version"), Some(&Value::i64_list([1, 16, 0]))); + assert_eq!(attributes.get("stablehlo_current_version"), Some(&Value::i64_list([1, 16, 2]))); assert_eq!(attributes.get("stablehlo_minimum_version"), Some(&Value::i64_list([0, 9, 0]))); assert_eq!(attributes.get("xla_version"), Some(&Value::i64(2))); assert_eq!(attributes.get("__missing__"), None); diff --git a/crates/ryft-xla-sys/build.rs b/crates/ryft-xla-sys/build.rs index 8bb1cdbf..d5bdd13d 100644 --- a/crates/ryft-xla-sys/build.rs +++ b/crates/ryft-xla-sys/build.rs @@ -865,31 +865,31 @@ impl BuildConfiguration { fn precompiled_artifact_checksum(&self, artifact: Artifact) -> Option<&'static str> { match (artifact, self.operating_system, self.architecture, self.device) { (Artifact::RyftXlaSys, OperatingSystem::Linux, Architecture::X86_64, Device::Cpu) => { - Some("6367f647aa629b14e8814093a688203bd6a2d578b488dab21956b6dd6b778f31") + Some("d0c13bf1114d92519198d3de7ac14ae6f86c71c54c6ccdddbc8e3535c0677293") } (Artifact::RyftXlaSys, OperatingSystem::Linux, Architecture::AArch64, Device::Cpu) => { - Some("35271b062a4e6c2c026aaaa3b92be99229ed7828465733083f5a70eaa4a56172") + Some("f54058f11853f898fc18f0facb8f43df20d550ddf94f84cebb76903f770acfc3") } (Artifact::RyftXlaSys, OperatingSystem::MacOS, Architecture::AArch64, Device::Cpu) => { - Some("314fee192c368cbb1b1e24a3a09ded72f604997f94c78b5790f5ae087dc1aa75") + Some("6b22c98945c4278a3128e6aa216d6965b46952592a3fe3c8531206e4f04123d9") } (Artifact::RyftXlaSys, OperatingSystem::Windows, Architecture::X86_64, Device::Cpu) => { - Some("3347a196c29a847c54cdb792a27046aaff7437e16a9c06c8f670e2a07250e8e1") + Some("4db15b1a360a31538e10fe07e3fa6b1fbd01ef7d6d7da811779ba5465cd8cb08") } (Artifact::PjrtPlugin, OperatingSystem::Linux, Architecture::X86_64, Device::Cuda12) => { - Some("378379f03bb9b99b0c7a647f6210ddd2d7dc7f082b2b988eb46f9030c7fa6238") + Some("711f94caf89964d6af8530aa32992f2a175d169f4a38e4ccc5b1cb0d490da704") } (Artifact::PjrtPlugin, OperatingSystem::Linux, Architecture::AArch64, Device::Cuda12) => { - Some("54f4fd763256b7b4f0b557d0822017820fd8f32d97d62d35460383a0a47bdd9e") + Some("2d991f4beaa376736d52fc0a83abd715955b771e98133763b79d00f71734752a") } (Artifact::PjrtPlugin, OperatingSystem::Linux, Architecture::X86_64, Device::Cuda13) => { - Some("3ca63efb91ed03c30e60b41a4a041ca1d4655f45957019757010cca805520587") + Some("5d221a0f0f896461414bfc707fd05a003a95686064632748f10d0756080fb3b6") } (Artifact::PjrtPlugin, OperatingSystem::Linux, Architecture::AArch64, Device::Cuda13) => { - Some("debed7053c456bbd76d9ae49c653ec46530a9fffa05be9bb9ca83c075ca2dc47") + Some("316ed01073b1d933016bbd41af7b8ea2fdcaa869e6dad3866e21db2903f69ffd") } (Artifact::PjrtPlugin, OperatingSystem::Linux, Architecture::X86_64, Device::Rocm7) => { - Some("02ac1b5fae22a901c1a6db7588919df206598eebb748c6a7f4c46f6142a6e456") + Some("cbee961f6e0321c47579f18297febf94cab9011d0db5493a3c30317ca783fc56") } (Artifact::PjrtPlugin, OperatingSystem::Linux, Architecture::X86_64, Device::Tpu) => { Some("5e600d7797ac801d0c903f52ae46c03538bb77817a48579aa581faa8d2a8a734") From 3cad6fa70281f7ad39df0b983914d15de6f3d9a1 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Sat, 2 May 2026 08:38:37 -0700 Subject: [PATCH 09/12] . --- crates/ryft-pjrt/src/buffers.rs | 114 ++++++++++++++++++++---------- crates/ryft-pjrt/src/transfers.rs | 15 ++-- 2 files changed, 83 insertions(+), 46 deletions(-) diff --git a/crates/ryft-pjrt/src/buffers.rs b/crates/ryft-pjrt/src/buffers.rs index d7374650..301a530d 100644 --- a/crates/ryft-pjrt/src/buffers.rs +++ b/crates/ryft-pjrt/src/buffers.rs @@ -716,6 +716,13 @@ impl Layout { } } + /// Creates the default dense [`Layout::Tiled`] for buffers with `rank` dimensions. The resulting [`Layout`] has no + /// tiles and orders dimensions from the most minor physical dimension to the most major physical dimension, which + /// corresponds to dense major-to-minor logical storage. + pub fn dense_major_to_minor(rank: usize) -> Self { + Layout::Tiled(TiledLayout::new((0..rank as u64).rev().collect(), Vec::new())) + } + /// Parses a rendered [`Layout`] (e.g., an XLA layout string) into a [`Layout`]. #[allow(clippy::should_implement_trait)] pub fn from_str>(value: S) -> Result { @@ -1198,7 +1205,7 @@ impl Buffer<'_> { element_type: self.element_type()?, dimensions: self.dimensions()?, #[allow(deprecated)] - layout: Some(self.layout()?), + layout: self.layout()?, }) } @@ -1214,7 +1221,7 @@ impl Buffer<'_> { /// based on the provided [`BufferSpecification`]). pub fn bitcast>(&self, specification: BufferSpecification) -> Result { use ffi::PJRT_Buffer_Bitcast_Args; - let layout = specification.layout.map(|layout| unsafe { layout.to_c_api() }); + let mut layout = unsafe { specification.layout.to_c_api() }; invoke_pjrt_api_error_fn!( self.api(), PJRT_Buffer_Bitcast, @@ -1223,7 +1230,7 @@ impl Buffer<'_> { element_type = specification.element_type.to_c_api(), dims = specification.dimensions.as_ref().as_ptr() as *const i64, num_dims = specification.dimensions.as_ref().len(), - device_layout = layout.map(|layout| &layout as *const _ as *mut _).unwrap_or(std::ptr::null_mut()), + device_layout = &mut layout as *mut _, }, { out_buffer }, ) @@ -1596,9 +1603,8 @@ pub struct BufferSpecification> { /// Dimensions (i.e., shape) of the buffer. pub dimensions: D, - /// Optional memory [`Layout`] of the buffer. If [`None`], then it is assumed to be a dense layout - /// with dimensions in major-to-minor order. - pub layout: Option, + /// Memory [`Layout`] of the buffer. + pub layout: Layout, } impl BufferSpecification> { @@ -1646,7 +1652,8 @@ impl BufferSpecification> { }; let layout = value[(closing_bracket_index + 1)..].trim(); - let layout = if layout.is_empty() { None } else { Some(Layout::from_str(layout)?) }; + let layout = + if layout.is_empty() { Layout::dense_major_to_minor(dimensions.len()) } else { Layout::from_str(layout)? }; Ok(BufferSpecification { element_type, dimensions, layout }) } @@ -1676,12 +1683,27 @@ impl BufferSpecification> { } }) .collect::, _>>()?, - layout: shape.layout.map(|layout| Layout::from_proto(*layout)).transpose()?, + layout: shape + .layout + .map(|layout| Layout::from_proto(*layout)) + .transpose()? + .unwrap_or_else(|| Layout::dense_major_to_minor(shape.dimensions.len())), }) } } impl> BufferSpecification { + /// Creates a new [`BufferSpecification`] using the default dense major-to-minor [`Layout::Tiled`]. + pub fn new(element_type: BufferType, dimensions: D) -> Self { + let layout = Layout::dense_major_to_minor(dimensions.as_ref().len()); + Self { element_type, dimensions, layout } + } + + /// Creates a new [`BufferSpecification`] using the provided [`Layout`]. + pub fn with_layout(element_type: BufferType, dimensions: D, layout: Layout) -> Self { + Self { element_type, dimensions, layout } + } + /// Returns the [`Shape`](crate::protos::Shape) Protobuf that corresponds to this [`BufferSpecification`]. pub fn proto(&self) -> Result { Ok(crate::protos::Shape { @@ -1689,7 +1711,7 @@ impl> BufferSpecification { dimensions: self.dimensions.as_ref().iter().map(|dimension| *dimension as i64).collect::>(), is_dynamic_dimension: vec![false; self.dimensions.as_ref().len()], tuple_shapes: Vec::new(), - layout: self.layout.as_ref().map(Layout::proto).transpose()?.map(Box::new), + layout: Some(Box::new(self.layout.proto()?)), }) } } @@ -1704,9 +1726,7 @@ impl> Display for BufferSpecification { dimensions.try_for_each(|dimension| write!(formatter, ",{dimension}"))?; } write!(formatter, "]")?; - if let Some(layout) = &self.layout { - write!(formatter, "{layout}")?; - } + write!(formatter, "{}", self.layout)?; Ok(()) } } @@ -2021,7 +2041,7 @@ impl<'c> DmaMappedBuffer<'c> { specification.dimensions, None, memory, - specification.layout, + Some(specification.layout), ) .map(|buffer| unsafe { std::mem::transmute::<_, Buffer<'c>>(buffer) }) } @@ -2233,7 +2253,7 @@ impl<'s> Client<'s> { memory: M, ) -> Result, Error> { use ffi::PJRT_Client_CreateUninitializedBuffer_Args; - let layout = specification.layout.map(|layout| unsafe { layout.to_c_api() }); + let mut layout = unsafe { specification.layout.to_c_api() }; invoke_pjrt_api_error_fn!( self.api(), PJRT_Client_CreateUninitializedBuffer, @@ -2242,7 +2262,7 @@ impl<'s> Client<'s> { shape_dims = specification.dimensions.as_ref().as_ptr() as *const i64, shape_num_dims = specification.dimensions.as_ref().len(), shape_element_type = specification.element_type.to_c_api(), - shape_layout = layout.map(|layout| &layout as *const _ as *mut _).unwrap_or(std::ptr::null_mut()), + shape_layout = &mut layout as *mut _, device = std::ptr::null_mut(), memory = memory.default_memory().to_c_api(), }, @@ -2291,7 +2311,7 @@ impl<'s> Client<'s> { stream: Option<*mut std::ffi::c_void>, ) -> Result, Error> { use ffi::PJRT_Client_CreateViewOfDeviceBuffer_Args; - let layout = specification.layout.map(|layout| unsafe { layout.to_c_api() }); + let mut layout = unsafe { specification.layout.to_c_api() }; extern "C" fn callback(_ptr: *mut std::ffi::c_void, arg: *mut std::ffi::c_void) { unsafe { Box::from_raw(arg as *mut F)() }; @@ -2306,7 +2326,7 @@ impl<'s> Client<'s> { dims = specification.dimensions.as_ref().as_ptr() as *const i64, num_dims = specification.dimensions.as_ref().len(), element_type = specification.element_type.to_c_api(), - layout = layout.map(|layout| &layout as *const _ as *mut _).unwrap_or(std::ptr::null_mut()), + layout = &mut layout as *mut _, device = std::ptr::null_mut(), memory = memory.default_memory().to_c_api(), stream = stream.unwrap_or(std::ptr::null_mut()) as isize, @@ -2361,7 +2381,7 @@ impl<'s> Client<'s> { V: AsRef, { use ffi::PJRT_Client_CreateErrorBuffer_Args; - let layout = specification.layout.map(|layout| unsafe { layout.to_c_api() }); + let mut layout = unsafe { specification.layout.to_c_api() }; let error_message = error.message(); let payload = payload .into_iter() @@ -2381,7 +2401,7 @@ impl<'s> Client<'s> { shape_dims = specification.dimensions.as_ref().as_ptr() as *const i64, shape_num_dims = specification.dimensions.as_ref().len(), shape_element_type = specification.element_type.to_c_api(), - shape_layout = layout.map(|layout| &layout as *const _ as *mut _).unwrap_or(std::ptr::null_mut()), + shape_layout = &mut layout as *mut _, memory = memory.default_memory().to_c_api(), payload = payload, payload_size = payload_size, @@ -2434,7 +2454,7 @@ impl<'s> Client<'s> { memory: M, ) -> Result<(Buffer<'_>, AliasBufferFulfillmentToken), Error> { use ffi::PJRT_Client_CreateAliasBuffer_Args; - let layout = specification.layout.map(|layout| unsafe { layout.to_c_api() }); + let mut layout = unsafe { specification.layout.to_c_api() }; invoke_pjrt_api_error_fn!( self.api(), PJRT_Client_CreateAliasBuffer, @@ -2444,7 +2464,7 @@ impl<'s> Client<'s> { shape_dims = specification.dimensions.as_ref().as_ptr() as *const i64, shape_num_dims = specification.dimensions.as_ref().len(), shape_element_type = specification.element_type.to_c_api(), - shape_layout = layout.map(|layout| &layout as *const _ as *mut _).unwrap_or(std::ptr::null_mut()), + shape_layout = &mut layout as *mut _, }, { alias_buffer, fulfill_alias_buffer_cb }, ) @@ -3723,6 +3743,15 @@ mod tests { assert_eq!(format!("{layout}"), "{1,0:T(4,*)}"); assert_eq!(format!("{layout:?}"), "Layout[{1,0:T(4,*)}]"); + // Test creating and round-tripping a dense major-to-minor [`Layout`]. + let layout = Layout::dense_major_to_minor(3); + assert_eq!(layout, Layout::Tiled(TiledLayout::new(vec![2, 1, 0], Vec::new()))); + assert_eq!(unsafe { Layout::from_c_api(&layout.to_c_api() as *const _) }, Ok(layout.clone())); + assert_eq!(Layout::from_str(layout.clone().to_string()), Ok(layout.clone())); + assert_eq!(Layout::from_proto(layout.clone().proto().unwrap()), Ok(layout.clone())); + assert_eq!(format!("{layout}"), "{2,1,0}"); + assert_eq!(format!("{layout:?}"), "Layout[{2,1,0}]"); + // Test round-tripping a [`StridedLayout`] through the C API. let layout = Layout::Strided(StridedLayout::new(vec![16, 4])); assert_eq!(unsafe { Layout::from_c_api(&layout.to_c_api() as *const _) }, Ok(layout.clone())); @@ -3813,12 +3842,12 @@ mod tests { Ok(BufferSpecification { element_type: BufferType::U8, dimensions: [4u64], - layout: Some(Layout::Tiled(TiledLayout { + layout: Layout::Tiled(TiledLayout { minor_to_major, tile_dimensions, tile_dimension_sizes, tile_count: 0, - })), + }), }) if minor_to_major == &[0] && tile_dimensions.is_empty() && tile_dimension_sizes.is_empty(), )); assert_eq!(buffer.on_device_size_in_bytes(), Ok(4)); @@ -3863,9 +3892,7 @@ mod tests { let client = test_cpu_client(); let device = client.addressable_devices().unwrap()[0].clone(); let buffer = client.buffer(&[1u8, 2u8, 3u8, 4u8], BufferType::U8, [4u64], None, device, None).unwrap(); - let buffer = buffer - .bitcast(BufferSpecification { element_type: BufferType::U32, dimensions: [1u64], layout: None }) - .unwrap(); + let buffer = buffer.bitcast(BufferSpecification::new(BufferType::U32, [1u64])).unwrap(); assert_eq!(buffer.element_type(), Ok(BufferType::U32)); assert_eq!(buffer.dimensions(), Ok([1u64].as_slice())); } @@ -4027,23 +4054,38 @@ mod tests { let specification = BufferSpecification { element_type: BufferType::F32, dimensions: vec![2, 3], - layout: Some(Layout::Tiled(TiledLayout::new( + layout: Layout::Tiled(TiledLayout::new( vec![1, 0], vec![ Tile { dimensions: vec![TileDimension::sized(4), TileDimension::sized(2)] }, Tile { dimensions: vec![TileDimension::combined()] }, ], - ))), + )), }; assert_eq!(BufferSpecification::from_str(specification.to_string()), Ok(specification.clone())); assert_eq!(BufferSpecification::from_proto(specification.proto().unwrap()), Ok(specification.clone())); assert_eq!(format!("{specification}"), "f32[2,3]{1,0:T(4,2)(*)}"); assert_eq!(format!("{specification:?}"), "BufferSpecification[f32[2,3]{1,0:T(4,2)(*)}]"); + let specification = BufferSpecification::new(BufferType::F32, vec![2, 3]); + assert_eq!(BufferSpecification::from_str("f32[2,3]"), Ok(specification.clone())); + assert_eq!( + BufferSpecification::from_proto(crate::protos::Shape { + element_type: BufferType::F32.proto() as i32, + dimensions: vec![2, 3], + is_dynamic_dimension: vec![false, false], + tuple_shapes: Vec::new(), + layout: None, + }), + Ok(specification.clone()), + ); + assert_eq!(format!("{specification}"), "f32[2,3]{1,0}"); + assert_eq!(format!("{specification:?}"), "BufferSpecification[f32[2,3]{1,0}]"); + let specification = BufferSpecification { element_type: BufferType::F32, dimensions: vec![2, 3], - layout: Some(Layout::Strided(StridedLayout::new(vec![12, 4]))), + layout: Layout::Strided(StridedLayout::new(vec![12, 4])), }; assert_eq!(BufferSpecification::from_str(specification.to_string()), Ok(specification.clone())); assert!(matches!( @@ -4140,11 +4182,7 @@ mod tests { assert!(!dma_mapped_buffer.is_empty()); let buffer = unsafe { dma_mapped_buffer.into_buffer( - BufferSpecification { - element_type: BufferType::U8, - dimensions: [data.len() as u64], - layout: None, - }, + BufferSpecification::new(BufferType::U8, [data.len() as u64]), device.default_memory().unwrap(), ) } @@ -4201,7 +4239,7 @@ mod tests { fn test_client_uninitialized_buffer() { let client = test_cpu_client(); let device = client.addressable_devices().unwrap()[0].clone(); - let specification = BufferSpecification { element_type: BufferType::U8, dimensions: [4u64], layout: None }; + let specification = BufferSpecification::new(BufferType::U8, [4u64]); let buffer = client.uninitialized_buffer(specification, device.clone()).unwrap(); assert_eq!(buffer.element_type(), Ok(BufferType::U8)); assert_eq!(buffer.dimensions(), Ok([4u64].as_slice())); @@ -4217,7 +4255,7 @@ mod tests { let specification = BufferSpecification { element_type: buffer.element_type().unwrap(), dimensions: buffer.dimensions().unwrap().to_vec(), - layout: None, + layout: Layout::dense_major_to_minor(buffer.rank().unwrap()), }; let borrowed_buffer = unsafe { client @@ -4235,7 +4273,7 @@ mod tests { let client = test_cpu_client(); let device = client.addressable_devices().unwrap()[0].clone(); let error = Error::aborted("test error"); - let specification = BufferSpecification { element_type: BufferType::U8, dimensions: [4u64], layout: None }; + let specification = BufferSpecification::new(BufferType::U8, [4u64]); let buffer = client.error_buffer(error.clone(), specification.clone(), device.clone()).unwrap(); assert!(matches!( @@ -4264,7 +4302,7 @@ mod tests { fn test_client_alias_buffer_and_fulfillment() { let client = test_cpu_client(); let device = client.addressable_devices().unwrap()[0].clone(); - let specification = BufferSpecification { element_type: BufferType::U8, dimensions: [4u64], layout: None }; + let specification = BufferSpecification::new(BufferType::U8, [4u64]); // Create a new alias buffer and fulfill it with some other buffer. let (alias_buffer, token) = client.alias_buffer(specification.clone(), device.clone()).unwrap(); diff --git a/crates/ryft-pjrt/src/transfers.rs b/crates/ryft-pjrt/src/transfers.rs index 01586ecf..4ef05dcd 100644 --- a/crates/ryft-pjrt/src/transfers.rs +++ b/crates/ryft-pjrt/src/transfers.rs @@ -191,9 +191,9 @@ impl<'c> HostToDeviceTransferManager<'c> { /// 2. `data` must point to a contiguous buffer large enough for the literal's dense array storage, including /// layout padding. In XLA terms this corresponds to [`ShapeUtil::ByteSizeOf`]( /// https://github.com/openxla/xla/blob/main/xla/shape_util.h#L177-L183). - /// 3. Element ordering must follow the provided `specification.layout` (or XLA's default dense layout if - /// `specification.layout` is [`None`]). For untiled layouts, linearization follows minor-to-major dimension - /// ordering as documented by [`IndexUtil::MultidimensionalIndexToLinearIndex`]( + /// 3. Element ordering must follow the provided `specification.layout`. For untiled layouts, linearization + /// follows minor-to-major dimension ordering as documented by + /// [`IndexUtil::MultidimensionalIndexToLinearIndex`]( /// https://github.com/openxla/xla/blob/main/xla/index_util.h#L41-L114). For tiled layouts, bytes must follow /// XLA tiled-layout rules (i.e., tile-major ordering, within-tile ordering, and edge padding) as described in /// the [official documentation]( @@ -230,8 +230,7 @@ impl<'c> HostToDeviceTransferManager<'c> { shape_dims = dimensions.as_ptr(), shape_num_dims = dimensions.len(), shape_element_type = specification.element_type.to_c_api(), - shape_layout = specification.layout.map(|layout| &layout.to_c_api() as *const _ as *mut _) - .unwrap_or(std::ptr::null_mut()), + shape_layout = &specification.layout.to_c_api() as *const _ as *mut _, }, { done_with_h2d_transfer }, )?; @@ -364,11 +363,11 @@ impl<'s> Client<'s> { .collect::>(); let layouts = buffer_specifications .iter() - .map(|specification| specification.layout.as_ref().map(|layout| unsafe { layout.to_c_api() })) + .map(|specification| unsafe { specification.layout.to_c_api() }) .collect::>(); let layouts = layouts .iter() - .map(|layout| layout.as_ref().map(|layout| layout as *const _ as *mut _).unwrap_or(std::ptr::null_mut())) + .map(|layout| layout as *const _ as *mut _) .collect::>(); invoke_pjrt_api_error_fn!( self.api(), @@ -1009,7 +1008,7 @@ mod tests { test_for_each_platform!(|_plugin, client, _platform| { let device = client.addressable_devices().unwrap().remove(0); let memory = device.default_memory().unwrap(); - let specification = BufferSpecification { element_type: BufferType::U8, dimensions: [8u64], layout: None }; + let specification = BufferSpecification::new(BufferType::U8, [8u64]); // Test a successful transfer. let manager = client.host_to_device_transfer_manager(vec![specification.clone()], memory).unwrap(); From 12c46199744d25f3473abddbc306976cdb391f92 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Sat, 2 May 2026 08:49:55 -0700 Subject: [PATCH 10/12] . --- crates/ryft-mlir/README.md | 1 - crates/ryft-mlir/src/operations/operation.rs | 33 +++++++++++++++++++- crates/ryft-mlir/src/values.rs | 21 ++++++++++--- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/crates/ryft-mlir/README.md b/crates/ryft-mlir/README.md index d53c1c8e..474c27c0 100644 --- a/crates/ryft-mlir/README.md +++ b/crates/ryft-mlir/README.md @@ -83,7 +83,6 @@ fn main() { ## Roadmap / TODOs -- [ ] Add support for `mlirOperationReplaceUsesOfWith` and `mlirBlockArgumentSetLocation`. - [ ] Add `Context` constructors like `i32_type`, etc. Maybe also `bool_type` as an alias for `i1_type`? - [ ] Clean up the API we have around elements attributes and use stronger typing, if possible. - [ ] `BooleanAttributeRef::is` panics (and the same for a 1-bit integer attribute in reverse). diff --git a/crates/ryft-mlir/src/operations/operation.rs b/crates/ryft-mlir/src/operations/operation.rs index a561cc40..90cbd904 100644 --- a/crates/ryft-mlir/src/operations/operation.rs +++ b/crates/ryft-mlir/src/operations/operation.rs @@ -15,7 +15,7 @@ use ryft_xla_sys::bindings::{ mlirOperationGetSuccessor, mlirOperationGetTypeID, mlirOperationHasInherentAttributeByName, mlirOperationHashValue, mlirOperationIsBeforeInBlock, mlirOperationMoveAfter, mlirOperationMoveBefore, mlirOperationPrint, mlirOperationPrintWithFlags, mlirOperationPrintWithState, mlirOperationRemoveAttributeByName, - mlirOperationRemoveDiscardableAttributeByName, mlirOperationSetAttributeByName, + mlirOperationRemoveDiscardableAttributeByName, mlirOperationReplaceUsesOfWith, mlirOperationSetAttributeByName, mlirOperationSetDiscardableAttributeByName, mlirOperationSetInherentAttributeByName, mlirOperationSetLocation, mlirOperationSetOperand, mlirOperationSetOperands, mlirOperationSetSuccessor, mlirOperationVerify, mlirOperationWalk, mlirOperationWriteBytecode, mlirOperationWriteBytecodeWithConfig, @@ -522,6 +522,25 @@ pub trait Operation<'o, 'c: 'o, 't: 'c>: Sized { } } + /// Replaces all uses of the `target` [`Value`] inside this [`Operation`] with the provided `replacement`. + /// + /// Note that this function is marked as _unsafe_ because if the provided `replacement` does not _dominate_ this + /// [`Operation`] according to MLIR's dominance rules (i.e., it is not defined before/above it in the current + /// control flow of the program), then calling this function results in undefined behavior. + unsafe fn replace_uses_of_with<'a, 'b, A: Value<'a, 'c, 't>, B: Value<'b, 'c, 't>>( + &mut self, + target: A, + replacement: B, + ) where + 'c: 'a + 'b, + { + // The following context borrow ensures that access to the underlying MLIR data structures is done safely from + // Rust. It is maybe more conservative than would be ideal, but that is due to the limited exposure to MLIR + // internals that we have when working with the MLIR C API. + let _guard = self.context().borrow_mut(); + unsafe { mlirOperationReplaceUsesOfWith(self.to_c_api(), target.to_c_api(), replacement.to_c_api()) } + } + /// Returns the number of results of this [`Operation`]. fn result_count(&self) -> usize { // The following context borrow ensures that access to the underlying MLIR data structures is done safely from @@ -1349,6 +1368,18 @@ mod tests { assert_eq!(op.operand(0).map(|operand| operand.value()), Some(argument_1)); assert_eq!(op.operand(1).map(|operand| operand.value()), Some(argument_2)); assert!(unsafe { !op.replace_operands(&[argument_2]) }); + + // Try replacing all uses of one value inside an operation. + let mut op = OperationBuilder::new("foo", context.unknown_location()) + .add_operand(argument_0) + .add_operand(argument_2) + .add_operand(argument_0) + .build() + .unwrap(); + unsafe { op.replace_uses_of_with(argument_0, argument_2) }; + assert_eq!(op.operand(0).map(|operand| operand.value()), Some(argument_2)); + assert_eq!(op.operand(1).map(|operand| operand.value()), Some(argument_2)); + assert_eq!(op.operand(2).map(|operand| operand.value()), Some(argument_2)); } #[test] diff --git a/crates/ryft-mlir/src/values.rs b/crates/ryft-mlir/src/values.rs index 7760649c..62cb5ebc 100644 --- a/crates/ryft-mlir/src/values.rs +++ b/crates/ryft-mlir/src/values.rs @@ -5,10 +5,10 @@ use std::marker::PhantomData; use ryft_xla_sys::bindings::{ MlirContext, MlirOpOperand, MlirValue, mlirBlockArgumentGetArgNumber, mlirBlockArgumentGetOwner, - mlirOpOperandGetNextUse, mlirOpOperandGetOperandNumber, mlirOpOperandGetOwner, mlirOpOperandGetValue, - mlirOpOperandIsNull, mlirOpResultGetOwner, mlirOpResultGetResultNumber, mlirValueDump, mlirValueGetFirstUse, - mlirValueGetLocation, mlirValueGetType, mlirValueIsABlockArgument, mlirValueIsAOpResult, mlirValuePrintAsOperand, - mlirValueReplaceAllUsesExcept, mlirValueReplaceAllUsesOfWith, mlirValueSetType, + mlirBlockArgumentSetLocation, mlirOpOperandGetNextUse, mlirOpOperandGetOperandNumber, mlirOpOperandGetOwner, + mlirOpOperandGetValue, mlirOpOperandIsNull, mlirOpResultGetOwner, mlirOpResultGetResultNumber, mlirValueDump, + mlirValueGetFirstUse, mlirValueGetLocation, mlirValueGetType, mlirValueIsABlockArgument, mlirValueIsAOpResult, + mlirValuePrintAsOperand, mlirValueReplaceAllUsesExcept, mlirValueReplaceAllUsesOfWith, mlirValueSetType, }; use crate::support::write_to_string_callback; @@ -259,6 +259,15 @@ impl<'b, 'c, 't> BlockArgumentRef<'b, 'c, 't> { let _guard = self.context.borrow(); unsafe { mlirBlockArgumentGetArgNumber(self.handle).cast_unsigned() } } + + /// Sets the [`Location`] of this [`BlockArgumentRef`]. + pub fn set_location>(&mut self, location: L) { + // The following context borrow ensures that access to the underlying MLIR data structures is done safely from + // Rust. It is maybe more conservative than would be ideal, but that is due to the limited exposure to MLIR + // internals that we have when working with the MLIR C API. + let _guard = self.context.borrow_mut(); + unsafe { mlirBlockArgumentSetLocation(self.handle, location.to_c_api()) } + } } impl<'v, 'b: 'v, 'c, 't> Value<'v, 'c, 't> for BlockArgumentRef<'b, 'c, 't> { @@ -551,6 +560,10 @@ mod tests { assert!(block_argument.is::()); assert!(!block_argument.is::()); assert_eq!(block_argument.block(), block); + let new_location = context.file_location("foo.mlir", 2, 3); + let mut block_argument = block_argument; + block_argument.set_location(new_location); + assert_eq!(block_argument.location(), new_location); block_argument.set_type(f64_type); assert_eq!(block_argument.r#type(), f64_type); } From 9988a7ec1c7225c48c96cba384ab5a8525859dd0 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Sat, 2 May 2026 09:08:37 -0700 Subject: [PATCH 11/12] . --- crates/ryft-mlir/src/values.rs | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/crates/ryft-mlir/src/values.rs b/crates/ryft-mlir/src/values.rs index 62cb5ebc..f8786422 100644 --- a/crates/ryft-mlir/src/values.rs +++ b/crates/ryft-mlir/src/values.rs @@ -5,10 +5,11 @@ use std::marker::PhantomData; use ryft_xla_sys::bindings::{ MlirContext, MlirOpOperand, MlirValue, mlirBlockArgumentGetArgNumber, mlirBlockArgumentGetOwner, - mlirBlockArgumentSetLocation, mlirOpOperandGetNextUse, mlirOpOperandGetOperandNumber, mlirOpOperandGetOwner, - mlirOpOperandGetValue, mlirOpOperandIsNull, mlirOpResultGetOwner, mlirOpResultGetResultNumber, mlirValueDump, - mlirValueGetFirstUse, mlirValueGetLocation, mlirValueGetType, mlirValueIsABlockArgument, mlirValueIsAOpResult, - mlirValuePrintAsOperand, mlirValueReplaceAllUsesExcept, mlirValueReplaceAllUsesOfWith, mlirValueSetType, + mlirBlockArgumentSetLocation, mlirBlockArgumentSetType, mlirOpOperandGetNextUse, mlirOpOperandGetOperandNumber, + mlirOpOperandGetOwner, mlirOpOperandGetValue, mlirOpOperandIsNull, mlirOpResultGetOwner, + mlirOpResultGetResultNumber, mlirValueDump, mlirValueGetFirstUse, mlirValueGetLocation, mlirValueGetType, + mlirValueIsABlockArgument, mlirValueIsAOpResult, mlirValuePrintAsOperand, mlirValueReplaceAllUsesExcept, + mlirValueReplaceAllUsesOfWith, mlirValueSetType, }; use crate::support::write_to_string_callback; @@ -260,6 +261,15 @@ impl<'b, 'c, 't> BlockArgumentRef<'b, 'c, 't> { unsafe { mlirBlockArgumentGetArgNumber(self.handle).cast_unsigned() } } + /// Sets the [`Type`] of this [`BlockArgumentRef`]. + pub fn set_type>(&mut self, r#type: T) { + // The following context borrow ensures that access to the underlying MLIR data structures is done safely from + // Rust. It is maybe more conservative than would be ideal, but that is due to the limited exposure to MLIR + // internals that we have when working with the MLIR C API. + let _guard = self.context.borrow_mut(); + unsafe { mlirBlockArgumentSetType(self.handle, r#type.to_c_api()) } + } + /// Sets the [`Location`] of this [`BlockArgumentRef`]. pub fn set_location>(&mut self, location: L) { // The following context borrow ensures that access to the underlying MLIR data structures is done safely from @@ -547,7 +557,7 @@ mod tests { let index_type = context.index_type(); let f64_type = context.float64_type(); let block = context.block(&[(index_type, location)]); - let block_argument = block.argument(0).unwrap(); + let mut block_argument = block.argument(0).unwrap(); assert_eq!(block_argument.context(), &context); assert_eq!(block_argument.name(false, false).ok().flatten(), None); assert_eq!(block_argument.name(true, false).ok().flatten(), None); @@ -560,12 +570,11 @@ mod tests { assert!(block_argument.is::()); assert!(!block_argument.is::()); assert_eq!(block_argument.block(), block); + block_argument.set_type(f64_type); + assert_eq!(block_argument.r#type(), f64_type); let new_location = context.file_location("foo.mlir", 2, 3); - let mut block_argument = block_argument; block_argument.set_location(new_location); assert_eq!(block_argument.location(), new_location); - block_argument.set_type(f64_type); - assert_eq!(block_argument.r#type(), f64_type); } #[test] From dc57ec648ec706054906ed86d6f813b7ad9d9623 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Sat, 2 May 2026 09:23:22 -0700 Subject: [PATCH 12/12] . --- .../stable_hlo/operations/communication.rs | 244 ++++++++++++------ 1 file changed, 172 insertions(+), 72 deletions(-) diff --git a/crates/ryft-mlir/src/dialects/stable_hlo/operations/communication.rs b/crates/ryft-mlir/src/dialects/stable_hlo/operations/communication.rs index 15a5ef7a..7cf34b53 100644 --- a/crates/ryft-mlir/src/dialects/stable_hlo/operations/communication.rs +++ b/crates/ryft-mlir/src/dialects/stable_hlo/operations/communication.rs @@ -4,11 +4,14 @@ use ryft_xla_sys::bindings::{ }; use crate::{ - Attribute, BooleanAttributeRef, Context, DenseIntegerElementsAttributeRef, DetachedOp, DetachedRegion, - DialectHandle, IntegerAttributeRef, IntoWithContext, Location, OneRegion, Operation, OperationBuilder, RegionRef, - Size, StringAttributeRef, StringRef, TensorTypeRef, Type, Value, mlir_op, mlir_op_trait, mlir_subtype_trait_impls, + Attribute, AttributeRef, BooleanAttributeRef, Context, DenseIntegerElementsAttributeRef, DetachedOp, + DetachedRegion, DialectHandle, IntegerAttributeRef, IntoWithContext, Location, OneRegion, Operation, + OperationBuilder, RegionRef, Size, StringAttributeRef, StringRef, TensorTypeRef, Type, Value, mlir_op, + mlir_op_trait, mlir_subtype_trait_impls, }; +use crate::dialects::stable_hlo::ReplicaGroupMeshAxesAttributeRef; + /// Represents the type of a StableHLO communication channel. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(i64)] @@ -632,42 +635,94 @@ pub const COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE: &str = "replica_groups"; /// Name of the [`Attribute`] that is used to store [`HasReplicaGroups::use_global_device_ids`]. pub const COLLECTIVE_USE_GLOBAL_DEVICE_IDS_ATTRIBUTE: &str = "use_global_device_ids"; +/// StableHLO collective replica groups. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ReplicaGroups<'c, 't> { + /// Replica groups represented explicitly as device IDs. + Dense(Vec>), + + /// Replica groups represented by a mesh and mesh-axis references. + MeshAxes(ReplicaGroupMeshAxesAttributeRef<'c, 't>), +} + +impl<'c, 't> ReplicaGroups<'c, 't> { + /// Creates dense replica groups from a slice of device-id groups. + pub fn dense(replica_groups: &[&[usize]]) -> Self { + Self::Dense(replica_groups.iter().map(|group| group.to_vec()).collect()) + } + + /// Creates mesh-axis replica groups from the corresponding StableHLO attribute. + pub fn mesh_axes(replica_groups: ReplicaGroupMeshAxesAttributeRef<'c, 't>) -> Self { + Self::MeshAxes(replica_groups) + } + + /// Returns the MLIR [`Attribute`] representation of these replica groups. + fn to_attribute>(&self, context: &'c Context<'t>, location: L) -> AttributeRef<'c, 't> { + match self { + Self::Dense(replica_groups) => { + context.stable_hlo_replica_groups_attribute(replica_groups, location).as_ref() + } + Self::MeshAxes(replica_groups) => replica_groups.as_ref(), + } + } +} + +impl<'c, 't> From> for ReplicaGroups<'c, 't> { + fn from(replica_groups: ReplicaGroupMeshAxesAttributeRef<'c, 't>) -> Self { + Self::mesh_axes(replica_groups) + } +} + /// Trait that represents collective [`Operation`]s that support specifying and operating over replica groups. pub trait HasReplicaGroups<'o, 'c: 'o, 't: 'c>: Operation<'o, 'c, 't> { - /// Returns the optional replica groups of this [`Operation`]. This must be non-empty if + /// Returns the replica groups of this [`Operation`]. This must be non-empty if /// [`HasReplicaGroups::use_global_device_ids`] is `true`. /// - /// This a [`Vec`] over replica groups where each group is represented as a [`Vec`] of device IDs. In most cases, - /// the groups must all have the same size. The only exception is [`all_reduce`] which supports non-uniform replica - /// groups. These groups determine the order in which the gather operation is performed and which devices - /// communicate with which other devices during the gather operation. - fn replica_groups(&self) -> Vec> { + /// The returned [`ReplicaGroups`] may contain either explicit device-id groups or mesh-axis replica groups. + fn replica_groups(&self) -> ReplicaGroups<'c, 't> { self.attribute(COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE) - .and_then(|attribute| attribute.cast::()) - .and_then(|attribute| { - let attribute_type = attribute.r#type().cast::(); - let device_ids = unsafe { attribute.i64_elements() }; - attribute_type.and_then(|tensor_type| { - tensor_type.dimension(1).value().map(|max_group_size| { - let mut groups = Vec::new(); - let mut device_ids = device_ids.peekable(); - while device_ids.peek().is_some() { - // We filter for non-negative values after the chunking below because the value `-1` is - // used as a "null" device padding value for when dealing with non-uniform replica groups. - groups.push( - device_ids - .by_ref() - .take(max_group_size) - .filter(|id| *id >= 0) - .map(|id| id as usize) - .collect(), - ); - } - groups - }) - }) + .map(|attribute| { + if let Some(attribute) = attribute.cast::() { + let attribute_type = attribute.r#type().cast::(); + let device_ids = unsafe { attribute.i64_elements() }; + ReplicaGroups::Dense( + attribute_type + .and_then(|tensor_type| { + tensor_type.dimension(1).value().map(|max_group_size| { + let mut groups = Vec::new(); + let mut device_ids = device_ids.peekable(); + while device_ids.peek().is_some() { + // We filter for non-negative values after the chunking below because the value + // `-1` is used as a "null" device padding value for when dealing with + // non-uniform replica groups. + groups.push( + device_ids + .by_ref() + .take(max_group_size) + .filter(|id| *id >= 0) + .map(|id| id as usize) + .collect(), + ); + } + groups + }) + }) + .unwrap_or_else(|| { + panic!( + "invalid '{COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE}' dense attribute in StableHLO \ + collective operation", + ) + }), + ) + } else if let Some(attribute) = attribute.cast::() { + ReplicaGroups::MeshAxes(attribute) + } else { + panic!( + "invalid '{COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE}' attribute in StableHLO collective operation", + ) + } }) - .unwrap_or(Vec::new()) + .unwrap_or_else(|| ReplicaGroups::Dense(Vec::new())) } /// Returns `true` if this [`Operation`] uses global device IDs. Defaults to `false` if not specified. @@ -678,11 +733,11 @@ pub trait HasReplicaGroups<'o, 'c: 'o, 't: 'c>: Operation<'o, 'c, 't> { } impl<'t> Context<'t> { - /// Internal helper for constructing the [`DenseIntegerElementsAttributeRef`] that is used to store + /// Internal helper for constructing the [`DenseIntegerElementsAttributeRef`] that is used to store dense /// [`HasReplicaGroups::replica_groups`]. fn stable_hlo_replica_groups_attribute<'c, L: Location<'c, 't>>( &'c self, - replica_groups: &[&[usize]], + replica_groups: &[Vec], location: L, ) -> DenseIntegerElementsAttributeRef<'c, 't> { let i64_type = self.signless_integer_type(64); @@ -694,7 +749,7 @@ impl<'t> Context<'t> { let mut attribute_values = Vec::with_capacity(group_count * max_group_size); for group in replica_groups { let group_size = group.len(); - for id in *group { + for id in group { attribute_values.push(self.integer_attribute(i64_type, *id as i64)); } if group_size < max_group_size { @@ -768,7 +823,7 @@ mlir_op_trait!(AllGather, @local SupportsChannelHandle); pub fn all_gather<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, T: Type<'c, 't>, L: Location<'c, 't>>( inputs: &[V], all_gather_dimension: usize, - replica_groups: &[&[usize]], + replica_groups: ReplicaGroups<'c, 't>, channel_id: Option, channel_type: Option, use_global_device_ids: bool, @@ -781,10 +836,7 @@ pub fn all_gather<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, T: Type<'c, 't>, L: let mut builder = OperationBuilder::new("stablehlo.all_gather", location) .add_operands(inputs) .add_attribute(ALL_GATHER_DIMENSION_ATTRIBUTE, context.integer_attribute(i64_type, all_gather_dimension as i64)) - .add_attribute( - COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, - context.stable_hlo_replica_groups_attribute(replica_groups, location), - ); + .add_attribute(COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, replica_groups.to_attribute(context, location)); if let Some(channel_id) = channel_id { builder = builder.add_attribute( COLLECTIVE_CHANNEL_HANDLE_ATTRIBUTE, @@ -861,7 +913,7 @@ mlir_op_trait!(AllReduce, @local SupportsChannelHandle); /// Note that if any of the inputs to this function are invalid, it will panic! pub fn all_reduce<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, L: Location<'c, 't>>( inputs: &[V], - replica_groups: &[&[usize]], + replica_groups: ReplicaGroups<'c, 't>, channel_id: Option, channel_type: Option, use_global_device_ids: bool, @@ -870,10 +922,9 @@ pub fn all_reduce<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, L: Location<'c, 't>> ) -> DetachedAllReduceOperation<'c, 't> { let context = location.context(); context.load_dialect(DialectHandle::stable_hlo()); - let mut builder = OperationBuilder::new("stablehlo.all_reduce", location).add_operands(inputs).add_attribute( - COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, - context.stable_hlo_replica_groups_attribute(replica_groups, location), - ); + let mut builder = OperationBuilder::new("stablehlo.all_reduce", location) + .add_operands(inputs) + .add_attribute(COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, replica_groups.to_attribute(context, location)); if let Some(channel_id) = channel_id { builder = builder.add_attribute( COLLECTIVE_CHANNEL_HANDLE_ATTRIBUTE, @@ -990,7 +1041,7 @@ pub fn all_to_all<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, L: Location<'c, 't>> split_dimension: usize, split_count: usize, concatenation_dimension: usize, - replica_groups: &[&[usize]], + replica_groups: ReplicaGroups<'c, 't>, channel_id: Option, channel_type: Option, use_global_device_ids: bool, @@ -1010,10 +1061,7 @@ pub fn all_to_all<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, L: Location<'c, 't>> ALL_TO_ALL_CONCATENATION_DIMENSION_ATTRIBUTE, context.integer_attribute(i64_type, concatenation_dimension as i64), ) - .add_attribute( - COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, - context.stable_hlo_replica_groups_attribute(replica_groups, location), - ); + .add_attribute(COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, replica_groups.to_attribute(context, location)); if let Some(channel_id) = channel_id { builder = builder.add_attribute( COLLECTIVE_CHANNEL_HANDLE_ATTRIBUTE, @@ -1080,18 +1128,16 @@ mlir_op_trait!(CollectiveBroadcast, @local SupportsChannelHandle); /// Note that if any of the inputs to this function are invalid, it will panic! pub fn collective_broadcast<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, L: Location<'c, 't>>( input: V, - replica_groups: &[&[usize]], + replica_groups: ReplicaGroups<'c, 't>, channel_id: Option, channel_type: Option, location: L, ) -> DetachedCollectiveBroadcastOperation<'c, 't> { let context = location.context(); context.load_dialect(DialectHandle::stable_hlo()); - let mut builder = - OperationBuilder::new("stablehlo.collective_broadcast", location).add_operand(input).add_attribute( - COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, - context.stable_hlo_replica_groups_attribute(replica_groups, location), - ); + let mut builder = OperationBuilder::new("stablehlo.collective_broadcast", location) + .add_operand(input) + .add_attribute(COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, replica_groups.to_attribute(context, location)); if let Some(channel_id) = channel_id { builder = builder.add_attribute( COLLECTIVE_CHANNEL_HANDLE_ATTRIBUTE, @@ -1256,7 +1302,7 @@ mlir_op_trait!(ReduceScatter, @local SupportsChannelHandle); pub fn reduce_scatter<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, T: Type<'c, 't>, L: Location<'c, 't>>( operand: V, dimension: usize, - replica_groups: &[&[usize]], + replica_groups: ReplicaGroups<'c, 't>, channel_id: Option, channel_type: Option, use_global_device_ids: bool, @@ -1272,10 +1318,7 @@ pub fn reduce_scatter<'v, 'c: 'v, 't: 'c, V: Value<'v, 'c, 't>, T: Type<'c, 't>, REDUCE_SCATTER_DIMENSION_ATTRIBUTE, location.context().integer_attribute(context.signless_integer_type(64), dimension as i64), ) - .add_attribute( - COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, - context.stable_hlo_replica_groups_attribute(replica_groups, location), - ); + .add_attribute(COLLECTIVE_REPLICA_GROUPS_ATTRIBUTE, replica_groups.to_attribute(context, location)); if let Some(channel_id) = channel_id { builder = builder.add_attribute( COLLECTIVE_CHANNEL_HANDLE_ATTRIBUTE, @@ -1658,7 +1701,7 @@ mod tests { let op = all_gather( &[block.argument(0).unwrap()], 1, - &[&[0, 2], &[1, 3]], + ReplicaGroups::dense(&[&[0, 2], &[1, 3]]), Some(0), Some(ChannelHandleType::DeviceToDevice), true, @@ -1666,7 +1709,7 @@ mod tests { location, ); assert_eq!(op.all_gather_dimension(), 1); - assert_eq!(op.replica_groups(), vec![vec![0, 2], vec![1, 3]]); + assert_eq!(op.replica_groups(), ReplicaGroups::Dense(vec![vec![0, 2], vec![1, 3]])); assert_eq!(op.channel_id(), Some(0)); assert_eq!(op.channel_type(), Some(ChannelHandleType::DeviceToDevice)); assert!(op.use_global_device_ids()); @@ -1700,6 +1743,63 @@ mod tests { } "}, ); + + // Test using mesh axis replica groups. + let module = context.module(location); + let i64_type = context.signless_integer_type(64); + let input_tensor_type = + context.tensor_type(i64_type, &[Size::Static(2), Size::Static(2)], None, location).unwrap(); + let output_tensor_type = + context.tensor_type(i64_type, &[Size::Static(2), Size::Static(4)], None, location).unwrap(); + let mesh_axis_x = context.stable_hlo_mesh_axis("x", 2); + let mesh_axis_y = context.stable_hlo_mesh_axis("y", 2); + let mesh = context.stable_hlo_mesh(context.array_attribute(&[mesh_axis_x, mesh_axis_y]), None); + let replica_axis = context.stable_hlo_axis_ref("x", None); + let replica_groups = context.stable_hlo_replica_group_mesh_axes(mesh, context.array_attribute(&[replica_axis])); + module.body().append_operation({ + let mut block = context.block(&[(input_tensor_type, location)]); + let op = all_gather( + &[block.argument(0).unwrap()], + 1, + replica_groups.into(), + None, + None, + false, + &[output_tensor_type], + location, + ); + assert_eq!(op.replica_groups(), ReplicaGroups::MeshAxes(replica_groups)); + let op = block.append_operation(op); + block.append_operation(func::r#return(&[op.result(0).unwrap()], location)); + func::func( + "test_all_gather_with_mesh_axis_replica_groups", + func::FuncAttributes { + arguments: vec![input_tensor_type.into()], + results: vec![output_tensor_type.into()], + ..Default::default() + }, + block.into(), + location, + ) + }); + assert!(module.verify()); + assert_eq!( + module.to_string(), + indoc! {" + module { + func.func @test_all_gather_with_mesh_axis_replica_groups(%arg0: tensor<2x2xi64>) -> tensor<2x4xi64> { + %0 = \"stablehlo.all_gather\"(%arg0) <{\ + all_gather_dim = 1 : i64, \ + replica_groups = #stablehlo.replica_group_mesh_axes<\ + mesh = #stablehlo.mesh, ]>, \ + axes = [#stablehlo.axis_ref]\ + >\ + }> : (tensor<2x2xi64>) -> tensor<2x4xi64> + return %0 : tensor<2x4xi64> + } + } + "}, + ); } #[test] @@ -1726,14 +1826,14 @@ mod tests { let computation = computation_region.into(); let op = all_reduce( &[block.argument(0).unwrap()], - &[&[0, 2], &[1]], + ReplicaGroups::dense(&[&[0, 2], &[1]]), Some(1), Some(ChannelHandleType::DeviceToDevice), true, computation, location, ); - assert_eq!(op.replica_groups(), vec![vec![0, 2], vec![1]]); + assert_eq!(op.replica_groups(), ReplicaGroups::Dense(vec![vec![0, 2], vec![1]])); assert_eq!(op.channel_id(), Some(1)); assert_eq!(op.channel_type(), Some(ChannelHandleType::DeviceToDevice)); assert!(op.use_global_device_ids()); @@ -1789,7 +1889,7 @@ mod tests { 1, 2, 0, - &[&[0, 2], &[1, 3]], + ReplicaGroups::dense(&[&[0, 2], &[1, 3]]), Some(1), Some(ChannelHandleType::DeviceToDevice), true, @@ -1798,7 +1898,7 @@ mod tests { assert_eq!(op.split_dimension(), 1); assert_eq!(op.split_count(), 2); assert_eq!(op.concatenation_dimension(), 0); - assert_eq!(op.replica_groups(), vec![vec![0, 2], vec![1, 3]]); + assert_eq!(op.replica_groups(), ReplicaGroups::Dense(vec![vec![0, 2], vec![1, 3]])); assert_eq!(op.channel_id(), Some(1)); assert_eq!(op.channel_type(), Some(ChannelHandleType::DeviceToDevice)); let op = block.append_operation(op); @@ -1845,12 +1945,12 @@ mod tests { let mut block = context.block(&[(tensor_type, location)]); let op = collective_broadcast( block.argument(0).unwrap(), - &[&[0, 2], &[1, 3]], + ReplicaGroups::dense(&[&[0, 2], &[1, 3]]), Some(1), Some(ChannelHandleType::DeviceToDevice), location, ); - assert_eq!(op.replica_groups(), vec![vec![0, 2], vec![1, 3]]); + assert_eq!(op.replica_groups(), ReplicaGroups::Dense(vec![vec![0, 2], vec![1, 3]])); assert_eq!(op.channel_id(), Some(1)); assert_eq!(op.channel_type(), Some(ChannelHandleType::DeviceToDevice)); let op = block.append_operation(op); @@ -1956,7 +2056,7 @@ mod tests { let reduce_scatter_op = reduce_scatter( input, 1, - &[&[0, 2], &[1, 3]], + ReplicaGroups::dense(&[&[0, 2], &[1, 3]]), Some(1), Some(ChannelHandleType::DeviceToDevice), false, @@ -1965,7 +2065,7 @@ mod tests { location, ); assert_eq!(reduce_scatter_op.dimension(), 1); - assert_eq!(reduce_scatter_op.replica_groups(), vec![vec![0, 2], vec![1, 3]]); + assert_eq!(reduce_scatter_op.replica_groups(), ReplicaGroups::Dense(vec![vec![0, 2], vec![1, 3]])); assert_eq!(reduce_scatter_op.channel_id(), Some(1)); assert_eq!(reduce_scatter_op.channel_type(), Some(ChannelHandleType::DeviceToDevice)); assert_eq!(reduce_scatter_op.use_global_device_ids(), false);