diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index ed8b7015e..4cfa9c963 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -119,6 +119,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset auto labels = std::make_unique(nq * topk); auto distances = std::make_unique(nq * topk); + const std::string metric_str = cfg.metric_type.value(); + if ((base_dataset->GetDim() != query_dataset->GetDim()) && + (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) || + IsMetricType(metric_str, metric::L2))) { + const std::string msg_e = + fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", + base_dataset->GetDim(), query_dataset->GetDim()); + LOG_KNOWHERE_ERROR_ << msg_e; + return expected::Err(Status::invalid_args, msg_e); + } + auto search_status = SearchWithBuf(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset_); if (search_status != Status::success) { @@ -170,9 +181,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ // LCOV_EXCL_STOP #endif - std::string metric_str = cfg.metric_type.value(); + const std::string metric_str = cfg.metric_type.value(); auto topk = cfg.k.value(); + if ((base_dataset->GetDim() != query_dataset->GetDim()) && + (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) || + IsMetricType(metric_str, metric::L2))) { + LOG_KNOWHERE_ERROR_ << "dimensionalities of the base dataset (" << base_dataset->GetDim() << ") and the query (" + << query_dataset->GetDim() << ") do not match"; + return Status::invalid_args; + } + if (is_emb_list) { if (!IsMetricType(metric_str, metric::MAX_SIM) && !IsMetricType(metric_str, metric::ORDERED_MAX_SIM) && !IsMetricType(metric_str, metric::ORDERED_MAX_SIM_WITH_WINDOW) && !IsMetricType(metric_str, metric::DTW)) { @@ -441,7 +460,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da // LCOV_EXCL_STOP #endif - std::string metric_str = cfg.metric_type.value(); + const std::string metric_str = cfg.metric_type.value(); const bool is_bm25 = IsMetricType(metric_str, metric::BM25); faiss::MetricType faiss_metric_type; @@ -452,6 +471,17 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da return expected::Err(result.error(), result.what()); } faiss_metric_type = result.value(); + + // additionally, perform a check for dimensionalities + if ((base_dataset->GetDim() != query_dataset->GetDim()) && + (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) || + IsMetricType(metric_str, metric::L2))) { + const std::string msg_e = + fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", + base_dataset->GetDim(), query_dataset->GetDim()); + LOG_KNOWHERE_ERROR_ << msg_e; + return expected::Err(Status::invalid_args, msg_e); + } } else { auto computer_or = GetDocValueComputer(cfg); if (!computer_or.has_value()) { diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 0adb55368..111a2a506 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -1212,6 +1212,15 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const auto hnsw_cfg = static_cast(*cfg); const auto k = hnsw_cfg.k.value(); + const std::string metric_str = hnsw_cfg.metric_type.value(); + if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) || + IsMetricType(metric_str, metric::L2))) { + const std::string msg_e = + fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim); + LOG_KNOWHERE_ERROR_ << msg_e; + return expected::Err(Status::invalid_args, msg_e); + } + BitsetView bitset(bitset_); if (!internal_offset_to_most_external_id.empty()) { bitset.set_out_ids(internal_offset_to_most_external_id.data(), internal_offset_to_most_external_id.size()); @@ -1473,6 +1482,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const auto* data = dataset->GetTensor(); const auto hnsw_cfg = static_cast(*cfg); + + const std::string metric_str = hnsw_cfg.metric_type.value(); + if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) || + IsMetricType(metric_str, metric::L2))) { + const std::string msg_e = + fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim); + LOG_KNOWHERE_ERROR_ << msg_e; + return expected::Err(Status::invalid_args, msg_e); + } + BitsetView bitset(bitset_); if (!internal_offset_to_most_external_id.empty()) { bitset.set_out_ids(internal_offset_to_most_external_id.data(), internal_offset_to_most_external_id.size()); diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index a65f4b05d..7724858c6 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -748,6 +748,15 @@ IvfIndexNode::Search(const DataSetPtr dataset, std::unique_ const IvfConfig& ivf_cfg = static_cast(*cfg); bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE); + const std::string metric_str = ivf_cfg.metric_type.value(); + if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) || + IsMetricType(metric_str, metric::L2))) { + const std::string msg_e = + fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim); + LOG_KNOWHERE_ERROR_ << msg_e; + return expected::Err(Status::invalid_args, msg_e); + } + auto k = ivf_cfg.k.value(); auto nprobe = ivf_cfg.nprobe.value(); @@ -928,6 +937,15 @@ IvfIndexNode::RangeSearch(const DataSetPtr dataset, std::un const IvfConfig& ivf_cfg = static_cast(*cfg); bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE); + const std::string metric_str = ivf_cfg.metric_type.value(); + if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) || + IsMetricType(metric_str, metric::L2))) { + const std::string msg_e = + fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim); + LOG_KNOWHERE_ERROR_ << msg_e; + return expected::Err(Status::invalid_args, msg_e); + } + float radius = ivf_cfg.radius.value(); float range_filter = ivf_cfg.range_filter.value(); bool is_ip = (index_->metric_type == faiss::METRIC_INNER_PRODUCT);