From a85f7026efd8d5402b00e3c3ce5c2c798bf00736 Mon Sep 17 00:00:00 2001 From: sayantn Date: Wed, 22 Apr 2026 04:33:19 +0530 Subject: [PATCH 1/2] Add remaining AMX intrinsics --- crates/core_arch/missing-x86.md | 37 - crates/core_arch/src/x86_64/amx.rs | 906 ++++++++++++++++++++++- crates/core_arch/src/x86_64/mod.rs | 14 + crates/stdarch-test/src/lib.rs | 4 + crates/stdarch-verify/src/lib.rs | 1 + crates/stdarch-verify/tests/x86-intel.rs | 4 + 6 files changed, 922 insertions(+), 44 deletions(-) diff --git a/crates/core_arch/missing-x86.md b/crates/core_arch/missing-x86.md index e9f68eb9e6..3a82f9761f 100644 --- a/crates/core_arch/missing-x86.md +++ b/crates/core_arch/missing-x86.md @@ -1,41 +1,4 @@ -
["AMX-BF16"]

- - * [ ] [`__tile_dpbf16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbf16ps) -

- - -
["AMX-COMPLEX"]

- - * [ ] [`__tile_cmmimfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_cmmimfp16ps) - * [ ] [`__tile_cmmrlfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_cmmrlfp16ps) -

- - -
["AMX-FP16"]

- - * [ ] [`__tile_dpfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpfp16ps) -

- - -
["AMX-INT8"]

- - * [ ] [`__tile_dpbssd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbssd) - * [ ] [`__tile_dpbsud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbsud) - * [ ] [`__tile_dpbusd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbusd) - * [ ] [`__tile_dpbuud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbuud) -

- - -
["AMX-TILE"]

- - * [ ] [`__tile_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_loadd) - * [ ] [`__tile_stored`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_stored) - * [ ] [`__tile_stream_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_stream_loadd) - * [ ] [`__tile_zero`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_zero) -

- -
["AVX512_FP16"]

