Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions src/ailego/math/matrix_utility.i
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <zvec/ailego/internal/platform.h>

namespace zvec {
Expand Down Expand Up @@ -46,7 +48,7 @@ static inline float HorizontalAdd_FP32_V128(__m128 v) {
return _mm_cvtss_f32(x4);
#endif
}
#endif // __SSE__
#endif // __SSE__

#if defined(__SSE2__)
static inline int32_t HorizontalAdd_INT32_V128(__m128i v) {
Expand All @@ -71,7 +73,7 @@ static inline int64_t HorizontalAdd_INT64_V128(__m128i v) {
_mm_add_epi64(_mm_shuffle_epi32(v, _MM_SHUFFLE(0, 0, 3, 2)), v));
#endif
}
#endif // __SSE2__
#endif // __SSE2__

#if defined(__SSSE3__)
static const __m128i POPCNT_LOOKUP_SSE =
Expand All @@ -86,7 +88,7 @@ static inline __m128i VerticalPopCount_INT8_V128(__m128i v) {
__m128i hi = _mm_shuffle_epi8(POPCNT_LOOKUP_SSE,
_mm_and_si128(_mm_srli_epi32(v, 4), low_mask));
return _mm_add_epi8(lo, hi);
#endif // __AVX512VL__ && __AVX512BITALG__
#endif // __AVX512VL__ && __AVX512BITALG__
}

static inline __m128i VerticalPopCount_INT16_V128(__m128i v) {
Expand All @@ -96,7 +98,7 @@ static inline __m128i VerticalPopCount_INT16_V128(__m128i v) {
__m128i total = VerticalPopCount_INT8_V128(v);
return _mm_add_epi16(_mm_srli_epi16(total, 8),
_mm_and_si128(total, _mm_set1_epi16(0xff)));
#endif // __AVX512VL__ && __AVX512BITALG__
#endif // __AVX512VL__ && __AVX512BITALG__
}

static inline __m128i VerticalPopCount_INT32_V128(__m128i v) {
Expand All @@ -107,17 +109,17 @@ static inline __m128i VerticalPopCount_INT32_V128(__m128i v) {
_mm_madd_epi16(VerticalPopCount_INT8_V128(v), _mm_set1_epi16(1));
return _mm_add_epi32(_mm_srli_epi32(total, 8),
_mm_and_si128(total, _mm_set1_epi32(0xff)));
#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__
#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__
}

static inline __m128i VerticalPopCount_INT64_V128(__m128i v) {
#if defined(__AVX512VL__) && defined(__AVX512VPOPCNTDQ__)
return _mm_popcnt_epi64(v);
#else
return _mm_sad_epu8(VerticalPopCount_INT8_V128(v), _mm_setzero_si128());
#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__
#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__
}
#endif // __SSSE3__
#endif // __SSSE3__

#if defined(__SSE4_1__)
static inline int16_t HorizontalMax_UINT8_V128(__m128i v) {
Expand All @@ -127,7 +129,7 @@ static inline int16_t HorizontalMax_UINT8_V128(__m128i v) {
v = _mm_max_epu8(v, _mm_srli_epi16(v, 8));
return static_cast<uint8_t>(_mm_cvtsi128_si32(v));
}
#endif // __SSE4_1__
#endif // __SSE4_1__

#if defined(__AVX__)
static inline float HorizontalMax_FP32_V256(__m256 v) {
Expand All @@ -147,7 +149,7 @@ static inline float HorizontalAdd_FP32_V256(__m256 v) {
__m128 x4 = _mm_add_ss(_mm256_castps256_ps128(x2), x3);
return _mm_cvtss_f32(x4);
}
#endif // __AVX__
#endif // __AVX__

#if defined(__AVX2__)
static const __m256i POPCNT_MASK1_INT8_AVX = _mm256_set1_epi8(0x0f);
Expand All @@ -169,7 +171,7 @@ static inline __m256i VerticalPopCount_INT8_V256(__m256i v) {
POPCNT_LOOKUP_AVX,
_mm256_and_si256(_mm256_srli_epi32(v, 4), POPCNT_MASK1_INT8_AVX));
return _mm256_add_epi8(lo, hi);
#endif // __AVX512VL__ && __AVX512BITALG__
#endif // __AVX512VL__ && __AVX512BITALG__
}

static inline __m256i VerticalPopCount_INT16_V256(__m256i v) {
Expand All @@ -179,7 +181,7 @@ static inline __m256i VerticalPopCount_INT16_V256(__m256i v) {
__m256i total = VerticalPopCount_INT8_V256(v);
return _mm256_add_epi16(_mm256_srli_epi16(total, 8),
_mm256_and_si256(total, POPCNT_MASK2_INT16_AVX));
#endif // __AVX512VL__ && __AVX512BITALG__
#endif // __AVX512VL__ && __AVX512BITALG__
}

static inline __m256i VerticalPopCount_INT32_V256(__m256i v) {
Expand All @@ -190,15 +192,15 @@ static inline __m256i VerticalPopCount_INT32_V256(__m256i v) {
_mm256_madd_epi16(VerticalPopCount_INT8_V256(v), POPCNT_MASK1_INT16_AVX);
return _mm256_add_epi32(_mm256_srli_epi32(total, 8),
_mm256_and_si256(total, POPCNT_MASK1_INT32_AVX));
#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__
#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__
}

static inline __m256i VerticalPopCount_INT64_V256(__m256i v) {
#if defined(__AVX512VL__) && defined(__AVX512VPOPCNTDQ__)
return _mm256_popcnt_epi64(v);
#else
return _mm256_sad_epu8(VerticalPopCount_INT8_V256(v), POPCNT_ZERO_AVX);
#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__
#endif // __AVX512VL__ && __AVX512VPOPCNTDQ__
}

static inline int16_t HorizontalMax_UINT8_V256(__m256i v) {
Expand Down Expand Up @@ -226,7 +228,7 @@ static inline int64_t HorizontalAdd_INT64_V256(__m256i v) {
__m128i x4 = _mm_add_epi64(_mm256_extractf128_si256(x2, 0), x3);
return _mm_cvtsi128_si64(x4);
}
#endif // __AVX2__
#endif // __AVX2__

#if defined(__AVX512F__)
static inline float HorizontalMax_FP32_V512(__m512 v) {
Expand All @@ -242,7 +244,7 @@ static inline float HorizontalAdd_FP32_V512(__m512 v) {
_mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(v), 1));
return HorizontalAdd_FP32_V256(_mm256_add_ps(low, high));
}
#endif // __AVX512F__
#endif // __AVX512F__

#if defined(__AVX512FP16__)
static inline float HorizontalMax_FP16_V512(__m512h v) {
Expand All @@ -259,7 +261,7 @@ static inline float HorizontalAdd_FP16_V512(__m512h v) {

return HorizontalAdd_FP32_V512(_mm512_add_ps(low, high));
}
#endif // __AVX512FP16__
#endif // __AVX512FP16__

} // namespace ailego
} // namespace zvec
} // namespace ailego
} // namespace zvec
18 changes: 16 additions & 2 deletions src/ailego/math_batch/distance_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
#include <zvec/ailego/math_batch/utils.h>
#include "ailego/math/distance_matrix.h"
#include "cosine_distance_batch.h"
#include "euclidean_distance_batch.h"
#include "inner_product_distance_batch.h"


namespace zvec::ailego {


template <
template <typename, size_t, size_t, typename = void> class DistanceType,
typename ValueType, size_t BatchSize, size_t PrefetchStep, typename = void>
Expand All @@ -44,6 +43,21 @@ struct BaseDistance {
out);
}

if constexpr (std::is_same_v<DistanceType<ValueType, 1, 1>,
EuclideanDistanceMatrix<ValueType, 1, 1>>) {
return DistanceBatch::EuclideanDistanceBatch<
ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim,
out);
}

if constexpr (std::is_same_v<
DistanceType<ValueType, 1, 1>,
SquaredEuclideanDistanceMatrix<ValueType, 1, 1>>) {
return DistanceBatch::SquaredEuclideanDistanceBatch<
ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim,
out);
}

_ComputeBatch(m, q, num, dim, out);
}
};
Expand Down
30 changes: 30 additions & 0 deletions src/ailego/math_batch/distance_batch_math.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#if defined(__AVX2__)

