From 09236a62545b3ce1b0720c0d00232fdf1ac0a5cf Mon Sep 17 00:00:00 2001 From: ray Date: Wed, 11 Feb 2026 16:11:13 +0800 Subject: [PATCH 1/6] add euclidean one2many implementation --- src/ailego/math/matrix_utility.i | 2 + src/ailego/math_batch/distance_batch.h | 3 +- src/ailego/math_batch/distance_batch_math.h | 26 ++ .../math_batch/euclidean_distance_batch.h | 163 +++++++++++ .../euclidean_distance_batch_impl.h | 150 ++++++++++ .../euclidean_distance_batch_impl_fp16.h | 256 ++++++++++++++++++ .../euclidean_distance_batch_impl_int8.h | 92 +++++++ .../math_batch/inner_product_distance_batch.h | 22 +- .../inner_product_distance_batch_impl.h | 16 +- .../inner_product_distance_batch_impl_fp16.h | 6 +- .../inner_product_distance_batch_impl_int8.h | 6 +- src/core/metric/euclidean_metric.cc | 28 ++ 12 files changed, 738 insertions(+), 32 deletions(-) create mode 100644 src/ailego/math_batch/distance_batch_math.h create mode 100644 src/ailego/math_batch/euclidean_distance_batch.h create mode 100644 src/ailego/math_batch/euclidean_distance_batch_impl.h create mode 100644 src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h create mode 100644 src/ailego/math_batch/euclidean_distance_batch_impl_int8.h diff --git a/src/ailego/math/matrix_utility.i b/src/ailego/math/matrix_utility.i index 34951478..18d5140e 100644 --- a/src/ailego/math/matrix_utility.i +++ b/src/ailego/math/matrix_utility.i @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#pragma once + #include namespace zvec { diff --git a/src/ailego/math_batch/distance_batch.h b/src/ailego/math_batch/distance_batch.h index c762a258..92ed65f1 100644 --- a/src/ailego/math_batch/distance_batch.h +++ b/src/ailego/math_batch/distance_batch.h @@ -18,11 +18,10 @@ #include "ailego/math/distance_matrix.h" #include "cosine_distance_batch.h" #include "inner_product_distance_batch.h" - +#include "euclidean_distance_batch.h" namespace zvec::ailego { - template < template class DistanceType, typename ValueType, size_t BatchSize, size_t PrefetchStep, typename = void> diff --git a/src/ailego/math_batch/distance_batch_math.h b/src/ailego/math_batch/distance_batch_math.h new file mode 100644 index 00000000..63ee6c18 --- /dev/null +++ b/src/ailego/math_batch/distance_batch_math.h @@ -0,0 +1,26 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +inline float sum4(__m128 v) { + v = _mm_add_ps(v, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(v), 8))); + return v[0] + v[1]; +} + +inline __m128 sum_top_bottom_avx(__m256 v) { + const __m128 high = _mm256_extractf128_ps(v, 1); + const __m128 low = _mm256_castps256_ps128(v); + return _mm_add_ps(high, low); +} diff --git a/src/ailego/math_batch/euclidean_distance_batch.h b/src/ailego/math_batch/euclidean_distance_batch.h new file mode 100644 index 00000000..d09d6a68 --- /dev/null +++ b/src/ailego/math_batch/euclidean_distance_batch.h @@ -0,0 +1,163 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "euclidean_distance_batch_impl.h" +#include "euclidean_distance_batch_impl_fp16.h" +#include "euclidean_distance_batch_impl_int8.h" + +namespace zvec::ailego::DistanceBatch { + +//SquaredEuclideanDistanceBatch +template +struct SquaredEuclideanDistanceBatch; + +// Function template partial specialization is not allowed, +// therefore the wrapper struct is required. +template +struct SquaredEuclideanDistanceBatchImpl { + using ValueType = typename std::remove_cv::type; + static void compute_one_to_many( + const ValueType *query, const ValueType **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { + return compute_one_to_many_squared_euclidean_fallback(query, ptrs, prefetch_ptrs, dim, sums); + } +}; + +template +struct SquaredEuclideanDistanceBatchImpl { + using ValueType = float; + static void compute_one_to_many( + const ValueType *query, const ValueType **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_squared_euclidean_avx512f_fp32( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_squared_euclidean_avx2_fp32( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_squared_euclidean_fallback(query, ptrs, prefetch_ptrs, dim, sums); + } +}; + +template +struct SquaredEuclideanDistanceBatchImpl { + using ValueType = int8_t; + static void compute_one_to_many( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_squared_euclidean_avx2_int8( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_squared_euclidean_fallback(query, ptrs, prefetch_ptrs, dim, sums); + } +}; + +template +struct SquaredEuclideanDistanceBatchImpl { + using ValueType = ailego::Float16; + static void compute_one_to_many( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { +#if defined(__AVX512FP16__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { + return compute_one_to_many_squared_euclidean_avx512fp16_fp16( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif +#if defined(__AVX512F__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { + return compute_one_to_many_squared_euclidean_avx512f_fp16( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_squared_euclidean_avx2_fp16( + query, ptrs, prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_squared_euclidean_fallback(query, ptrs, prefetch_ptrs, dim, sums); + } +}; + +template +struct SquaredEuclideanDistanceBatch { + using ValueType = typename std::remove_cv::type; + + static inline void ComputeBatch(const ValueType **vecs, + const ValueType *query, size_t num_vecs, + size_t dim, float *results) { + size_t i = 0; + for (; i + BatchSize <= num_vecs; i += BatchSize) { + std::array prefetch_ptrs; + for (size_t j = 0; j < BatchSize; ++j) { + if (i + j + BatchSize * PrefetchStep < num_vecs) { + prefetch_ptrs[j] = vecs[i + j + BatchSize * PrefetchStep]; + } else { + prefetch_ptrs[j] = nullptr; + } + } + SquaredEuclideanDistanceBatchImpl::compute_one_to_many( + query, &vecs[i], prefetch_ptrs, dim, &results[i]); + } + for (; i < num_vecs; ++i) { // TODO: unroll by 1, 2, 4, 8, etc. + std::array prefetch_ptrs{nullptr}; + SquaredEuclideanDistanceBatchImpl::compute_one_to_many( + query, &vecs[i], prefetch_ptrs, dim, &results[i]); + } + } +}; + +//EuclideanDistanceBatch +template +struct EuclideanDistanceBatch; + +template +struct EuclideanDistanceBatch { + using ValueType = typename std::remove_cv::type; + + static inline void ComputeBatch(const ValueType **vecs, + const ValueType *query, size_t num_vecs, + size_t dim, float *results) { + SquaredEuclideanDistanceBatch::ComputeBatch(vecs, query, num_vecs, dim, results); + + for (size_t i=0; i +#include +#include +#include +#include +#include "distance_batch_math.h" + +namespace zvec::ailego::DistanceBatch { + +template +static void compute_one_to_many_squared_euclidean_fallback( + const ValueType *query, const ValueType **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { + for (size_t j = 0; j < BatchSize; ++j) { + sums[j] = 0.0; + SquaredEuclideanDistanceMatrix::Compute(ptrs[j], query, dim, sums + j); + ailego_prefetch(&prefetch_ptrs[j]); + } +} + +#if defined(__AVX512F__) + +template +static std::enable_if_t, void> +compute_one_to_many_squared_euclidean_avx512f_fp32( + const ValueType *query, const ValueType **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results) { + __m512 accs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm512_setzero_ps(); + } + size_t dim = 0; + for (; dim + 16 <= dimensionality; dim += 16) { + __m512 q = _mm512_loadu_ps(query + dim); + __m512 data_regs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm512_loadu_ps(ptrs[i] + dim); + } + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + for (size_t i = 0; i < dp_batch; ++i) { + __m512 diff = _mm512_sub_ps(q, data_regs[i]); + accs[i] = _mm512_fmadd_ps(diff, diff, accs[i]); + } + } + + if (dim < dimensionality) { + __mmask32 mask = (__mmask32)((1 << (dimensionality - dim)) - 1); + + for (size_t i = 0; i < dp_batch; ++i) { + __m512 zmm_undefined = _mm512_undefined_ps(); + + accs[i] = + _mm512_mask3_fmadd_ps(_mm512_mask_loadu_ps( + zmm_undefined, mask, query + dim), + _mm512_mask_loadu_ps( + zmm_undefined, mask, ptrs[i] + dim), + accs[i], mask); + } + } + + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = HorizontalAdd_FP32_V512(accs[i]); + } +} + +#endif + +#if defined(__AVX2__) + +template +static std::enable_if_t, void> +compute_one_to_many_squared_euclidean_avx2_fp32( + const ValueType *query, const ValueType **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results) { + __m256 accs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm256_setzero_ps(); + } + size_t dim = 0; + for (; dim + 8 <= dimensionality; dim += 8) { + __m256 q = _mm256_loadu_ps(query + dim); + __m256 data_regs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm256_loadu_ps(ptrs[i] + dim); + } + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + for (size_t i = 0; i < dp_batch; ++i) { + __m256 diff = _mm256_sub_ps(q, data_regs[i]); + accs[i] = _mm256_fmadd_ps(diff, diff, accs[i]); + } + } + + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = HorizontalAdd_FP32_V256(_mm256_add_ps(accs[i])); + + switch (dimensionality - dim) { + case 7: + SSD_FP32_GENERAL(query[6], ptrs[i][6], results[i]); + /* FALLTHRU */ + case 6: + SSD_FP32_GENERAL(query[5], ptrs[i][5], results[i]); + /* FALLTHRU */ + case 5: + SSD_FP32_GENERAL(query[4], ptrs[i][4], results[i]); + /* FALLTHRU */ + case 4: + SSD_FP32_GENERAL(query[3], ptrs[i][3], results[i]); + /* FALLTHRU */ + case 3: + SSD_FP32_GENERAL(query[2], ptrs[i][2], results[i]); + /* FALLTHRU */ + case 2: + SSD_FP32_GENERAL(query[1], ptrs[i][1], results[i]); + /* FALLTHRU */ + case 1: + SSD_FP32_GENERAL(query[0], ptrs[i][0], results[i]); + } + } +} +#endif + + +} // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h b/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h new file mode 100644 index 00000000..23c117e6 --- /dev/null +++ b/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h @@ -0,0 +1,256 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace zvec::ailego::DistanceBatch { + +#if defined(__AVX512FP16__) + +template +static std::enable_if_t, void> +compute_one_to_many_squared_euclidean_avx512fp16_fp16( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results) { + __m512h accs[dp_batch]; + + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm512_setzero_ph(); + } + + size_t dim = 0; + for (; dim + 32 <= dimensionality; dim += 32) { + __m512h q = _mm512_loadu_ph(query + dim); + + __m512h data_regs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm512_loadu_ph(ptrs[i] + dim); + } + + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + + for (size_t i = 0; i < dp_batch; ++i) { + __m512h diff = _mm512_sub_ph(data_regs[i], q); + accs[i] = _mm512_fmadd_ph(diff, diff, accs[i]); + } + } + + if (dim < dimensionality) { + __mmask32 mask = (__mmask32)((1 << (dimensionality - dim)) - 1); + + for (size_t i = 0; i < dp_batch; ++i) { + __m512i zmm_undefined = _mm512_undefined_epi32(); + __m512h zmm_undefined_ph = _mm512_undefined_ph(); + __m512h zmm_d = _mm512_mask_sub_ph( + zmm_undefined_ph, mask, + _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, query + dim)), + _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, ptrs[i] + dim))); + + accs[i] = _mm512_mask3_fmadd_ph(zmm_d, zmm_d, accs[i], mask); + } + } + + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = HorizontalAdd_FP16_V512(accs[i]); + } +} + +#endif + +#if defined(__AVX512F__) + +template +static std::enable_if_t, void> +compute_one_to_many_squared_euclidean_avx512f_fp16( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results) { + __m512 accs[dp_batch]; + + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm512_setzero_ps(); + } + + size_t dim = 0; + for (; dim + 32 <= dimensionality; dim += 32) { + __m512i q = + _mm512_loadu_si512(reinterpret_cast(query + dim)); + + __m512 q1 = _mm512_cvtph_ps(_mm512_castsi512_si256(q)); + __m512 q2 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(q, 1)); + + __m512 data_regs_1[dp_batch]; + __m512 data_regs_2[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + __m512i m = + _mm512_loadu_si512(reinterpret_cast(ptrs[i] + dim)); + + data_regs_1[i] = _mm512_cvtph_ps(_mm512_castsi512_si256(m)); + data_regs_2[i] = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(m, 1)); + } + + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + + for (size_t i = 0; i < dp_batch; ++i) { + __m512 diff1 = _mm512_sub_ps(q1, data_regs_1[i]); + accs[i] = _mm512_fmadd_ps(diff1, diff1, accs[i]); + + __m512 diff2 = _mm512_sub_ps(q2, data_regs_2[i]); + accs[i] = _mm512_fmadd_ps(diff2,diff2, accs[i]); + } + } + + if (dim + 16 < dimensionality) { + __m512 q = _mm512_cvtph_ps( + _mm256_loadu_si256(reinterpret_cast(query + dim))); + + __m512 data_regs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm512_cvtph_ps( + _mm256_loadu_si256(reinterpret_cast(ptrs[i] + dim))); + accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]); + } + + dim += 16; + } + + __m256 acc_new[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + acc_new[i] = _mm256_add_ps( + _mm512_castps512_ps256(accs[i]), + _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(accs[i]), 1))); + } + + if (dim + 8 < dimensionality) { + __m256 q = _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(query + dim))); + + for (size_t i = 0; i < dp_batch; ++i) { + __m256 m = _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ptrs[i] + dim))); + + __m256 diff = _mm256_sub_ps(m, q); + acc_new[i] = _mm256_fmadd_ps(diff, diff, acc_new[i]); + } + + dim += 8; + } + + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = HorizontalAdd_FP32_V256(acc_new[i]); + } + + for (; dim < dimensionality; ++dim) { + for (size_t i = 0; i < dp_batch; ++i) { + float diff = (*(query + dim)) - (*(ptrs[i] + dim)); + results[i] += diff * diff; + } + } +} +#endif + +#if defined(__AVX2__) + +template +static std::enable_if_t, void> +compute_one_to_many_squared_euclidean_avx2_fp16( + const ailego::Float16 *query, const ailego::Float16 **ptrs, + std::array &prefetch_ptrs, + size_t dimensionality, float *results) { + __m256 accs[dp_batch]; + + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm256_setzero_ps(); + } + + size_t dim = 0; + for (; dim + 16 <= dimensionality; dim += 16) { + __m256i q = + _mm256_loadu_si256(reinterpret_cast(query + dim)); + + __m256 q1 = _mm256_cvtph_ps(_mm256_castsi256_si128(q)); + __m256 q2 = _mm256_cvtph_ps(_mm256_extractf128_si256(q, 1)); + + __m256 data_regs_1[dp_batch]; + __m256 data_regs_2[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + __m256i m = + _mm256_loadu_si256(reinterpret_cast(ptrs[i] + dim)); + + data_regs_1[i] = _mm256_cvtph_ps(_mm256_castsi256_si128(m)); + data_regs_2[i] = _mm256_cvtph_ps(_mm256_extractf128_si256(m, 1)); + } + + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + + for (size_t i = 0; i < dp_batch; ++i) { + __m256 diff1 = _mm256_sub_ps(q1, data_regs_1[i]); + accs[i] = _mm256_fmadd_ps(diff1, diff1, accs[i]); + + __m256 diff2 = _mm256_sub_ps(q1, data_regs_2[i]); + accs[i] = _mm256_fmadd_ps(diff2, diff2, accs[i]); + } + } + + if (dim + 8 < dimensionality) { + __m256 q = _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(query + dim))); + + __m256 data_regs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ptrs[i] + dim))); + + __m256 diff = _mm256_sub_ps(q, data_regs[i]); + accs[i] = _mm256_fmadd_ps(diff, diff, accs[i]); + } + + dim += 8; + } + + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = HorizontalAdd_FP32_V256(accs[i]); + } + + for (; dim < dimensionality; ++dim) { + for (size_t i = 0; i < dp_batch; ++i) { + float diff = (*(query + dim)) - (*(ptrs[i] + dim)); + results[i] += diff * diff; + } + } +} + +#endif + + +} // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h b/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h new file mode 100644 index 00000000..5ff7ee4c --- /dev/null +++ b/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h @@ -0,0 +1,92 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace zvec::ailego::DistanceBatch { + +#if defined(__AVX2__) + +template +static std::enable_if_t, void> +compute_one_to_many_squared_euclidean_avx2_int8( + const int8_t *query, const int8_t **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results) { + __m256i accs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm256_setzero_si256(); + } + + size_t dim = 0; + for (; dim + 32 <= dimensionality; dim += 32) { + __m256i q = _mm256_loadu_si256((const __m256i *)(query + dim)); + __m256i data_regs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm256_loadu_si256((const __m256i *)(ptrs[i] + dim)); + } + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + __m256i q_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q)); + __m256i q_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q, 1)); + __m256i data_lo[dp_batch]; + __m256i data_hi[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_lo[i] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(data_regs[i])); + data_hi[i] = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(data_regs[i], 1)); + } + __m256i prod_lo[dp_batch]; + __m256i prod_hi[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + prod_lo[i] = _mm256_madd_epi16(q_lo, data_lo[i]); + prod_hi[i] = _mm256_madd_epi16(q_hi, data_hi[i]); + } + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = + _mm256_add_epi32(accs[i], _mm256_add_epi32(prod_lo[i], prod_hi[i])); + } + } + std::array temp_results; + for (size_t i = 0; i < dp_batch; ++i) { + __m128i lo = _mm256_castsi256_si128(accs[i]); + __m128i hi = _mm256_extracti128_si256(accs[i], 1); + __m128i sum128 = _mm_add_epi32(lo, hi); + sum128 = _mm_hadd_epi32(sum128, sum128); + sum128 = _mm_hadd_epi32(sum128, sum128); + temp_results[i] = _mm_cvtsi128_si32(sum128); + } + for (; dim < dimensionality; ++dim) { + int8_t q = query[dim]; + for (size_t i = 0; i < dp_batch; ++i) { + temp_results[i] += q * static_cast(ptrs[i][dim]); + } + } + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = static_cast(temp_results[i]); + } +} + +#endif + + +} // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/ailego/math_batch/inner_product_distance_batch.h b/src/ailego/math_batch/inner_product_distance_batch.h index f5799497..02203016 100644 --- a/src/ailego/math_batch/inner_product_distance_batch.h +++ b/src/ailego/math_batch/inner_product_distance_batch.h @@ -38,7 +38,7 @@ struct InnerProductDistanceBatchImpl { const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { - return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, dim, sums); } static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { return nullptr; @@ -54,11 +54,11 @@ struct InnerProductDistanceBatchImpl { float *sums) { #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_avx2_fp32( + return compute_one_to_many_inner_product_avx2_fp32( query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, dim, sums); } static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { @@ -78,23 +78,23 @@ struct InnerProductDistanceBatchImpl { // query, ptrs, prefetch_ptrs, dim, sums); #if defined(__AVX512VNNI__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { - return compute_one_to_many_avx512_vnni_int8( + return compute_one_to_many_inner_product_avx512_vnni_int8( query, ptrs, prefetch_ptrs, dim, sums); } #endif #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_avx2_int8( + return compute_one_to_many_inner_product_avx2_int8( query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, dim, sums); } static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { #if defined(__AVX512VNNI__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) { - return compute_one_to_many_avx512_vnni_int8_query_preprocess; + return compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess; } #endif return nullptr; @@ -110,23 +110,23 @@ struct InnerProductDistanceBatchImpl { float *sums) { #if defined(__AVX512FP16__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { - return compute_one_to_many_avx512fp16_fp16( + return compute_one_to_many_inner_product_avx512fp16_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - return compute_one_to_many_avx512f_fp16( + return compute_one_to_many_inner_product_avx512f_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_avx2_fp16( + return compute_one_to_many_inner_product_avx2_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, dim, sums); } }; diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl.h b/src/ailego/math_batch/inner_product_distance_batch_impl.h index 1554426f..55e3e6af 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl.h +++ b/src/ailego/math_batch/inner_product_distance_batch_impl.h @@ -19,11 +19,12 @@ #include #include #include +#include "distance_batch_math.h" namespace zvec::ailego::DistanceBatch { template -static void compute_one_to_many_fallback( +static void compute_one_to_many_inner_product_fallback( const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { @@ -36,20 +37,9 @@ static void compute_one_to_many_fallback( #if defined(__AVX2__) -inline float sum4(__m128 v) { - v = _mm_add_ps(v, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(v), 8))); - return v[0] + v[1]; -} - -inline __m128 sum_top_bottom_avx(__m256 v) { - const __m128 high = _mm256_extractf128_ps(v, 1); - const __m128 low = _mm256_castps256_ps128(v); - return _mm_add_ps(high, low); -} - template static std::enable_if_t, void> -compute_one_to_many_avx2_fp32( +compute_one_to_many_inner_product_avx2_fp32( const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16.h b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16.h index db9e81e0..30486cea 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_fp16.h +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_fp16.h @@ -26,7 +26,7 @@ namespace zvec::ailego::DistanceBatch { template static std::enable_if_t, void> -compute_one_to_many_avx512fp16_fp16( +compute_one_to_many_inner_product_avx512fp16_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { @@ -82,7 +82,7 @@ compute_one_to_many_avx512fp16_fp16( template static std::enable_if_t, void> -compute_one_to_many_avx512f_fp16( +compute_one_to_many_inner_product_avx512f_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { @@ -172,7 +172,7 @@ compute_one_to_many_avx512f_fp16( template static std::enable_if_t, void> -compute_one_to_many_avx2_fp16( +compute_one_to_many_inner_product_avx2_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h b/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h index 9f49effd..0c514836 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h @@ -23,7 +23,7 @@ namespace zvec::ailego::DistanceBatch { #if defined(__AVX512VNNI__) -static void compute_one_to_many_avx512_vnni_int8_query_preprocess(void *query, +static void compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess(void *query, size_t dim) { const int8_t *input = reinterpret_cast(query); uint8_t *output = reinterpret_cast(query); @@ -51,7 +51,7 @@ static void compute_one_to_many_avx512_vnni_int8_query_preprocess(void *query, // query is unsigned template -static void compute_one_to_many_avx512_vnni_int8( +static void compute_one_to_many_inner_product_avx512_vnni_int8( const int8_t *query, const int8_t **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { @@ -159,7 +159,7 @@ static void compute_one_to_many_avx512_vnni_int8( template static std::enable_if_t, void> -compute_one_to_many_avx2_int8( +compute_one_to_many_inner_product_avx2_int8( const int8_t *query, const int8_t **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { diff --git a/src/core/metric/euclidean_metric.cc b/src/core/metric/euclidean_metric.cc index a1a8d598..bcd2ab23 100644 --- a/src/core/metric/euclidean_metric.cc +++ b/src/core/metric/euclidean_metric.cc @@ -1009,6 +1009,34 @@ class EuclideanMetric : public IndexMetric { } } + //! Retrieve distance function for query + MatrixBatchDistance batch_distance(void) const override { + switch (data_type_) { + case IndexMeta::DataType::DT_FP16: + return reinterpret_cast( + ailego::BaseDistance::ComputeBatch); + + case IndexMeta::DataType::DT_FP32: + return reinterpret_cast( + ailego::BaseDistance::ComputeBatch); + + case IndexMeta::DataType::DT_INT8: + return reinterpret_cast( + ailego::BaseDistance::ComputeBatch); + + case IndexMeta::DataType::DT_INT4: + return reinterpret_cast( + ailego::BaseDistance::ComputeBatch); + + default: + return nullptr; + } + } + //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; From 6370754f083550d6a8ff6e60b4637a7b6a5c13b9 Mon Sep 17 00:00:00 2001 From: ray Date: Wed, 11 Feb 2026 17:18:22 +0800 Subject: [PATCH 2/6] add euclidean one2many implementation --- .../euclidean_distance_batch_impl.h | 6 + .../euclidean_distance_batch_impl_int8.h | 107 +++++++++++++----- 2 files changed, 85 insertions(+), 28 deletions(-) diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl.h b/src/ailego/math_batch/euclidean_distance_batch_impl.h index 83dbb217..539c4ad1 100644 --- a/src/ailego/math_batch/euclidean_distance_batch_impl.h +++ b/src/ailego/math_batch/euclidean_distance_batch_impl.h @@ -21,6 +21,12 @@ #include #include "distance_batch_math.h" +#define SSD_FP32_GENERAL(m, q, sum) \ + { \ + float x = m - q; \ + sum += (x * x); \ + } + namespace zvec::ailego::DistanceBatch { template diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h b/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h index 5ff7ee4c..51c6fd56 100644 --- a/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h +++ b/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h @@ -19,6 +19,12 @@ #include #include +#define SSD_INT8_GENERAL(m, q, sum) \ + { \ + int32_t x = m - q; \ + sum += static_cast(x * x); \ + } + namespace zvec::ailego::DistanceBatch { #if defined(__AVX2__) @@ -46,43 +52,88 @@ compute_one_to_many_squared_euclidean_avx2_int8( ailego_prefetch(prefetch_ptrs[i] + dim); } } - __m256i q_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q)); - __m256i q_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q, 1)); - __m256i data_lo[dp_batch]; - __m256i data_hi[dp_batch]; - for (size_t i = 0; i < dp_batch; ++i) { - data_lo[i] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(data_regs[i])); - data_hi[i] = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(data_regs[i], 1)); - } - __m256i prod_lo[dp_batch]; - __m256i prod_hi[dp_batch]; - for (size_t i = 0; i < dp_batch; ++i) { - prod_lo[i] = _mm256_madd_epi16(q_lo, data_lo[i]); - prod_hi[i] = _mm256_madd_epi16(q_hi, data_hi[i]); - } + for (size_t i = 0; i < dp_batch; ++i) { - accs[i] = - _mm256_add_epi32(accs[i], _mm256_add_epi32(prod_lo[i], prod_hi[i])); + __m256i data_diff = _mm256_sub_epi8(_mm256_max_epi8(q, data_regs[i]), + _mm256_min_epi8(q, data_regs[i])); + + __m256i diff0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data_diff)); + __m256i diff1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(data_diff, 1)); + accs[i] = _mm256_add_epi32(_mm256_madd_epi16(diff0, diff0), accs[i]); + accs[i] = _mm256_add_epi32(_mm256_madd_epi16(diff1, diff1), accs[i]); } } - std::array temp_results; + for (size_t i = 0; i < dp_batch; ++i) { - __m128i lo = _mm256_castsi256_si128(accs[i]); - __m128i hi = _mm256_extracti128_si256(accs[i], 1); - __m128i sum128 = _mm_add_epi32(lo, hi); - sum128 = _mm_hadd_epi32(sum128, sum128); - sum128 = _mm_hadd_epi32(sum128, sum128); - temp_results[i] = _mm_cvtsi128_si32(sum128); + results[i] = HorizontalAdd_INT32_V256(accs[i]); } - for (; dim < dimensionality; ++dim) { - int8_t q = query[dim]; + + if (dimensionality >= dim + 16) { for (size_t i = 0; i < dp_batch; ++i) { - temp_results[i] += q * static_cast(ptrs[i][dim]); + __m128i q = _mm_loadu_si128((const __m128i *)query + dim); + __m128i data_regs = _mm_loadu_si128((const __m128i *)(ptrs[i]+dim)); + + __m128i diff = _mm_sub_epi8(_mm_max_epi8(q, data_regs), + _mm_min_epi8(q, data_regs)); + + __m128i diff0 = _mm_cvtepu8_epi16(diff); + __m128i diff1 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(diff, diff)); + __m128i sum = _mm_add_epi32(_mm_madd_epi16(diff0, diff0), + _mm_madd_epi16(diff1, diff1)); + + results[i] += static_cast(HorizontalAdd_INT32_V128(sum)); } + + dim += 16; } + for (size_t i = 0; i < dp_batch; ++i) { - results[i] = static_cast(temp_results[i]); + switch (dimensionality - dim) { + case 15: + SSD_INT8_GENERAL(query+dim, ptrs[14]+dim, results[i]); + /* FALLTHRU */ + case 14: + SSD_INT8_GENERAL(query+dim, ptrs[13+dim], results[i]); + /* FALLTHRU */ + case 13: + SSD_INT8_GENERAL(query+dim, ptrs[12]+dim, results[i]); + /* FALLTHRU */ + case 12: + SSD_INT8_GENERAL(query+dim, ptrs[11]+dim, results[i]); + /* FALLTHRU */ + case 11: + SSD_INT8_GENERAL(query+dim, ptrs[10+dim], results[i]); + /* FALLTHRU */ + case 10: + SSD_INT8_GENERAL(query+dim, ptrs[9]+dim, results[i]); + /* FALLTHRU */ + case 9: + SSD_INT8_GENERAL(query+dim, ptrs[8]+dim, results[i]); + /* FALLTHRU */ + case 8: + SSD_INT8_GENERAL(query+dim, ptrs[7]+dim, results[i]); + /* FALLTHRU */ + case 7: + SSD_INT8_GENERAL(query+dim, ptrs[6]+dim, results[i]); + /* FALLTHRU */ + case 6: + SSD_INT8_GENERAL(query+dim, ptrs[5]+dim, results[i]); + /* FALLTHRU */ + case 5: + SSD_INT8_GENERAL(query+dim, ptrs[4]+dim, results[i]); + /* FALLTHRU */ + case 4: + SSD_INT8_GENERAL(query+dim, ptrs[3]+dim, results[i]); + /* FALLTHRU */ + case 3: + SSD_INT8_GENERAL(query+dim, ptrs[2]+dim, results[i]); + /* FALLTHRU */ + case 2: + SSD_INT8_GENERAL(query+dim, ptrs[1]+dim, results[i]); + /* FALLTHRU */ + case 1: + SSD_INT8_GENERAL(query+dim, ptrs[0]+dim, results[i]); + } } } From 200cc06b1bb5b717fcc1befd77df2bdb32ef6821 Mon Sep 17 00:00:00 2001 From: ray Date: Fri, 13 Feb 2026 17:26:38 +0800 Subject: [PATCH 3/6] add euclidean one2many implementation --- src/ailego/math_batch/distance_batch.h | 14 ++++++++++++++ .../math_batch/euclidean_distance_batch_impl.h | 2 +- .../euclidean_distance_batch_impl_fp16.h | 2 +- src/core/metric/euclidean_metric.cc | 12 ------------ src/core/metric/hamming_metric.cc | 16 ---------------- 5 files changed, 16 insertions(+), 30 deletions(-) diff --git a/src/ailego/math_batch/distance_batch.h b/src/ailego/math_batch/distance_batch.h index 92ed65f1..0588b619 100644 --- a/src/ailego/math_batch/distance_batch.h +++ b/src/ailego/math_batch/distance_batch.h @@ -43,6 +43,20 @@ struct BaseDistance { out); } + if constexpr (std::is_same_v, + EuclideanDistanceMatrix>) { + return DistanceBatch::EuclideanDistanceBatch< + ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim, + out); + } + + if constexpr (std::is_same_v, + SquaredEuclideanDistanceMatrix>) { + return DistanceBatch::SquaredEuclideanDistanceBatch< + ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim, + out); + } + _ComputeBatch(m, q, num, dim, out); } }; diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl.h b/src/ailego/math_batch/euclidean_distance_batch_impl.h index 539c4ad1..94b19f28 100644 --- a/src/ailego/math_batch/euclidean_distance_batch_impl.h +++ b/src/ailego/math_batch/euclidean_distance_batch_impl.h @@ -124,7 +124,7 @@ compute_one_to_many_squared_euclidean_avx2_fp32( } for (size_t i = 0; i < dp_batch; ++i) { - results[i] = HorizontalAdd_FP32_V256(_mm256_add_ps(accs[i])); + results[i] = HorizontalAdd_FP32_V256(accs[i]); switch (dimensionality - dim) { case 7: diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h b/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h index 23c117e6..bf61858f 100644 --- a/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h +++ b/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h @@ -217,7 +217,7 @@ compute_one_to_many_squared_euclidean_avx2_fp16( __m256 diff1 = _mm256_sub_ps(q1, data_regs_1[i]); accs[i] = _mm256_fmadd_ps(diff1, diff1, accs[i]); - __m256 diff2 = _mm256_sub_ps(q1, data_regs_2[i]); + __m256 diff2 = _mm256_sub_ps(q2, data_regs_2[i]); accs[i] = _mm256_fmadd_ps(diff2, diff2, accs[i]); } } diff --git a/src/core/metric/euclidean_metric.cc b/src/core/metric/euclidean_metric.cc index bcd2ab23..93eb936e 100644 --- a/src/core/metric/euclidean_metric.cc +++ b/src/core/metric/euclidean_metric.cc @@ -853,18 +853,6 @@ class SquaredEuclideanMetric : public IndexMetric { //! Retrieve distance function for query MatrixBatchDistance batch_distance(void) const override { switch (data_type_) { - case IndexMeta::DataType::DT_BINARY32: - return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); - -#if defined(AILEGO_M64) - case IndexMeta::DataType::DT_BINARY64: - return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); -#endif // AILEGO_M64 - case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::BaseDistance( - ailego::BaseDistance::ComputeBatch); - } -#endif - if (feature_type_ == IndexMeta::DataType::DT_BINARY32) { - return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); - } - return nullptr; - } - //! Retrieve distance function for index features MatrixDistance distance_matrix(size_t m, size_t n) const override { #if defined(AILEGO_M64) From 9c07ff9036acb03e714b945354d9680756b54762 Mon Sep 17 00:00:00 2001 From: ray Date: Thu, 26 Feb 2026 20:13:30 +0800 Subject: [PATCH 4/6] format codes --- src/ailego/math/matrix_utility.i | 36 +++++++------- .../euclidean_distance_batch_impl.h | 15 +++--- .../euclidean_distance_batch_impl_fp16.h | 16 ++++--- .../euclidean_distance_batch_impl_int8.h | 47 ++++++++++--------- .../math_batch/inner_product_distance_batch.h | 18 ++++--- .../inner_product_distance_batch_impl_int8.h | 4 +- src/core/metric/euclidean_metric.cc | 12 ++--- 7 files changed, 78 insertions(+), 70 deletions(-) diff --git a/src/ailego/math/matrix_utility.i b/src/ailego/math/matrix_utility.i index 18d5140e..8ae3e245 100644 --- a/src/ailego/math/matrix_utility.i +++ b/src/ailego/math/matrix_utility.i @@ -48,7 +48,7 @@ static inline float HorizontalAdd_FP32_V128(__m128 v) { return _mm_cvtss_f32(x4); #endif } -#endif // __SSE__ +#endif // __SSE__ #if defined(__SSE2__) static inline int32_t HorizontalAdd_INT32_V128(__m128i v) { @@ -73,7 +73,7 @@ static inline int64_t HorizontalAdd_INT64_V128(__m128i v) { _mm_add_epi64(_mm_shuffle_epi32(v, _MM_SHUFFLE(0, 0, 3, 2)), v)); #endif } -#endif // __SSE2__ +#endif // __SSE2__ #if defined(__SSSE3__) static const __m128i POPCNT_LOOKUP_SSE = @@ -88,7 +88,7 @@ static inline __m128i VerticalPopCount_INT8_V128(__m128i v) { __m128i hi = _mm_shuffle_epi8(POPCNT_LOOKUP_SSE, _mm_and_si128(_mm_srli_epi32(v, 4), low_mask)); return _mm_add_epi8(lo, hi); -#endif // __AVX512VL__ && __AVX512BITALG__ +#endif // __AVX512VL__ && __AVX512BITALG__ } static inline __m128i VerticalPopCount_INT16_V128(__m128i v) { @@ -98,7 +98,7 @@ static inline __m128i VerticalPopCount_INT16_V128(__m128i v) { __m128i total = VerticalPopCount_INT8_V128(v); return _mm_add_epi16(_mm_srli_epi16(total, 8), _mm_and_si128(total, _mm_set1_epi16(0xff))); -#endif // __AVX512VL__ && __AVX512BITALG__ +#endif // __AVX512VL__ && __AVX512BITALG__ } static inline __m128i VerticalPopCount_INT32_V128(__m128i v) { @@ -109,7 +109,7 @@ static inline __m128i VerticalPopCount_INT32_V128(__m128i v) { _mm_madd_epi16(VerticalPopCount_INT8_V128(v), _mm_set1_epi16(1)); return _mm_add_epi32(_mm_srli_epi32(total, 8), _mm_and_si128(total, _mm_set1_epi32(0xff))); -#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ +#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ } static inline __m128i VerticalPopCount_INT64_V128(__m128i v) { @@ -117,9 +117,9 @@ static inline __m128i VerticalPopCount_INT64_V128(__m128i v) { return _mm_popcnt_epi64(v); #else return _mm_sad_epu8(VerticalPopCount_INT8_V128(v), _mm_setzero_si128()); -#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ +#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ } -#endif // __SSSE3__ +#endif // __SSSE3__ #if defined(__SSE4_1__) static inline int16_t HorizontalMax_UINT8_V128(__m128i v) { @@ -129,7 +129,7 @@ static inline int16_t HorizontalMax_UINT8_V128(__m128i v) { v = _mm_max_epu8(v, _mm_srli_epi16(v, 8)); return static_cast(_mm_cvtsi128_si32(v)); } -#endif // __SSE4_1__ +#endif // __SSE4_1__ #if defined(__AVX__) static inline float HorizontalMax_FP32_V256(__m256 v) { @@ -149,7 +149,7 @@ static inline float HorizontalAdd_FP32_V256(__m256 v) { __m128 x4 = _mm_add_ss(_mm256_castps256_ps128(x2), x3); return _mm_cvtss_f32(x4); } -#endif // __AVX__ +#endif // __AVX__ #if defined(__AVX2__) static const __m256i POPCNT_MASK1_INT8_AVX = _mm256_set1_epi8(0x0f); @@ -171,7 +171,7 @@ static inline __m256i VerticalPopCount_INT8_V256(__m256i v) { POPCNT_LOOKUP_AVX, _mm256_and_si256(_mm256_srli_epi32(v, 4), POPCNT_MASK1_INT8_AVX)); return _mm256_add_epi8(lo, hi); -#endif // __AVX512VL__ && __AVX512BITALG__ +#endif // __AVX512VL__ && __AVX512BITALG__ } static inline __m256i VerticalPopCount_INT16_V256(__m256i v) { @@ -181,7 +181,7 @@ static inline __m256i VerticalPopCount_INT16_V256(__m256i v) { __m256i total = VerticalPopCount_INT8_V256(v); return _mm256_add_epi16(_mm256_srli_epi16(total, 8), _mm256_and_si256(total, POPCNT_MASK2_INT16_AVX)); -#endif // __AVX512VL__ && __AVX512BITALG__ +#endif // __AVX512VL__ && __AVX512BITALG__ } static inline __m256i VerticalPopCount_INT32_V256(__m256i v) { @@ -192,7 +192,7 @@ static inline __m256i VerticalPopCount_INT32_V256(__m256i v) { _mm256_madd_epi16(VerticalPopCount_INT8_V256(v), POPCNT_MASK1_INT16_AVX); return _mm256_add_epi32(_mm256_srli_epi32(total, 8), _mm256_and_si256(total, POPCNT_MASK1_INT32_AVX)); -#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ +#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ } static inline __m256i VerticalPopCount_INT64_V256(__m256i v) { @@ -200,7 +200,7 @@ static inline __m256i VerticalPopCount_INT64_V256(__m256i v) { return _mm256_popcnt_epi64(v); #else return _mm256_sad_epu8(VerticalPopCount_INT8_V256(v), POPCNT_ZERO_AVX); -#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ +#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ } static inline int16_t HorizontalMax_UINT8_V256(__m256i v) { @@ -228,7 +228,7 @@ static inline int64_t HorizontalAdd_INT64_V256(__m256i v) { __m128i x4 = _mm_add_epi64(_mm256_extractf128_si256(x2, 0), x3); return _mm_cvtsi128_si64(x4); } -#endif // __AVX2__ +#endif // __AVX2__ #if defined(__AVX512F__) static inline float HorizontalMax_FP32_V512(__m512 v) { @@ -244,7 +244,7 @@ static inline float HorizontalAdd_FP32_V512(__m512 v) { _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(v), 1)); return HorizontalAdd_FP32_V256(_mm256_add_ps(low, high)); } -#endif // __AVX512F__ +#endif // __AVX512F__ #if defined(__AVX512FP16__) static inline float HorizontalMax_FP16_V512(__m512h v) { @@ -261,7 +261,7 @@ static inline float HorizontalAdd_FP16_V512(__m512h v) { return HorizontalAdd_FP32_V512(_mm512_add_ps(low, high)); } -#endif // __AVX512FP16__ +#endif // __AVX512FP16__ -} // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace ailego +} // namespace zvec \ No newline at end of file diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl.h b/src/ailego/math_batch/euclidean_distance_batch_impl.h index 94b19f28..19be4ead 100644 --- a/src/ailego/math_batch/euclidean_distance_batch_impl.h +++ b/src/ailego/math_batch/euclidean_distance_batch_impl.h @@ -36,12 +36,13 @@ static void compute_one_to_many_squared_euclidean_fallback( float *sums) { for (size_t j = 0; j < BatchSize; ++j) { sums[j] = 0.0; - SquaredEuclideanDistanceMatrix::Compute(ptrs[j], query, dim, sums + j); + SquaredEuclideanDistanceMatrix::Compute(ptrs[j], query, + dim, sums + j); ailego_prefetch(&prefetch_ptrs[j]); } } -#if defined(__AVX512F__) +#if defined(__AVX512F__) template static std::enable_if_t, void> @@ -77,12 +78,10 @@ compute_one_to_many_squared_euclidean_avx512f_fp32( for (size_t i = 0; i < dp_batch; ++i) { __m512 zmm_undefined = _mm512_undefined_ps(); - accs[i] = - _mm512_mask3_fmadd_ps(_mm512_mask_loadu_ps( - zmm_undefined, mask, query + dim), - _mm512_mask_loadu_ps( - zmm_undefined, mask, ptrs[i] + dim), - accs[i], mask); + accs[i] = _mm512_mask3_fmadd_ps( + _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim), + _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim), accs[i], + mask); } } diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h b/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h index bf61858f..23196d52 100644 --- a/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h +++ b/src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h @@ -63,10 +63,12 @@ compute_one_to_many_squared_euclidean_avx512fp16_fp16( for (size_t i = 0; i < dp_batch; ++i) { __m512i zmm_undefined = _mm512_undefined_epi32(); __m512h zmm_undefined_ph = _mm512_undefined_ph(); - __m512h zmm_d = _mm512_mask_sub_ph( - zmm_undefined_ph, mask, - _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, query + dim)), - _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, ptrs[i] + dim))); + __m512h zmm_d = + _mm512_mask_sub_ph(zmm_undefined_ph, mask, + _mm512_castsi512_ph(_mm512_mask_loadu_epi16( + zmm_undefined, mask, query + dim)), + _mm512_castsi512_ph(_mm512_mask_loadu_epi16( + zmm_undefined, mask, ptrs[i] + dim))); accs[i] = _mm512_mask3_fmadd_ph(zmm_d, zmm_d, accs[i], mask); } @@ -118,11 +120,11 @@ compute_one_to_many_squared_euclidean_avx512f_fp16( } for (size_t i = 0; i < dp_batch; ++i) { - __m512 diff1 = _mm512_sub_ps(q1, data_regs_1[i]); + __m512 diff1 = _mm512_sub_ps(q1, data_regs_1[i]); accs[i] = _mm512_fmadd_ps(diff1, diff1, accs[i]); - __m512 diff2 = _mm512_sub_ps(q2, data_regs_2[i]); - accs[i] = _mm512_fmadd_ps(diff2,diff2, accs[i]); + __m512 diff2 = _mm512_sub_ps(q2, data_regs_2[i]); + accs[i] = _mm512_fmadd_ps(diff2, diff2, accs[i]); } } diff --git a/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h b/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h index 51c6fd56..e69fafa0 100644 --- a/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h +++ b/src/ailego/math_batch/euclidean_distance_batch_impl_int8.h @@ -39,7 +39,7 @@ compute_one_to_many_squared_euclidean_avx2_int8( for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm256_setzero_si256(); } - + size_t dim = 0; for (; dim + 32 <= dimensionality; dim += 32) { __m256i q = _mm256_loadu_si256((const __m256i *)(query + dim)); @@ -52,13 +52,14 @@ compute_one_to_many_squared_euclidean_avx2_int8( ailego_prefetch(prefetch_ptrs[i] + dim); } } - + for (size_t i = 0; i < dp_batch; ++i) { __m256i data_diff = _mm256_sub_epi8(_mm256_max_epi8(q, data_regs[i]), - _mm256_min_epi8(q, data_regs[i])); + _mm256_min_epi8(q, data_regs[i])); __m256i diff0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data_diff)); - __m256i diff1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(data_diff, 1)); + __m256i diff1 = + _mm256_cvtepu8_epi16(_mm256_extractf128_si256(data_diff, 1)); accs[i] = _mm256_add_epi32(_mm256_madd_epi16(diff0, diff0), accs[i]); accs[i] = _mm256_add_epi32(_mm256_madd_epi16(diff1, diff1), accs[i]); } @@ -71,15 +72,15 @@ compute_one_to_many_squared_euclidean_avx2_int8( if (dimensionality >= dim + 16) { for (size_t i = 0; i < dp_batch; ++i) { __m128i q = _mm_loadu_si128((const __m128i *)query + dim); - __m128i data_regs = _mm_loadu_si128((const __m128i *)(ptrs[i]+dim)); + __m128i data_regs = _mm_loadu_si128((const __m128i *)(ptrs[i] + dim)); - __m128i diff = _mm_sub_epi8(_mm_max_epi8(q, data_regs), - _mm_min_epi8(q, data_regs)); + __m128i diff = + _mm_sub_epi8(_mm_max_epi8(q, data_regs), _mm_min_epi8(q, data_regs)); __m128i diff0 = _mm_cvtepu8_epi16(diff); __m128i diff1 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(diff, diff)); __m128i sum = _mm_add_epi32(_mm_madd_epi16(diff0, diff0), - _mm_madd_epi16(diff1, diff1)); + _mm_madd_epi16(diff1, diff1)); results[i] += static_cast(HorizontalAdd_INT32_V128(sum)); } @@ -90,49 +91,49 @@ compute_one_to_many_squared_euclidean_avx2_int8( for (size_t i = 0; i < dp_batch; ++i) { switch (dimensionality - dim) { case 15: - SSD_INT8_GENERAL(query+dim, ptrs[14]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[14] + dim, results[i]); /* FALLTHRU */ case 14: - SSD_INT8_GENERAL(query+dim, ptrs[13+dim], results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[13 + dim], results[i]); /* FALLTHRU */ case 13: - SSD_INT8_GENERAL(query+dim, ptrs[12]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[12] + dim, results[i]); /* FALLTHRU */ case 12: - SSD_INT8_GENERAL(query+dim, ptrs[11]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[11] + dim, results[i]); /* FALLTHRU */ case 11: - SSD_INT8_GENERAL(query+dim, ptrs[10+dim], results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[10 + dim], results[i]); /* FALLTHRU */ case 10: - SSD_INT8_GENERAL(query+dim, ptrs[9]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[9] + dim, results[i]); /* FALLTHRU */ case 9: - SSD_INT8_GENERAL(query+dim, ptrs[8]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[8] + dim, results[i]); /* FALLTHRU */ case 8: - SSD_INT8_GENERAL(query+dim, ptrs[7]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[7] + dim, results[i]); /* FALLTHRU */ case 7: - SSD_INT8_GENERAL(query+dim, ptrs[6]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[6] + dim, results[i]); /* FALLTHRU */ case 6: - SSD_INT8_GENERAL(query+dim, ptrs[5]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[5] + dim, results[i]); /* FALLTHRU */ case 5: - SSD_INT8_GENERAL(query+dim, ptrs[4]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[4] + dim, results[i]); /* FALLTHRU */ case 4: - SSD_INT8_GENERAL(query+dim, ptrs[3]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[3] + dim, results[i]); /* FALLTHRU */ case 3: - SSD_INT8_GENERAL(query+dim, ptrs[2]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[2] + dim, results[i]); /* FALLTHRU */ case 2: - SSD_INT8_GENERAL(query+dim, ptrs[1]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[1] + dim, results[i]); /* FALLTHRU */ case 1: - SSD_INT8_GENERAL(query+dim, ptrs[0]+dim, results[i]); + SSD_INT8_GENERAL(query + dim, ptrs[0] + dim, results[i]); } } } diff --git a/src/ailego/math_batch/inner_product_distance_batch.h b/src/ailego/math_batch/inner_product_distance_batch.h index 02203016..e0447f23 100644 --- a/src/ailego/math_batch/inner_product_distance_batch.h +++ b/src/ailego/math_batch/inner_product_distance_batch.h @@ -38,7 +38,8 @@ struct InnerProductDistanceBatchImpl { const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { - return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, + prefetch_ptrs, dim, sums); } static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { return nullptr; @@ -58,7 +59,8 @@ struct InnerProductDistanceBatchImpl { query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, + prefetch_ptrs, dim, sums); } static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { @@ -88,7 +90,8 @@ struct InnerProductDistanceBatchImpl { query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, + prefetch_ptrs, dim, sums); } static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() { @@ -110,13 +113,15 @@ struct InnerProductDistanceBatchImpl { float *sums) { #if defined(__AVX512FP16__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { - return compute_one_to_many_inner_product_avx512fp16_fp16( + return compute_one_to_many_inner_product_avx512fp16_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - return compute_one_to_many_inner_product_avx512f_fp16( + return compute_one_to_many_inner_product_avx512f_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif @@ -126,7 +131,8 @@ struct InnerProductDistanceBatchImpl { query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_inner_product_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_inner_product_fallback(query, ptrs, + prefetch_ptrs, dim, sums); } }; diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h b/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h index 0c514836..976eeb8b 100644 --- a/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_int8.h @@ -23,8 +23,8 @@ namespace zvec::ailego::DistanceBatch { #if defined(__AVX512VNNI__) -static void compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess(void *query, - size_t dim) { +static void compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess( + void *query, size_t dim) { const int8_t *input = reinterpret_cast(query); uint8_t *output = reinterpret_cast(query); diff --git a/src/core/metric/euclidean_metric.cc b/src/core/metric/euclidean_metric.cc index 93eb936e..dd37315e 100644 --- a/src/core/metric/euclidean_metric.cc +++ b/src/core/metric/euclidean_metric.cc @@ -1007,18 +1007,18 @@ class EuclideanMetric : public IndexMetric { case IndexMeta::DataType::DT_FP32: return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); + ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_INT8: return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); + ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_INT4: return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); + ailego::BaseDistance::ComputeBatch); default: return nullptr; From 9545fbab3963de72d380219f0034e1ab01a8589f Mon Sep 17 00:00:00 2001 From: raymond Date: Mon, 2 Mar 2026 16:22:35 +0800 Subject: [PATCH 5/6] format codes --- src/ailego/math_batch/distance_batch.h | 7 +-- .../math_batch/euclidean_distance_batch.h | 45 ++++++++++++------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/ailego/math_batch/distance_batch.h b/src/ailego/math_batch/distance_batch.h index 0588b619..9494be85 100644 --- a/src/ailego/math_batch/distance_batch.h +++ b/src/ailego/math_batch/distance_batch.h @@ -17,8 +17,8 @@ #include #include "ailego/math/distance_matrix.h" #include "cosine_distance_batch.h" -#include "inner_product_distance_batch.h" #include "euclidean_distance_batch.h" +#include "inner_product_distance_batch.h" namespace zvec::ailego { @@ -50,8 +50,9 @@ struct BaseDistance { out); } - if constexpr (std::is_same_v, - SquaredEuclideanDistanceMatrix>) { + if constexpr (std::is_same_v< + DistanceType, + SquaredEuclideanDistanceMatrix>) { return DistanceBatch::SquaredEuclideanDistanceBatch< ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim, out); diff --git a/src/ailego/math_batch/euclidean_distance_batch.h b/src/ailego/math_batch/euclidean_distance_batch.h index d09d6a68..96705d4e 100644 --- a/src/ailego/math_batch/euclidean_distance_batch.h +++ b/src/ailego/math_batch/euclidean_distance_batch.h @@ -26,7 +26,7 @@ namespace zvec::ailego::DistanceBatch { -//SquaredEuclideanDistanceBatch +// SquaredEuclideanDistanceBatch template struct SquaredEuclideanDistanceBatch; @@ -39,7 +39,8 @@ struct SquaredEuclideanDistanceBatchImpl { const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { - return compute_one_to_many_squared_euclidean_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_squared_euclidean_fallback( + query, ptrs, prefetch_ptrs, dim, sums); } }; @@ -52,18 +53,21 @@ struct SquaredEuclideanDistanceBatchImpl { float *sums) { #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_squared_euclidean_avx512f_fp32( + return compute_one_to_many_squared_euclidean_avx512f_fp32( query, ptrs, prefetch_ptrs, dim, sums); } #endif #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_squared_euclidean_avx2_fp32( + return compute_one_to_many_squared_euclidean_avx2_fp32( query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_squared_euclidean_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_squared_euclidean_fallback( + query, ptrs, prefetch_ptrs, dim, sums); } }; @@ -76,11 +80,13 @@ struct SquaredEuclideanDistanceBatchImpl { float *sums) { #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_squared_euclidean_avx2_int8( + return compute_one_to_many_squared_euclidean_avx2_int8( query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_squared_euclidean_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_squared_euclidean_fallback( + query, ptrs, prefetch_ptrs, dim, sums); } }; @@ -93,23 +99,27 @@ struct SquaredEuclideanDistanceBatchImpl { float *sums) { #if defined(__AVX512FP16__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) { - return compute_one_to_many_squared_euclidean_avx512fp16_fp16( + return compute_one_to_many_squared_euclidean_avx512fp16_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { - return compute_one_to_many_squared_euclidean_avx512f_fp16( + return compute_one_to_many_squared_euclidean_avx512f_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif #if defined(__AVX2__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { - return compute_one_to_many_squared_euclidean_avx2_fp16( + return compute_one_to_many_squared_euclidean_avx2_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif - return compute_one_to_many_squared_euclidean_fallback(query, ptrs, prefetch_ptrs, dim, sums); + return compute_one_to_many_squared_euclidean_fallback( + query, ptrs, prefetch_ptrs, dim, sums); } }; @@ -130,8 +140,10 @@ struct SquaredEuclideanDistanceBatch { prefetch_ptrs[j] = nullptr; } } - SquaredEuclideanDistanceBatchImpl::compute_one_to_many( - query, &vecs[i], prefetch_ptrs, dim, &results[i]); + SquaredEuclideanDistanceBatchImpl< + ValueType, BatchSize>::compute_one_to_many(query, &vecs[i], + prefetch_ptrs, dim, + &results[i]); } for (; i < num_vecs; ++i) { // TODO: unroll by 1, 2, 4, 8, etc. std::array prefetch_ptrs{nullptr}; @@ -141,7 +153,7 @@ struct SquaredEuclideanDistanceBatch { } }; -//EuclideanDistanceBatch +// EuclideanDistanceBatch template struct EuclideanDistanceBatch; @@ -152,9 +164,10 @@ struct EuclideanDistanceBatch { static inline void ComputeBatch(const ValueType **vecs, const ValueType *query, size_t num_vecs, size_t dim, float *results) { - SquaredEuclideanDistanceBatch::ComputeBatch(vecs, query, num_vecs, dim, results); + SquaredEuclideanDistanceBatch::ComputeBatch( + vecs, query, num_vecs, dim, results); - for (size_t i=0; i Date: Mon, 2 Mar 2026 17:07:53 +0800 Subject: [PATCH 6/6] fix macro missing --- src/ailego/math_batch/distance_batch_math.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ailego/math_batch/distance_batch_math.h b/src/ailego/math_batch/distance_batch_math.h index 63ee6c18..a2672cd8 100644 --- a/src/ailego/math_batch/distance_batch_math.h +++ b/src/ailego/math_batch/distance_batch_math.h @@ -14,6 +14,8 @@ #pragma once +#if defined(__AVX2__) + inline float sum4(__m128 v) { v = _mm_add_ps(v, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(v), 8))); return v[0] + v[1]; @@ -24,3 +26,5 @@ inline __m128 sum_top_bottom_avx(__m256 v) { const __m128 low = _mm256_castps256_ps128(v); return _mm_add_ps(high, low); } + +#endif