From 8660081c63e7f21b81e5ebbc67620e6160964095 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Wed, 28 Dec 2022 12:02:00 +0100 Subject: [PATCH 1/8] :recycle: use different index SIMD dtype --- src/simd/generic.rs | 80 +++++++++++++++----------- src/simd/simd_f16.rs | 78 ++++++++++++++++++++------ src/simd/simd_f32.rs | 131 +++++++++++++++++++++++++++---------------- src/simd/simd_f64.rs | 102 ++++++++++++++++++--------------- src/simd/simd_i16.rs | 78 +++++++++++++++++++------- src/simd/simd_i32.rs | 74 ++++++++++++++++++------ src/simd/simd_i64.rs | 48 +++++++++++++--- src/simd/simd_i8.rs | 114 +++++++++++++++++++++++++------------ src/simd/simd_u16.rs | 66 ++++++++++++++++++---- src/simd/simd_u32.rs | 66 ++++++++++++++++++---- src/simd/simd_u64.rs | 50 ++++++++++++++--- src/simd/simd_u8.rs | 68 ++++++++++++++++++---- src/utils.rs | 12 +++- 13 files changed, 696 insertions(+), 271 deletions(-) diff --git a/src/simd/generic.rs b/src/simd/generic.rs index b317b70..1bfe9e2 100644 --- a/src/simd/generic.rs +++ b/src/simd/generic.rs @@ -8,14 +8,16 @@ use crate::scalar::{ScalarArgMinMax, SCALAR}; // TODO: other potential generic SIMDIndexDtype: Copy #[allow(clippy::missing_safety_doc)] // TODO: add safety docs? pub trait SIMD< - ScalarDType: Copy + PartialOrd + AsPrimitive, - SIMDVecDtype: Copy, + ValueDType: Copy + PartialOrd, + SIMDValueDtype: Copy, + IndexDtype: Copy + PartialOrd + AsPrimitive, + SIMDIndexDtype: Copy, SIMDMaskDtype: Copy, const LANE_SIZE: usize, > { - const INITIAL_INDEX: SIMDVecDtype; - const MAX_INDEX: usize; // Integers > this value **cannot** be accurately represented in SIMDVecDtype + const INITIAL_INDEX: SIMDIndexDtype; + const MAX_INDEX: usize; // Integers > this value **cannot** be accurately represented in SIMDIndexDtype #[inline(always)] fn _find_largest_lower_multiple_of_lane_size(n: usize) -> usize { @@ -24,22 +26,34 @@ pub trait SIMD< // ------------------------------------ SIMD HELPERS -------------------------------------- - unsafe fn _reg_to_arr(reg: SIMDVecDtype) -> [ScalarDType; LANE_SIZE]; + unsafe fn _reg_to_arr_values(reg: SIMDValueDtype) -> [ValueDType; LANE_SIZE]; - unsafe fn _mm_loadu(data: *const ScalarDType) -> SIMDVecDtype; + unsafe fn _reg_to_arr_indices(reg: SIMDIndexDtype) -> [IndexDtype; LANE_SIZE]; - unsafe fn _mm_set1(a: usize) -> SIMDVecDtype; + unsafe fn _mm_loadu(data: *const ValueDType) -> SIMDValueDtype; - unsafe fn _mm_add(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDVecDtype; + unsafe fn _mm_set1(a: usize) -> SIMDIndexDtype; - unsafe fn _mm_cmpgt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype; + unsafe fn _mm_add(a: SIMDIndexDtype, b: SIMDIndexDtype) -> SIMDIndexDtype; - unsafe fn _mm_cmplt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype; + unsafe fn _mm_cmpgt(a: SIMDValueDtype, b: SIMDValueDtype) -> SIMDMaskDtype; - unsafe fn _mm_blendv(a: SIMDVecDtype, b: SIMDVecDtype, mask: SIMDMaskDtype) -> SIMDVecDtype; + unsafe fn _mm_cmplt(a: SIMDValueDtype, b: SIMDValueDtype) -> SIMDMaskDtype; + + unsafe fn _mm_blendv_values( + a: SIMDValueDtype, + b: SIMDValueDtype, + mask: SIMDMaskDtype, + ) -> SIMDValueDtype; + + unsafe fn _mm_blendv_indices( + a: SIMDIndexDtype, + b: SIMDIndexDtype, + mask: SIMDMaskDtype, + ) -> SIMDIndexDtype; #[inline(always)] - unsafe fn _horiz_min(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) { + unsafe fn _horiz_min(index: SIMDIndexDtype, value: SIMDValueDtype) -> (usize, ValueDType) { // This becomes the bottleneck when using 8-bit data types, as for every 2**7 // or 2**8 elements, the SIMD inner loop is executed (& thus also terminated) // to avoid overflow. @@ -48,14 +62,14 @@ pub trait SIMD< // see: https://stackoverflow.com/a/9798369 // Note: this is not a bottleneck for 16-bit data types, as the termination of // the SIMD inner loop is 2**8 times less frequent. - let index_arr = Self::_reg_to_arr(index); - let value_arr = Self::_reg_to_arr(value); + let index_arr = Self::_reg_to_arr_indices(index); + let value_arr = Self::_reg_to_arr_values(value); let (min_index, min_value) = min_index_value(&index_arr, &value_arr); (min_index.as_(), min_value) } #[inline(always)] - unsafe fn _horiz_max(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) { + unsafe fn _horiz_max(index: SIMDIndexDtype, value: SIMDValueDtype) -> (usize, ValueDType) { // This becomes the bottleneck when using 8-bit data types, as for every 2**7 // or 2**8 elements, the SIMD inner loop is executed (& thus also terminated) // to avoid overflow. @@ -64,28 +78,28 @@ pub trait SIMD< // see: https://stackoverflow.com/a/9798369 // Note: this is not a bottleneck for 16-bit data types, as the termination of // the SIMD inner loop is 2**8 times less frequent. - let index_arr = Self::_reg_to_arr(index); - let value_arr = Self::_reg_to_arr(value); + let index_arr = Self::_reg_to_arr_indices(index); + let value_arr = Self::_reg_to_arr_values(value); let (max_index, max_value) = max_index_value(&index_arr, &value_arr); (max_index.as_(), max_value) } // ------------------------------------ ARGMINMAX -------------------------------------- - unsafe fn argminmax(data: ArrayView1) -> (usize, usize); + unsafe fn argminmax(data: ArrayView1) -> (usize, usize); #[inline(always)] - unsafe fn _argminmax(data: ArrayView1) -> (usize, usize) + unsafe fn _argminmax(data: ArrayView1) -> (usize, usize) where - SCALAR: ScalarArgMinMax, + SCALAR: ScalarArgMinMax, { argminmax_generic(data, LANE_SIZE, Self::_overflow_safe_core_argminmax) } #[inline(always)] unsafe fn _overflow_safe_core_argminmax( - arr: ArrayView1, - ) -> (usize, ScalarDType, usize, ScalarDType) { + arr: ArrayView1, + ) -> (usize, ValueDType, usize, ValueDType) { // 0. Get the max value of the data type - which needs to be divided by LANE_SIZE let dtype_max = Self::_find_largest_lower_multiple_of_lane_size(Self::MAX_INDEX); @@ -129,11 +143,11 @@ pub trait SIMD< // TODO: can be cleaner (perhaps?) #[inline(always)] unsafe fn _get_min_max_index_value( - index_low: SIMDVecDtype, - values_low: SIMDVecDtype, - index_high: SIMDVecDtype, - values_high: SIMDVecDtype, - ) -> (usize, ScalarDType, usize, ScalarDType) { + index_low: SIMDIndexDtype, + values_low: SIMDValueDtype, + index_high: SIMDIndexDtype, + values_high: SIMDValueDtype, + ) -> (usize, ValueDType, usize, ValueDType) { let (min_index, min_value) = Self::_horiz_min(index_low, values_low); let (max_index, max_value) = Self::_horiz_max(index_high, values_high); (min_index, min_value, max_index, max_value) @@ -141,8 +155,8 @@ pub trait SIMD< #[inline(always)] unsafe fn _core_argminmax( - arr: ArrayView1, - ) -> (usize, ScalarDType, usize, ScalarDType) { + arr: ArrayView1, + ) -> (usize, ValueDType, usize, ValueDType) { assert_eq!(arr.len() % LANE_SIZE, 0); // Efficient calculation of argmin and argmax together let mut new_index = Self::INITIAL_INDEX; @@ -166,11 +180,11 @@ pub trait SIMD< let lt_mask = Self::_mm_cmplt(new_values, values_low); let gt_mask = Self::_mm_cmpgt(new_values, values_high); - index_low = Self::_mm_blendv(index_low, new_index, lt_mask); - index_high = Self::_mm_blendv(index_high, new_index, gt_mask); + index_low = Self::_mm_blendv_indices(index_low, new_index, lt_mask); + index_high = Self::_mm_blendv_indices(index_high, new_index, gt_mask); - values_low = Self::_mm_blendv(values_low, new_values, lt_mask); - values_high = Self::_mm_blendv(values_high, new_values, gt_mask); + values_low = Self::_mm_blendv_values(values_low, new_values, lt_mask); + values_high = Self::_mm_blendv_values(values_high, new_values, gt_mask); }); Self::_get_min_max_index_value(index_low, values_low, index_high, values_high) diff --git a/src/simd/simd_f16.rs b/src/simd/simd_f16.rs index a2c0b76..5caab80 100644 --- a/src/simd/simd_f16.rs +++ b/src/simd/simd_f16.rs @@ -57,7 +57,7 @@ mod avx2 { std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -67,7 +67,13 @@ mod avx2 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [f16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [f16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -98,7 +104,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -260,13 +271,19 @@ mod sse { std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [f16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [f16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -297,7 +314,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -426,7 +448,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -437,7 +459,13 @@ mod avx512 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [f16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [f16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -468,7 +496,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u32) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u32) -> __m512i { + _mm512_mask_blend_epi16(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u32) -> __m512i { _mm512_mask_blend_epi16(mask, a, b) } @@ -614,13 +647,19 @@ mod neon { std::mem::transmute::(reg) } - impl SIMD for NEON { + impl SIMD for NEON { const INITIAL_INDEX: int16x8_t = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const MAX_INDEX: usize = u16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: int16x8_t) -> [f16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: int16x8_t) -> [f16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: uint16x8_t) -> [u16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -633,13 +672,13 @@ mod neon { } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int16x8_t { - vdupq_n_s16(a as i16) + unsafe fn _mm_set1(a: usize) -> uint16x8_t { + vdupq_n_u16(a as u16) } #[inline(always)] - unsafe fn _mm_add(a: int16x8_t, b: int16x8_t) -> int16x8_t { - vaddq_s16(a, b) + unsafe fn _mm_add(a: uint16x8_t, b: uint16x8_t) -> uint16x8_t { + vaddq_u16(a, b) } #[inline(always)] @@ -653,10 +692,15 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: int16x8_t, b: int16x8_t, mask: uint16x8_t) -> int16x8_t { + unsafe fn _mm_blendv_values(a: int16x8_t, b: int16x8_t, mask: uint16x8_t) -> int16x8_t { vbslq_s16(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { + vbslq_u16(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs index 35aa7c3..8208d11 100644 --- a/src/simd/simd_f32.rs +++ b/src/simd/simd_f32.rs @@ -19,33 +19,34 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_32; - impl SIMD for AVX2 { - const INITIAL_INDEX: __m256 = unsafe { - std::mem::transmute([ - 0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, - ]) - }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + impl SIMD for AVX2 { + const INITIAL_INDEX: __m256i = + unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; + const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256) -> [f32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256) -> [f32; LANE_SIZE] { std::mem::transmute::<__m256, [f32; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f32) -> __m256 { _mm256_loadu_ps(data as *const f32) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256 { - _mm256_set1_ps(a as f32) + unsafe fn _mm_set1(a: usize) -> __m256i { + _mm256_set1_epi32(a as i32) } #[inline(always)] - unsafe fn _mm_add(a: __m256, b: __m256) -> __m256 { - _mm256_add_ps(a, b) + unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { + _mm256_add_epi32(a, b) } #[inline(always)] @@ -59,10 +60,15 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256, b: __m256, mask: __m256) -> __m256 { + unsafe fn _mm_blendv_values(a: __m256, b: __m256, mask: __m256) -> __m256 { _mm256_blendv_ps(a, b, mask) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256) -> __m256i { + _mm256_blendv_epi8(a, b, _mm256_castps_si256(mask)) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "avx")] @@ -171,30 +177,33 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_32; - impl SIMD for SSE { - const INITIAL_INDEX: __m128 = - unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + impl SIMD for SSE { + const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; + const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128) -> [f32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128) -> [f32; LANE_SIZE] { std::mem::transmute::<__m128, [f32; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f32) -> __m128 { _mm_loadu_ps(data as *const f32) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128 { - _mm_set1_ps(a as f32) + unsafe fn _mm_set1(a: usize) -> __m128i { + _mm_set1_epi32(a as i32) } #[inline(always)] - unsafe fn _mm_add(a: __m128, b: __m128) -> __m128 { - _mm_add_ps(a, b) + unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { + _mm_add_epi32(a, b) } #[inline(always)] @@ -208,10 +217,15 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128, b: __m128, mask: __m128) -> __m128 { + unsafe fn _mm_blendv_values(a: __m128, b: __m128, mask: __m128) -> __m128 { _mm_blendv_ps(a, b, mask) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128) -> __m128i { + _mm_blendv_epi8(a, b, _mm_castps_si128(mask)) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "sse4.1")] @@ -303,34 +317,38 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_32; - impl SIMD for AVX512 { - const INITIAL_INDEX: __m512 = unsafe { + impl SIMD for AVX512 { + const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ - 0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, 9.0f32, - 10.0f32, 11.0f32, 12.0f32, 13.0f32, 14.0f32, 15.0f32, + 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, + 13i32, 14i32, 15i32, ]) }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512) -> [f32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512) -> [f32; LANE_SIZE] { std::mem::transmute::<__m512, [f32; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f32) -> __m512 { _mm512_loadu_ps(data as *const f32) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512 { - _mm512_set1_ps(a as f32) + unsafe fn _mm_set1(a: usize) -> __m512i { + _mm512_set1_epi32(a as i32) } #[inline(always)] - unsafe fn _mm_add(a: __m512, b: __m512) -> __m512 { - _mm512_add_ps(a, b) + unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { + _mm512_add_epi32(a, b) } #[inline(always)] @@ -354,7 +372,7 @@ mod avx512 { // { _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ) } #[inline(always)] - unsafe fn _mm_blendv(a: __m512, b: __m512, mask: u16) -> __m512 { + unsafe fn _mm_blendv_values(a: __m512, b: __m512, mask: u16) -> __m512 { _mm512_mask_blend_ps(mask, a, b) } // unimplemented!("AVX512 blendv instructions for ps require a u16 mask.") @@ -365,6 +383,11 @@ mod avx512 { // _mm512_mask_mov_ps(a, _mm512_castps_si512(mask), b) // } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u16) -> __m512i { + _mm512_mask_blend_epi32(mask, a, b) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "avx512f")] @@ -473,16 +496,19 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; - impl SIMD for NEON { - const INITIAL_INDEX: float32x4_t = - unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) }; + impl SIMD for NEON { + const INITIAL_INDEX: uint32x4_t = unsafe { std::mem::transmute([0u32, 1u32, 2u32, 3u32]) }; + const MAX_INDEX: usize = u32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: float32x4_t) -> [f32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: float32x4_t) -> [f32; LANE_SIZE] { std::mem::transmute::(reg) } - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint32x4_t) -> [u32; LANE_SIZE] { + std::mem::transmute::(reg) + } #[inline(always)] unsafe fn _mm_loadu(data: *const f32) -> float32x4_t { @@ -490,13 +516,13 @@ mod neon { } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> float32x4_t { - vdupq_n_f32(a as f32) + unsafe fn _mm_set1(a: usize) -> uint32x4_t { + vdupq_n_u32(a as u32) } #[inline(always)] - unsafe fn _mm_add(a: float32x4_t, b: float32x4_t) -> float32x4_t { - vaddq_f32(a, b) + unsafe fn _mm_add(a: uint32x4_t, b: uint32x4_t) -> uint32x4_t { + vaddq_u32(a, b) } #[inline(always)] @@ -510,10 +536,19 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: float32x4_t, b: float32x4_t, mask: uint32x4_t) -> float32x4_t { + unsafe fn _mm_blendv_values( + a: float32x4_t, + b: float32x4_t, + mask: uint32x4_t, + ) -> float32x4_t { vbslq_f32(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { + vbslq_u32(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] diff --git a/src/simd/simd_f64.rs b/src/simd/simd_f64.rs index e252cde..d63ed98 100644 --- a/src/simd/simd_f64.rs +++ b/src/simd/simd_f64.rs @@ -16,33 +16,33 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_64; - impl SIMD for AVX2 { - const INITIAL_INDEX: __m256d = - unsafe { std::mem::transmute([0.0f64, 1.0f64, 2.0f64, 3.0f64]) }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; + impl SIMD for AVX2 { + const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; + const MAX_INDEX: usize = i64::MAX as usize; // TODO overflow on x86? #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256d) -> [f64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256d) -> [f64; LANE_SIZE] { std::mem::transmute::<__m256d, [f64; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f64) -> __m256d { _mm256_loadu_pd(data as *const f64) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256d { - _mm256_set1_pd(a as f64) + unsafe fn _mm_set1(a: usize) -> __m256i { + _mm256_set1_epi64x(a as i64) } #[inline(always)] - unsafe fn _mm_add(a: __m256d, b: __m256d) -> __m256d { - _mm256_add_pd(a, b) + unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { + _mm256_add_epi64(a, b) } #[inline(always)] @@ -56,10 +56,15 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256d, b: __m256d, mask: __m256d) -> __m256d { + unsafe fn _mm_blendv_values(a: __m256d, b: __m256d, mask: __m256d) -> __m256d { _mm256_blendv_pd(a, b, mask) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256d) -> __m256i { + _mm256_blendv_epi8(a, b, _mm256_castpd_si256(mask)) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "avx")] @@ -153,32 +158,33 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_64; - impl SIMD for SSE { - const INITIAL_INDEX: __m128d = unsafe { std::mem::transmute([0.0f64, 1.0f64]) }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; + impl SIMD for SSE { + const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) }; + const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128d) -> [f64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128d) -> [f64; LANE_SIZE] { std::mem::transmute::<__m128d, [f64; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f64) -> __m128d { _mm_loadu_pd(data as *const f64) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128d { - _mm_set1_pd(a as f64) + unsafe fn _mm_set1(a: usize) -> __m128i { + _mm_set1_epi64x(a as i64) } #[inline(always)] - unsafe fn _mm_add(a: __m128d, b: __m128d) -> __m128d { - _mm_add_pd(a, b) + unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { + _mm_add_epi64(a, b) } #[inline(always)] @@ -192,10 +198,15 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128d, b: __m128d, mask: __m128d) -> __m128d { + unsafe fn _mm_blendv_values(a: __m128d, b: __m128d, mask: __m128d) -> __m128d { _mm_blendv_pd(a, b, mask) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128d) -> __m128i { + _mm_blendv_epi8(a, b, _mm_castpd_si128(mask)) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "sse4.1")] @@ -276,36 +287,34 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_64; - impl SIMD for AVX512 { - const INITIAL_INDEX: __m512d = unsafe { - std::mem::transmute([ - 0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, - ]) - }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; + impl SIMD for AVX512 { + const INITIAL_INDEX: __m512i = + unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; + const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512d) -> [f64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512d) -> [f64; LANE_SIZE] { std::mem::transmute::<__m512d, [f64; LANE_SIZE]>(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const f64) -> __m512d { _mm512_loadu_pd(data as *const f64) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512d { - _mm512_set1_pd(a as f64) + unsafe fn _mm_set1(a: usize) -> __m512i { + _mm512_set1_epi64(a as i64) } #[inline(always)] - unsafe fn _mm_add(a: __m512d, b: __m512d) -> __m512d { - _mm512_add_pd(a, b) + unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { + _mm512_add_epi64(a, b) } #[inline(always)] @@ -319,10 +328,15 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512d, b: __m512d, mask: u8) -> __m512d { + unsafe fn _mm_blendv_values(a: __m512d, b: __m512d, mask: u8) -> __m512d { _mm512_mask_blend_pd(mask, a, b) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u8) -> __m512i { + _mm512_mask_blend_epi64(mask, a, b) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "avx512f")] diff --git a/src/simd/simd_i16.rs b/src/simd/simd_i16.rs index 56ce25c..4768959 100644 --- a/src/simd/simd_i16.rs +++ b/src/simd/simd_i16.rs @@ -19,7 +19,7 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_16; - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -29,7 +29,12 @@ mod avx2 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256i) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256i) -> [i16; LANE_SIZE] { + std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i16; LANE_SIZE] { std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) } @@ -59,7 +64,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -173,13 +183,18 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_16; - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128i) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128i) -> [i16; LANE_SIZE] { + std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i16; LANE_SIZE] { std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) } @@ -209,7 +224,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -306,7 +326,7 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_16; - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -317,7 +337,12 @@ mod avx512 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512i) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512i) -> [i16; LANE_SIZE] { + std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i16; LANE_SIZE] { std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) } @@ -347,7 +372,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u32) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u32) -> __m512i { + _mm512_mask_blend_epi16(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u32) -> __m512i { _mm512_mask_blend_epi16(mask, a, b) } @@ -461,29 +491,34 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_16; - impl SIMD for NEON { - const INITIAL_INDEX: int16x8_t = - unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + impl SIMD for NEON { + const INITIAL_INDEX: uint16x8_t = + unsafe { std::mem::transmute([0u16, 1u16, 2u16, 3u16, 4u16, 5u16, 6u16, 7u16]) }; + const MAX_INDEX: usize = u16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: int16x8_t) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: int16x8_t) -> [i16; LANE_SIZE] { std::mem::transmute::(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint16x8_t) -> [u16; LANE_SIZE] { + std::mem::transmute::(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const i16) -> int16x8_t { vld1q_s16(data as *const i16) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int16x8_t { - vdupq_n_s16(a as i16) + unsafe fn _mm_set1(a: usize) -> uint16x8_t { + vdupq_n_u16(a as u16) } #[inline(always)] - unsafe fn _mm_add(a: int16x8_t, b: int16x8_t) -> int16x8_t { - vaddq_s16(a, b) + unsafe fn _mm_add(a: uint16x8_t, b: uint16x8_t) -> uint16x8_t { + vaddq_u16(a, b) } #[inline(always)] @@ -497,10 +532,15 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: int16x8_t, b: int16x8_t, mask: uint16x8_t) -> int16x8_t { + unsafe fn _mm_blendv_values(a: int16x8_t, b: int16x8_t, mask: uint16x8_t) -> int16x8_t { vbslq_s16(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { + vbslq_u16(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] diff --git a/src/simd/simd_i32.rs b/src/simd/simd_i32.rs index 54bbb96..8b11209 100644 --- a/src/simd/simd_i32.rs +++ b/src/simd/simd_i32.rs @@ -19,13 +19,18 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_32; - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i32; LANE_SIZE] { std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) } @@ -55,7 +60,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -153,12 +163,17 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_32; - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i32; LANE_SIZE] { std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) } @@ -188,7 +203,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -273,7 +293,7 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_32; - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, @@ -283,7 +303,12 @@ mod avx512 { const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i32; LANE_SIZE] { std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) } @@ -313,7 +338,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u16) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u16) -> __m512i { + _mm512_mask_blend_epi32(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u16) -> __m512i { _mm512_mask_blend_epi32(mask, a, b) } @@ -411,28 +441,33 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; - impl SIMD for NEON { + impl SIMD for NEON { const INITIAL_INDEX: int32x4_t = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; - const MAX_INDEX: usize = i32::MAX as usize; + const MAX_INDEX: usize = u32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: int32x4_t) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: int32x4_t) -> [i32; LANE_SIZE] { std::mem::transmute::(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint32x4_t) -> [u32; LANE_SIZE] { + std::mem::transmute::(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const i32) -> int32x4_t { vld1q_s32(data) } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int32x4_t { - vdupq_n_s32(a as i32) + unsafe fn _mm_set1(a: usize) -> uint32x4_t { + vdupq_n_u32(a as u32) } #[inline(always)] - unsafe fn _mm_add(a: int32x4_t, b: int32x4_t) -> int32x4_t { - vaddq_s32(a, b) + unsafe fn _mm_add(a: uint32x4_t, b: uint32x4_t) -> uint32x4_t { + vaddq_u32(a, b) } #[inline(always)] @@ -446,10 +481,15 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: int32x4_t, b: int32x4_t, mask: uint32x4_t) -> int32x4_t { + unsafe fn _mm_blendv_values(a: int32x4_t, b: int32x4_t, mask: uint32x4_t) -> int32x4_t { vbslq_s32(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { + vbslq_u32(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] diff --git a/src/simd/simd_i64.rs b/src/simd/simd_i64.rs index 277ecff..d1846c4 100644 --- a/src/simd/simd_i64.rs +++ b/src/simd/simd_i64.rs @@ -16,12 +16,17 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_64; - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256i) -> [i64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i64; LANE_SIZE] { std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) } @@ -51,7 +56,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -149,12 +159,17 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_64; - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128i) -> [i64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i64; LANE_SIZE] { std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) } @@ -184,7 +199,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -269,13 +289,18 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_64; - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512i) -> [i64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i64; LANE_SIZE] { std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) } @@ -305,7 +330,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u8) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u8) -> __m512i { + _mm512_mask_blend_epi64(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u8) -> __m512i { _mm512_mask_blend_epi64(mask, a, b) } diff --git a/src/simd/simd_i8.rs b/src/simd/simd_i8.rs index a2eefbe..fd709bd 100644 --- a/src/simd/simd_i8.rs +++ b/src/simd/simd_i8.rs @@ -19,7 +19,7 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_8; - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -30,7 +30,12 @@ mod avx2 { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256i) -> [i8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m256i) -> [i8; LANE_SIZE] { + std::mem::transmute::<__m256i, [i8; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i8; LANE_SIZE] { std::mem::transmute::<__m256i, [i8; LANE_SIZE]>(reg) } @@ -60,7 +65,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -228,7 +238,7 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_8; - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -238,7 +248,12 @@ mod sse { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128i) -> [i8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m128i) -> [i8; LANE_SIZE] { + std::mem::transmute::<__m128i, [i8; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i8; LANE_SIZE] { std::mem::transmute::<__m128i, [i8; LANE_SIZE]>(reg) } @@ -268,7 +283,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -415,7 +435,7 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_8; - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -428,7 +448,12 @@ mod avx512 { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512i) -> [i8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: __m512i) -> [i8; LANE_SIZE] { + std::mem::transmute::<__m512i, [i8; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i8; LANE_SIZE] { std::mem::transmute::<__m512i, [i8; LANE_SIZE]>(reg) } @@ -458,7 +483,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u64) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u64) -> __m512i { + _mm512_mask_blend_epi8(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u64) -> __m512i { _mm512_mask_blend_epi8(mask, a, b) } @@ -630,20 +660,25 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_8; - impl SIMD for NEON { - const INITIAL_INDEX: int8x16_t = unsafe { + impl SIMD for NEON { + const INITIAL_INDEX: uint8x16_t = unsafe { std::mem::transmute([ - 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, - 15i8, + 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, + 15u8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const MAX_INDEX: usize = u8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: int8x16_t) -> [i8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: int8x16_t) -> [i8; LANE_SIZE] { std::mem::transmute::(reg) } + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint8x16_t) -> [u8; LANE_SIZE] { + std::mem::transmute::(reg) + } + #[inline(always)] unsafe fn _mm_loadu(data: *const i8) -> int8x16_t { // TODO: requires v7 @@ -651,13 +686,13 @@ mod neon { } #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int8x16_t { - vdupq_n_s8(a as i8) + unsafe fn _mm_set1(a: usize) -> uint8x16_t { + vdupq_n_u8(a as u8) } #[inline(always)] - unsafe fn _mm_add(a: int8x16_t, b: int8x16_t) -> int8x16_t { - vaddq_s8(a, b) + unsafe fn _mm_add(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t { + vaddq_u8(a, b) } #[inline(always)] @@ -671,10 +706,15 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: int8x16_t, b: int8x16_t, mask: uint8x16_t) -> int8x16_t { + unsafe fn _mm_blendv_values(a: int8x16_t, b: int8x16_t, mask: uint8x16_t) -> int8x16_t { vbslq_s8(mask, b, a) } + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint8x16_t, b: uint8x16_t, mask: uint8x16_t) -> uint8x16_t { + vbslq_u8(mask, b, a) + } + // ------------------------------------ ARGMINMAX -------------------------------------- #[target_feature(enable = "neon")] @@ -683,7 +723,7 @@ mod neon { } #[inline(always)] - unsafe fn _horiz_min(index: int8x16_t, value: int8x16_t) -> (usize, i8) { + unsafe fn _horiz_min(index: uint8x16_t, value: int8x16_t) -> (usize, i8) { // 0. Find the minimum value let mut vmin: int8x16_t = value; vmin = vminq_s8(vmin, vextq_s8(vmin, vmin, 8)); @@ -696,24 +736,24 @@ mod neon { // 1. Create a mask with the index of the minimum value let mask = vceqq_s8(value, vmin); // 2. Blend the mask with the index - let search_index = vbslq_s8( + let search_index = vbslq_u8( mask, index, // if mask is 1, use index - vdupq_n_s8(i8::MAX), // if mask is 0, use i8::MAX + vdupq_n_u8(u8::MAX), // if mask is 0, use u8::MAX ); // 3. Find the minimum index - let mut imin: int8x16_t = search_index; - imin = vminq_s8(imin, vextq_s8(imin, imin, 8)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 4)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 2)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 1)); - let min_index: usize = vgetq_lane_s8(imin, 0) as usize; + let mut imin: uint8x16_t = search_index; + imin = vminq_u8(imin, vextq_u8(imin, imin, 8)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 4)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 2)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 1)); + let min_index: usize = vgetq_lane_u8(imin, 0) as usize; (min_index, min_value) } #[inline(always)] - unsafe fn _horiz_max(index: int8x16_t, value: int8x16_t) -> (usize, i8) { + unsafe fn _horiz_max(index: uint8x16_t, value: int8x16_t) -> (usize, i8) { // 0. Find the maximum value let mut vmax: int8x16_t = value; vmax = vmaxq_s8(vmax, vextq_s8(vmax, vmax, 8)); @@ -726,18 +766,18 @@ mod neon { // 1. Create a mask with the index of the maximum value let mask = vceqq_s8(value, vmax); // 2. Blend the mask with the index - let search_index = vbslq_s8( + let search_index = vbslq_u8( mask, index, // if mask is 1, use index - vdupq_n_s8(i8::MAX), // if mask is 0, use i8::MAX + vdupq_n_u8(u8::MAX), // if mask is 0, use u8::MAX ); // 3. Find the maximum index let mut imin: int8x16_t = search_index; - imin = vminq_s8(imin, vextq_s8(imin, imin, 8)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 4)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 2)); - imin = vminq_s8(imin, vextq_s8(imin, imin, 1)); - let max_index: usize = vgetq_lane_s8(imin, 0) as usize; + imin = vminq_u8(imin, vextq_u8(imin, imin, 8)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 4)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 2)); + imin = vminq_u8(imin, vextq_u8(imin, imin, 1)); + let max_index: usize = vgetq_lane_u8(imin, 0) as usize; (max_index, max_value) } diff --git a/src/simd/simd_u16.rs b/src/simd/simd_u16.rs index 916a8d1..bc5d7e8 100644 --- a/src/simd/simd_u16.rs +++ b/src/simd/simd_u16.rs @@ -60,7 +60,7 @@ mod avx2 { std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -70,7 +70,13 @@ mod avx2 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [u16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [u16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -102,7 +108,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -247,13 +258,19 @@ mod sse { std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [u16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [u16; LANE_SIZE] { + // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -284,7 +301,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -415,7 +437,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -426,7 +448,12 @@ mod avx512 { const MAX_INDEX: usize = i16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [u16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [u16; LANE_SIZE] { + unimplemented!("We work with decrordi16 and override _get_min_index_value and _get_max_index_value") + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i16; LANE_SIZE] { unimplemented!("We work with decrordi16 and override _get_min_index_value and _get_max_index_value") } @@ -456,7 +483,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u32) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u32) -> __m512i { + _mm512_mask_blend_epi16(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u32) -> __m512i { _mm512_mask_blend_epi16(mask, a, b) } @@ -589,13 +621,18 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_16; - impl SIMD for NEON { + impl SIMD for NEON { const INITIAL_INDEX: uint16x8_t = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; const MAX_INDEX: usize = u16::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: uint16x8_t) -> [u16; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: uint16x8_t) -> [u16; LANE_SIZE] { + std::mem::transmute::(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint16x8_t) -> [u16; LANE_SIZE] { std::mem::transmute::(reg) } @@ -625,7 +662,12 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { + unsafe fn _mm_blendv_values(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { + vbslq_u16(mask, b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint16x8_t, b: uint16x8_t, mask: uint16x8_t) -> uint16x8_t { vbslq_u16(mask, b, a) } diff --git a/src/simd/simd_u32.rs b/src/simd/simd_u32.rs index 9eaf16a..78e2bae 100644 --- a/src/simd/simd_u32.rs +++ b/src/simd/simd_u32.rs @@ -42,13 +42,19 @@ mod avx2 { std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [u32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [u32; LANE_SIZE] { + // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i32; LANE_SIZE] { // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -79,7 +85,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -209,12 +220,18 @@ mod sse { std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [u32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [u32; LANE_SIZE] { + // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i32; LANE_SIZE] { // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -245,7 +262,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -365,7 +387,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, @@ -375,7 +397,12 @@ mod avx512 { const MAX_INDEX: usize = i32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [u32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [u32; LANE_SIZE] { + unimplemented!("We work with decrordu32 and override _get_min_index_value and _get_max_index_value") + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i32; LANE_SIZE] { unimplemented!("We work with decrordu32 and override _get_min_index_value and _get_max_index_value") } @@ -405,7 +432,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u16) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u16) -> __m512i { + _mm512_mask_blend_epi32(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u16) -> __m512i { _mm512_mask_blend_epi32(mask, a, b) } @@ -522,12 +554,17 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; - impl SIMD for NEON { + impl SIMD for NEON { const INITIAL_INDEX: uint32x4_t = unsafe { std::mem::transmute([0u32, 1u32, 2u32, 3u32]) }; const MAX_INDEX: usize = u32::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: uint32x4_t) -> [u32; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: uint32x4_t) -> [u32; LANE_SIZE] { + std::mem::transmute::(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint32x4_t) -> [u32; LANE_SIZE] { std::mem::transmute::(reg) } @@ -557,7 +594,12 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { + unsafe fn _mm_blendv_values(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { + vbslq_u32(mask, b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { vbslq_u32(mask, b, a) } diff --git a/src/simd/simd_u64.rs b/src/simd/simd_u64.rs index 39b40df..a9bd54f 100644 --- a/src/simd/simd_u64.rs +++ b/src/simd/simd_u64.rs @@ -39,12 +39,18 @@ mod avx2 { std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [u64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [u64; LANE_SIZE] { + // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i64; LANE_SIZE] { // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -75,7 +81,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -205,12 +216,18 @@ mod sse { std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [u64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [u64; LANE_SIZE] { + // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i64; LANE_SIZE] { // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -241,7 +258,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -361,13 +383,18 @@ mod avx512 { std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; const MAX_INDEX: usize = i64::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [u64; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [u64; LANE_SIZE] { + unimplemented!("We work with decrordi64 and override _get_min_index_value and _get_max_index_value") + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i64; LANE_SIZE] { unimplemented!("We work with decrordi64 and override _get_min_index_value and _get_max_index_value") } @@ -397,7 +424,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u8) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u8) -> __m512i { + _mm512_mask_blend_epi64(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u8) -> __m512i { _mm512_mask_blend_epi64(mask, a, b) } diff --git a/src/simd/simd_u8.rs b/src/simd/simd_u8.rs index a79cced..1dff0d7 100644 --- a/src/simd/simd_u8.rs +++ b/src/simd/simd_u8.rs @@ -46,7 +46,7 @@ mod avx2 { std::mem::transmute::<__m256i, [i8; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMD for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -57,7 +57,13 @@ mod avx2 { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m256i) -> [u8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m256i) -> [u8; LANE_SIZE] { + // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i8; LANE_SIZE] { // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -88,7 +94,12 @@ mod avx2 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + unsafe fn _mm_blendv_values(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } @@ -282,7 +293,7 @@ mod sse { std::mem::transmute::<__m128i, [i8; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMD for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -292,7 +303,13 @@ mod sse { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m128i) -> [u8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m128i) -> [u8; LANE_SIZE] { + // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value + unimplemented!() + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i8; LANE_SIZE] { // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -323,7 +340,12 @@ mod sse { } #[inline(always)] - unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + unsafe fn _mm_blendv_values(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } @@ -499,7 +521,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i8; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMD for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -512,7 +534,14 @@ mod avx512 { const MAX_INDEX: usize = i8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(_: __m512i) -> [u8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(_: __m512i) -> [u8; LANE_SIZE] { + unimplemented!( + "We work with decrordi8 and override _get_min_index_value and _get_max_index_value" + ) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i8; LANE_SIZE] { unimplemented!( "We work with decrordi8 and override _get_min_index_value and _get_max_index_value" ) @@ -544,7 +573,12 @@ mod avx512 { } #[inline(always)] - unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u64) -> __m512i { + unsafe fn _mm_blendv_values(a: __m512i, b: __m512i, mask: u64) -> __m512i { + _mm512_mask_blend_epi8(mask, a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: __m512i, b: __m512i, mask: u64) -> __m512i { _mm512_mask_blend_epi8(mask, a, b) } @@ -730,7 +764,7 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_8; - impl SIMD for NEON { + impl SIMD for NEON { const INITIAL_INDEX: uint8x16_t = unsafe { std::mem::transmute([ 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, @@ -740,7 +774,12 @@ mod neon { const MAX_INDEX: usize = u8::MAX as usize; #[inline(always)] - unsafe fn _reg_to_arr(reg: uint8x16_t) -> [u8; LANE_SIZE] { + unsafe fn _reg_to_arr_values(reg: uint8x16_t) -> [u8; LANE_SIZE] { + std::mem::transmute::(reg) + } + + #[inline(always)] + unsafe fn _reg_to_arr_indices(reg: uint8x16_t) -> [u8; LANE_SIZE] { std::mem::transmute::(reg) } @@ -770,7 +809,12 @@ mod neon { } #[inline(always)] - unsafe fn _mm_blendv(a: uint8x16_t, b: uint8x16_t, mask: uint8x16_t) -> uint8x16_t { + unsafe fn _mm_blendv_values(a: uint8x16_t, b: uint8x16_t, mask: uint8x16_t) -> uint8x16_t { + vbslq_u8(mask, b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv_indices(a: uint8x16_t, b: uint8x16_t, mask: uint8x16_t) -> uint8x16_t { vbslq_u8(mask, b, a) } diff --git a/src/utils.rs b/src/utils.rs index 55fd93a..d91a177 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -23,7 +23,11 @@ // } #[inline(always)] -pub(crate) fn min_index_value(index: &[T], values: &[T]) -> (T, T) { +pub(crate) fn min_index_value(index: &[Tidx], values: &[Tval]) -> (Tidx, Tval) +where + Tidx: Copy + PartialOrd, + Tval: Copy + PartialOrd, +{ assert_eq!(index.len(), values.len()); values .iter() @@ -62,7 +66,11 @@ pub(crate) fn min_index_value(index: &[T], values: &[T]) - // } #[inline(always)] -pub(crate) fn max_index_value(index: &[T], values: &[T]) -> (T, T) { +pub(crate) fn max_index_value(index: &[Tidx], values: &[Tval]) -> (Tidx, Tval) +where + Tidx: Copy + PartialOrd, + Tval: Copy + PartialOrd, +{ assert_eq!(index.len(), values.len()); values .iter() From a7e8e10091c111859daa80f7e9590224f8ba69f3 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Wed, 28 Dec 2022 13:57:57 +0100 Subject: [PATCH 2/8] :bug: update unimplement macro --- src/simd/generic.rs | 14 +++++++++++--- src/simd/simd_u16.rs | 2 +- src/simd/simd_u32.rs | 6 +++--- src/simd/simd_u64.rs | 4 ++-- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/simd/generic.rs b/src/simd/generic.rs index 1bfe9e2..f84b8f8 100644 --- a/src/simd/generic.rs +++ b/src/simd/generic.rs @@ -194,11 +194,15 @@ pub trait SIMD< #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] macro_rules! unimplement_simd { ($scalar_type:ty, $reg:ty, $simd_type:ident) => { - impl SIMD<$scalar_type, $reg, $reg, 0> for $simd_type { + impl SIMD<$scalar_type, $reg, $scalar_type, $reg, $reg, 0> for $simd_type { const INITIAL_INDEX: $reg = 0; const MAX_INDEX: usize = 0; - unsafe fn _reg_to_arr(_reg: $reg) -> [$scalar_type; 0] { + unsafe fn _reg_to_arr_values(_reg: $reg) -> [$scalar_type; 0] { + unimplemented!() + } + + unsafe fn _reg_to_arr_indices(_reg: $reg) -> [$scalar_type; 0] { unimplemented!() } @@ -222,7 +226,11 @@ macro_rules! unimplement_simd { unimplemented!() } - unsafe fn _mm_blendv(_a: $reg, _b: $reg, _mask: $reg) -> $reg { + unsafe fn _mm_blendv_values(_a: $reg, _b: $reg, _mask: $reg) -> $reg { + unimplemented!() + } + + unsafe fn _mm_blendv_indices(_a: $reg, _b: $reg, _mask: $reg) -> $reg { unimplemented!() } diff --git a/src/simd/simd_u16.rs b/src/simd/simd_u16.rs index bc5d7e8..350f866 100644 --- a/src/simd/simd_u16.rs +++ b/src/simd/simd_u16.rs @@ -270,7 +270,7 @@ mod sse { } #[inline(always)] - unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i16; LANE_SIZE] { + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i16; LANE_SIZE] { // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value unimplemented!() } diff --git a/src/simd/simd_u32.rs b/src/simd/simd_u32.rs index 78e2bae..ab68bc8 100644 --- a/src/simd/simd_u32.rs +++ b/src/simd/simd_u32.rs @@ -54,7 +54,7 @@ mod avx2 { } #[inline(always)] - unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i32; LANE_SIZE] { // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -231,7 +231,7 @@ mod sse { } #[inline(always)] - unsafe fn _reg_to_arr_indices(reg: __m128i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_indices(_: __m128i) -> [i32; LANE_SIZE] { // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -402,7 +402,7 @@ mod avx512 { } #[inline(always)] - unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i32; LANE_SIZE] { + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i32; LANE_SIZE] { unimplemented!("We work with decrordu32 and override _get_min_index_value and _get_max_index_value") } diff --git a/src/simd/simd_u64.rs b/src/simd/simd_u64.rs index a9bd54f..5eca73b 100644 --- a/src/simd/simd_u64.rs +++ b/src/simd/simd_u64.rs @@ -50,7 +50,7 @@ mod avx2 { } #[inline(always)] - unsafe fn _reg_to_arr_indices(reg: __m256i) -> [i64; LANE_SIZE] { + unsafe fn _reg_to_arr_indices(_: __m256i) -> [i64; LANE_SIZE] { // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value unimplemented!() } @@ -394,7 +394,7 @@ mod avx512 { } #[inline(always)] - unsafe fn _reg_to_arr_indices(reg: __m512i) -> [i64; LANE_SIZE] { + unsafe fn _reg_to_arr_indices(_: __m512i) -> [i64; LANE_SIZE] { unimplemented!("We work with decrordi64 and override _get_min_index_value and _get_max_index_value") } From 0b3cb3d1882b84eec1ba31e3d35efada54098a30 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Wed, 28 Dec 2022 14:00:25 +0100 Subject: [PATCH 3/8] :see_no_evil: --- src/simd/simd_i32.rs | 2 +- src/simd/simd_i8.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simd/simd_i32.rs b/src/simd/simd_i32.rs index 8b11209..3748bf2 100644 --- a/src/simd/simd_i32.rs +++ b/src/simd/simd_i32.rs @@ -442,7 +442,7 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; impl SIMD for NEON { - const INITIAL_INDEX: int32x4_t = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; + const INITIAL_INDEX: uint32x4_t = unsafe { std::mem::transmute([0u32, 1u32, 2u32, 3u32]) }; const MAX_INDEX: usize = u32::MAX as usize; #[inline(always)] diff --git a/src/simd/simd_i8.rs b/src/simd/simd_i8.rs index fd709bd..4636665 100644 --- a/src/simd/simd_i8.rs +++ b/src/simd/simd_i8.rs @@ -772,7 +772,7 @@ mod neon { vdupq_n_u8(u8::MAX), // if mask is 0, use u8::MAX ); // 3. Find the maximum index - let mut imin: int8x16_t = search_index; + let mut imin: uint8x16_t = search_index; imin = vminq_u8(imin, vextq_u8(imin, imin, 8)); imin = vminq_u8(imin, vextq_u8(imin, imin, 4)); imin = vminq_u8(imin, vextq_u8(imin, imin, 2)); From 9be790f11a446bdf51382ae082098538eec35b18 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Wed, 28 Dec 2022 14:06:58 +0100 Subject: [PATCH 4/8] :bug: --- src/simd/simd_f16.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/simd/simd_f16.rs b/src/simd/simd_f16.rs index 5caab80..8347b9f 100644 --- a/src/simd/simd_f16.rs +++ b/src/simd/simd_f16.rs @@ -648,8 +648,8 @@ mod neon { } impl SIMD for NEON { - const INITIAL_INDEX: int16x8_t = - unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; + const INITIAL_INDEX: uint16x8_t = + unsafe { std::mem::transmute([0u16, 1u16, 2u16, 3u16, 4u16, 5u16, 6u16, 7u16]) }; const MAX_INDEX: usize = u16::MAX as usize; #[inline(always)] @@ -710,9 +710,9 @@ mod neon { #[inline(always)] unsafe fn _get_min_max_index_value( - index_low: int16x8_t, + index_low: uint16x8_t, values_low: int16x8_t, - index_high: int16x8_t, + index_high: uint16x8_t, values_high: int16x8_t, ) -> (usize, f16, usize, f16) { let (min_index, min_value) = Self::_horiz_min(index_low, values_low); From 6a4a97cd17ba33f504fcfd3f29cfa354d565d57c Mon Sep 17 00:00:00 2001 From: jvdd Date: Wed, 28 Dec 2022 14:33:40 +0100 Subject: [PATCH 5/8] :thinking: f32 and f64 avx -> avx2 --- src/lib.rs | 3 --- src/simd/simd_f32.rs | 10 +++++----- src/simd/simd_f64.rs | 8 ++++---- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 52c998f..fc6a552 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,9 +117,6 @@ macro_rules! impl_argminmax { return unsafe { AVX512::argminmax(self) } } else if is_x86_feature_detected!("avx2") { return unsafe { AVX2::argminmax(self) } - } else if is_x86_feature_detected!("avx") & (<$t>::NB_BITS >= 32) & (<$t>::IS_FLOAT == true) { - // f32 and f64 do not require avx2 - return unsafe { AVX2::argminmax(self) } // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers // // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) { // // SSE4.2 is needed for comparing 64-bit integers diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs index 8208d11..e7546d4 100644 --- a/src/simd/simd_f32.rs +++ b/src/simd/simd_f32.rs @@ -71,7 +71,7 @@ mod avx2 { // ------------------------------------ ARGMINMAX -------------------------------------- - #[target_feature(enable = "avx")] + #[target_feature(enable = "avx2")] unsafe fn argminmax(data: ArrayView1) -> (usize, usize) { Self::_argminmax(data) } @@ -95,7 +95,7 @@ mod avx2 { #[test] fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -110,7 +110,7 @@ mod avx2 { #[test] fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -137,7 +137,7 @@ mod avx2 { #[test] fn test_no_overflow() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -152,7 +152,7 @@ mod avx2 { #[test] fn test_many_random_runs() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } diff --git a/src/simd/simd_f64.rs b/src/simd/simd_f64.rs index d63ed98..ce14232 100644 --- a/src/simd/simd_f64.rs +++ b/src/simd/simd_f64.rs @@ -67,7 +67,7 @@ mod avx2 { // ------------------------------------ ARGMINMAX -------------------------------------- - #[target_feature(enable = "avx")] + #[target_feature(enable = "avx2")] unsafe fn argminmax(data: ArrayView1) -> (usize, usize) { Self::_argminmax(data) } @@ -91,7 +91,7 @@ mod avx2 { #[test] fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -106,7 +106,7 @@ mod avx2 { #[test] fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } @@ -133,7 +133,7 @@ mod avx2 { #[test] fn test_many_random_runs() { - if !is_x86_feature_detected!("avx") { + if !is_x86_feature_detected!("avx2") { return; } From f927296a54ef4b85e5dd7982131708ce95d42899 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Wed, 28 Dec 2022 14:38:37 +0100 Subject: [PATCH 6/8] :see_no_evil: --- benches/bench_f32.rs | 8 ++++---- benches/bench_f64.rs | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/benches/bench_f32.rs b/benches/bench_f32.rs index dc86827..70d4c8b 100644 --- a/benches/bench_f32.rs +++ b/benches/bench_f32.rs @@ -27,7 +27,7 @@ fn minmax_f32_random_array_long(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_random_long_f32", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -68,7 +68,7 @@ fn minmax_f32_random_array_short(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_random_short_f32", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -109,7 +109,7 @@ fn minmax_f32_worst_case_array_long(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_worst_long_f32", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -150,7 +150,7 @@ fn minmax_f32_worst_case_array_short(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_worst_short_f32", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); diff --git a/benches/bench_f64.rs b/benches/bench_f64.rs index a876d8e..820f515 100644 --- a/benches/bench_f64.rs +++ b/benches/bench_f64.rs @@ -25,7 +25,7 @@ fn minmax_f64_random_array_long(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_random_long_f64", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -54,7 +54,7 @@ fn minmax_f64_random_array_short(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_random_short_f64", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -83,7 +83,7 @@ fn minmax_f64_worst_case_array_long(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_worst_long_f64", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); @@ -112,7 +112,7 @@ fn minmax_f64_worst_case_array_short(c: &mut Criterion) { }); } #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { + if is_x86_feature_detected!("avx2") { c.bench_function("avx_worst_short_f64", |b| { b.iter(|| unsafe { AVX2::argminmax(black_box(data.view())) }) }); From 22f54f700657cb6efc2d91420b350dd22bda9a12 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Thu, 5 Jan 2023 13:01:05 +0100 Subject: [PATCH 7/8] :see_no_evil: fix merge conflict issues --- src/simd/generic.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/simd/generic.rs b/src/simd/generic.rs index cd938f0..a16718b 100644 --- a/src/simd/generic.rs +++ b/src/simd/generic.rs @@ -85,7 +85,7 @@ pub trait SIMD< } #[inline(always)] - unsafe fn _mm_prefetch(data: *const ScalarDType) { + unsafe fn _mm_prefetch(data: *const ValueDType) { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { #[cfg(target_arch = "x86")] @@ -242,12 +242,12 @@ pub trait SIMD< let gt_mask = Self::_mm_cmpgt(new_values, values_high); // Update the highest and lowest values - values_low = Self::_mm_blendv(values_low, new_values, lt_mask); - values_high = Self::_mm_blendv(values_high, new_values, gt_mask); + values_low = Self::_mm_blendv_values(values_low, new_values, lt_mask); + values_high = Self::_mm_blendv_values(values_high, new_values, gt_mask); // Update the index if the new value is lower/higher - index_low = Self::_mm_blendv(index_low, new_index, lt_mask); - index_high = Self::_mm_blendv(index_high, new_index, gt_mask); + index_low = Self::_mm_blendv_indices(index_low, new_index, lt_mask); + index_high = Self::_mm_blendv_indices(index_high, new_index, gt_mask); // 25 is a non-scientific number, but seems to work overall // => TODO: probably this should be in function of the data type From 1cead8b03abd97030416068422c3325830ee65fe Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Thu, 5 Jan 2023 13:41:50 +0100 Subject: [PATCH 8/8] :bug: fix NEON 16-bit horiz SIMD --- src/simd/simd_f16.rs | 32 ++++++++++++++++---------------- src/simd/simd_i16.rs | 32 ++++++++++++++++---------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/simd/simd_f16.rs b/src/simd/simd_f16.rs index a0c954b..ae30f9a 100644 --- a/src/simd/simd_f16.rs +++ b/src/simd/simd_f16.rs @@ -822,7 +822,7 @@ mod neon { } #[inline(always)] - unsafe fn _horiz_min(index: int16x8_t, value: int16x8_t) -> (usize, f16) { + unsafe fn _horiz_min(index: uint16x8_t, value: int16x8_t) -> (usize, f16) { // 0. Find the minimum value let mut vmin: int16x8_t = value; vmin = vminq_s16(vmin, vextq_s16(vmin, vmin, 4)); @@ -834,23 +834,23 @@ mod neon { // 1. Create a mask with the index of the minimum value let mask = vceqq_s16(value, vmin); // 2. Blend the mask with the index - let search_index = vbslq_s16( + let search_index = vbslq_u16( mask, index, // if mask is 1, use index - vdupq_n_s16(i16::MAX), // if mask is 0, use i16::MAX + vdupq_n_u16(u16::MAX), // if mask is 0, use u16::MAX ); // 3. Find the minimum index - let mut imin: int16x8_t = search_index; - imin = vminq_s16(imin, vextq_s16(imin, imin, 4)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 2)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); - let min_index: usize = vgetq_lane_s16(imin, 0) as usize; + let mut imin: uint16x8_t = search_index; + imin = vminq_u16(imin, vextq_u16(imin, imin, 4)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 2)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 1)); + let min_index: usize = vgetq_lane_u16(imin, 0) as usize; (min_index, _ord_i16_to_f16(min_value)) } #[inline(always)] - unsafe fn _horiz_max(index: int16x8_t, value: int16x8_t) -> (usize, f16) { + unsafe fn _horiz_max(index: uint16x8_t, value: int16x8_t) -> (usize, f16) { // 0. Find the maximum value let mut vmax: int16x8_t = value; vmax = vmaxq_s16(vmax, vextq_s16(vmax, vmax, 4)); @@ -862,17 +862,17 @@ mod neon { // 1. Create a mask with the index of the maximum value let mask = vceqq_s16(value, vmax); // 2. Blend the mask with the index - let search_index = vbslq_s16( + let search_index = vbslq_u16( mask, index, // if mask is 1, use index - vdupq_n_s16(i16::MAX), // if mask is 0, use i16::MAX + vdupq_n_u16(u16::MAX), // if mask is 0, use u16::MAX ); // 3. Find the maximum index - let mut imin: int16x8_t = search_index; - imin = vminq_s16(imin, vextq_s16(imin, imin, 4)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 2)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); - let max_index: usize = vgetq_lane_s16(imin, 0) as usize; + let mut imin: uint16x8_t = search_index; + imin = vminq_u16(imin, vextq_u16(imin, imin, 4)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 2)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 1)); + let max_index: usize = vgetq_lane_u16(imin, 0) as usize; (max_index, _ord_i16_to_f16(max_value)) } diff --git a/src/simd/simd_i16.rs b/src/simd/simd_i16.rs index 37e2a92..8471253 100644 --- a/src/simd/simd_i16.rs +++ b/src/simd/simd_i16.rs @@ -729,7 +729,7 @@ mod neon { } #[inline(always)] - unsafe fn _horiz_min(index: int16x8_t, value: int16x8_t) -> (usize, i16) { + unsafe fn _horiz_min(index: uint16x8_t, value: int16x8_t) -> (usize, i16) { // 0. Find the minimum value let mut vmin: int16x8_t = value; vmin = vminq_s16(vmin, vextq_s16(vmin, vmin, 4)); @@ -741,23 +741,23 @@ mod neon { // 1. Create a mask with the index of the minimum value let mask = vceqq_s16(value, vmin); // 2. Blend the mask with the index - let search_index = vbslq_s16( + let search_index = vbslq_u16( mask, index, // if mask is 1, use index - vdupq_n_s16(i16::MAX), // if mask is 0, use i16::MAX + vdupq_n_u16(u16::MAX), // if mask is 0, use u16::MAX ); // 3. Find the minimum index - let mut imin: int16x8_t = search_index; - imin = vminq_s16(imin, vextq_s16(imin, imin, 4)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 2)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); - let min_index: usize = vgetq_lane_s16(imin, 0) as usize; + let mut imin: uint16x8_t = search_index; + imin = vminq_u16(imin, vextq_u16(imin, imin, 4)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 2)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 1)); + let min_index: usize = vgetq_lane_u16(imin, 0) as usize; (min_index, min_value) } #[inline(always)] - unsafe fn _horiz_max(index: int16x8_t, value: int16x8_t) -> (usize, i16) { + unsafe fn _horiz_max(index: uint16x8_t, value: int16x8_t) -> (usize, i16) { // 0. Find the maximum value let mut vmax: int16x8_t = value; vmax = vmaxq_s16(vmax, vextq_s16(vmax, vmax, 4)); @@ -769,17 +769,17 @@ mod neon { // 1. Create a mask with the index of the maximum value let mask = vceqq_s16(value, vmax); // 2. Blend the mask with the index - let search_index = vbslq_s16( + let search_index = vbslq_u16( mask, index, // if mask is 1, use index - vdupq_n_s16(i16::MAX), // if mask is 0, use i16::MAX + vdupq_n_u16(u16::MAX), // if mask is 0, use u16::MAX ); // 3. Find the maximum index - let mut imin: int16x8_t = search_index; - imin = vminq_s16(imin, vextq_s16(imin, imin, 4)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 2)); - imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); - let max_index: usize = vgetq_lane_s16(imin, 0) as usize; + let mut imin: uint16x8_t = search_index; + imin = vminq_u16(imin, vextq_u16(imin, imin, 4)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 2)); + imin = vminq_u16(imin, vextq_u16(imin, imin, 1)); + let max_index: usize = vgetq_lane_u16(imin, 0) as usize; (max_index, max_value) }