inline float sum4(__m128 v) {
v = _mm_add_ps(v, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(v), 8)));
return v[0] + v[1];
}

inline __m128 sum_top_bottom_avx(__m256 v) {
const __m128 high = _mm256_extractf128_ps(v, 1);
const __m128 low = _mm256_castps256_ps128(v);
return _mm_add_ps(high, low);
}

#endif
176 changes: 176 additions & 0 deletions src/ailego/math_batch/euclidean_distance_batch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>
#include <ailego/internal/cpu_features.h>
#include <ailego/utility/math_helper.h>
#include <zvec/ailego/internal/platform.h>
#include <zvec/ailego/math_batch/utils.h>
#include <zvec/ailego/utility/type_helper.h>
#include "euclidean_distance_batch_impl.h"
#include "euclidean_distance_batch_impl_fp16.h"
#include "euclidean_distance_batch_impl_int8.h"

namespace zvec::ailego::DistanceBatch {

// SquaredEuclideanDistanceBatch
template <typename T, size_t BatchSize, size_t PrefetchStep, typename = void>
struct SquaredEuclideanDistanceBatch;

// Function template partial specialization is not allowed,
// therefore the wrapper struct is required.
template <typename T, size_t BatchSize>
struct SquaredEuclideanDistanceBatchImpl {
using ValueType = typename std::remove_cv<T>::type;
static void compute_one_to_many(
const ValueType *query, const ValueType **ptrs,
std::array<const ValueType *, BatchSize> &prefetch_ptrs, size_t dim,
float *sums) {
return compute_one_to_many_squared_euclidean_fallback(
query, ptrs, prefetch_ptrs, dim, sums);
}
};

template <size_t BatchSize>
struct SquaredEuclideanDistanceBatchImpl<float, BatchSize> {
using ValueType = float;
static void compute_one_to_many(
const ValueType *query, const ValueType **ptrs,
std::array<const ValueType *, BatchSize> &prefetch_ptrs, size_t dim,
float *sums) {
#if defined(__AVX512F__)
if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
return compute_one_to_many_squared_euclidean_avx512f_fp32<ValueType,
BatchSize>(
query, ptrs, prefetch_ptrs, dim, sums);
}
#endif

#if defined(__AVX2__)
if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
return compute_one_to_many_squared_euclidean_avx2_fp32<ValueType,
BatchSize>(
query, ptrs, prefetch_ptrs, dim, sums);
}
#endif
return compute_one_to_many_squared_euclidean_fallback(
query, ptrs, prefetch_ptrs, dim, sums);
}
};

