6
6
import traceback
7
7
from contextlib import asynccontextmanager
8
8
from importlib import resources as importlib_resources
9
- from typing import Any , Dict , List , Optional
9
+ from typing import Any , Dict , List , Optional , Union
10
10
11
11
import pandas as pd
12
12
import psutil
@@ -86,10 +86,18 @@ class MaterializeIncrementalRequest(BaseModel):
86
86
class GetOnlineFeaturesRequest (BaseModel ):
87
87
entities : Dict [str , List [Any ]]
88
88
feature_service : Optional [str ] = None
89
- features : Optional [ List [str ]] = None
89
+ features : List [str ] = []
90
90
full_feature_names : bool = False
91
- query_embedding : Optional [List [float ]] = None
91
+
92
+
93
+ class GetOnlineDocumentsRequest (BaseModel ):
94
+ feature_service : Optional [str ] = None
95
+ features : List [str ] = []
96
+ full_feature_names : bool = False
97
+ top_k : Optional [int ] = None
98
+ query : Optional [List [float ]] = None
92
99
query_string : Optional [str ] = None
100
+ api_version : Optional [int ] = 1
93
101
94
102
95
103
class ChatMessage (BaseModel ):
@@ -110,7 +118,10 @@ class SaveDocumentRequest(BaseModel):
110
118
data : dict
111
119
112
120
113
- def _get_features (request : GetOnlineFeaturesRequest , store : "feast.FeatureStore" ):
121
+ def _get_features (
122
+ request : Union [GetOnlineFeaturesRequest , GetOnlineDocumentsRequest ],
123
+ store : "feast.FeatureStore" ,
124
+ ):
114
125
if request .feature_service :
115
126
feature_service = store .get_feature_service (
116
127
request .feature_service , allow_cache = True
@@ -246,24 +257,26 @@ async def get_online_features(request: GetOnlineFeaturesRequest) -> Dict[str, An
246
257
dependencies = [Depends (inject_user_details )],
247
258
)
248
259
async def retrieve_online_documents (
249
- request : GetOnlineFeaturesRequest ,
260
+ request : GetOnlineDocumentsRequest ,
250
261
) -> Dict [str , Any ]:
251
262
logger .warning (
252
263
"This endpoint is in alpha and will be moved to /get-online-features when stable."
253
264
)
254
265
# Initialize parameters for FeatureStore.retrieve_online_documents_v2(...) call
255
266
features = await run_in_threadpool (_get_features , request , store )
256
267
257
- read_params = dict (
258
- features = features ,
259
- full_feature_names = request .full_feature_names ,
260
- query = request .query_embedding ,
261
- query_string = request .query_string ,
262
- )
268
+ read_params = dict (features = features , query = request .query , top_k = request .top_k )
269
+ if request .api_version == 2 and request .query_string is not None :
270
+ read_params ["query_string" ] = request .query_string
263
271
264
- response = await run_in_threadpool (
265
- lambda : store .retrieve_online_documents_v2 (** read_params ) # type: ignore
266
- )
272
+ if request .api_version == 2 :
273
+ response = await run_in_threadpool (
274
+ lambda : store .retrieve_online_documents_v2 (** read_params ) # type: ignore
275
+ )
276
+ else :
277
+ response = await run_in_threadpool (
278
+ lambda : store .retrieve_online_documents (** read_params ) # type: ignore
279
+ )
267
280
268
281
# Convert the Protobuf object to JSON and return it
269
282
response_dict = await run_in_threadpool (
0 commit comments