Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto labels = std::make_unique<int64_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

const std::string metric_str = cfg.metric_type.value();
if ((base_dataset->GetDim() != query_dataset->GetDim()) &&
(IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IP metrics could also work with sparse vectors.
Can we skip all the query dimension checks by identifying whether the base_dataset and query_dataset are sparse?

IsMetricType(metric_str, metric::L2))) {
const std::string msg_e =
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match",
base_dataset->GetDim(), query_dataset->GetDim());
LOG_KNOWHERE_ERROR_ << msg_e;
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
}

auto search_status =
SearchWithBuf<DataType>(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset_);
if (search_status != Status::success) {
Expand Down Expand Up @@ -170,9 +181,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
// LCOV_EXCL_STOP
#endif

std::string metric_str = cfg.metric_type.value();
const std::string metric_str = cfg.metric_type.value();
auto topk = cfg.k.value();

if ((base_dataset->GetDim() != query_dataset->GetDim()) &&
(IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
IsMetricType(metric_str, metric::L2))) {
LOG_KNOWHERE_ERROR_ << "dimensionalities of the base dataset (" << base_dataset->GetDim() << ") and the query ("
<< query_dataset->GetDim() << ") do not match";
return Status::invalid_args;
}

if (is_emb_list) {
if (!IsMetricType(metric_str, metric::MAX_SIM) && !IsMetricType(metric_str, metric::ORDERED_MAX_SIM) &&
!IsMetricType(metric_str, metric::ORDERED_MAX_SIM_WITH_WINDOW) && !IsMetricType(metric_str, metric::DTW)) {
Expand Down Expand Up @@ -441,7 +460,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
// LCOV_EXCL_STOP
#endif

std::string metric_str = cfg.metric_type.value();
const std::string metric_str = cfg.metric_type.value();
const bool is_bm25 = IsMetricType(metric_str, metric::BM25);

faiss::MetricType faiss_metric_type;
Expand All @@ -452,6 +471,17 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return expected<DataSetPtr>::Err(result.error(), result.what());
}
faiss_metric_type = result.value();

// additionally, perform a check for dimensionalities
if ((base_dataset->GetDim() != query_dataset->GetDim()) &&
(IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
IsMetricType(metric_str, metric::L2))) {
const std::string msg_e =
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match",
base_dataset->GetDim(), query_dataset->GetDim());
LOG_KNOWHERE_ERROR_ << msg_e;
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
}
} else {
auto computer_or = GetDocValueComputer<float>(cfg);
if (!computer_or.has_value()) {
Expand Down
19 changes: 19 additions & 0 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,15 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
const auto hnsw_cfg = static_cast<const FaissHnswConfig&>(*cfg);
const auto k = hnsw_cfg.k.value();

const std::string metric_str = hnsw_cfg.metric_type.value();
if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
IsMetricType(metric_str, metric::L2))) {
const std::string msg_e =
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim);
LOG_KNOWHERE_ERROR_ << msg_e;
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
}

BitsetView bitset(bitset_);
if (!internal_offset_to_most_external_id.empty()) {
bitset.set_out_ids(internal_offset_to_most_external_id.data(), internal_offset_to_most_external_id.size());
Expand Down Expand Up @@ -1473,6 +1482,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
const auto* data = dataset->GetTensor();

const auto hnsw_cfg = static_cast<const FaissHnswConfig&>(*cfg);

const std::string metric_str = hnsw_cfg.metric_type.value();
if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
IsMetricType(metric_str, metric::L2))) {
const std::string msg_e =
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put this repeated code into a utility function.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, this should have already been intercepted in Milvus. What issue are you encountering here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a KIOXIA unit test that triggered this problem.

LOG_KNOWHERE_ERROR_ << msg_e;
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
}

BitsetView bitset(bitset_);
if (!internal_offset_to_most_external_id.empty()) {
bitset.set_out_ids(internal_offset_to_most_external_id.data(), internal_offset_to_most_external_id.size());
Expand Down
18 changes: 18 additions & 0 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,15 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSetPtr dataset, std::unique_
const IvfConfig& ivf_cfg = static_cast<const IvfConfig&>(*cfg);
bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE);

const std::string metric_str = ivf_cfg.metric_type.value();
if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
IsMetricType(metric_str, metric::L2))) {
const std::string msg_e =
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim);
LOG_KNOWHERE_ERROR_ << msg_e;
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
}

auto k = ivf_cfg.k.value();
auto nprobe = ivf_cfg.nprobe.value();

Expand Down Expand Up @@ -928,6 +937,15 @@ IvfIndexNode<DataType, IndexType>::RangeSearch(const DataSetPtr dataset, std::un
const IvfConfig& ivf_cfg = static_cast<const IvfConfig&>(*cfg);
bool is_cosine = IsMetricType(ivf_cfg.metric_type.value(), knowhere::metric::COSINE);

const std::string metric_str = ivf_cfg.metric_type.value();
if ((dim != Dim()) && (IsMetricType(metric_str, metric::COSINE) || IsMetricType(metric_str, metric::IP) ||
IsMetricType(metric_str, metric::L2))) {
const std::string msg_e =
fmt::format("dimensionalities of the base dataset ({}) and the query ({}) do not match", Dim(), dim);
LOG_KNOWHERE_ERROR_ << msg_e;
return expected<DataSetPtr>::Err(Status::invalid_args, msg_e);
}

float radius = ivf_cfg.radius.value();
float range_filter = ivf_cfg.range_filter.value();
bool is_ip = (index_->metric_type == faiss::METRIC_INNER_PRODUCT);
Expand Down
Loading