template <size_t BatchSize>
struct SquaredEuclideanDistanceBatchImpl<int8_t, BatchSize> {
using ValueType = int8_t;
static void compute_one_to_many(
const int8_t *query, const int8_t **ptrs,
std::array<const int8_t *, BatchSize> &prefetch_ptrs, size_t dim,
float *sums) {
#if defined(__AVX2__)
if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
return compute_one_to_many_squared_euclidean_avx2_int8<ValueType,
BatchSize>(
query, ptrs, prefetch_ptrs, dim, sums);
}
#endif
return compute_one_to_many_squared_euclidean_fallback(
query, ptrs, prefetch_ptrs, dim, sums);
}
};

template <size_t BatchSize>
struct SquaredEuclideanDistanceBatchImpl<ailego::Float16, BatchSize> {
using ValueType = ailego::Float16;
static void compute_one_to_many(
const ailego::Float16 *query, const ailego::Float16 **ptrs,
std::array<const ailego::Float16 *, BatchSize> &prefetch_ptrs, size_t dim,
float *sums) {
#if defined(__AVX512FP16__)
if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) {
return compute_one_to_many_squared_euclidean_avx512fp16_fp16<ValueType,
BatchSize>(
query, ptrs, prefetch_ptrs, dim, sums);
}
#endif
#if defined(__AVX512F__)
if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
return compute_one_to_many_squared_euclidean_avx512f_fp16<ValueType,
BatchSize>(
query, ptrs, prefetch_ptrs, dim, sums);
}
#endif
#if defined(__AVX2__)
if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
return compute_one_to_many_squared_euclidean_avx2_fp16<ValueType,
BatchSize>(
query, ptrs, prefetch_ptrs, dim, sums);
}
#endif
return compute_one_to_many_squared_euclidean_fallback(
query, ptrs, prefetch_ptrs, dim, sums);
}
};

template <typename T, size_t BatchSize, size_t PrefetchStep, typename>
struct SquaredEuclideanDistanceBatch {
using ValueType = typename std::remove_cv<T>::type;

static inline void ComputeBatch(const ValueType **vecs,
const ValueType *query, size_t num_vecs,
size_t dim, float *results) {
size_t i = 0;
for (; i + BatchSize <= num_vecs; i += BatchSize) {
std::array<const ValueType *, BatchSize> prefetch_ptrs;
for (size_t j = 0; j < BatchSize; ++j) {
if (i + j + BatchSize * PrefetchStep < num_vecs) {
prefetch_ptrs[j] = vecs[i + j + BatchSize * PrefetchStep];
} else {
prefetch_ptrs[j] = nullptr;
}
}
SquaredEuclideanDistanceBatchImpl<
ValueType, BatchSize>::compute_one_to_many(query, &vecs[i],
prefetch_ptrs, dim,
&results[i]);
}
for (; i < num_vecs; ++i) { // TODO: unroll by 1, 2, 4, 8, etc.
std::array<const ValueType *, 1> prefetch_ptrs{nullptr};
SquaredEuclideanDistanceBatchImpl<ValueType, 1>::compute_one_to_many(
query, &vecs[i], prefetch_ptrs, dim, &results[i]);
}
}
};

// EuclideanDistanceBatch
template <typename T, size_t BatchSize, size_t PrefetchStep, typename = void>
struct EuclideanDistanceBatch;

template <typename T, size_t BatchSize, size_t PrefetchStep, typename>
struct EuclideanDistanceBatch {
using ValueType = typename std::remove_cv<T>::type;

static inline void ComputeBatch(const ValueType **vecs,
const ValueType *query, size_t num_vecs,
size_t dim, float *results) {
SquaredEuclideanDistanceBatch<T, BatchSize, PrefetchStep>::ComputeBatch(
vecs, query, num_vecs, dim, results);

for (size_t i = 0; i < num_vecs; ++i) {
results[i] = std::sqrt(results[i]);
}
}
};

} // namespace zvec::ailego::DistanceBatch
Loading
Loading