Skip to content

Commit b8e39ea

Browse files
authored
fix: Register BatchFeatureView in feature repos correctly (#3092)
* fix: Registry BatchFeatureView in feature repos correctly Signed-off-by: Achal Shah <[email protected]> * tests Signed-off-by: Achal Shah <[email protected]> Signed-off-by: Achal Shah <[email protected]>
1 parent c93b4cc commit b8e39ea

File tree

4 files changed

+77
-2
lines changed

4 files changed

+77
-2
lines changed

sdk/python/feast/repo_operations.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,15 @@ def parse_repo(repo_root: Path) -> RepoContents:
172172
assert stream_source
173173
if not any((stream_source is ds) for ds in res.data_sources):
174174
res.data_sources.append(stream_source)
175+
elif isinstance(obj, BatchFeatureView) and not any(
176+
(obj is bfv) for bfv in res.feature_views
177+
):
178+
res.feature_views.append(obj)
179+
180+
# Handle batch sources defined with feature views.
181+
batch_source = obj.batch_source
182+
if not any((batch_source is ds) for ds in res.data_sources):
183+
res.data_sources.append(batch_source)
175184
elif isinstance(obj, Entity) and not any(
176185
(obj is entity) for entity in res.entities
177186
):
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from datetime import timedelta
2+
3+
from feast import BatchFeatureView, Entity, Field, FileSource
4+
from feast.types import Float32, Int32, Int64
5+
6+
driver_hourly_stats = FileSource(
7+
path="%PARQUET_PATH%", # placeholder to be replaced by the test
8+
timestamp_field="event_timestamp",
9+
created_timestamp_column="created",
10+
)
11+
12+
driver = Entity(
13+
name="driver_id",
14+
description="driver id",
15+
)
16+
17+
18+
driver_hourly_stats_view = BatchFeatureView(
19+
name="driver_hourly_stats",
20+
entities=[driver],
21+
ttl=timedelta(days=1),
22+
schema=[
23+
Field(name="conv_rate", dtype=Float32),
24+
Field(name="acc_rate", dtype=Float32),
25+
Field(name="avg_daily_trips", dtype=Int64),
26+
Field(name="driver_id", dtype=Int32),
27+
],
28+
online=True,
29+
source=driver_hourly_stats,
30+
tags={},
31+
)
32+
33+
34+
global_daily_stats = FileSource(
35+
path="%PARQUET_PATH_GLOBAL%", # placeholder to be replaced by the test
36+
timestamp_field="event_timestamp",
37+
created_timestamp_column="created",
38+
)
39+
40+
41+
global_stats_feature_view = BatchFeatureView(
42+
name="global_daily_stats",
43+
entities=None,
44+
ttl=timedelta(days=1),
45+
schema=[
46+
Field(name="num_rides", dtype=Int32),
47+
Field(name="avg_ride_length", dtype=Float32),
48+
],
49+
online=True,
50+
source=global_daily_stats,
51+
tags={},
52+
)

sdk/python/tests/unit/local_feast_tests/test_e2e_local.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@ def test_e2e_local() -> None:
5151
runner, store, start_date, end_date, driver_df
5252
)
5353

54+
with runner.local_repo(
55+
get_example_repo("example_feature_repo_with_bfvs.py")
56+
.replace("%PARQUET_PATH%", driver_stats_path)
57+
.replace("%PARQUET_PATH_GLOBAL%", global_stats_path),
58+
"file",
59+
) as store:
60+
_test_materialize_and_online_retrieval(
61+
runner, store, start_date, end_date, driver_df
62+
)
63+
5464
with runner.local_repo(
5565
get_example_repo("example_feature_repo_with_ttl_0.py")
5666
.replace("%PARQUET_PATH%", driver_stats_path)

sdk/python/tests/utils/cli_repo_creator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def local_repo(self, example_repo_py: str, offline_store: str):
8888
stderr = result.stderr.decode("utf-8")
8989
print(f"Apply stdout:\n{stdout}")
9090
print(f"Apply stderr:\n{stderr}")
91-
assert result.returncode == 0
91+
assert (
92+
result.returncode == 0
93+
), f"stdout: {result.stdout}\nstderr: {result.stderr}"
9294

9395
yield FeatureStore(repo_path=str(repo_path), config=None)
9496

@@ -97,4 +99,6 @@ def local_repo(self, example_repo_py: str, offline_store: str):
9799
stderr = result.stderr.decode("utf-8")
98100
print(f"Apply stdout:\n{stdout}")
99101
print(f"Apply stderr:\n{stderr}")
100-
assert result.returncode == 0
102+
assert (
103+
result.returncode == 0
104+
), f"stdout: {result.stdout}\nstderr: {result.stderr}"

0 commit comments

Comments
 (0)