From 0d61ee304df41d9d0c82ba2dcaf0edfd39253238 Mon Sep 17 00:00:00 2001 From: sayantn Date: Sun, 19 Apr 2026 22:53:43 +0530 Subject: [PATCH 1/2] Add AMX-AVX512 BF16 intrinsics --- crates/core_arch/src/x86_64/amx.rs | 178 +++++++++++++++++++++++++++++ 1 file changed, 178 insertions(+) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index 03bbe3e449..62e46097e6 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -480,6 +480,72 @@ pub unsafe fn _tile_cvtrowps2phli() -> __m512h tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +#[inline] +#[rustc_legacy_const_generics(0)] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr( + all(test, any(target_os = "linux", target_env = "msvc")), + assert_instr(tcvtrowps2bf16h, TILE = 0) +)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cvtrowps2bf16h<const TILE: i32>(row: u32) -> __m512bh { + static_assert_uimm_bits!(TILE, 3); + tcvtrowps2bf16h(TILE as i8, row).as_m512bh() +} + +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. 
+#[inline] +#[rustc_legacy_const_generics(0, 1)] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr( + all(test, any(target_os = "linux", target_env = "msvc")), + assert_instr(tcvtrowps2bf16h, TILE = 0, ROW = 0) +)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cvtrowps2bf16hi<const TILE: i32, const ROW: i32>() -> __m512bh { + static_assert_uimm_bits!(TILE, 3); + static_assert_uimm_bits!(ROW, 6); + tcvtrowps2bf16hi(TILE as i8, ROW as u32).as_m512bh() +} + +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +#[inline] +#[rustc_legacy_const_generics(0)] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr( + all(test, any(target_os = "linux", target_env = "msvc")), + assert_instr(tcvtrowps2bf16l, TILE = 0) +)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cvtrowps2bf16l<const TILE: i32>(row: u32) -> __m512bh { + static_assert_uimm_bits!(TILE, 3); + tcvtrowps2bf16l(TILE as i8, row).as_m512bh() +} + +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. 
+#[inline] +#[rustc_legacy_const_generics(0, 1)] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr( + all(test, any(target_os = "linux", target_env = "msvc")), + assert_instr(tcvtrowps2bf16l, TILE = 0, ROW = 0) +)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cvtrowps2bf16li<const TILE: i32, const ROW: i32>() -> __m512bh { + static_assert_uimm_bits!(TILE, 3); + static_assert_uimm_bits!(ROW, 6); + tcvtrowps2bf16li(TILE as i8, ROW as u32).as_m512bh() +} + /// Moves one row of tile data into a zmm vector register #[inline] #[rustc_legacy_const_generics(0)] @@ -567,6 +633,14 @@ unsafe extern "C" { fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tcvtrowps2phli"] fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16h"] + fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16hi"] + fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16l"] + fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16li"] + fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32; #[link_name = "llvm.x86.tilemovrow"] fn tilemovrow(tile: i8, row: u32) -> i32x16; #[link_name = "llvm.x86.tilemovrowi"] @@ -1276,6 +1350,110 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test_tile_cvtrowps2bf16h() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_loadd::<0>(array.as_ptr().cast(), 64); + for i in 0..16 { + let row = _tile_cvtrowps2bf16h::<0>(i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + 0 + } else { + _mm_cvtness_sbh(i as _).to_bits() + }) + ); + } + } + } + + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test_tile_cvtrowps2bf16hi() { + unsafe { + _init_amx(); + 
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_loadd::<0>(array.as_ptr().cast(), 64); + for i in 0..16 { + let row = wrap_imm4!(_tile_cvtrowps2bf16hi::<0>, i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + 0 + } else { + _mm_cvtness_sbh(i as _).to_bits() + }) + ); + } + } + } + + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test_tile_cvtrowps2bf16l() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_loadd::<0>(array.as_ptr().cast(), 64); + for i in 0..16 { + let row = _tile_cvtrowps2bf16l::<0>(i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + _mm_cvtness_sbh(i as _).to_bits() + } else { + 0 + }) + ); + } + } + } + + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test_tile_cvtrowps2bf16li() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_loadd::<0>(array.as_ptr().cast(), 64); + for i in 0..16 { + let row = wrap_imm4!(_tile_cvtrowps2bf16li::<0>, i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + _mm_cvtness_sbh(i as _).to_bits() + } else { + 0 + }) + ); + } + } + } + #[simd_test(enable = "amx-tf32")] fn test_tile_mmultf32ps() { unsafe { From 7eca5b6dd97d7a71b50d5ee8ee3a7ff374d9a69a Mon Sep 17 00:00:00 2001 From: sayantn Date: Tue, 21 Apr 2026 16:49:09 +0530 Subject: [PATCH 2/2] enable AMX instruction tests in windows-gnu --- crates/core_arch/src/x86_64/amx.rs | 38 
+++++++++++++++--------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index 62e46097e6..b3b3e86750 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -252,7 +252,7 @@ pub unsafe fn _tile_cmmrlfp16ps() { #[rustc_legacy_const_generics(0, 1, 2)] #[target_feature(enable = "amx-fp8")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -271,7 +271,7 @@ pub unsafe fn _tile_dpbf8ps() { #[rustc_legacy_const_generics(0, 1, 2)] #[target_feature(enable = "amx-fp8")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -290,7 +290,7 @@ pub unsafe fn _tile_dpbhf8ps() { #[rustc_legacy_const_generics(0, 1, 2)] #[target_feature(enable = "amx-fp8")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -309,7 +309,7 @@ pub unsafe fn _tile_dphbf8ps() { #[rustc_legacy_const_generics(0, 1, 2)] #[target_feature(enable = "amx-fp8")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tdphf8ps, DST = 0, A = 1, B = 2) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -329,7 +329,7 @@ pub unsafe fn _tile_dphf8ps() { #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-movrs")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tileloaddrs, DST = 0) )] 
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -349,7 +349,7 @@ pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-movrs")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tileloaddrst1, DST = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -372,7 +372,7 @@ pub unsafe fn _tile_stream_loaddrs(base: *const u8, stride: usiz #[rustc_legacy_const_generics(0, 1, 2)] #[target_feature(enable = "amx-tf32")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -389,7 +389,7 @@ pub unsafe fn _tile_mmultf32ps() { #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowd2ps, TILE = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -404,7 +404,7 @@ pub unsafe fn _tile_cvtrowd2ps(row: u32) -> __m512 { #[rustc_legacy_const_generics(0, 1)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -421,7 +421,7 @@ pub unsafe fn _tile_cvtrowd2psi() -> __m512 { #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phh, TILE = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -437,7 +437,7 @@ pub unsafe fn _tile_cvtrowps2phh(row: u32) -> 
__m512h { #[rustc_legacy_const_generics(0, 1)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -454,7 +454,7 @@ pub unsafe fn _tile_cvtrowps2phhi() -> __m512h #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phl, TILE = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -470,7 +470,7 @@ pub unsafe fn _tile_cvtrowps2phl(row: u32) -> __m512h { #[rustc_legacy_const_generics(0, 1)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -487,7 +487,7 @@ pub unsafe fn _tile_cvtrowps2phli() -> __m512h #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16h, TILE = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -503,7 +503,7 @@ pub unsafe fn _tile_cvtrowps2bf16h(row: u32) -> __m512bh { #[rustc_legacy_const_generics(0, 1)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16h, TILE = 0, ROW = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -520,7 +520,7 @@ pub unsafe fn _tile_cvtrowps2bf16hi() -> __m512 #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-avx512,avx10.2")] 
#[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16l, TILE = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -536,7 +536,7 @@ pub unsafe fn _tile_cvtrowps2bf16l(row: u32) -> __m512bh { #[rustc_legacy_const_generics(0, 1)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16l, TILE = 0, ROW = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -551,7 +551,7 @@ pub unsafe fn _tile_cvtrowps2bf16li() -> __m512 #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tilemovrow, TILE = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] @@ -565,7 +565,7 @@ pub unsafe fn _tile_movrow(row: u32) -> __m512i { #[rustc_legacy_const_generics(0, 1)] #[target_feature(enable = "amx-avx512,avx10.2")] #[cfg_attr( - all(test, any(target_os = "linux", target_env = "msvc")), + all(test, not(target_vendor = "apple")), assert_instr(tilemovrow, TILE = 0, ROW = 0) )] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")]