diff --git a/.github/mergify.yml b/.github/mergify.yml
index c5a415a2a..cf7fe8397 100644
--- a/.github/mergify.yml
+++ b/.github/mergify.yml
@@ -35,7 +35,7 @@ pull_request_rules:
- base=1.x
- and:
- -body~=\#[0-9]{1,6}(\s+|$)
- - -body~=https://github.com/milvus-io/knowhere/issues/[0-9]{1,6}(\s+|$)
+ - -body~=https://github.com/zilliztech/Knowhere/issues/[0-9]{1,6}(\s+|$)
- -label=kind/improvement
- -title~=\[automated\]
actions:
@@ -55,7 +55,7 @@ pull_request_rules:
- or:
- or:
- body~=\#[0-9]{1,6}(\s+|$)
- - body~=https://github.com/milvus-io/knowhere/issues/[0-9]{1,6}(\s+|$)
+ - body~=https://github.com/zilliztech/Knowhere/issues/[0-9]{1,6}(\s+|$)
- label=kind/improvement
actions:
label:
diff --git a/README.md b/README.md
index 3e605b696..b591d174a 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
-This document will help you to build the Knowhere repository from source code and to run unit tests. Please [file an issue](https://github.com/milvus-io/knowhere/issues/new) if there's a problem.
+This document will help you to build the Knowhere repository from source code and to run unit tests. Please [file an issue](https://github.com/zilliztech/knowhere/issues/new) if there's a problem.
## Introduction
diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc
index a1f283a89..7810af423 100644
--- a/src/common/comp/brute_force.cc
+++ b/src/common/comp/brute_force.cc
@@ -33,11 +33,6 @@ expected
BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get();
- bool is_cosine = IsMetricType(metric_str, metric::COSINE);
- if (is_cosine) {
- Normalize(*base_dataset);
- }
-
auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
@@ -71,11 +66,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
- if (is_cosine) {
+ faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
+ if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
+ faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
+ } else {
+ faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
- faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
- faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard: {
@@ -123,11 +120,6 @@ Status
BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
const Json& config, const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get();
- bool is_cosine = IsMetricType(metric_str, metric::COSINE);
- if (is_cosine) {
- Normalize(*base_dataset);
- }
-
auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
@@ -167,11 +159,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
- if (is_cosine) {
+ faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
+ if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
+ faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
+ } else {
+ faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
- faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
- faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard: {
@@ -262,10 +256,13 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
case faiss::METRIC_INNER_PRODUCT: {
is_ip = true;
auto cur_query = (float*)xq + dim * index;
- if (is_cosine) {
+ if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
+ faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
+ } else {
+ faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
+ bitset);
}
- faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
break;
}
case faiss::METRIC_Jaccard: {
diff --git a/thirdparty/faiss/faiss/utils/distances.cpp b/thirdparty/faiss/faiss/utils/distances.cpp
index b8c5e2dbb..793419227 100644
--- a/thirdparty/faiss/faiss/utils/distances.cpp
+++ b/thirdparty/faiss/faiss/utils/distances.cpp
@@ -14,6 +14,7 @@
#include
#include
#include
+#include "simd/hook.h"
#include
@@ -284,6 +285,44 @@ void exhaustive_L2sqr_seq(
}
}
+namespace {
+float fvec_cosine(const float* x, const float* y, size_t d) {
+ return fvec_inner_product(x, y, d) / sqrtf(fvec_norm_L2sqr(y, d));
+}
+} // namespace
+
+template
+void exhaustive_cosine_seq(
+ const float* x,
+ const float* y,
+ size_t d,
+ size_t nx,
+ size_t ny,
+ ResultHandler& res,
+ const BitsetView bitset) {
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
+ int nt = std::min(int(nx), omp_get_max_threads());
+
+#pragma omp parallel num_threads(nt)
+ {
+ SingleResultHandler resi(res);
+#pragma omp for
+ for (int64_t i = 0; i < nx; i++) {
+ const float* x_i = x + i * d;
+ const float* y_j = y;
+ resi.begin(i);
+ for (size_t j = 0; j < ny; j++) {
+ if (bitset.empty() || !bitset.test(j)) {
+ float disij = fvec_cosine(x_i, y_j, d);
+ resi.add_result(disij, j);
+ }
+ y_j += d;
+ }
+ resi.end();
+ }
+ }
+}
+
/** Find the nearest neighbors for nx queries in a set of ny vectors */
template
void exhaustive_inner_product_blas(
@@ -426,6 +465,76 @@ void exhaustive_L2sqr_blas(
}
}
+template
+void exhaustive_cosine_blas(
+ const float* x,
+ const float* y,
+ size_t d,
+ size_t nx,
+ size_t ny,
+ ResultHandler& res,
+ const BitsetView bitset = nullptr) {
+ // BLAS does not like empty matrices
+ if (nx == 0 || ny == 0)
+ return;
+
+ /* block sizes */
+ const size_t bs_x = distance_compute_blas_query_bs;
+ const size_t bs_y = distance_compute_blas_database_bs;
+ // const size_t bs_x = 16, bs_y = 16;
+ std::unique_ptr ip_block(new float[bs_x * bs_y]);
+ std::unique_ptr y_norms(new float[nx]);
+ std::unique_ptr del2;
+
+ fvec_norms_L2(y_norms.get(), x, d, nx);
+
+ for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
+ size_t i1 = i0 + bs_x;
+ if (i1 > nx)
+ i1 = nx;
+
+ res.begin_multiple(i0, i1);
+
+ for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
+ size_t j1 = j0 + bs_y;
+ if (j1 > ny)
+ j1 = ny;
+ /* compute the actual dot products */
+ {
+ float one = 1, zero = 0;
+ FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
+ sgemm_("Transpose",
+ "Not transpose",
+ &nyi,
+ &nxi,
+ &di,
+ &one,
+ y + j0 * d,
+ &di,
+ x + i0 * d,
+ &di,
+ &zero,
+ ip_block.get(),
+ &nyi);
+ }
+#pragma omp parallel for
+ for (int64_t i = i0; i < i1; i++) {
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
+
+ for (size_t j = j0; j < j1; j++) {
+ float ip = *ip_line;
+ float dis = ip / y_norms[j];
+ *ip_line = dis;
+ ip_line++;
+ }
+ }
+ res.add_results(j0, j1, ip_block.get(), bitset);
+ }
+ res.end_multiple();
+ InterruptCallback::check();
+ }
+}
+
template
static void knn_jaccard_blas(
const float* x,
@@ -577,6 +686,34 @@ void knn_L2sqr(
}
}
+void knn_cosine(
+ const float* x,
+ const float* y,
+ size_t d,
+ size_t nx,
+ size_t ny,
+ float_minheap_array_t* ha,
+ const BitsetView bitset) {
+ if (ha->k < distance_compute_min_k_reservoir) {
+ HeapResultHandler> res(
+ ha->nh, ha->val, ha->ids, ha->k);
+ if (nx < distance_compute_blas_threshold) {
+ exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
+ } else {
+ exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
+ }
+ } else {
+ ReservoirResultHandler> res(
+ ha->nh, ha->val, ha->ids, ha->k);
+ if (nx < distance_compute_blas_threshold) {
+ exhaustive_L2sqr_IP_seq(
+ x, y, d, nx, ny, res, fvec_inner_product, bitset);
+ } else {
+ exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
+ }
+ }
+}
+
struct NopDistanceCorrection {
float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const {
return dis;
@@ -640,6 +777,23 @@ void range_search_inner_product(
}
}
+void range_search_cosine(
+ const float* x,
+ const float* y,
+ size_t d,
+ size_t nx,
+ size_t ny,
+ float radius,
+ RangeSearchResult* res,
+ const BitsetView bitset) {
+ RangeSearchResultHandler> resh(res, radius);
+ if (nx < distance_compute_blas_threshold) {
+ exhaustive_cosine_seq(x, y, d, nx, ny, resh, bitset);
+ } else {
+ exhaustive_cosine_blas(x, y, d, nx, ny, resh, bitset);
+ }
+}
+
/***************************************************************************
* compute a subset of distances
***************************************************************************/
diff --git a/thirdparty/faiss/faiss/utils/distances.h b/thirdparty/faiss/faiss/utils/distances.h
index ebc51f7f2..2d015a3ef 100644
--- a/thirdparty/faiss/faiss/utils/distances.h
+++ b/thirdparty/faiss/faiss/utils/distances.h
@@ -199,6 +199,15 @@ void knn_L2sqr(
const float* y_norm2 = nullptr,
const BitsetView bitset = nullptr);
+void knn_cosine(
+ const float* x,
+ const float* y,
+ size_t d,
+ size_t nx,
+ size_t ny,
+ float_minheap_array_t* ha,
+ const BitsetView bitset);
+
void knn_jaccard(
const float* x,
const float* y,
@@ -265,6 +274,16 @@ void range_search_inner_product(
RangeSearchResult* result,
const BitsetView bitset = nullptr);
+void range_search_cosine(
+ const float* x,
+ const float* y,
+ size_t d,
+ size_t nx,
+ size_t ny,
+ float radius,
+ RangeSearchResult* result,
+ const BitsetView bitset = nullptr);
+
/***************************************************************************
* PQ tables computations
***************************************************************************/