["AVX512_FP16"]
* [ ] [`_mm256_set1_pch`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_set1_pch)
diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs
index b3b3e86750..7d693ec1a5 100644
--- a/crates/core_arch/src/x86_64/amx.rs
+++ b/crates/core_arch/src/x86_64/amx.rs
@@ -1,14 +1,34 @@
+use crate::core_arch::x86_64::{__tile1024i, Tile};
use crate::core_arch::{simd::*, x86::*};
#[cfg(test)]
use stdarch_test::assert_instr;
-/// Load tile configuration from a 64-byte memory location specified by mem_addr.
+/// Load tile configuration from a 64-byte memory location specified by `mem_addr`.
/// The tile configuration format is specified below, and includes the tile type pallette,
/// the number of bytes per row, and the number of rows. If the specified pallette_id is zero,
/// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed.
/// Any invalid configurations will result in #GP fault.
///
+/// ```intel
+/// // format of memory payload. each field is a byte.
+/// 0: palette
+/// 1: start_row
+/// 2-15: reserved, must be zero
+/// 16-17: tile0.colsb
+/// 18-19: tile1.colsb
+/// 20-21: tile2.colsb
+/// ...
+/// 30-31: tile7.colsb
+/// 32-47: reserved, must be zero
+/// 48: tile0.rows
+/// 49: tile1.rows
+/// 50: tile2.rows
+/// ...
+/// 55: tile7.rows
+/// 56-63: reserved, must be zero
+/// ```
+///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875)
#[inline]
#[target_feature(enable = "amx-tile")]
@@ -18,8 +38,8 @@ pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
ldtilecfg(mem_addr);
}
-/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr.
-/// The tile configuration format is specified below, and includes the tile type pallette,
+/// Stores the current tile configuration to a 64-byte memory location specified by `mem_addr`.
+/// The tile configuration format is as specified in [`_tile_loadconfig`], and includes the tile type palette,
/// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879)
@@ -31,7 +51,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
sttilecfg(mem_addr);
}
-/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig.
+/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration previously configured via [`_tile_loadconfig`].
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877)
#[inline]
@@ -41,7 +61,19 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadd(base: *const u8, stride: usize) {
static_assert_uimm_bits!(DST, 3);
- tileloadd64(DST as i8, base, stride);
+ tileloadd64(DST as i8, base, stride as u64);
+}
+
+/// Load tile rows from memory specified by base address and stride into destination tile dst. The shape
+/// of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_loadd&ig_expand=6877)
+#[inline]
+#[target_feature(enable = "amx-tile")]
+#[cfg_attr(test, assert_instr(tileloadd))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_loadd(dst: *mut __tile1024i, base: *const u8, stride: usize) {
+ (*dst).tile = tileloadd64_internal((*dst).rows, (*dst).cols, base, stride as u64);
}
/// Release the tile configuration to return to the init state, which releases all storage it currently holds.
@@ -55,7 +87,7 @@ pub unsafe fn _tile_release() {
tilerelease();
}
-/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig.
+/// Store the tile specified by src to memory specified by base address and stride using the tile configuration previously configured via [`_tile_loadconfig`].
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881)
#[inline]
@@ -65,11 +97,23 @@ pub unsafe fn _tile_release() {
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stored(base: *mut u8, stride: usize) {
static_assert_uimm_bits!(DST, 3);
- tilestored64(DST as i8, base, stride);
+ tilestored64(DST as i8, base, stride as u64);
}
-/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration
-/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will
+/// Store the tile specified by src to memory specified by base address and stride. The shape of the tile
+/// is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stored&ig_expand=6881)
+#[inline]
+#[target_feature(enable = "amx-tile")]
+#[cfg_attr(test, assert_instr(tilestored))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_stored(base: *mut u8, stride: usize, src: __tile1024i) {
+ tilestored64_internal(src.rows, src.cols, base, stride as u64, src.tile);
+}
+
+/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration
+/// previously configured via [`_tile_loadconfig`]. This intrinsic provides a hint to the implementation that the data will
/// likely not be reused in the near future and the data caching can be optimized accordingly.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883)
@@ -80,10 +124,24 @@ pub unsafe fn _tile_stored(base: *mut u8, stride: usize) {
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) {
static_assert_uimm_bits!(DST, 3);
- tileloaddt164(DST as i8, base, stride);
+ tileloaddt164(DST as i8, base, stride as u64);
+}
+
+/// Load tile rows from memory specified by base address and stride into destination tile dst. The shape
+/// of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+/// This intrinsic provides a hint to the implementation that the data will likely not be reused in the
+/// near future and the data caching can be optimized accordingly.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stream_loadd&ig_expand=6883)
+#[inline]
+#[target_feature(enable = "amx-tile")]
+#[cfg_attr(test, assert_instr(tileloaddt1))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_stream_loadd(dst: *mut __tile1024i, base: *const u8, stride: usize) {
+ (*dst).tile = tileloaddt164_internal((*dst).rows, (*dst).cols, base, stride as u64);
}
-/// Zero the tile specified by tdest.
+/// Zero the tile specified by `tdest`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
#[inline]
@@ -96,6 +154,18 @@ pub unsafe fn _tile_zero() {
tilezero(DST as i8);
}
+/// Zero the tile specified by `dst`. The shape of the tile is specified in the struct of [`__tile1024i`].
+/// The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_zero&ig_expand=6885)
+#[inline]
+#[target_feature(enable = "amx-tile")]
+#[cfg_attr(test, assert_instr(tilezero))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_zero(dst: *mut __tile1024i) {
+ (*dst).tile = tilezero_internal((*dst).rows, (*dst).cols);
+}
+
/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
/// accumulating the intermediate single-precision (32-bit) floating-point elements
/// with elements in dst, and store the 32-bit result back to tile dst.
@@ -113,6 +183,20 @@ pub unsafe fn _tile_dpbf16ps() {
tdpbf16ps(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
+/// accumulating the intermediate single-precision (32-bit) floating-point elements
+/// with elements in dst, and store the 32-bit result back to tile dst. The shape of the tile
+/// is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbf16ps&ig_expand=6864)
+#[inline]
+#[target_feature(enable = "amx-bf16")]
+#[cfg_attr(test, assert_instr(tdpbf16ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dpbf16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdpbf16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
@@ -131,6 +215,21 @@ pub unsafe fn _tile_dpbssd() {
tdpbssd(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of bytes in tiles with a source/destination accumulator.
+/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
+/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
+/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbssd&ig_expand=6866)
+#[inline]
+#[target_feature(enable = "amx-int8")]
+#[cfg_attr(test, assert_instr(tdpbssd))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dpbssd(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdpbssd_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
@@ -149,6 +248,21 @@ pub unsafe fn _tile_dpbsud() {
tdpbsud(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of bytes in tiles with a source/destination accumulator.
+/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
+/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
+/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbsud&ig_expand=6868)
+#[inline]
+#[target_feature(enable = "amx-int8")]
+#[cfg_attr(test, assert_instr(tdpbsud))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dpbsud(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdpbsud_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
@@ -167,6 +281,21 @@ pub unsafe fn _tile_dpbusd() {
tdpbusd(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of bytes in tiles with a source/destination accumulator.
+/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
+/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
+/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbusd&ig_expand=6870)
+#[inline]
+#[target_feature(enable = "amx-int8")]
+#[cfg_attr(test, assert_instr(tdpbusd))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dpbusd(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdpbusd_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
@@ -185,6 +314,21 @@ pub unsafe fn _tile_dpbuud() {
tdpbuud(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of bytes in tiles with a source/destination accumulator.
+/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
+/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
+/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbuud&ig_expand=6872)
+#[inline]
+#[target_feature(enable = "amx-int8")]
+#[cfg_attr(test, assert_instr(tdpbuud))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dpbuud(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdpbuud_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
/// accumulating the intermediate single-precision (32-bit) floating-point elements
/// with elements in dst, and store the 32-bit result back to tile dst.
@@ -202,6 +346,20 @@ pub unsafe fn _tile_dpfp16ps() {
tdpfp16ps(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
+/// accumulating the intermediate single-precision (32-bit) floating-point elements
+/// with elements in dst, and store the 32-bit result back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpfp16ps&ig_expand=6874)
+#[inline]
+#[target_feature(enable = "amx-fp16")]
+#[cfg_attr(test, assert_instr(tdpfp16ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dpfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdpfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b),
@@ -223,6 +381,24 @@ pub unsafe fn _tile_cmmimfp16ps() {
tcmmimfp16ps(DST as i8, A as i8, B as i8);
}
+/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
+/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
+/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b),
+/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
+/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of
+/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added,
+/// and then accumulated into the corresponding row and column of dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmimfp16ps&ig_expand=6860)
+#[inline]
+#[target_feature(enable = "amx-complex")]
+#[cfg_attr(test, assert_instr(tcmmimfp16ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_cmmimfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tcmmimfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
/// Calculates the real part of the result. For each possible combination of (row of a, column of b),
@@ -244,6 +420,24 @@ pub unsafe fn _tile_cmmrlfp16ps() {
tcmmrlfp16ps(DST as i8, A as i8, B as i8);
}
+/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
+/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
+/// Calculates the real part of the result. For each possible combination of (row of a, column of b),
+/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
+/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of
+/// the a element is multiplied with the imaginary part of the corresponding b elements.
+/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmrlfp16ps&ig_expand=6862)
+#[inline]
+#[target_feature(enable = "amx-complex")]
+#[cfg_attr(test, assert_instr(tcmmrlfp16ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_cmmrlfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tcmmrlfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2)
/// floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
@@ -263,6 +457,19 @@ pub unsafe fn _tile_dpbf8ps() {
tdpbf8ps(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2)
+/// floating-point elements in tile b, accumulating the intermediate single-precision
+/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
+/// back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-fp8")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbf8ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dpbf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdpbf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8
/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
@@ -282,6 +489,19 @@ pub unsafe fn _tile_dpbhf8ps() {
tdpbhf8ps(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8
+/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision
+/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
+/// back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-fp8")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbhf8ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dpbhf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdpbhf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8
/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
@@ -301,6 +521,19 @@ pub unsafe fn _tile_dphbf8ps() {
tdphbf8ps(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8
+/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision
+/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
+/// back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-fp8")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphbf8ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dphbf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdphbf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3)
/// floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
@@ -320,8 +553,21 @@ pub unsafe fn _tile_dphf8ps() {
tdphf8ps(DST as i8, A as i8, B as i8);
}
+/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3)
+/// floating-point elements in tile b, accumulating the intermediate single-precision
+/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
+/// back to tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-fp8")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphf8ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_dphf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tdphf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Load tile rows from memory specified by base address and stride into destination tile dst
-/// using the tile configuration previously configured via _tile_loadconfig.
+/// using the tile configuration previously configured via [`_tile_loadconfig`].
/// Additionally, this intrinsic indicates the source memory location is likely to become
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
/// before it is written, assuming it is ever written in the future.
@@ -335,11 +581,24 @@ pub unsafe fn _tile_dphf8ps() {
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) {
static_assert_uimm_bits!(DST, 3);
- tileloaddrs64(DST as i8, base, stride);
+ tileloaddrs64(DST as i8, base, stride as u64);
+}
+
+/// Load tile rows from memory specified by base address and stride into destination tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+/// Additionally, this intrinsic indicates the source memory location is likely to become
+/// read-shared by multiple processors, i.e., read in the future by at least one other processor
+/// before it is written, assuming it is ever written in the future.
+#[inline]
+#[target_feature(enable = "amx-movrs")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tileloaddrs))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usize) {
+ (*dst).tile = tileloaddrs64_internal((*dst).rows, (*dst).cols, base, stride as u64);
}
/// Load tile rows from memory specified by base address and stride into destination tile dst
-/// using the tile configuration previously configured via _tile_loadconfig.
+/// using the tile configuration previously configured via [`_tile_loadconfig`].
/// Provides a hint to the implementation that the data would be reused but does not need
/// to be resident in the nearest cache levels.
/// Additionally, this intrinsic indicates the source memory location is likely to become
@@ -355,7 +614,22 @@ pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) {
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loaddrs(base: *const u8, stride: usize) {
static_assert_uimm_bits!(DST, 3);
- tileloaddrst164(DST as i8, base, stride);
+ tileloaddrst164(DST as i8, base, stride as u64);
+}
+
+/// Load tile rows from memory specified by base address and stride into destination tile dst.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+/// Provides a hint to the implementation that the data would be reused but does not need
+/// to be resident in the nearest cache levels.
+/// Additionally, this intrinsic indicates the source memory location is likely to become
+/// read-shared by multiple processors, i.e., read in the future by at least one other processor
+/// before it is written, assuming it is ever written in the future.
+#[inline]
+#[target_feature(enable = "amx-movrs")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tileloaddrst1))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_stream_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usize) {
+ (*dst).tile = tileloaddrst164_internal((*dst).rows, (*dst).cols, base, stride as u64);
}
/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit)
@@ -383,6 +657,25 @@ pub unsafe fn _tile_mmultf32ps() {
tmmultf32ps(DST as i8, A as i8, B as i8);
}
+/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit)
+/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the
+/// results into a packed single precision tile.
+/// For each possible combination of (row of a, column of b), it performs
+/// - convert to TF32
+/// - multiply the corresponding elements of a and b
+/// - accumulate the results into the corresponding row and column of dst using round-to-nearest-even
+/// rounding mode.
+/// Output FP32 denormals are always flushed to zero, input single precision denormals are always
+/// handled and *not* treated as zero.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-tf32")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tmmultf32ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_mmultf32ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) {
+ (*dst).tile = tmmultf32ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile);
+}
+
/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
/// elements to packed single-precision (32-bit) floating-point elements.
#[inline]
@@ -414,6 +707,17 @@ pub unsafe fn _tile_cvtrowd2psi() -> __m512 {
tcvtrowd2psi(TILE as i8, ROW as u32).as_m512()
}
+/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
+/// elements to packed single-precision (32-bit) floating-point elements.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-avx512,avx10.2")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowd2ps))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_cvtrowd2ps(src: __tile1024i, row: u32) -> __m512 {
+ tcvtrowd2ps_internal(src.rows, src.cols, src.tile, row).as_m512()
+}
+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
@@ -447,6 +751,18 @@ pub unsafe fn _tile_cvtrowps2phhi() -> __m512h
tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h()
}
+/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
+/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
+/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-avx512,avx10.2")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phh))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_cvtrowps2phh(src: __tile1024i, row: u32) -> __m512h {
+ tcvtrowps2phh_internal(src.rows, src.cols, src.tile, row).as_m512h()
+}
+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
@@ -480,6 +796,18 @@ pub unsafe fn _tile_cvtrowps2phli() -> __m512h
tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
}
+/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
+/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
+/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-avx512,avx10.2")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phl))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_cvtrowps2phl(src: __tile1024i, row: u32) -> __m512h {
+ tcvtrowps2phl_internal(src.rows, src.cols, src.tile, row).as_m512h()
+}
+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
@@ -513,6 +841,18 @@ pub unsafe fn _tile_cvtrowps2bf16hi() -> __m512
tcvtrowps2bf16hi(TILE as i8, ROW as u32).as_m512bh()
}
+/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
+/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
+/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-avx512,avx10.2")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16h))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_cvtrowps2bf16h(src: __tile1024i, row: u32) -> __m512bh {
+ tcvtrowps2bf16h_internal(src.rows, src.cols, src.tile, row).as_m512bh()
+}
+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
@@ -546,6 +886,18 @@ pub unsafe fn _tile_cvtrowps2bf16li() -> __m512
tcvtrowps2bf16li(TILE as i8, ROW as u32).as_m512bh()
}
+/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
+/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
+/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-avx512,avx10.2")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16l))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_cvtrowps2bf16l(src: __tile1024i, row: u32) -> __m512bh {
+ tcvtrowps2bf16l_internal(src.rows, src.cols, src.tile, row).as_m512bh()
+}
+
/// Moves one row of tile data into a zmm vector register
#[inline]
#[rustc_legacy_const_generics(0)]
@@ -575,83 +927,170 @@ pub unsafe fn _tile_movrowi() -> __m512i {
tilemovrowi(TILE as i8, ROW as u32).as_m512i()
}
+/// Moves one row of tile data into a zmm vector register
+/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler.
+#[inline]
+#[target_feature(enable = "amx-avx512,avx10.2")]
+#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tilemovrow))]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub unsafe fn __tile_movrow(src: __tile1024i, row: u32) -> __m512i {
+ tilemovrow_internal(src.rows, src.cols, src.tile, row).as_m512i()
+}
+
#[allow(improper_ctypes)]
-unsafe extern "C" {
+unsafe extern "unadjusted" {
#[link_name = "llvm.x86.ldtilecfg"]
fn ldtilecfg(mem_addr: *const u8);
#[link_name = "llvm.x86.sttilecfg"]
fn sttilecfg(mem_addr: *mut u8);
+
#[link_name = "llvm.x86.tileloadd64"]
- fn tileloadd64(dst: i8, base: *const u8, stride: usize);
+ fn tileloadd64(dst: i8, base: *const u8, stride: u64);
+ #[link_name = "llvm.x86.tileloadd64.internal"]
+ fn tileloadd64_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile;
+
#[link_name = "llvm.x86.tileloaddt164"]
- fn tileloaddt164(dst: i8, base: *const u8, stride: usize);
+ fn tileloaddt164(dst: i8, base: *const u8, stride: u64);
+ #[link_name = "llvm.x86.tileloaddt164.internal"]
+ fn tileloaddt164_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile;
+
#[link_name = "llvm.x86.tilerelease"]
fn tilerelease();
+
#[link_name = "llvm.x86.tilestored64"]
- fn tilestored64(dst: i8, base: *mut u8, stride: usize);
+ fn tilestored64(dst: i8, base: *mut u8, stride: u64);
+ #[link_name = "llvm.x86.tilestored64.internal"]
+ fn tilestored64_internal(rows: u16, cols: u16, base: *mut u8, stride: u64, src: Tile);
+
#[link_name = "llvm.x86.tilezero"]
fn tilezero(dst: i8);
+ #[link_name = "llvm.x86.tilezero.internal"]
+ fn tilezero_internal(rows: u16, cols: u16) -> Tile;
+
#[link_name = "llvm.x86.tdpbf16ps"]
fn tdpbf16ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdpbf16ps.internal"]
+ fn tdpbf16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdpbuud"]
fn tdpbuud(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdpbuud.internal"]
+ fn tdpbuud_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdpbusd"]
fn tdpbusd(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdpbusd.internal"]
+ fn tdpbusd_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdpbsud"]
fn tdpbsud(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdpbsud.internal"]
+ fn tdpbsud_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdpbssd"]
fn tdpbssd(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdpbssd.internal"]
+ fn tdpbssd_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdpfp16ps"]
fn tdpfp16ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdpfp16ps.internal"]
+ fn tdpfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tcmmimfp16ps"]
fn tcmmimfp16ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tcmmimfp16ps.internal"]
+ fn tcmmimfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tcmmrlfp16ps"]
fn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tcmmrlfp16ps.internal"]
+ fn tcmmrlfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdpbf8ps"]
fn tdpbf8ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdpbf8ps.internal"]
+ fn tdpbf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdpbhf8ps"]
fn tdpbhf8ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdpbhf8ps.internal"]
+ fn tdpbhf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdphbf8ps"]
fn tdphbf8ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdphbf8ps.internal"]
+ fn tdphbf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tdphf8ps"]
fn tdphf8ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tdphf8ps.internal"]
+ fn tdphf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tileloaddrs64"]
- fn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
+ fn tileloaddrs64(dst: i8, base: *const u8, stride: u64);
+ #[link_name = "llvm.x86.tileloaddrs64.internal"]
+ fn tileloaddrs64_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile;
+
#[link_name = "llvm.x86.tileloaddrst164"]
- fn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
+ fn tileloaddrst164(dst: i8, base: *const u8, stride: u64);
+ #[link_name = "llvm.x86.tileloaddrst164.internal"]
+ fn tileloaddrst164_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile;
+
#[link_name = "llvm.x86.tmmultf32ps"]
fn tmmultf32ps(dst: i8, a: i8, b: i8);
+ #[link_name = "llvm.x86.tmmultf32ps.internal"]
+ fn tmmultf32ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile;
+
#[link_name = "llvm.x86.tcvtrowd2ps"]
fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
#[link_name = "llvm.x86.tcvtrowd2psi"]
fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16;
+ #[link_name = "llvm.x86.tcvtrowd2ps.internal"]
+ fn tcvtrowd2ps_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f32x16;
+
#[link_name = "llvm.x86.tcvtrowps2phh"]
fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phhi"]
fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32;
+ #[link_name = "llvm.x86.tcvtrowps2phh.internal"]
+ fn tcvtrowps2phh_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f16x32;
+
#[link_name = "llvm.x86.tcvtrowps2phl"]
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phli"]
fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
+ #[link_name = "llvm.x86.tcvtrowps2phl.internal"]
+ fn tcvtrowps2phl_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f16x32;
+
#[link_name = "llvm.x86.tcvtrowps2bf16h"]
fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16hi"]
fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32;
+ #[link_name = "llvm.x86.tcvtrowps2bf16h.internal"]
+ fn tcvtrowps2bf16h_internal(rows: u16, cols: u16, src: Tile, row: u32) -> u16x32;
+
#[link_name = "llvm.x86.tcvtrowps2bf16l"]
fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16li"]
fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32;
+ #[link_name = "llvm.x86.tcvtrowps2bf16l.internal"]
+ fn tcvtrowps2bf16l_internal(rows: u16, cols: u16, src: Tile, row: u32) -> u16x32;
+
#[link_name = "llvm.x86.tilemovrow"]
fn tilemovrow(tile: i8, row: u32) -> i32x16;
#[link_name = "llvm.x86.tilemovrowi"]
fn tilemovrowi(tile: i8, row: u32) -> i32x16;
+ #[link_name = "llvm.x86.tilemovrow.internal"]
+ fn tilemovrow_internal(rows: u16, cols: u16, src: Tile, row: u32) -> i32x16;
}
#[cfg(test)]
mod tests {
use crate::core_arch::x86::_mm_cvtness_sbh;
use crate::core_arch::x86_64::*;
- use core::{array, mem::transmute};
+ use core::array;
+ use core::mem::MaybeUninit;
use stdarch_test::simd_test;
#[cfg(target_os = "linux")]
use syscalls::{Sysno, syscall};
@@ -704,19 +1143,23 @@ mod tests {
#[cfg(target_os = "linux")]
#[target_feature(enable = "amx-tile")]
#[inline]
- unsafe fn _init_amx() {
+ fn _init_amx() {
let mut ret: usize;
let mut xfeatures: usize = 0;
- ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize)
- .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
+ ret = unsafe {
+ syscall!(Sysno::arch_prctl, 0x1022, &raw mut xfeatures)
+ .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed")
+ };
if ret != 0 {
panic!("Failed to get XFEATURES");
} else {
match 0b11 & (xfeatures >> 17) {
0 => panic!("AMX is not available"),
1 => {
- ret = syscall!(Sysno::arch_prctl, 0x1023, 18)
- .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
+ ret = unsafe {
+ syscall!(Sysno::arch_prctl, 0x1023, 18)
+ .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed")
+ };
if ret != 0 {
panic!("Failed to enable AMX");
}
@@ -727,6 +1170,18 @@ mod tests {
}
}
+ impl __tile1024i {
+ #[inline]
+ #[target_feature(enable = "amx-tile")]
+ fn zeroed(rows: u16, cols: u16) -> Self {
+ Self {
+ rows,
+ cols,
+ tile: unsafe { super::tilezero_internal(rows, cols) },
+ }
+ }
+ }
+
#[simd_test(enable = "amx-tile")]
fn test_tile_loadconfig() {
unsafe {
@@ -759,12 +1214,26 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mut out = [[1_i8; 64]; 16];
- _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
+ _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(out, [[0; 64]; 16]);
}
}
+ #[simd_test(enable = "amx-tile")]
+ fn test__tile_zero() {
+ unsafe {
+ _init_amx();
+
+ let tile = __tile1024i::zeroed(16, 64);
+
+ let mut out = [[1_i8; 64]; 16];
+ __tile_stored(out.as_mut_ptr().cast(), 64, tile);
+
+ assert_eq!(out, [[0; 64]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-tile")]
fn test_tile_stored() {
unsafe {
@@ -776,12 +1245,26 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mut out = [[1_i8; 64]; 16];
- _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
+ _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(out, [[0; 64]; 16]);
}
}
+ #[simd_test(enable = "amx-tile")]
+ fn test__tile_stored() {
+ unsafe {
+ _init_amx();
+
+ let tile = __tile1024i::zeroed(16, 64);
+
+ let mut out = [[1_i8; 64]; 16];
+ __tile_stored(out.as_mut_ptr().cast(), 64, tile);
+
+ assert_eq!(out, [[0; 64]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-tile")]
fn test_tile_loadd() {
unsafe {
@@ -793,14 +1276,30 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mat = [1_i8; 1024];
- _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
+ _tile_loadd::<0>(mat.as_ptr().cast(), 64);
let mut out = [[0_i8; 64]; 16];
- _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
+ _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(out, [[1; 64]; 16]);
}
}
+ #[simd_test(enable = "amx-tile")]
+ fn test__tile_loadd() {
+ unsafe {
+ _init_amx();
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+
+ let mat = [1_i8; 1024];
+ __tile_loadd(&mut tile, mat.as_ptr().cast(), 64);
+ let mut out = [[0_i8; 64]; 16];
+ __tile_stored(out.as_mut_ptr().cast(), 64, tile);
+
+ assert_eq!(out, [[1; 64]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-tile")]
fn test_tile_stream_loadd() {
unsafe {
@@ -812,14 +1311,30 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mat = [1_i8; 1024];
- _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
+ _tile_stream_loadd::<0>(mat.as_ptr().cast(), 64);
let mut out = [[0_i8; 64]; 16];
- _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
+ _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(out, [[1; 64]; 16]);
}
}
+ #[simd_test(enable = "amx-tile")]
+ fn test__tile_stream_loadd() {
+ unsafe {
+ _init_amx();
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+
+ let mat = [1_i8; 1024];
+ __tile_stream_loadd(&mut tile, mat.as_ptr().cast(), 64);
+ let mut out = [[0_i8; 64]; 16];
+ __tile_stored(out.as_mut_ptr().cast(), 64, tile);
+
+ assert_eq!(out, [[1; 64]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-tile")]
fn test_tile_release() {
unsafe {
@@ -827,14 +1342,15 @@ mod tests {
}
}
- #[simd_test(enable = "amx-bf16,avx512f")]
+ const BF16_1: u16 = 0x3f80;
+ const BF16_2: u16 = 0x4000;
+
+ #[simd_test(enable = "amx-bf16")]
fn test_tile_dpbf16ps() {
unsafe {
_init_amx();
- let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
- let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
- let ones: [u8; 1024] = transmute([bf16_1; 512]);
- let twos: [u8; 1024] = transmute([bf16_2; 512]);
+ let ones = [BF16_1; 512];
+ let twos = [BF16_2; 512];
let mut res = [[0f32; 16]; 16];
let mut config = __tilecfg::default();
config.palette = 1;
@@ -844,15 +1360,36 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const u8, 64);
- _tile_loadd::<2>(&twos as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr().cast(), 64);
+ _tile_loadd::<2>(twos.as_ptr().cast(), 64);
_tile_dpbf16ps::<0, 1, 2>();
- _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
+ _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(res, [[64f32; 16]; 16]);
}
}
+ #[simd_test(enable = "amx-bf16,avx512f")]
+ fn test__tile_dpbf16ps() {
+ unsafe {
+ _init_amx();
+ let ones = [BF16_1; 512];
+ let twos = [BF16_2; 512];
+ let mut res = [[0f32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr().cast(), 64);
+ __tile_loadd(&mut b, twos.as_ptr().cast(), 64);
+ __tile_dpbf16ps(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[64f32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-int8")]
fn test_tile_dpbssd() {
unsafe {
@@ -868,15 +1405,36 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
- _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr().cast(), 64);
+ _tile_loadd::<2>(twos.as_ptr().cast(), 64);
_tile_dpbssd::<0, 1, 2>();
- _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
+ _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(res, [[128_i32; 16]; 16]);
}
}
+ #[simd_test(enable = "amx-int8")]
+ fn test__tile_dpbssd() {
+ unsafe {
+ _init_amx();
+ let ones = [-1_i8; 1024];
+ let twos = [-2_i8; 1024];
+ let mut res = [[0_i32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr().cast(), 64);
+ __tile_loadd(&mut b, twos.as_ptr().cast(), 64);
+ __tile_dpbssd(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[128_i32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-int8")]
fn test_tile_dpbsud() {
unsafe {
@@ -892,15 +1450,36 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
- _tile_loadd::<2>(&twos as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr().cast(), 64);
+ _tile_loadd::<2>(twos.as_ptr(), 64);
_tile_dpbsud::<0, 1, 2>();
- _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
+ _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(res, [[-128_i32; 16]; 16]);
}
}
+ #[simd_test(enable = "amx-int8")]
+ fn test__tile_dpbsud() {
+ unsafe {
+ _init_amx();
+ let ones = [-1_i8; 1024];
+ let twos = [2_u8; 1024];
+ let mut res = [[0_i32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr().cast(), 64);
+ __tile_loadd(&mut b, twos.as_ptr(), 64);
+ __tile_dpbsud(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[-128_i32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-int8")]
fn test_tile_dpbusd() {
unsafe {
@@ -916,15 +1495,36 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const u8, 64);
- _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr(), 64);
+ _tile_loadd::<2>(twos.as_ptr().cast(), 64);
_tile_dpbusd::<0, 1, 2>();
- _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
+ _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(res, [[-128_i32; 16]; 16]);
}
}
+ #[simd_test(enable = "amx-int8")]
+ fn test__tile_dpbusd() {
+ unsafe {
+ _init_amx();
+ let ones = [1_u8; 1024];
+ let twos = [-2_i8; 1024];
+ let mut res = [[0_i32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr(), 64);
+ __tile_loadd(&mut b, twos.as_ptr().cast(), 64);
+ __tile_dpbusd(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[-128_i32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-int8")]
fn test_tile_dpbuud() {
unsafe {
@@ -940,15 +1540,36 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const u8, 64);
- _tile_loadd::<2>(&twos as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr(), 64);
+ _tile_loadd::<2>(twos.as_ptr(), 64);
_tile_dpbuud::<0, 1, 2>();
- _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
+ _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(res, [[128_i32; 16]; 16]);
}
}
+ #[simd_test(enable = "amx-int8")]
+ fn test__tile_dpbuud() {
+ unsafe {
+ _init_amx();
+ let ones = [1_u8; 1024];
+ let twos = [2_u8; 1024];
+ let mut res = [[0_i32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr(), 64);
+ __tile_loadd(&mut b, twos.as_ptr(), 64);
+ __tile_dpbuud(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[128_i32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-fp16")]
fn test_tile_dpfp16ps() {
unsafe {
@@ -964,15 +1585,36 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
- _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr().cast(), 64);
+ _tile_loadd::<2>(twos.as_ptr().cast(), 64);
_tile_dpfp16ps::<0, 1, 2>();
- _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
+ _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(res, [[64f32; 16]; 16]);
}
}
+ #[simd_test(enable = "amx-fp16")]
+ fn test__tile_dpfp16ps() {
+ unsafe {
+ _init_amx();
+ let ones = [1f16; 512];
+ let twos = [2f16; 512];
+ let mut res = [[0f32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr().cast(), 64);
+ __tile_loadd(&mut b, twos.as_ptr().cast(), 64);
+ __tile_dpfp16ps(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[64f32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-complex")]
fn test_tile_cmmimfp16ps() {
unsafe {
@@ -988,15 +1630,36 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
- _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr().cast(), 64);
+ _tile_loadd::<2>(twos.as_ptr().cast(), 64);
_tile_cmmimfp16ps::<0, 1, 2>();
- _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
+ _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(res, [[64f32; 16]; 16]);
}
}
+ #[simd_test(enable = "amx-complex")]
+ fn test__tile_cmmimfp16ps() {
+ unsafe {
+ _init_amx();
+ let ones = [1f16; 512];
+ let twos = [2f16; 512];
+ let mut res = [[0f32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr().cast(), 64);
+ __tile_loadd(&mut b, twos.as_ptr().cast(), 64);
+ __tile_cmmimfp16ps(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[64f32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-complex")]
fn test_tile_cmmrlfp16ps() {
unsafe {
@@ -1012,15 +1675,36 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
- _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr().cast(), 64);
+ _tile_loadd::<2>(twos.as_ptr().cast(), 64);
_tile_cmmrlfp16ps::<0, 1, 2>();
- _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
+ _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(res, [[0f32; 16]; 16]);
}
}
+ #[simd_test(enable = "amx-complex")]
+ fn test__tile_cmmrlfp16ps() {
+ unsafe {
+ _init_amx();
+ let ones = [1f16; 512];
+ let twos = [2f16; 512];
+ let mut res = [[0f32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr().cast(), 64);
+ __tile_loadd(&mut b, twos.as_ptr().cast(), 64);
+ __tile_cmmrlfp16ps(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[0f32; 16]; 16]);
+ }
+ }
+
const BF8_ONE: u8 = 0x3c;
const BF8_TWO: u8 = 0x40;
const HF8_ONE: u8 = 0x38;
@@ -1041,8 +1725,8 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const u8, 64);
- _tile_loadd::<2>(&twos as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr(), 64);
+ _tile_loadd::<2>(twos.as_ptr(), 64);
_tile_dpbf8ps::<0, 1, 2>();
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
@@ -1050,6 +1734,27 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-fp8")]
+ fn test__tile_dpbf8ps() {
+ unsafe {
+ _init_amx();
+ let ones = [BF8_ONE; 1024];
+ let twos = [BF8_TWO; 1024];
+ let mut res = [[0.0_f32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr(), 64);
+ __tile_loadd(&mut b, twos.as_ptr(), 64);
+ __tile_dpbf8ps(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[128.0_f32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-fp8")]
fn test_tile_dpbhf8ps() {
unsafe {
@@ -1065,8 +1770,8 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const u8, 64);
- _tile_loadd::<2>(&twos as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr(), 64);
+ _tile_loadd::<2>(twos.as_ptr(), 64);
_tile_dpbhf8ps::<0, 1, 2>();
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
@@ -1074,6 +1779,27 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-fp8")]
+ fn test__tile_dpbhf8ps() {
+ unsafe {
+ _init_amx();
+ let ones = [BF8_ONE; 1024];
+ let twos = [HF8_TWO; 1024];
+ let mut res = [[0.0_f32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr(), 64);
+ __tile_loadd(&mut b, twos.as_ptr(), 64);
+ __tile_dpbhf8ps(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[128.0_f32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-fp8")]
fn test_tile_dphbf8ps() {
unsafe {
@@ -1089,8 +1815,8 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const u8, 64);
- _tile_loadd::<2>(&twos as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr(), 64);
+ _tile_loadd::<2>(twos.as_ptr(), 64);
_tile_dphbf8ps::<0, 1, 2>();
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
@@ -1098,6 +1824,27 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-fp8")]
+ fn test__tile_dphbf8ps() {
+ unsafe {
+ _init_amx();
+ let ones = [HF8_ONE; 1024];
+ let twos = [BF8_TWO; 1024];
+ let mut res = [[0.0_f32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr(), 64);
+ __tile_loadd(&mut b, twos.as_ptr(), 64);
+ __tile_dphbf8ps(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[128.0_f32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-fp8")]
fn test_tile_dphf8ps() {
unsafe {
@@ -1113,8 +1860,8 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
- _tile_loadd::<1>(&ones as *const u8, 64);
- _tile_loadd::<2>(&twos as *const u8, 64);
+ _tile_loadd::<1>(ones.as_ptr(), 64);
+ _tile_loadd::<2>(twos.as_ptr(), 64);
_tile_dphf8ps::<0, 1, 2>();
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
_tile_release();
@@ -1122,6 +1869,27 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-fp8")]
+ fn test__tile_dphf8ps() {
+ unsafe {
+ _init_amx();
+ let ones = [HF8_ONE; 1024];
+ let twos = [HF8_TWO; 1024];
+ let mut res = [[0.0_f32; 16]; 16];
+
+ let mut a = __tile1024i::zeroed(16, 64);
+ let mut b = __tile1024i::zeroed(16, 64);
+ let mut c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut a, ones.as_ptr(), 64);
+ __tile_loadd(&mut b, twos.as_ptr(), 64);
+ __tile_dphf8ps(&mut c, a, b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, c);
+
+ assert_eq!(res, [[128.0_f32; 16]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-movrs")]
fn test_tile_loaddrs() {
unsafe {
@@ -1133,14 +1901,30 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mat = [1_i8; 1024];
- _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
+ _tile_loaddrs::<0>(mat.as_ptr().cast(), 64);
let mut out = [[0_i8; 64]; 16];
- _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
+ _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(out, [[1; 64]; 16]);
}
}
+ #[simd_test(enable = "amx-movrs")]
+ fn test__tile_loaddrs() {
+ unsafe {
+ _init_amx();
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+
+ let mat = [1_i8; 1024];
+ __tile_loaddrs(&mut tile, mat.as_ptr().cast(), 64);
+ let mut out = [[0_i8; 64]; 16];
+ __tile_stored(out.as_mut_ptr().cast(), 64, tile);
+
+ assert_eq!(out, [[1; 64]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-movrs")]
fn test_tile_stream_loaddrs() {
unsafe {
@@ -1152,14 +1936,30 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mat = [1_i8; 1024];
- _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
+ _tile_stream_loaddrs::<0>(mat.as_ptr().cast(), 64);
let mut out = [[0_i8; 64]; 16];
- _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
+ _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
_tile_release();
assert_eq!(out, [[1; 64]; 16]);
}
}
+ #[simd_test(enable = "amx-movrs")]
+ fn test__tile_stream_loaddrs() {
+ unsafe {
+ _init_amx();
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+
+ let mat = [1_i8; 1024];
+ __tile_stream_loaddrs(&mut tile, mat.as_ptr().cast(), 64);
+ let mut out = [[0_i8; 64]; 16];
+ __tile_stored(out.as_mut_ptr().cast(), 64, tile);
+
+ assert_eq!(out, [[1; 64]; 16]);
+ }
+ }
+
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_movrow() {
unsafe {
@@ -1223,6 +2023,22 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-avx512,avx10.2")]
+ fn test__tile_movrow() {
+ unsafe {
+ _init_amx();
+ let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+ __tile_loadd(&mut tile, array.as_ptr().cast(), 64);
+
+ for i in 0..16 {
+ let row = __tile_movrow(tile, i);
+ assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
+ }
+ }
+ }
+
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowd2ps() {
unsafe {
@@ -1262,6 +2078,22 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-avx512,avx10.2")]
+ fn test__tile_cvtrowd2ps() {
+ unsafe {
+ _init_amx();
+ let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+ __tile_loadd(&mut tile, array.as_ptr().cast(), 64);
+
+ for i in 0..16 {
+ let row = __tile_cvtrowd2ps(tile, i);
+ assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
+ }
+ }
+ }
+
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phh() {
unsafe {
@@ -1306,6 +2138,25 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-avx512,avx10.2")]
+ fn test__tile_cvtrowps2phh() {
+ unsafe {
+ _init_amx();
+ let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+ __tile_loadd(&mut tile, array.as_ptr().cast(), 64);
+
+ for i in 0..16 {
+ let row = __tile_cvtrowps2phh(tile, i);
+ assert_eq!(
+ *row.as_f16x32().as_array(),
+ array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
+ );
+ }
+ }
+ }
+
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phl() {
unsafe {
@@ -1350,6 +2201,25 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-avx512,avx10.2")]
+ fn test__tile_cvtrowps2phl() {
+ unsafe {
+ _init_amx();
+ let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+ __tile_loadd(&mut tile, array.as_ptr().cast(), 64);
+
+ for i in 0..16 {
+ let row = __tile_cvtrowps2phl(tile, i);
+ assert_eq!(
+ *row.as_f16x32().as_array(),
+ array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
+ );
+ }
+ }
+ }
+
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16h() {
unsafe {
@@ -1402,6 +2272,29 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-avx512,avx10.2")]
+ fn test__tile_cvtrowps2bf16h() {
+ unsafe {
+ _init_amx();
+ let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+ __tile_loadd(&mut tile, array.as_ptr().cast(), 64);
+
+ for i in 0..16 {
+ let row = __tile_cvtrowps2bf16h(tile, i);
+ assert_eq!(
+ *row.as_u16x32().as_array(),
+ array::from_fn(|j| if j & 1 == 0 {
+ 0
+ } else {
+ _mm_cvtness_sbh(i as _).to_bits()
+ })
+ );
+ }
+ }
+ }
+
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16l() {
unsafe {
@@ -1454,6 +2347,29 @@ mod tests {
}
}
+ #[simd_test(enable = "amx-avx512,avx10.2")]
+ fn test__tile_cvtrowps2bf16l() {
+ unsafe {
+ _init_amx();
+ let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
+
+ let mut tile = __tile1024i::zeroed(16, 64);
+ __tile_loadd(&mut tile, array.as_ptr().cast(), 64);
+
+ for i in 0..16 {
+ let row = __tile_cvtrowps2bf16l(tile, i);
+ assert_eq!(
+ *row.as_u16x32().as_array(),
+ array::from_fn(|j| if j & 1 == 0 {
+ _mm_cvtness_sbh(i as _).to_bits()
+ } else {
+ 0
+ })
+ );
+ }
+ }
+ }
+
#[simd_test(enable = "amx-tf32")]
fn test_tile_mmultf32ps() {
unsafe {
@@ -1480,4 +2396,26 @@ mod tests {
assert_eq!(res, expected);
}
}
+
+ #[simd_test(enable = "amx-tf32")]
+ fn test__tile_mmultf32ps() {
+ unsafe {
+ _init_amx();
+ let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
+ let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
+ let mut res = [[0.0; 16]; 16];
+
+ let mut tile_a = __tile1024i::zeroed(16, 64);
+ let mut tile_b = __tile1024i::zeroed(16, 64);
+ let mut tile_c = __tile1024i::zeroed(16, 64);
+
+ __tile_loadd(&mut tile_a, a.as_ptr().cast(), 64);
+ __tile_loadd(&mut tile_b, b.as_ptr().cast(), 64);
+ __tile_mmultf32ps(&mut tile_c, tile_a, tile_b);
+ __tile_stored(res.as_mut_ptr().cast(), 64, tile_c);
+
+ let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
+ assert_eq!(res, expected);
+ }
+ }
}
diff --git a/crates/core_arch/src/x86_64/mod.rs b/crates/core_arch/src/x86_64/mod.rs
index 46384176e0..ffc2daaefa 100644
--- a/crates/core_arch/src/x86_64/mod.rs
+++ b/crates/core_arch/src/x86_64/mod.rs
@@ -3,6 +3,20 @@
#[macro_use]
mod macros;
+// Any 1024-byte vector type should work here; it is only used as opaque storage for the tile data.
+type Tile = crate::core_arch::simd::Simd;
+
+/// A tile register, used by AMX instructions; holds the tile data together with its dimensions (`rows` x `cols` bytes).
+// TODO: add more docs
+#[derive(Copy, Clone, Debug)]
+#[allow(non_camel_case_types)]
+#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
+pub struct __tile1024i {
+ pub rows: u16,
+ pub cols: u16,
+ tile: Tile,
+}
+
mod fxsr;
#[stable(feature = "simd_x86", since = "1.27.0")]
pub use self::fxsr::*;
diff --git a/crates/stdarch-test/src/lib.rs b/crates/stdarch-test/src/lib.rs
index ecaf95f617..c468ebd12b 100644
--- a/crates/stdarch-test/src/lib.rs
+++ b/crates/stdarch-test/src/lib.rs
@@ -172,6 +172,10 @@ pub fn assert(shim_addr: usize, fnname: &str, expected: &str) {
// vst1q_p64_x4_nop : #instructions = 33 >= 22 (limit)
"nop" if fnname.contains("vst1q_p64") => 34,
+ // AMX intrinsics generate a lot of move instructions to load/store the tile registers
+ // due to the Rust ABI
+ _ if fnname.contains("___tile") => 165,
+
// Original limit was 20 instructions, but ARM DSP Intrinsics
// are exactly 20 instructions long. So, bump the limit to 22
// instead of adding here a long list of exceptions.
diff --git a/crates/stdarch-verify/src/lib.rs b/crates/stdarch-verify/src/lib.rs
index f7304ab326..5412ab466a 100644
--- a/crates/stdarch-verify/src/lib.rs
+++ b/crates/stdarch-verify/src/lib.rs
@@ -202,6 +202,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream {
"_MM_MANTISSA_NORM_ENUM" => quote! { &MM_MANTISSA_NORM_ENUM },
"_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM },
"_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM },
+ "__tile1024i" => quote! { &TILE1024I },
"bool" => quote! { &BOOL },
"bf16" => quote! { &BF16 },
"f16" => quote! { &F16 },
diff --git a/crates/stdarch-verify/tests/x86-intel.rs b/crates/stdarch-verify/tests/x86-intel.rs
index 024a873de1..aad19ca55a 100644
--- a/crates/stdarch-verify/tests/x86-intel.rs
+++ b/crates/stdarch-verify/tests/x86-intel.rs
@@ -62,6 +62,7 @@ static MM_CMPINT_ENUM: Type = Type::MM_CMPINT_ENUM;
static MM_MANTISSA_NORM_ENUM: Type = Type::MM_MANTISSA_NORM_ENUM;
static MM_MANTISSA_SIGN_ENUM: Type = Type::MM_MANTISSA_SIGN_ENUM;
static MM_PERM_ENUM: Type = Type::MM_PERM_ENUM;
+static TILE1024I: Type = Type::TILE1024I;
static TUPLE: Type = Type::Tuple;
static CPUID: Type = Type::CpuidResult;
@@ -102,6 +103,7 @@ enum Type {
CpuidResult,
Never,
Ordering,
+ TILE1024I,
}
stdarch_verify::x86_functions!(static FUNCTIONS);
@@ -774,6 +776,7 @@ fn equate(
(&Type::MMASK32, "__mmask32") => {}
(&Type::MMASK16, "__mmask16") => {}
(&Type::MMASK8, "__mmask8") => {}
+ (&Type::TILE1024I, "__tile1024i") => {}
(&Type::MutPtr(_type), "void*") | (&Type::ConstPtr(_type), "void const*") => {
let pointed_type = pointed_type(intrinsic)?;
@@ -812,6 +815,7 @@ fn equate(
(&Type::MutPtr(&Type::M512BH), "__m512bh*") => {}
(&Type::MutPtr(&Type::M512I), "__m512i*") => {}
(&Type::MutPtr(&Type::M512D), "__m512d*") => {}
+ (&Type::MutPtr(&Type::TILE1024I), "__tile1024i*") => {}
(&Type::ConstPtr(&Type::PrimFloat(16)), "_Float16 const*") => {}
(&Type::ConstPtr(&Type::PrimFloat(32)), "float const*") => {}