Skip to content

Commit c0e2ad7

Browse files
authored
feat: Add file write_to_offline_store functionality (#2808)
* Scaffolding for offline store push Signed-off-by: Kevin Zhang <[email protected]> * Lint Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]> * File source offline push Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]> * Address review comments Signed-off-by: Kevin Zhang <[email protected]>
1 parent 4934d84 commit c0e2ad7

File tree

6 files changed

+280
-8
lines changed

6 files changed

+280
-8
lines changed

sdk/python/feast/feature_store.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,7 +1405,7 @@ def write_to_online_store(
14051405
provider.ingest_df(feature_view, entities, df)
14061406

14071407
@log_exceptions_and_usage
1408-
def write_to_offline_store(
1408+
def _write_to_offline_store(
14091409
self,
14101410
feature_view_name: str,
14111411
df: pd.DataFrame,
@@ -1423,8 +1423,9 @@ def write_to_offline_store(
14231423
feature_view = self.get_feature_view(
14241424
feature_view_name, allow_registry_cache=allow_registry_cache
14251425
)
1426+
table = pa.Table.from_pandas(df)
14261427
provider = self._get_provider()
1427-
provider.ingest_df_to_offline_store(feature_view, df)
1428+
provider.ingest_df_to_offline_store(feature_view, table)
14281429

14291430
@log_exceptions_and_usage
14301431
def get_online_features(

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import uuid
22
from datetime import datetime
33
from pathlib import Path
4-
from typing import Callable, List, Optional, Tuple, Union
4+
from typing import Any, Callable, List, Optional, Tuple, Union
55

66
import dask.dataframe as dd
77
import pandas as pd
@@ -404,6 +404,42 @@ def write_logged_features(
404404
existing_data_behavior="overwrite_or_ignore",
405405
)
406406

407+
@staticmethod
def offline_write_batch(
    config: RepoConfig,
    feature_view: FeatureView,
    data: pyarrow.Table,
    progress: Optional[Callable[[int], Any]],
):
    """Append a pyarrow table of feature rows to the feature view's file source.

    The existing parquet file is read back, the new rows are prepended to it,
    and the combined table is rewritten in place.

    Args:
        config: Repo configuration; its offline store must be a
            FileOfflineStoreConfig.
        feature_view: Feature view whose FileSource batch source receives
            the rows.
        data: pyarrow table whose columns must match the existing file's
            columns in name and order (the schema is cast if dtypes differ).
        progress: Optional callback invoked with the number of rows written.

    Raises:
        ValueError: If the feature view has no batch source, the offline
            store config is not a file config, the batch source is not a
            FileSource, or the input columns do not match the existing file.
    """
    if not feature_view.batch_source:
        raise ValueError(
            "feature view does not have a batch source to persist offline data"
        )
    if not isinstance(config.offline_store, FileOfflineStoreConfig):
        raise ValueError(
            f"offline store config is of type {type(config.offline_store)} when file type required"
        )
    if not isinstance(feature_view.batch_source, FileSource):
        raise ValueError(
            f"feature view batch source is {type(feature_view.batch_source)} not file source"
        )
    file_options = feature_view.batch_source.file_options
    filesystem, path = FileSource.create_filesystem_and_path(
        file_options.uri, file_options.s3_endpoint_override
    )

    prev_table = pyarrow.parquet.read_table(path, memory_map=True)
    # NOTE: column order must match exactly; reordering is not handled here.
    if prev_table.column_names != data.column_names:
        raise ValueError(
            f"Input dataframe has incorrect schema or wrong order, expected columns are: {prev_table.column_names}"
        )
    if data.schema != prev_table.schema:
        data = data.cast(prev_table.schema)
    new_table = pyarrow.concat_tables([data, prev_table])
    # Context manager guarantees the writer is closed (and the parquet footer
    # written) even if write_table raises; the original leaked the writer on
    # error. After the cast above, data.schema == new_table.schema.
    with pyarrow.parquet.ParquetWriter(
        path, new_table.schema, filesystem=filesystem
    ) as writer:
        writer.write_table(new_table)
    # Honor the progress contract from the OfflineStore base class; the
    # original accepted the callback but never called it.
    if progress:
        progress(data.num_rows)
442+
407443

408444
def _get_entity_df_event_timestamp_range(
409445
entity_df: Union[pd.DataFrame, str], entity_df_event_timestamp_col: str,

sdk/python/feast/infra/offline_stores/offline_store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ def write_logged_features(
274274
@staticmethod
275275
def offline_write_batch(
276276
config: RepoConfig,
277-
table: FeatureView,
278-
data: pd.DataFrame,
277+
feature_view: FeatureView,
278+
data: pyarrow.Table,
279279
progress: Optional[Callable[[int], Any]],
280280
):
281281
"""
@@ -287,7 +287,7 @@ def offline_write_batch(
287287
Args:
288288
config: Repo configuration object
289289
feature_view: FeatureView to write the data to.
290-
data: dataframe containing feature data and timestamp column for historical feature retrieval
290+
data: pyarrow table containing feature data and timestamp column for historical feature retrieval
291291
progress: Optional function to be called once every mini-batch of rows is written to
292292
the online store. Can be used to display progress.
293293
"""

sdk/python/feast/infra/passthrough_provider.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,11 @@ def offline_write_batch(
104104
self,
105105
config: RepoConfig,
106106
table: FeatureView,
107-
data: pd.DataFrame,
107+
data: pa.Table,
108108
progress: Optional[Callable[[int], Any]],
109109
) -> None:
110110
set_usage_attribute("provider", self.__class__.__name__)
111+
111112
if self.offline_store:
112113
self.offline_store.offline_write_batch(config, table, data, progress)
113114

@@ -143,6 +144,14 @@ def ingest_df(
143144
self.repo_config, feature_view, rows_to_write, progress=None
144145
)
145146

147+
def ingest_df_to_offline_store(self, feature_view: FeatureView, table: pa.Table):
    """Write a pyarrow table directly into the offline store for *feature_view*.

    Applies the batch source's field mapping (when one is configured) and
    delegates to ``offline_write_batch`` with no progress callback.
    """
    set_usage_attribute("provider", self.__class__.__name__)

    # Rename columns per the batch source's field mapping so the table matches
    # the schema the offline store expects.
    if feature_view.batch_source.field_mapping is not None:
        table = _run_field_mapping(table, feature_view.batch_source.field_mapping)

    self.offline_write_batch(self.repo_config, feature_view, table, None)
154+
146155
def materialize_single_feature_view(
147156
self,
148157
config: RepoConfig,

sdk/python/feast/infra/provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def ingest_df(
127127
pass
128128

129129
def ingest_df_to_offline_store(
130-
self, feature_view: FeatureView, df: pd.DataFrame,
130+
self, feature_view: FeatureView, df: pyarrow.Table,
131131
):
132132
"""
133133
Ingests a pyarrow Table directly into the offline store
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import random
2+
from datetime import datetime, timedelta
3+
4+
import numpy as np
5+
import pandas as pd
6+
import pytest
7+
8+
from feast import FeatureView, Field
9+
from feast.types import Float32, Int32
10+
from tests.integration.feature_repos.universal.entities import driver
11+
12+
13+
@pytest.mark.integration
14+
@pytest.mark.universal_online_stores
15+
def test_writing_incorrect_order_fails(environment, universal_data_sources):
16+
# TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
17+
store = environment.feature_store
18+
_, _, data_sources = universal_data_sources
19+
driver_stats = FeatureView(
20+
name="driver_stats",
21+
entities=["driver"],
22+
schema=[
23+
Field(name="avg_daily_trips", dtype=Int32),
24+
Field(name="conv_rate", dtype=Float32),
25+
],
26+
source=data_sources.driver,
27+
)
28+
29+
now = datetime.utcnow()
30+
ts = pd.Timestamp(now).round("ms")
31+
32+
entity_df = pd.DataFrame.from_dict(
33+
{"driver_id": [1001, 1002], "event_timestamp": [ts - timedelta(hours=3), ts]}
34+
)
35+
36+
store.apply([driver(), driver_stats])
37+
df = store.get_historical_features(
38+
entity_df=entity_df,
39+
features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
40+
full_feature_names=False,
41+
).to_df()
42+
43+
assert df["conv_rate"].isnull().all()
44+
assert df["avg_daily_trips"].isnull().all()
45+
46+
expected_df = pd.DataFrame.from_dict(
47+
{
48+
"driver_id": [1001, 1002],
49+
"event_timestamp": [ts - timedelta(hours=3), ts],
50+
"conv_rate": [random.random(), random.random()],
51+
"avg_daily_trips": [random.randint(0, 10), random.randint(0, 10)],
52+
"created": [ts, ts],
53+
},
54+
)
55+
with pytest.raises(ValueError):
56+
store._write_to_offline_store(
57+
driver_stats.name, expected_df, allow_registry_cache=False
58+
)
59+
60+
61+
@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_writing_incorrect_schema_fails(environment, universal_data_sources):
    """Pushing a dataframe with a column unknown to the source must raise ValueError."""
    # TODO(kevjumba): handle incorrect order later; for now the schema must be
    # in the exact order the file source uses.
    fs = environment.feature_store
    _, _, data_sources = universal_data_sources

    driver_stats = FeatureView(
        name="driver_stats",
        entities=["driver"],
        schema=[
            Field(name="avg_daily_trips", dtype=Int32),
            Field(name="conv_rate", dtype=Float32),
        ],
        source=data_sources.driver,
    )

    now_ts = pd.Timestamp(datetime.utcnow()).round("ms")
    lookup_df = pd.DataFrame(
        {
            "driver_id": [1001, 1002],
            "event_timestamp": [now_ts - timedelta(hours=3), now_ts],
        }
    )

    fs.apply([driver(), driver_stats])

    retrieved = fs.get_historical_features(
        entity_df=lookup_df,
        features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
        full_feature_names=False,
    ).to_df()

    # Nothing has been written yet, so every feature value is null.
    assert retrieved["conv_rate"].isnull().all()
    assert retrieved["avg_daily_trips"].isnull().all()

    # "incorrect_schema" is not a column of the underlying file source.
    mismatched_df = pd.DataFrame(
        {
            "event_timestamp": [now_ts - timedelta(hours=3), now_ts],
            "driver_id": [1001, 1002],
            "conv_rate": [random.random(), random.random()],
            "incorrect_schema": [random.randint(0, 10), random.randint(0, 10)],
            "created": [now_ts, now_ts],
        }
    )
    with pytest.raises(ValueError):
        fs._write_to_offline_store(
            driver_stats.name, mismatched_df, allow_registry_cache=False
        )
107+
108+
109+
@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
    """Write two batches to the offline store and verify both can be read back.

    Bug fix: the original assertions used ``assert np.where(cond)``.
    Single-argument ``np.where`` returns a (non-empty) tuple of index arrays,
    which is always truthy, so those asserts could never fail. They are
    replaced with genuine element-wise comparisons.
    """
    store = environment.feature_store
    _, _, data_sources = universal_data_sources
    driver_stats = FeatureView(
        name="driver_stats",
        entities=["driver"],
        schema=[
            Field(name="avg_daily_trips", dtype=Int32),
            Field(name="conv_rate", dtype=Float32),
            Field(name="acc_rate", dtype=Float32),
        ],
        source=data_sources.driver,
        ttl=timedelta(minutes=10),
    )

    now = datetime.utcnow()
    # The original passed unit="ns", which pandas ignores for datetime input.
    ts = pd.Timestamp(now)

    entity_df = pd.DataFrame.from_dict(
        {
            "driver_id": [1001, 1001],
            "event_timestamp": [ts - timedelta(hours=4), ts - timedelta(hours=3)],
        }
    )

    store.apply([driver(), driver_stats])
    df = store.get_historical_features(
        entity_df=entity_df,
        features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
        full_feature_names=False,
    ).to_df()

    # Nothing written yet: every feature value should be null.
    assert df["conv_rate"].isnull().all()
    assert df["avg_daily_trips"].isnull().all()

    def assert_features_match(actual: pd.DataFrame, expected: pd.DataFrame):
        # Sort by event time so the check is independent of the row order
        # returned by get_historical_features.
        actual = actual.sort_values("event_timestamp").reset_index(drop=True)
        expected = expected.sort_values("event_timestamp").reset_index(drop=True)
        assert len(actual) == len(expected)
        for col in ("conv_rate", "acc_rate", "avg_daily_trips"):
            if col in actual.columns:
                # allclose tolerates the float64 -> float32 round-trip through
                # the Float32 feature columns.
                assert np.allclose(actual[col], expected[col]), col

    first_df = pd.DataFrame.from_dict(
        {
            "event_timestamp": [ts - timedelta(hours=4), ts - timedelta(hours=3)],
            "driver_id": [1001, 1001],
            "conv_rate": [random.random(), random.random()],
            "acc_rate": [random.random(), random.random()],
            "avg_daily_trips": [random.randint(0, 10), random.randint(0, 10)],
            "created": [ts, ts],
        },
    )
    store._write_to_offline_store(
        driver_stats.name, first_df, allow_registry_cache=False
    )

    after_write_df = store.get_historical_features(
        entity_df=entity_df,
        features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
        full_feature_names=False,
    ).to_df()

    assert_features_match(after_write_df, first_df)

    second_df = pd.DataFrame.from_dict(
        {
            "event_timestamp": [ts - timedelta(hours=1), ts],
            "driver_id": [1001, 1001],
            "conv_rate": [random.random(), random.random()],
            "acc_rate": [random.random(), random.random()],
            "avg_daily_trips": [random.randint(0, 10), random.randint(0, 10)],
            "created": [ts, ts],
        },
    )

    store._write_to_offline_store(
        driver_stats.name, second_df, allow_registry_cache=False
    )

    entity_df = pd.DataFrame.from_dict(
        {
            "driver_id": [1001, 1001, 1001, 1001],
            "event_timestamp": [
                ts - timedelta(hours=4),
                ts - timedelta(hours=3),
                ts - timedelta(hours=1),
                ts,
            ],
        }
    )

    after_write_df = store.get_historical_features(
        entity_df=entity_df,
        features=[
            "driver_stats:conv_rate",
            "driver_stats:acc_rate",
            "driver_stats:avg_daily_trips",
        ],
        full_feature_names=False,
    ).to_df()

    # Both writes should now be visible.
    assert_features_match(after_write_df, pd.concat([first_df, second_df]))

0 commit comments

Comments
 (0)