Skip to content

Commit c5539fd

Browse files
authored
fix: Fix on demand feature view crash from inference when it uses df.apply (#2713)
* fix: Fix on demand feature view crash from inference when transformation uses df.apply
  Signed-off-by: Danny Chiao <[email protected]>
* Fix inference
  Signed-off-by: Danny Chiao <[email protected]>
* Fix test
  Signed-off-by: Danny Chiao <[email protected]>
1 parent cebf609 commit c5539fd

File tree

4 files changed: 168 additions (+168) and 3 deletions (-3).

sdk/python/feast/on_demand_feature_view.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import copy
22
import functools
33
import warnings
4+
from datetime import datetime
45
from types import MethodType
5-
from typing import Dict, List, Optional, Type, Union
6+
from typing import Any, Dict, List, Optional, Type, Union
67

78
import dill
89
import pandas as pd
@@ -442,18 +443,29 @@ def infer_features(self):
442443
Raises:
443444
RegistryInferenceFailure: The set of features could not be inferred.
444445
"""
446+
rand_df_value: Dict[str, Any] = {
447+
"float": 1.0,
448+
"int": 1,
449+
"str": "hello world",
450+
"bytes": str.encode("hello world"),
451+
"bool": True,
452+
"datetime64[ns]": datetime.utcnow(),
453+
}
454+
445455
df = pd.DataFrame()
446456
for feature_view_projection in self.source_feature_view_projections.values():
447457
for feature in feature_view_projection.features:
448458
dtype = feast_value_type_to_pandas_type(feature.dtype.to_value_type())
449459
df[f"{feature_view_projection.name}__{feature.name}"] = pd.Series(
450460
dtype=dtype
451461
)
452-
df[f"{feature.name}"] = pd.Series(dtype=dtype)
462+
sample_val = rand_df_value[dtype] if dtype in rand_df_value else None
463+
df[f"{feature.name}"] = pd.Series(data=sample_val, dtype=dtype)
453464
for request_data in self.source_request_sources.values():
454465
for field in request_data.schema:
455466
dtype = feast_value_type_to_pandas_type(field.dtype.to_value_type())
456-
df[f"{field.name}"] = pd.Series(dtype=dtype)
467+
sample_val = rand_df_value[dtype] if dtype in rand_df_value else None
468+
df[f"{field.name}"] = pd.Series(sample_val, dtype=dtype)
457469
output_df: pd.DataFrame = self.udf.__call__(df)
458470
inferred_features = []
459471
for f, dt in zip(output_df.columns, output_df.dtypes):
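
on_demand_feature_view_repo.py (new file; the path header is missing from this capture — the file name is inferred from the get_example_repo("on_demand_feature_view_repo.py") call in test_odfv_apply, and its directory is not shown)

Lines changed: 48 additions & 0 deletions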
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from datetime import timedelta
2+
3+
import pandas as pd
4+
5+
from feast import FeatureView, Field, FileSource
6+
from feast.on_demand_feature_view import on_demand_feature_view
7+
from feast.types import Float32, String
8+
9+
driver_stats = FileSource(
10+
name="driver_stats_source",
11+
path="data/driver_stats_lat_lon.parquet",
12+
timestamp_field="event_timestamp",
13+
created_timestamp_column="created",
14+
description="A table describing the stats of a driver based on hourly logs",
15+
16+
)
17+
18+
driver_daily_features_view = FeatureView(
19+
name="driver_daily_features",
20+
entities=["driver"],
21+
ttl=timedelta(seconds=8640000000),
22+
schema=[
23+
Field(name="daily_miles_driven", dtype=Float32),
24+
Field(name="lat", dtype=Float32),
25+
Field(name="lon", dtype=Float32),
26+
Field(name="string_feature", dtype=String),
27+
],
28+
online=True,
29+
source=driver_stats,
30+
tags={"production": "True"},
31+
32+
)
33+
34+
35+
@on_demand_feature_view(
36+
sources=[driver_daily_features_view],
37+
schema=[
38+
Field(name="first_char", dtype=String),
39+
Field(name="concat_string", dtype=String),
40+
],
41+
)
42+
def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame:
43+
df = pd.DataFrame()
44+
df["concat_string"] = inputs.apply(
45+
lambda x: x.string_feature + "hello", axis=1
46+
).astype("string")
47+
df["first_char"] = inputs["string_feature"].str[:1].astype("string")
48+
return df

sdk/python/tests/integration/registration/test_cli.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,37 @@ def test_nullable_online_store(test_nullable_online_store) -> None:
201201
runner.run(["teardown"], cwd=repo_path)
202202

203203

204+
@pytest.mark.integration
205+
@pytest.mark.universal_offline_stores
206+
def test_odfv_apply(environment) -> None:
207+
project = f"test_odfv_apply{str(uuid.uuid4()).replace('-', '')[:8]}"
208+
runner = CliRunner()
209+
210+
with tempfile.TemporaryDirectory() as repo_dir_name:
211+
try:
212+
repo_path = Path(repo_dir_name)
213+
feature_store_yaml = make_feature_store_yaml(
214+
project, environment.test_repo_config, repo_path
215+
)
216+
217+
repo_config = repo_path / "feature_store.yaml"
218+
219+
repo_config.write_text(dedent(feature_store_yaml))
220+
221+
repo_example = repo_path / "example.py"
222+
repo_example.write_text(get_example_repo("on_demand_feature_view_repo.py"))
223+
result = runner.run(["apply"], cwd=repo_path)
224+
assertpy.assert_that(result.returncode).is_equal_to(0)
225+
226+
# entity & feature view list commands should succeed
227+
result = runner.run(["entities", "list"], cwd=repo_path)
228+
assertpy.assert_that(result.returncode).is_equal_to(0)
229+
result = runner.run(["on-demand-feature-views", "list"], cwd=repo_path)
230+
assertpy.assert_that(result.returncode).is_equal_to(0)
231+
finally:
232+
runner.run(["teardown"], cwd=repo_path)
233+
234+
204235
@contextmanager
205236
def setup_third_party_provider_repo(provider_name: str):
206237
with tempfile.TemporaryDirectory() as repo_dir_name:

sdk/python/tests/integration/registration/test_registry.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,80 @@ def test_apply_feature_view_success(test_registry):
234234
test_registry._get_registry_proto()
235235

236236

237+
@pytest.mark.parametrize(
238+
"test_registry", [lazy_fixture("local_registry")],
239+
)
240+
def test_apply_on_demand_feature_view_success(test_registry):
241+
# Create Feature Views
242+
driver_stats = FileSource(
243+
name="driver_stats_source",
244+
path="data/driver_stats_lat_lon.parquet",
245+
timestamp_field="event_timestamp",
246+
created_timestamp_column="created",
247+
description="A table describing the stats of a driver based on hourly logs",
248+
249+
)
250+
251+
driver_daily_features_view = FeatureView(
252+
name="driver_daily_features",
253+
entities=["driver"],
254+
ttl=timedelta(seconds=8640000000),
255+
schema=[
256+
Field(name="daily_miles_driven", dtype=Float32),
257+
Field(name="lat", dtype=Float32),
258+
Field(name="lon", dtype=Float32),
259+
Field(name="string_feature", dtype=String),
260+
],
261+
online=True,
262+
source=driver_stats,
263+
tags={"production": "True"},
264+
265+
)
266+
267+
@on_demand_feature_view(
268+
sources=[driver_daily_features_view],
269+
schema=[Field(name="first_char", dtype=String)],
270+
)
271+
def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame:
272+
df = pd.DataFrame()
273+
df["first_char"] = inputs["string_feature"].str[:1].astype("string")
274+
return df
275+
276+
project = "project"
277+
278+
# Register Feature View
279+
test_registry.apply_feature_view(location_features_from_push, project)
280+
281+
feature_views = test_registry.list_on_demand_feature_views(project)
282+
283+
# List Feature Views
284+
assert (
285+
len(feature_views) == 1
286+
and feature_views[0].name == "location_features_from_push"
287+
and feature_views[0].features[0].name == "first_char"
288+
and feature_views[0].features[0].dtype == String
289+
)
290+
291+
feature_view = test_registry.get_on_demand_feature_view(
292+
"location_features_from_push", project
293+
)
294+
assert (
295+
feature_view.name == "location_features_from_push"
296+
and feature_view.features[0].name == "first_char"
297+
and feature_view.features[0].dtype == String
298+
)
299+
300+
test_registry.delete_feature_view("location_features_from_push", project)
301+
feature_views = test_registry.list_on_demand_feature_views(project)
302+
assert len(feature_views) == 0
303+
304+
test_registry.teardown()
305+
306+
# Will try to reload registry, which will fail because the file has been deleted
307+
with pytest.raises(FileNotFoundError):
308+
test_registry._get_registry_proto()
309+
310+
237311
@pytest.mark.parametrize(
238312
"test_registry", [lazy_fixture("local_registry")],
239313
)

0 commit comments

Comments
 (0)