Skip to content

Commit 40823cc

Browse files
add checks for a dimensionality of a query
Signed-off-by: Alexandr Guzhva <[email protected]>
1 parent e05e2ed commit 40823cc

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

src/common/comp/brute_force.cc

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
119119
auto labels = std::make_unique<int64_t[]>(nq * topk);
120120
auto distances = std::make_unique<float[]>(nq * topk);
121121

122+
const std::string metric_str = cfg.metric_type.value();
123+
if ((base_dataset->GetDim() != query_dataset->GetDim()) &&
124+
(IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
125+
IsMetricType(metric_str, metric::L2))) {
126+
const std::string msg_e =
127+
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match",
128+
base_dataset->GetDim(), query_dataset->GetDim());
129+
LOG_KNOWHERE_ERROR_ << msg_e;
130+
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
131+
}
132+
122133
auto search_status =
123134
SearchWithBuf<DataType>(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset_);
124135
if (search_status != Status::success) {
@@ -170,9 +181,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
170181
// LCOV_EXCL_STOP
171182
#endif
172183

173-
std::string metric_str = cfg.metric_type.value();
184+
const std::string metric_str = cfg.metric_type.value();
174185
auto topk = cfg.k.value();
175186

187+
if ((base_dataset->GetDim() != query_dataset->GetDim()) &&
188+
(IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
189+
IsMetricType(metric_str, metric::L2))) {
190+
LOG_KNOWHERE_ERROR_ << "dimensionalities of the base dataset (" << base_dataset->GetDim() << ") and the query ("
191+
<< query_dataset->GetDim() << ") do not match";
192+
return Status::invalid_args;
193+
}
194+
176195
if (is_emb_list) {
177196
if (!IsMetricType(metric_str, metric::MAX_SIM) && !IsMetricType(metric_str, metric::ORDERED_MAX_SIM) &&
178197
!IsMetricType(metric_str, metric::ORDERED_MAX_SIM_WITH_WINDOW) && !IsMetricType(metric_str, metric::DTW)) {
@@ -441,7 +460,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
441460
// LCOV_EXCL_STOP
442461
#endif
443462

444-
std::string metric_str = cfg.metric_type.value();
463+
const std::string metric_str = cfg.metric_type.value();
445464
const bool is_bm25 = IsMetricType(metric_str, metric::BM25);
446465

447466
faiss::MetricType faiss_metric_type;
@@ -452,6 +471,17 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
452471
return expected<DataSetPtr>::Err(result.error(), result.what());
453472
}
454473
faiss_metric_type = result.value();
474+
475+
// additionally, perform a check for dimensionalities
476+
if ((base_dataset->GetDim() != query_dataset->GetDim()) &&
477+
(IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
478+
IsMetricType(metric_str, metric::L2))) {
479+
const std::string msg_e =
480+
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match",
481+
base_dataset->GetDim(), query_dataset->GetDim());
482+
LOG_KNOWHERE_ERROR_ << msg_e;
483+
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
484+
}
455485
} else {
456486
auto computer_or = GetDocValueComputer<float>(cfg);
457487
if (!computer_or.has_value()) {

src/index/hnsw/faiss_hnsw.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,15 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
12121212
const auto hnsw_cfg = static_cast<const FaissHnswConfig&>(*cfg);
12131213
const auto k = hnsw_cfg.k.value();
12141214

1215+
const std::string metric_str = hnsw_cfg.metric_type.value();
1216+
if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
1217+
IsMetricType(metric_str, metric::L2))) {
1218+
const std::string msg_e =
1219+
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim);
1220+
LOG_KNOWHERE_ERROR_ << msg_e;
1221+
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
1222+
}
1223+
12151224
BitsetView bitset(bitset_);
12161225
if (!internal_offset_to_most_external_id.empty()) {
12171226
bitset.set_out_ids(internal_offset_to_most_external_id.data(), internal_offset_to_most_external_id.size());
@@ -1473,6 +1482,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
14731482
const auto* data = dataset->GetTensor();
14741483

14751484
const auto hnsw_cfg = static_cast<const FaissHnswConfig&>(*cfg);
1485+
1486+
const std::string metric_str = hnsw_cfg.metric_type.value();
1487+
if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
1488+
IsMetricType(metric_str, metric::L2))) {
1489+
const std::string msg_e =
1490+
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim);
1491+
LOG_KNOWHERE_ERROR_ << msg_e;
1492+
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
1493+
}
1494+
14761495
BitsetView bitset(bitset_);
14771496
if (!internal_offset_to_most_external_id.empty()) {
14781497
bitset.set_out_ids(internal_offset_to_most_external_id.data(), internal_offset_to_most_external_id.size());

src/index/ivf/ivf.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,15 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSetPtr dataset, std::unique_
748748
const IvfConfig& ivf_cfg = static_cast<const IvfConfig&>(*cfg);
749749
bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE);
750750

751+
const std::string metric_str = ivf_cfg.metric_type.value();
752+
if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
753+
IsMetricType(metric_str, metric::L2))) {
754+
const std::string msg_e =
755+
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim);
756+
LOG_KNOWHERE_ERROR_ << msg_e;
757+
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
758+
}
759+
751760
auto k = ivf_cfg.k.value();
752761
auto nprobe = ivf_cfg.nprobe.value();
753762

@@ -928,6 +937,15 @@ IvfIndexNode<DataType, IndexType>::RangeSearch(const DataSetPtr dataset, std::un
928937
const IvfConfig& ivf_cfg = static_cast<const IvfConfig&>(*cfg);
929938
bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE);
930939

940+
const std::string metric_str = ivf_cfg.metric_type.value();
941+
if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
942+
IsMetricType(metric_str, metric::L2))) {
943+
const std::string msg_e =
944+
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim);
945+
LOG_KNOWHERE_ERROR_ << msg_e;
946+
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
947+
}
948+
931949
float radius = ivf_cfg.radius.value();
932950
float range_filter = ivf_cfg.range_filter.value();
933951
bool is_ip = (index_->metric_type == faiss::METRIC_INNER_PRODUCT);

0 commit comments

Comments
 (0)