Commit a59c33a
feat: Implement spark materialization engine (#3184)
* implement spark materialization engine
* remove redundant code
* make function private
* refactor serializing into a class
* switch to using `foreachPartition`
* remove batch_size parameter
* add partitions parameter
* linting
* rename spark to spark.offline and spark.engine
* fix to test
* forgot to stage
* revert spark.offline to spark to ensure backward compatibility
* fix import
* remove code from testing a large data set
* linting
* test without repartition
* test alternate connection string
* use redis online creator

Signed-off-by: niklasvm <[email protected]>
1 parent 7bc1dff commit a59c33a

File tree

3 files changed: +343 -0 lines changed

sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
import tempfile
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, List, Literal, Optional, Sequence, Union

import dill
import pandas as pd
import pyarrow
from tqdm import tqdm

from feast.batch_feature_view import BatchFeatureView
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.infra.materialization.batch_materialization_engine import (
    BatchMaterializationEngine,
    MaterializationJob,
    MaterializationJobStatus,
    MaterializationTask,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
    SparkOfflineStore,
    SparkRetrievalJob,
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.passthrough_provider import PassthroughProvider
from feast.infra.registry.base_registry import BaseRegistry
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.stream_feature_view import StreamFeatureView
from feast.utils import (
    _convert_arrow_to_proto,
    _get_column_names,
    _run_pyarrow_field_mapping,
)


class SparkMaterializationEngineConfig(FeastConfigBaseModel):
    """Batch Materialization Engine config for spark engine"""

    type: Literal["spark.engine"] = "spark.engine"
    """ Type selector"""

    partitions: int = 0
    """Number of partitions to use when writing data to online store. If 0, no repartitioning is done"""


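# Illustration only -- not part of the committed file. The config above is a
# pydantic model, so it can be constructed directly; in a feature repo the same
# keys would normally come from the `batch_engine` configuration, as the
# integration test below does with {"type": "spark.engine", "partitions": 10}.
def _example_engine_config():
    engine_config = SparkMaterializationEngineConfig(partitions=10)
    assert engine_config.type == "spark.engine"
    assert engine_config.partitions == 10

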
@dataclass
class SparkMaterializationJob(MaterializationJob):
    def __init__(
        self,
        job_id: str,
        status: MaterializationJobStatus,
        error: Optional[BaseException] = None,
    ) -> None:
        super().__init__()
        self._job_id: str = job_id
        self._status: MaterializationJobStatus = status
        self._error: Optional[BaseException] = error

    def status(self) -> MaterializationJobStatus:
        return self._status

    def error(self) -> Optional[BaseException]:
        return self._error

    def should_be_retried(self) -> bool:
        return False

    def job_id(self) -> str:
        return self._job_id

    def url(self) -> Optional[str]:
        return None


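# Illustration only -- not part of the committed file. SparkMaterializationJob
# is a plain value object; materialize() below returns one per task, so a
# caller can inspect status()/error() instead of catching exceptions.
def _example_inspect_job():
    job = SparkMaterializationJob(
        job_id="driver_hourly_stats-2022-01-01-2022-01-02",
        status=MaterializationJobStatus.ERROR,
        error=ValueError("connection refused"),
    )
    assert not job.should_be_retried()
    if job.status() == MaterializationJobStatus.ERROR:
        print(f"{job.job_id()} failed: {job.error()}")

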
class SparkMaterializationEngine(BatchMaterializationEngine):
    def update(
        self,
        project: str,
        views_to_delete: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        views_to_keep: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
    ):
        # Nothing to set up.
        pass

    def teardown_infra(
        self,
        project: str,
        fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
        entities: Sequence[Entity],
    ):
        # Nothing to tear down.
        pass

    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: SparkOfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        if not isinstance(offline_store, SparkOfflineStore):
            raise TypeError(
                "SparkMaterializationEngine is only compatible with the SparkOfflineStore"
            )
        super().__init__(
            repo_config=repo_config,
            offline_store=offline_store,
            online_store=online_store,
            **kwargs,
        )

    def materialize(
        self, registry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        return [
            self._materialize_one(
                registry,
                task.feature_view,
                task.start_time,
                task.end_time,
                task.project,
                task.tqdm_builder,
            )
            for task in tasks
        ]

    def _materialize_one(
        self,
        registry: BaseRegistry,
        feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
        start_date: datetime,
        end_date: datetime,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ):
        entities = []
        for entity_name in feature_view.entities:
            entities.append(registry.get_entity(entity_name, project))

        (
            join_key_columns,
            feature_name_columns,
            timestamp_field,
            created_timestamp_column,
        ) = _get_column_names(feature_view, entities)

        job_id = f"{feature_view.name}-{start_date}-{end_date}"

        try:
            offline_job: SparkRetrievalJob = (
                self.offline_store.pull_latest_from_table_or_query(
                    config=self.repo_config,
                    data_source=feature_view.batch_source,
                    join_key_columns=join_key_columns,
                    feature_name_columns=feature_name_columns,
                    timestamp_field=timestamp_field,
                    created_timestamp_column=created_timestamp_column,
                    start_date=start_date,
                    end_date=end_date,
                )
            )

            spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
                feature_view=feature_view, repo_config=self.repo_config
            )

            spark_df = offline_job.to_spark_df()
            if self.repo_config.batch_engine.partitions != 0:
                spark_df = spark_df.repartition(
                    self.repo_config.batch_engine.partitions
                )

            spark_df.foreachPartition(
                lambda x: _process_by_partition(x, spark_serialized_artifacts)
            )

            return SparkMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
            )
        except BaseException as e:
            return SparkMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
            )


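# Illustration only -- not part of the committed file. A standalone sketch of
# the foreachPartition pattern used in _materialize_one: Spark hands the
# function an iterator of Rows per partition, which is why
# _process_by_partition below rebuilds a pandas DataFrame before writing.
def _example_foreach_partition():
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame(
        [(1, 0.5), (2, 0.7), (3, 0.9)], ["driver_id", "value"]
    ).repartition(2)  # mirrors the optional `partitions` setting

    def handle_partition(rows):
        # `rows` is an iterator of pyspark.sql.Row objects for one partition.
        batch = [row.asDict() for row in rows]
        print(f"received {len(batch)} rows in this partition")

    df.foreachPartition(handle_partition)

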
@dataclass
class _SparkSerializedArtifacts:
    """Class to assist with serializing unpicklable artifacts to the spark workers"""

    feature_view_proto: str
    repo_config_file: str

    @classmethod
    def serialize(cls, feature_view, repo_config):
        # serialize to proto
        feature_view_proto = feature_view.to_proto().SerializeToString()

        # serialize repo_config to disk. Will be used to instantiate the online store
        repo_config_file = tempfile.NamedTemporaryFile(delete=False).name
        with open(repo_config_file, "wb") as f:
            dill.dump(repo_config, f)

        return _SparkSerializedArtifacts(
            feature_view_proto=feature_view_proto, repo_config_file=repo_config_file
        )

    def unserialize(self):
        # unserialize
        proto = FeatureViewProto()
        proto.ParseFromString(self.feature_view_proto)
        feature_view = FeatureView.from_proto(proto)

        # load
        with open(self.repo_config_file, "rb") as f:
            repo_config = dill.load(f)

        provider = PassthroughProvider(repo_config)
        online_store = provider.online_store
        return feature_view, online_store, repo_config


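# Illustration only -- not part of the committed file. Round trip of the
# artifacts: the FeatureView travels to executors as proto bytes and the
# RepoConfig via dill; `feature_view` and `repo_config` are assumed to be an
# applied FeatureView and the active RepoConfig, as in _materialize_one above.
def _example_artifact_round_trip(feature_view, repo_config):
    artifacts = _SparkSerializedArtifacts.serialize(
        feature_view=feature_view, repo_config=repo_config
    )
    fv, online_store, cfg = artifacts.unserialize()
    assert fv.name == feature_view.name
    return fv, online_store, cfg

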
def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
    """Load pandas df to online store"""

    # convert to pyarrow table
    dicts = []
    for row in rows:
        dicts.append(row.asDict())

    df = pd.DataFrame.from_records(dicts)
    if df.shape[0] == 0:
        print("Skipping")
        return

    table = pyarrow.Table.from_pandas(df)

    # unserialize artifacts
    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()

    if feature_view.batch_source.field_mapping is not None:
        table = _run_pyarrow_field_mapping(
            table, feature_view.batch_source.field_mapping
        )

    join_key_to_value_type = {
        entity.name: entity.dtype.to_value_type()
        for entity in feature_view.entity_columns
    }

    rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
    online_store.online_write_batch(
        repo_config,
        feature_view,
        rows_to_write,
        lambda x: None,
    )

sdk/python/feast/repo_config.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@
     "snowflake.engine": "feast.infra.materialization.snowflake_engine.SnowflakeMaterializationEngine",
     "lambda": "feast.infra.materialization.aws_lambda.lambda_engine.LambdaMaterializationEngine",
     "bytewax": "feast.infra.materialization.contrib.bytewax.bytewax_materialization_engine.BytewaxMaterializationEngine",
+    "spark.engine": "feast.infra.materialization.contrib.spark.spark_materialization_engine.SparkMaterializationEngine",
 }
 
 ONLINE_STORE_CLASS_FOR_TYPE = {
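The new "spark.engine" entry maps the config's type string to the engine's import path. Below is a minimal sketch (not part of the change) of how such a "module.Class" string can be resolved at runtime; Feast's own loader differs in its details, so this is illustrative only.

import importlib

path = (
    "feast.infra.materialization.contrib.spark."
    "spark_materialization_engine.SparkMaterializationEngine"
)
module_name, class_name = path.rsplit(".", 1)
engine_cls = getattr(importlib.import_module(module_name), class_name)
print(engine_cls.__name__)  # SparkMaterializationEngine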
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
from datetime import timedelta

import pytest

from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.field import Field
from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import (
    SparkDataSourceCreator,
)
from feast.types import Float32
from tests.data.data_creator import create_basic_driver_dataset
from tests.integration.feature_repos.integration_test_repo_config import (
    IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.repo_configuration import (
    construct_test_environment,
)
from tests.integration.feature_repos.universal.online_store.redis import (
    RedisOnlineStoreCreator,
)
from tests.utils.e2e_test_validation import validate_offline_online_store_consistency


@pytest.mark.integration
def test_spark_materialization_consistency():
    spark_config = IntegrationTestRepoConfig(
        provider="local",
        online_store_creator=RedisOnlineStoreCreator,
        offline_store_creator=SparkDataSourceCreator,
        batch_engine={"type": "spark.engine", "partitions": 10},
    )
    spark_environment = construct_test_environment(
        spark_config, None, entity_key_serialization_version=1
    )

    df = create_basic_driver_dataset()

    ds = spark_environment.data_source_creator.create_data_source(
        df,
        spark_environment.feature_store.project,
        field_mapping={"ts_1": "ts"},
    )

    fs = spark_environment.feature_store
    driver = Entity(
        name="driver_id",
        join_keys=["driver_id"],
    )

    driver_stats_fv = FeatureView(
        name="driver_hourly_stats",
        entities=[driver],
        ttl=timedelta(weeks=52),
        schema=[Field(name="value", dtype=Float32)],
        source=ds,
    )

    try:
        fs.apply([driver, driver_stats_fv])

        print(df)

        # materialization is run in two steps and
        # we use timestamp from generated dataframe as a split point
        split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)

        print(f"Split datetime: {split_dt}")

        validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt)
    finally:
        fs.teardown()


if __name__ == "__main__":
    test_spark_materialization_consistency()
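The consistency check above materializes the feature view in two steps around split_dt. A rough sketch (not part of the change) of what those two steps amount to, reusing fs, df, and split_dt from the test; the exact windows chosen by validate_offline_online_store_consistency may differ.

from datetime import timedelta


def _example_two_step_materialize(fs, df, split_dt):
    end_date = df["ts_1"].max().to_pydatetime()
    # First window: everything up to the split point.
    fs.materialize(start_date=end_date - timedelta(weeks=52), end_date=split_dt)
    # Second window: the remainder, picking up where the first run stopped.
    fs.materialize(start_date=split_dt, end_date=end_date)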