* [ ] [`_mm256_set1_pch`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_set1_pch) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index b3b3e86750..5fc47c5cfd 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -1,3 +1,4 @@ +use crate::core_arch::x86_64::{__tile1024i, Tile}; use crate::core_arch::{simd::*, x86::*}; #[cfg(test)] @@ -44,6 +45,17 @@ pub unsafe fn _tile_loadd(base: *const u8, stride: usize) { tileloadd64(DST as i8, base, stride); } +/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_loadd&ig_expand=6877) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tileloadd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_loadd(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloadd64_internal((*dst).rows, (*dst).cols, base, stride as u64); +} + /// Release the tile configuration to return to the init state, which releases all storage it currently holds. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878) @@ -68,6 +80,17 @@ pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { tilestored64(DST as i8, base, stride); } +/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. 
+/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stored&ig_expand=6881) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tilestored))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stored(base: *mut u8, stride: usize, src: __tile1024i) { + tilestored64_internal(src.rows, src.cols, base, stride as u64, src.tile); +} + /// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration /// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will /// likely not be reused in the near future and the data caching can be optimized accordingly. @@ -83,6 +106,19 @@ pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) tileloaddt164(DST as i8, base, stride); } +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration +/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will +/// likely not be reused in the near future and the data caching can be optimized accordingly. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stream_loadd&ig_expand=6883) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tileloaddt1))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stream_loadd(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddt164_internal((*dst).rows, (*dst).cols, base, stride as u64); +} + /// Zero the tile specified by tdest. 
/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) @@ -96,6 +132,17 @@ pub unsafe fn _tile_zero() { tilezero(DST as i8); } +/// Zero the tile specified by dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_zero&ig_expand=6885) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tilezero))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_zero(dst: *mut __tile1024i) { + (*dst).tile = tilezero_internal((*dst).rows, (*dst).cols); +} + /// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b, /// accumulating the intermediate single-precision (32-bit) floating-point elements /// with elements in dst, and store the 32-bit result back to tile dst. @@ -113,6 +160,19 @@ pub unsafe fn _tile_dpbf16ps() { tdpbf16ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b, +/// accumulating the intermediate single-precision (32-bit) floating-point elements +/// with elements in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbf16ps&ig_expand=6864) +#[inline] +#[target_feature(enable = "amx-bf16")] +#[cfg_attr(test, assert_instr(tdpbf16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbf16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbf16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. 
@@ -131,6 +191,20 @@ pub unsafe fn _tile_dpbssd() { tdpbssd(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding +/// signed 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbssd&ig_expand=6866) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbssd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbssd(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbssd_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -149,6 +223,20 @@ pub unsafe fn _tile_dpbsud() { tdpbsud(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding +/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. 
+/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbsud&ig_expand=6868) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbsud))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbsud(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbsud_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -167,6 +255,20 @@ pub unsafe fn _tile_dpbusd() { tdpbusd(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding +/// signed 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbusd&ig_expand=6870) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbusd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbusd(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbusd_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. 
@@ -185,6 +287,20 @@ pub unsafe fn _tile_dpbuud() { tdpbuud(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding +/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbuud&ig_expand=6872) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbuud))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbuud(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbuud_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, /// accumulating the intermediate single-precision (32-bit) floating-point elements /// with elements in dst, and store the 32-bit result back to tile dst. @@ -202,6 +318,19 @@ pub unsafe fn _tile_dpfp16ps() { tdpfp16ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, +/// accumulating the intermediate single-precision (32-bit) floating-point elements +/// with elements in dst, and store the 32-bit result back to tile dst. 
+/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpfp16ps&ig_expand=6874) +#[inline] +#[target_feature(enable = "amx-fp16")] +#[cfg_attr(test, assert_instr(tdpfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. /// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), @@ -223,6 +352,23 @@ pub unsafe fn _tile_cmmimfp16ps() { tcmmimfp16ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. +/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. +/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), +/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). +/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of +/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added, +/// and then accumulated into the corresponding row and column of dst. 
+/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmimfp16ps&ig_expand=6860) +#[inline] +#[target_feature(enable = "amx-complex")] +#[cfg_attr(test, assert_instr(tcmmimfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cmmimfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tcmmimfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. /// Calculates the real part of the result. For each possible combination of (row of a, column of b), @@ -244,6 +390,23 @@ pub unsafe fn _tile_cmmrlfp16ps() { tcmmrlfp16ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. +/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. +/// Calculates the real part of the result. For each possible combination of (row of a, column of b), +/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). +/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of +/// the a element is multiplied with the imaginary part of the corresponding b elements. +/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst. 
+/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmrlfp16ps&ig_expand=6862) +#[inline] +#[target_feature(enable = "amx-complex")] +#[cfg_attr(test, assert_instr(tcmmrlfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cmmrlfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tcmmrlfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2) /// floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -263,6 +426,18 @@ pub unsafe fn _tile_dpbf8ps() { tdpbf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2) +/// floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. 
+#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8 /// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -282,6 +457,18 @@ pub unsafe fn _tile_dpbhf8ps() { tdpbhf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8 +/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. 
+#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbhf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbhf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbhf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8 /// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -301,6 +488,18 @@ pub unsafe fn _tile_dphbf8ps() { tdphbf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8 +/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. 
+#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphbf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dphbf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdphbf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3) /// floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -320,6 +519,18 @@ pub unsafe fn _tile_dphf8ps() { tdphf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3) +/// floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dphf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdphf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Load tile rows from memory specified by base address and stride into destination tile dst /// using the tile configuration previously configured via _tile_loadconfig. /// Additionally, this intrinsic indicates the source memory location is likely to become @@ -338,6 +549,19 @@ pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { tileloaddrs64(DST as i8, base, stride); } +/// Load tile rows from memory specified by base address and stride into destination tile dst +/// using the tile configuration previously configured via _tile_loadconfig. 
+/// Additionally, this intrinsic indicates the source memory location is likely to become +/// read-shared by multiple processors, i.e., read in the future by at least one other processor +/// before it is written, assuming it is ever written in the future. +#[inline] +#[target_feature(enable = "amx-movrs")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tileloaddrs))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddrs64_internal((*dst).rows, (*dst).cols, base, stride as u64); +} + /// Load tile rows from memory specified by base address and stride into destination tile dst /// using the tile configuration previously configured via _tile_loadconfig. /// Provides a hint to the implementation that the data would be reused but does not need @@ -358,6 +582,21 @@ pub unsafe fn _tile_stream_loaddrs(base: *const u8, stride: usiz tileloaddrst164(DST as i8, base, stride); } +/// Load tile rows from memory specified by base address and stride into destination tile dst +/// using the tile configuration previously configured via _tile_loadconfig. +/// Provides a hint to the implementation that the data would be reused but does not need +/// to be resident in the nearest cache levels. +/// Additionally, this intrinsic indicates the source memory location is likely to become +/// read-shared by multiple processors, i.e., read in the future by at least one other processor +/// before it is written, assuming it is ever written in the future. 
+#[inline] +#[target_feature(enable = "amx-movrs")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tileloaddrst1))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stream_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddrst164_internal((*dst).rows, (*dst).cols, base, stride as u64); +} + /// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) /// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the /// results into a packed single precision tile. @@ -383,6 +622,24 @@ pub unsafe fn _tile_mmultf32ps() { tmmultf32ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) +/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the +/// results into a packed single precision tile. +/// For each possible combination of (row of a, column of b), it performs +/// - convert to TF32 +/// - multiply the corresponding elements of a and b +/// - accumulate the results into the corresponding row and column of dst using round-to-nearest-even +/// rounding mode. +/// Output FP32 denormals are always flushed to zero, input single precision denormals are always +/// handled and *not* treated as zero. +#[inline] +#[target_feature(enable = "amx-tf32")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tmmultf32ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_mmultf32ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tmmultf32ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer /// elements to packed single-precision (32-bit) floating-point elements. 
#[inline] @@ -414,6 +671,16 @@ pub unsafe fn _tile_cvtrowd2psi() -> __m512 { tcvtrowd2psi(TILE as i8, ROW as u32).as_m512() } +/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer +/// elements to packed single-precision (32-bit) floating-point elements. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowd2ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowd2ps(src: __tile1024i, row: u32) -> __m512 { + tcvtrowd2ps_internal(src.rows, src.cols, src.tile, row).as_m512() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. @@ -447,6 +714,17 @@ pub unsafe fn _tile_cvtrowps2phhi() -> __m512h tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phh))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2phh(src: __tile1024i, row: u32) -> __m512h { + tcvtrowps2phh_internal(src.rows, src.cols, src.tile, row).as_m512h() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. 
The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. @@ -480,6 +758,17 @@ pub unsafe fn _tile_cvtrowps2phli() -> __m512h tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phl))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2phl(src: __tile1024i, row: u32) -> __m512h { + tcvtrowps2phl_internal(src.rows, src.cols, src.tile, row).as_m512h() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. @@ -513,6 +802,17 @@ pub unsafe fn _tile_cvtrowps2bf16hi() -> __m512 tcvtrowps2bf16hi(TILE as i8, ROW as u32).as_m512bh() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. 
+#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16h))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2bf16h(src: __tile1024i, row: u32) -> __m512bh { + tcvtrowps2bf16h_internal(src.rows, src.cols, src.tile, row).as_m512bh() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. @@ -546,6 +846,17 @@ pub unsafe fn _tile_cvtrowps2bf16li() -> __m512 tcvtrowps2bf16li(TILE as i8, ROW as u32).as_m512bh() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. 
+#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16l))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2bf16l(src: __tile1024i, row: u32) -> __m512bh { + tcvtrowps2bf16l_internal(src.rows, src.cols, src.tile, row).as_m512bh() +} + /// Moves one row of tile data into a zmm vector register #[inline] #[rustc_legacy_const_generics(0)] @@ -575,83 +886,169 @@ pub unsafe fn _tile_movrowi() -> __m512i { tilemovrowi(TILE as i8, ROW as u32).as_m512i() } +/// Moves one row of tile data into a zmm vector register +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tilemovrow))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_movrow(src: __tile1024i, row: u32) -> __m512i { + tilemovrow_internal(src.rows, src.cols, src.tile, row).as_m512i() +} + #[allow(improper_ctypes)] -unsafe extern "C" { +unsafe extern "unadjusted" { #[link_name = "llvm.x86.ldtilecfg"] fn ldtilecfg(mem_addr: *const u8); #[link_name = "llvm.x86.sttilecfg"] fn sttilecfg(mem_addr: *mut u8); + #[link_name = "llvm.x86.tileloadd64"] fn tileloadd64(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloadd64.internal"] + fn tileloadd64_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tileloaddt164"] fn tileloaddt164(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloaddt164.internal"] + fn tileloaddt164_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tilerelease"] fn tilerelease(); + #[link_name = "llvm.x86.tilestored64"] fn tilestored64(dst: i8, base: *mut u8, stride: usize); + #[link_name = "llvm.x86.tilestored64.internal"] + fn tilestored64_internal(rows: u16, cols: u16, base: *mut u8, stride: u64, src: Tile); + #[link_name = 
"llvm.x86.tilezero"] fn tilezero(dst: i8); + #[link_name = "llvm.x86.tilezero.internal"] + fn tilezero_internal(rows: u16, cols: u16) -> Tile; + #[link_name = "llvm.x86.tdpbf16ps"] fn tdpbf16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbf16ps.internal"] + fn tdpbf16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbuud"] fn tdpbuud(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbuud.internal"] + fn tdpbuud_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbusd"] fn tdpbusd(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbusd.internal"] + fn tdpbusd_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbsud"] fn tdpbsud(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbsud.internal"] + fn tdpbsud_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbssd"] fn tdpbssd(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbssd.internal"] + fn tdpbssd_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpfp16ps"] fn tdpfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpfp16ps.internal"] + fn tdpfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcmmimfp16ps"] fn tcmmimfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tcmmimfp16ps.internal"] + fn tcmmimfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcmmrlfp16ps"] fn tcmmrlfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tcmmrlfp16ps.internal"] + fn tcmmrlfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbf8ps"] fn tdpbf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbf8ps.internal"] + fn tdpbf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> 
Tile; + #[link_name = "llvm.x86.tdpbhf8ps"] fn tdpbhf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbhf8ps.internal"] + fn tdpbhf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdphbf8ps"] fn tdphbf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdphbf8ps.internal"] + fn tdphbf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdphf8ps"] fn tdphf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdphf8ps.internal"] + fn tdphf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tileloaddrs64"] fn tileloaddrs64(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloaddrs64.internal"] + fn tileloaddrs64_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tileloaddrst164"] fn tileloaddrst164(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloaddrst164.internal"] + fn tileloaddrst164_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tmmultf32ps"] fn tmmultf32ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tmmultf32ps.internal"] + fn tmmultf32ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcvtrowd2ps"] fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16; #[link_name = "llvm.x86.tcvtrowd2psi"] fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16; + #[link_name = "llvm.x86.tcvtrowd2ps.internal"] + fn tcvtrowd2ps_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f32x16; + #[link_name = "llvm.x86.tcvtrowps2phh"] fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tcvtrowps2phhi"] fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phh.internal"] + fn tcvtrowps2phh_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phl"] fn 
tcvtrowps2phl(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tcvtrowps2phli"] fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phl.internal"] + fn tcvtrowps2phl_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16h"] fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32; #[link_name = "llvm.x86.tcvtrowps2bf16hi"] fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16h.internal"] + fn tcvtrowps2bf16h_internal(rows: u16, cols: u16, src: Tile, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16l"] fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32; #[link_name = "llvm.x86.tcvtrowps2bf16li"] fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16l.internal"] + fn tcvtrowps2bf16l_internal(rows: u16, cols: u16, src: Tile, row: u32) -> u16x32; + #[link_name = "llvm.x86.tilemovrow"] fn tilemovrow(tile: i8, row: u32) -> i32x16; #[link_name = "llvm.x86.tilemovrowi"] fn tilemovrowi(tile: i8, row: u32) -> i32x16; + #[link_name = "llvm.x86.tilemovrow.internal"] + fn tilemovrow_internal(rows: u16, cols: u16, src: Tile, row: u32) -> i32x16; } #[cfg(test)] mod tests { use crate::core_arch::x86::_mm_cvtness_sbh; use crate::core_arch::x86_64::*; - use core::{array, mem::transmute}; + use core::array; + use core::mem::{MaybeUninit, transmute}; use stdarch_test::simd_test; #[cfg(target_os = "linux")] use syscalls::{Sysno, syscall}; @@ -727,6 +1124,18 @@ mod tests { } } + impl __tile1024i { + #[inline] + #[target_feature(enable = "amx-tile")] + fn zeroed(rows: u16, cols: u16) -> Self { + Self { + rows, + cols, + tile: unsafe { super::tilezero_internal(rows, cols) }, + } + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_loadconfig() { unsafe { @@ -765,6 +1174,20 @@ mod tests { } } + #[simd_test(enable = "amx-tile")] + fn test__tile_zero() { + unsafe { + _init_amx(); + + let tile = __tile1024i::zeroed(16, 64); + 
+ let mut out = [[1_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[0; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_stored() { unsafe { @@ -782,6 +1205,20 @@ mod tests { } } + #[simd_test(enable = "amx-tile")] + fn test__tile_stored() { + unsafe { + _init_amx(); + + let tile = __tile1024i::zeroed(16, 64); + + let mut out = [[1_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[0; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_loadd() { unsafe { @@ -801,6 +1238,22 @@ mod tests { } } + #[simd_test(enable = "amx-tile")] + fn test__tile_loadd() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_loadd(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_stream_loadd() { unsafe { @@ -820,6 +1273,22 @@ mod tests { } } + #[simd_test(enable = "amx-tile")] + fn test__tile_stream_loadd() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_stream_loadd(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_release() { unsafe { @@ -827,14 +1296,15 @@ mod tests { } } - #[simd_test(enable = "amx-bf16,avx512f")] + const BF16_1: u16 = 0x3f80; + const BF16_2: u16 = 0x4000; + + #[simd_test(enable = "amx-bf16")] fn test_tile_dpbf16ps() { unsafe { _init_amx(); - let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits(); - let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits(); - let ones: [u8; 1024] = transmute([bf16_1; 512]); - let twos: [u8; 1024] = transmute([bf16_2; 512]); + let ones: [u8; 1024] = transmute([BF16_1; 512]); + let twos: [u8; 
1024] = transmute([BF16_2; 512]); let mut res = [[0f32; 16]; 16]; let mut config = __tilecfg::default(); config.palette = 1; @@ -853,6 +1323,27 @@ mod tests { } } + #[simd_test(enable = "amx-bf16,avx512f")] + fn test__tile_dpbf16ps() { + unsafe { + _init_amx(); + let ones = [BF16_1; 512]; + let twos = [BF16_2; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbf16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbssd() { unsafe { @@ -877,6 +1368,27 @@ mod tests { } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbssd() { + unsafe { + _init_amx(); + let ones = [-1_i8; 1024]; + let twos = [-2_i8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbssd(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbsud() { unsafe { @@ -901,6 +1413,27 @@ mod tests { } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbsud() { + unsafe { + _init_amx(); + let ones = [-1_i8; 1024]; + let twos = [2_u8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbsud(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[-128_i32; 16]; 16]); 
+ } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbusd() { unsafe { @@ -925,6 +1458,27 @@ mod tests { } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbusd() { + unsafe { + _init_amx(); + let ones = [1_u8; 1024]; + let twos = [-2_i8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbusd(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[-128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbuud() { unsafe { @@ -949,6 +1503,27 @@ mod tests { } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbuud() { + unsafe { + _init_amx(); + let ones = [1_u8; 1024]; + let twos = [2_u8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbuud(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp16")] fn test_tile_dpfp16ps() { unsafe { @@ -973,6 +1548,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp16")] + fn test__tile_dpfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-complex")] fn test_tile_cmmimfp16ps() 
{ unsafe { @@ -997,6 +1593,27 @@ mod tests { } } + #[simd_test(enable = "amx-complex")] + fn test__tile_cmmimfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_cmmimfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-complex")] fn test_tile_cmmrlfp16ps() { unsafe { @@ -1021,6 +1638,27 @@ mod tests { } } + #[simd_test(enable = "amx-complex")] + fn test__tile_cmmrlfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_cmmrlfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[0f32; 16]; 16]); + } + } + const BF8_ONE: u8 = 0x3c; const BF8_TWO: u8 = 0x40; const HF8_ONE: u8 = 0x38; @@ -1050,6 +1688,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dpbf8ps() { + unsafe { + _init_amx(); + let ones = [BF8_ONE; 1024]; + let twos = [BF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dpbhf8ps() { unsafe { @@ -1074,6 
+1733,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dpbhf8ps() { + unsafe { + _init_amx(); + let ones = [BF8_ONE; 1024]; + let twos = [HF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbhf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dphbf8ps() { unsafe { @@ -1098,6 +1778,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dphbf8ps() { + unsafe { + _init_amx(); + let ones = [HF8_ONE; 1024]; + let twos = [BF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dphbf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dphf8ps() { unsafe { @@ -1122,6 +1823,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dphf8ps() { + unsafe { + _init_amx(); + let ones = [HF8_ONE; 1024]; + let twos = [HF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dphf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-movrs")] fn test_tile_loaddrs() { unsafe { @@ -1141,6 +1863,22 @@ mod tests { } } + #[simd_test(enable = 
"amx-movrs")] + fn test__tile_loaddrs() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_loaddrs(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-movrs")] fn test_tile_stream_loaddrs() { unsafe { @@ -1160,6 +1898,22 @@ mod tests { } } + #[simd_test(enable = "amx-movrs")] + fn test__tile_stream_loaddrs() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_stream_loaddrs(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_movrow() { unsafe { @@ -1223,6 +1977,22 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_movrow() { + unsafe { + _init_amx(); + let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_movrow(tile, i); + assert_eq!(*row.as_u8x64().as_array(), [i as _; _]); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowd2ps() { unsafe { @@ -1262,6 +2032,22 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowd2ps() { + unsafe { + _init_amx(); + let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowd2ps(tile, i); + assert_eq!(*row.as_f32x16().as_array(), [i as _; _]); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2phh() { unsafe { @@ -1306,6 +2092,25 @@ mod tests { } } + #[simd_test(enable = 
"amx-avx512,avx10.2")] + fn test__tile_cvtrowps2phh() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2phh(tile, i); + assert_eq!( + *row.as_f16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2phl() { unsafe { @@ -1350,6 +2155,25 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2phl() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2phl(tile, i); + assert_eq!( + *row.as_f16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2bf16h() { unsafe { @@ -1402,6 +2226,29 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2bf16h() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2bf16h(tile, i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + 0 + } else { + _mm_cvtness_sbh(i as _).to_bits() + }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2bf16l() { unsafe { @@ -1454,6 +2301,29 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2bf16l() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut 
tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2bf16l(tile, i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + _mm_cvtness_sbh(i as _).to_bits() + } else { + 0 + }) + ); + } + } + } + #[simd_test(enable = "amx-tf32")] fn test_tile_mmultf32ps() { unsafe { @@ -1480,4 +2350,26 @@ mod tests { assert_eq!(res, expected); } } + + #[simd_test(enable = "amx-tf32")] + fn test__tile_mmultf32ps() { + unsafe { + _init_amx(); + let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _]; + let mut res = [[0.0; 16]; 16]; + + let mut tile_a = __tile1024i::zeroed(16, 64); + let mut tile_b = __tile1024i::zeroed(16, 64); + let mut tile_c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut tile_a, a.as_ptr().cast(), 64); + __tile_loadd(&mut tile_b, b.as_ptr().cast(), 64); + __tile_mmultf32ps(&mut tile_c, tile_a, tile_b); + __tile_stored(res.as_mut_ptr().cast(), 64, tile_c); + + let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32)); + assert_eq!(res, expected); + } + } } diff --git a/crates/core_arch/src/x86_64/mod.rs b/crates/core_arch/src/x86_64/mod.rs index 46384176e0..3309c0fd9b 100644 --- a/crates/core_arch/src/x86_64/mod.rs +++ b/crates/core_arch/src/x86_64/mod.rs @@ -3,6 +3,20 @@ #[macro_use] mod macros; +// Any 1024-byte vector should work +type Tile = crate::core_arch::simd::Simd; + +/// A tile register, used by AMX instructions. 
+// TODO: add more docs +#[derive(Copy, Clone, Debug)] +#[allow(non_camel_case_types)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub struct __tile1024i { + rows: u16, + cols: u16, + tile: Tile, +} + mod fxsr; #[stable(feature = "simd_x86", since = "1.27.0")] pub use self::fxsr::*; diff --git a/crates/stdarch-test/src/lib.rs b/crates/stdarch-test/src/lib.rs index ecaf95f617..c468ebd12b 100644 --- a/crates/stdarch-test/src/lib.rs +++ b/crates/stdarch-test/src/lib.rs @@ -172,6 +172,10 @@ pub fn assert(shim_addr: usize, fnname: &str, expected: &str) { // vst1q_p64_x4_nop : #instructions = 33 >= 22 (limit) "nop" if fnname.contains("vst1q_p64") => 34, + // AMX intrinsics generate a lot of move instructions to load/store the tile registers + // due to Rust ABI + _ if fnname.contains("___tile") => 165, + // Original limit was 20 instructions, but ARM DSP Intrinsics // are exactly 20 instructions long. So, bump the limit to 22 // instead of adding here a long list of exceptions. diff --git a/crates/stdarch-verify/src/lib.rs b/crates/stdarch-verify/src/lib.rs index f7304ab326..5412ab466a 100644 --- a/crates/stdarch-verify/src/lib.rs +++ b/crates/stdarch-verify/src/lib.rs @@ -202,6 +202,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream { "_MM_MANTISSA_NORM_ENUM" => quote! { &MM_MANTISSA_NORM_ENUM }, "_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM }, "_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM }, + "__tile1024i" => quote! { &TILE1024I }, "bool" => quote! { &BOOL }, "bf16" => quote! { &BF16 }, "f16" => quote! 
{ &F16 }, diff --git a/crates/stdarch-verify/tests/x86-intel.rs b/crates/stdarch-verify/tests/x86-intel.rs index 024a873de1..aad19ca55a 100644 --- a/crates/stdarch-verify/tests/x86-intel.rs +++ b/crates/stdarch-verify/tests/x86-intel.rs @@ -62,6 +62,7 @@ static MM_CMPINT_ENUM: Type = Type::MM_CMPINT_ENUM; static MM_MANTISSA_NORM_ENUM: Type = Type::MM_MANTISSA_NORM_ENUM; static MM_MANTISSA_SIGN_ENUM: Type = Type::MM_MANTISSA_SIGN_ENUM; static MM_PERM_ENUM: Type = Type::MM_PERM_ENUM; +static TILE1024I: Type = Type::TILE1024I; static TUPLE: Type = Type::Tuple; static CPUID: Type = Type::CpuidResult; @@ -102,6 +103,7 @@ enum Type { CpuidResult, Never, Ordering, + TILE1024I, } stdarch_verify::x86_functions!(static FUNCTIONS); @@ -774,6 +776,7 @@ fn equate( (&Type::MMASK32, "__mmask32") => {} (&Type::MMASK16, "__mmask16") => {} (&Type::MMASK8, "__mmask8") => {} + (&Type::TILE1024I, "__tile1024i") => {} (&Type::MutPtr(_type), "void*") | (&Type::ConstPtr(_type), "void const*") => { let pointed_type = pointed_type(intrinsic)?; @@ -812,6 +815,7 @@ fn equate( (&Type::MutPtr(&Type::M512BH), "__m512bh*") => {} (&Type::MutPtr(&Type::M512I), "__m512i*") => {} (&Type::MutPtr(&Type::M512D), "__m512d*") => {} + (&Type::MutPtr(&Type::TILE1024I), "__tile1024i*") => {} (&Type::ConstPtr(&Type::PrimFloat(16)), "_Float16 const*") => {} (&Type::ConstPtr(&Type::PrimFloat(32)), "float const*") => {} From 5cbce6063a482183e50c1288d7bb44248a6dda11 Mon Sep 17 00:00:00 2001 From: sayantn Date: Wed, 22 Apr 2026 05:19:45 +0530 Subject: [PATCH 2/2] some doc fixes and tidying up --- crates/core_arch/src/x86_64/amx.rs | 210 ++++++++++++++++++----------- crates/core_arch/src/x86_64/mod.rs | 4 +- 2 files changed, 130 insertions(+), 84 deletions(-) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index 5fc47c5cfd..7d693ec1a5 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -4,12 +4,31 @@ use crate::core_arch::{simd::*, 
x86::*}; #[cfg(test)] use stdarch_test::assert_instr; -/// Load tile configuration from a 64-byte memory location specified by mem_addr. +/// Load tile configuration from a 64-byte memory location specified by `mem_addr`. /// The tile configuration format is specified below, and includes the tile type pallette, /// the number of bytes per row, and the number of rows. If the specified pallette_id is zero, /// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed. /// Any invalid configurations will result in #GP fault. /// +/// ```intel +/// // format of memory payload. each field is a byte. +/// 0: palette +/// 1: start_row +/// 2-15: reserved, must be zero +/// 16-17: tile0.colsb +/// 18-19: tile1.colsb +/// 20-21: tile2.colsb +/// ... +/// 30-31: tile7.colsb +/// 32-47: reserved, must be zero +/// 48: tile0.rows +/// 49: tile1.rows +/// 50: tile2.rows +/// ... +/// 55: tile7.rows +/// 56-63: reserved, must be zero +/// ``` +/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875) #[inline] #[target_feature(enable = "amx-tile")] @@ -19,8 +38,8 @@ pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { ldtilecfg(mem_addr); } -/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr. -/// The tile configuration format is specified below, and includes the tile type pallette, +/// Stores the current tile configuration to a 64-byte memory location specified by `mem_addr`. +/// The tile configuration format is as specified in [`_tile_loadconfig`], and includes the tile type pallette, /// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory. 
/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879) @@ -32,7 +51,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { sttilecfg(mem_addr); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration previously configured via [`_tile_loadconfig`]. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877) #[inline] @@ -42,10 +61,11 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloadd64(DST as i8, base, stride); + tileloadd64(DST as i8, base, stride as u64); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. +/// Load tile rows from memory specified by base address and stride into destination tile dst. The shape +/// of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_loadd&ig_expand=6877) #[inline] @@ -67,7 +87,7 @@ pub unsafe fn _tile_release() { tilerelease(); } -/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. +/// Store the tile specified by src to memory specified by base address and stride using the tile configuration previously configured via [`_tile_loadconfig`]. 
/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881) #[inline] @@ -77,10 +97,11 @@ pub unsafe fn _tile_release() { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tilestored64(DST as i8, base, stride); + tilestored64(DST as i8, base, stride as u64); } -/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. +/// Store the tile specified by src to memory specified by base address and stride. The shape of the tile +/// is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stored&ig_expand=6881) #[inline] @@ -91,8 +112,8 @@ pub unsafe fn __tile_stored(base: *mut u8, stride: usize, src: __tile1024i) { tilestored64_internal(src.rows, src.cols, base, stride as u64, src.tile); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration -/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration +/// previously configured via [`_tile_loadconfig`]. This intrinsic provides a hint to the implementation that the data will /// likely not be reused in the near future and the data caching can be optimized accordingly. 
/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883) @@ -103,12 +124,13 @@ pub unsafe fn __tile_stored(base: *mut u8, stride: usize, src: __tile1024i) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddt164(DST as i8, base, stride); + tileloaddt164(DST as i8, base, stride as u64); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration -/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will -/// likely not be reused in the near future and the data caching can be optimized accordingly. +/// Load tile rows from memory specified by base address and stride into destination tile dst. The shape +/// of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// This intrinsic provides a hint to the implementation that the data will likely not be reused in the +/// near future and the data caching can be optimized accordingly. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stream_loadd&ig_expand=6883) #[inline] @@ -119,7 +141,7 @@ pub unsafe fn __tile_stream_loadd(dst: *mut __tile1024i, base: *const u8, stride (*dst).tile = tileloaddt164_internal((*dst).rows, (*dst).cols, base, stride as u64); } -/// Zero the tile specified by tdest. +/// Zero the tile specified by `tdest`. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) #[inline] @@ -132,7 +154,8 @@ pub unsafe fn _tile_zero() { tilezero(DST as i8); } -/// Zero the tile specified by dst. +/// Zero the tile specified by `dst`. 
The shape of the tile is specified in the struct of [`__tile1024i`]. +/// The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_zero&ig_expand=6885) #[inline] @@ -162,7 +185,8 @@ pub unsafe fn _tile_dpbf16ps() { /// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, /// accumulating the intermediate single-precision (32-bit) floating-point elements -/// with elements in dst, and store the 32-bit result back to tile dst. +/// with elements in dst, and store the 32-bit result back to tile dst. The shape of the tile +/// is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbf16ps&ig_expand=6864) #[inline] @@ -195,6 +219,7 @@ pub unsafe fn _tile_dpbssd() { /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbssd&ig_expand=6866) #[inline] @@ -227,6 +252,7 @@ pub unsafe fn _tile_dpbsud() { /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. 
/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbsud&ig_expand=6868) #[inline] @@ -259,6 +285,7 @@ pub unsafe fn _tile_dpbusd() { /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbusd&ig_expand=6870) #[inline] @@ -291,6 +318,7 @@ pub unsafe fn _tile_dpbuud() { /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbuud&ig_expand=6872) #[inline] @@ -321,6 +349,7 @@ pub unsafe fn _tile_dpfp16ps() { /// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, /// accumulating the intermediate single-precision (32-bit) floating-point elements /// with elements in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. 
/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpfp16ps&ig_expand=6874) #[inline] @@ -359,6 +388,7 @@ pub unsafe fn _tile_cmmimfp16ps() { /// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of /// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added, /// and then accumulated into the corresponding row and column of dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmimfp16ps&ig_expand=6860) #[inline] @@ -397,6 +427,7 @@ pub unsafe fn _tile_cmmrlfp16ps() { /// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of /// the a element is multiplied with the imaginary part of the corresponding b elements. /// The two accumulated results are added, and then accumulated into the corresponding row and column of dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmrlfp16ps&ig_expand=6862) #[inline] @@ -430,6 +461,7 @@ pub unsafe fn _tile_dpbf8ps() { /// floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result /// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. 
#[inline] #[target_feature(enable = "amx-fp8")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbf8ps))] @@ -461,6 +493,7 @@ pub unsafe fn _tile_dpbhf8ps() { /// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result /// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. #[inline] #[target_feature(enable = "amx-fp8")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbhf8ps))] @@ -492,6 +525,7 @@ pub unsafe fn _tile_dphbf8ps() { /// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result /// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. #[inline] #[target_feature(enable = "amx-fp8")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphbf8ps))] @@ -523,6 +557,7 @@ pub unsafe fn _tile_dphf8ps() { /// floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result /// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. #[inline] #[target_feature(enable = "amx-fp8")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphf8ps))] @@ -532,7 +567,7 @@ pub unsafe fn __tile_dphf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile102 } /// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. 
+/// using the tile configuration previously configured via [`_tile_loadconfig`]. /// Additionally, this intrinsic indicates the source memory location is likely to become /// read-shared by multiple processors, i.e., read in the future by at least one other processor /// before it is written, assuming it is ever written in the future. @@ -546,11 +581,11 @@ pub unsafe fn __tile_dphf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile102 #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddrs64(DST as i8, base, stride); + tileloaddrs64(DST as i8, base, stride as u64); } -/// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. +/// Load tile rows from memory specified by base address and stride into destination tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// Additionally, this intrinsic indicates the source memory location is likely to become /// read-shared by multiple processors, i.e., read in the future by at least one other processor /// before it is written, assuming it is ever written in the future. @@ -563,7 +598,7 @@ pub unsafe fn __tile_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usi } /// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. +/// using the tile configuration previously configured via [`_tile_loadconfig`]. /// Provides a hint to the implementation that the data would be reused but does not need /// to be resident in the nearest cache levels. 
/// Additionally, this intrinsic indicates the source memory location is likely to become @@ -579,11 +614,11 @@ pub unsafe fn __tile_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usi #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stream_loaddrs(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddrst164(DST as i8, base, stride); + tileloaddrst164(DST as i8, base, stride as u64); } -/// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. +/// Load tile rows from memory specified by base address and stride into destination tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. /// Provides a hint to the implementation that the data would be reused but does not need /// to be resident in the nearest cache levels. /// Additionally, this intrinsic indicates the source memory location is likely to become @@ -632,6 +667,7 @@ pub unsafe fn _tile_mmultf32ps() { /// rounding mode. /// Output FP32 denormals are always flushed to zero, input single precision denormals are always /// handled and *not* treated as zero. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. #[inline] #[target_feature(enable = "amx-tf32")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tmmultf32ps))] @@ -673,6 +709,7 @@ pub unsafe fn _tile_cvtrowd2psi() -> __m512 { /// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer /// elements to packed single-precision (32-bit) floating-point elements. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. 
#[inline] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowd2ps))] @@ -717,6 +754,7 @@ pub unsafe fn _tile_cvtrowps2phhi() -> __m512h /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. #[inline] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phh))] @@ -761,6 +799,7 @@ pub unsafe fn _tile_cvtrowps2phli() -> __m512h /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. #[inline] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phl))] @@ -805,6 +844,7 @@ pub unsafe fn _tile_cvtrowps2bf16hi() -> __m512 /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. 
#[inline] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16h))] @@ -849,6 +889,7 @@ pub unsafe fn _tile_cvtrowps2bf16li() -> __m512 /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. #[inline] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16l))] @@ -887,6 +928,7 @@ pub unsafe fn _tile_movrowi() -> __m512i { } /// Moves one row of tile data into a zmm vector register +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. 
#[inline] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tilemovrow))] @@ -903,20 +945,20 @@ unsafe extern "unadjusted" { fn sttilecfg(mem_addr: *mut u8); #[link_name = "llvm.x86.tileloadd64"] - fn tileloadd64(dst: i8, base: *const u8, stride: usize); + fn tileloadd64(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tileloadd64.internal"] fn tileloadd64_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; #[link_name = "llvm.x86.tileloaddt164"] - fn tileloaddt164(dst: i8, base: *const u8, stride: usize); + fn tileloaddt164(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tileloaddt164.internal"] fn tileloaddt164_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; - + #[link_name = "llvm.x86.tilerelease"] fn tilerelease(); #[link_name = "llvm.x86.tilestored64"] - fn tilestored64(dst: i8, base: *mut u8, stride: usize); + fn tilestored64(dst: i8, base: *mut u8, stride: u64); #[link_name = "llvm.x86.tilestored64.internal"] fn tilestored64_internal(rows: u16, cols: u16, base: *mut u8, stride: u64, src: Tile); @@ -986,12 +1028,12 @@ unsafe extern "unadjusted" { fn tdphf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; #[link_name = "llvm.x86.tileloaddrs64"] - fn tileloaddrs64(dst: i8, base: *const u8, stride: usize); + fn tileloaddrs64(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tileloaddrs64.internal"] fn tileloaddrs64_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; #[link_name = "llvm.x86.tileloaddrst164"] - fn tileloaddrst164(dst: i8, base: *const u8, stride: usize); + fn tileloaddrst164(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tileloaddrst164.internal"] fn tileloaddrst164_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; @@ -1048,7 +1090,7 @@ mod tests { use crate::core_arch::x86::_mm_cvtness_sbh; use crate::core_arch::x86_64::*; 
use core::array; - use core::mem::{MaybeUninit, transmute}; + use core::mem::MaybeUninit; use stdarch_test::simd_test; #[cfg(target_os = "linux")] use syscalls::{Sysno, syscall}; @@ -1101,19 +1143,23 @@ mod tests { #[cfg(target_os = "linux")] #[target_feature(enable = "amx-tile")] #[inline] - unsafe fn _init_amx() { + fn _init_amx() { let mut ret: usize; let mut xfeatures: usize = 0; - ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize) - .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed"); + ret = unsafe { + syscall!(Sysno::arch_prctl, 0x1022, &raw mut xfeatures) + .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed") + }; if ret != 0 { panic!("Failed to get XFEATURES"); } else { match 0b11 & (xfeatures >> 17) { 0 => panic!("AMX is not available"), 1 => { - ret = syscall!(Sysno::arch_prctl, 0x1023, 18) - .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed"); + ret = unsafe { + syscall!(Sysno::arch_prctl, 0x1023, 18) + .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed") + }; if ret != 0 { panic!("Failed to enable AMX"); } @@ -1168,7 +1214,7 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } @@ -1199,7 +1245,7 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } @@ -1230,9 +1276,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loadd::<0>(&mat as *const i8 as *const u8, 64); + _tile_loadd::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); 
_tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -1265,9 +1311,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64); + _tile_stream_loadd::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -1303,8 +1349,8 @@ mod tests { fn test_tile_dpbf16ps() { unsafe { _init_amx(); - let ones: [u8; 1024] = transmute([BF16_1; 512]); - let twos: [u8; 1024] = transmute([BF16_2; 512]); + let ones = [BF16_1; 512]; + let twos = [BF16_2; 512]; let mut res = [[0f32; 16]; 16]; let mut config = __tilecfg::default(); config.palette = 1; @@ -1314,10 +1360,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbf16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -1359,10 +1405,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); - _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbssd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } @@ -1404,10 +1450,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + 
_tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbsud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } @@ -1449,10 +1495,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbusd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } @@ -1494,10 +1540,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbuud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } @@ -1539,10 +1585,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -1584,10 +1630,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + 
_tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_cmmimfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -1629,10 +1675,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_cmmrlfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[0f32; 16]; 16]); } @@ -1679,8 +1725,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1724,8 +1770,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbhf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1769,8 +1815,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dphbf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1814,8 +1860,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + 
_tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dphf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1855,9 +1901,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64); + _tile_loaddrs::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -1890,9 +1936,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64); + _tile_stream_loaddrs::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } diff --git a/crates/core_arch/src/x86_64/mod.rs b/crates/core_arch/src/x86_64/mod.rs index 3309c0fd9b..ffc2daaefa 100644 --- a/crates/core_arch/src/x86_64/mod.rs +++ b/crates/core_arch/src/x86_64/mod.rs @@ -12,8 +12,8 @@ type Tile = crate::core_arch::simd::Simd; #[allow(non_camel_case_types)] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub struct __tile1024i { - rows: u16, - cols: u16, + pub rows: u16, + pub cols: u16, tile: Tile, }