@@ -119,6 +119,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
119
119
auto labels = std::make_unique<int64_t []>(nq * topk);
120
120
auto distances = std::make_unique<float []>(nq * topk);
121
121
122
+ const std::string metric_str = cfg.metric_type .value ();
123
+ if ((base_dataset->GetDim () != query_dataset->GetDim ()) &&
124
+ (IsMetricType (metric_str, metric::COSINE) || IsMetricType (metric_str, metric::IP) ||
125
+ IsMetricType (metric_str, metric::L2))) {
126
+ const std::string msg_e =
127
+ fmt::format (" dimensionalities of the base dataset ({}) and the query ({}) do not match" ,
128
+ base_dataset->GetDim (), query_dataset->GetDim ());
129
+ LOG_KNOWHERE_ERROR_ << msg_e;
130
+ return expected<DataSetPtr>::Err (Status::invalid_args, msg_e);
131
+ }
132
+
122
133
auto search_status =
123
134
SearchWithBuf<DataType>(base_dataset, query_dataset, labels.get (), distances.get (), config, bitset_);
124
135
if (search_status != Status::success) {
@@ -170,9 +181,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
170
181
// LCOV_EXCL_STOP
171
182
#endif
172
183
173
- std::string metric_str = cfg.metric_type .value ();
184
+ const std::string metric_str = cfg.metric_type .value ();
174
185
auto topk = cfg.k .value ();
175
186
187
+ if ((base_dataset->GetDim () != query_dataset->GetDim ()) &&
188
+ (IsMetricType (metric_str, metric::COSINE) || IsMetricType (metric_str, metric::IP) ||
189
+ IsMetricType (metric_str, metric::L2))) {
190
+ LOG_KNOWHERE_ERROR_ << " dimensionalities of the base dataset (" << base_dataset->GetDim () << " ) and the query ("
191
+ << query_dataset->GetDim () << " ) do not match" ;
192
+ return Status::invalid_args;
193
+ }
194
+
176
195
if (is_emb_list) {
177
196
if (!IsMetricType (metric_str, metric::MAX_SIM) && !IsMetricType (metric_str, metric::ORDERED_MAX_SIM) &&
178
197
!IsMetricType (metric_str, metric::ORDERED_MAX_SIM_WITH_WINDOW) && !IsMetricType (metric_str, metric::DTW)) {
@@ -441,7 +460,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
441
460
// LCOV_EXCL_STOP
442
461
#endif
443
462
444
- std::string metric_str = cfg.metric_type .value ();
463
+ const std::string metric_str = cfg.metric_type .value ();
445
464
const bool is_bm25 = IsMetricType (metric_str, metric::BM25);
446
465
447
466
faiss::MetricType faiss_metric_type;
@@ -452,6 +471,17 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
452
471
return expected<DataSetPtr>::Err (result.error (), result.what ());
453
472
}
454
473
faiss_metric_type = result.value ();
474
+
475
+ // additionally, perform a check for dimensionalities
476
+ if ((base_dataset->GetDim () != query_dataset->GetDim ()) &&
477
+ (IsMetricType (metric_str, metric::COSINE) || IsMetricType (metric_str, metric::IP) ||
478
+ IsMetricType (metric_str, metric::L2))) {
479
+ const std::string msg_e =
480
+ fmt::format (" dimensionalities of the base dataset ({}) and the query ({}) do not match" ,
481
+ base_dataset->GetDim (), query_dataset->GetDim ());
482
+ LOG_KNOWHERE_ERROR_ << msg_e;
483
+ return expected<DataSetPtr>::Err (Status::invalid_args, msg_e);
484
+ }
455
485
} else {
456
486
auto computer_or = GetDocValueComputer<float >(cfg);
457
487
if (!computer_or.has_value ()) {
0 commit comments