Commit 5b0cc87

feat: Implement spark offline store offline_write_batch method (#3076)
* create integration spark data sets as files rather than a temp table
* add offline_write_batch method to spark offline store
* remove some comments
* fix linting issue
* fix more linting issues
* fix flake8 errors

Signed-off-by: niklasvm <[email protected]>
1 parent cdd1b07 commit 5b0cc87

File tree: 2 files changed, +95 -4 lines changed

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 81 additions & 1 deletion
@@ -1,7 +1,7 @@
 import tempfile
 import warnings
 from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas
@@ -191,6 +191,68 @@ def get_historical_features(
             ),
         )
 
+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, SparkOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when spark type required"
+            )
+        if not isinstance(feature_view.batch_source, SparkSource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not spark source"
+            )
+
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
+        if column_names != table.column_names:
+            raise ValueError(
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
+            )
+
+        spark_session = get_spark_session_or_start_new_with_repoconfig(
+            store_config=config.offline_store
+        )
+
+        if feature_view.batch_source.path:
+            # write data to disk so that it can be loaded into spark (for preserving column types)
+            with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
+                print(tmp_file.name)
+                pq.write_table(table, tmp_file.name)
+
+                # load data
+                df_batch = spark_session.read.parquet(tmp_file.name)
+
+                # load existing data to get spark table schema
+                df_existing = spark_session.read.format(
+                    feature_view.batch_source.file_format
+                ).load(feature_view.batch_source.path)
+
+                # cast columns if applicable
+                df_batch = _cast_data_frame(df_batch, df_existing)
+
+                df_batch.write.format(feature_view.batch_source.file_format).mode(
+                    "append"
+                ).save(feature_view.batch_source.path)
+        elif feature_view.batch_source.query:
+            raise NotImplementedError(
+                "offline_write_batch not implemented for batch sources specified by query"
+            )
+        else:
+            raise NotImplementedError(
+                "offline_write_batch not implemented for batch sources specified by a table"
+            )
+
     @staticmethod
     @log_exceptions_and_usage(offline_store="spark")
     def pull_all_from_table_or_query(
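
For orientation, here is a minimal sketch of exercising the new static method directly on the offline store class. Only the `offline_write_batch` signature and its validation rules come from the diff above; `repo_config` and `driver_stats_fv` are hypothetical placeholders for a real `RepoConfig` (configured with a `SparkOfflineStoreConfig`) and a `FeatureView` backed by a path-based `SparkSource`.

```python
# Hypothetical usage sketch; repo_config and driver_stats_fv are placeholders
# for objects created elsewhere in a Feast repo, not part of this commit.
import pandas as pd
import pyarrow as pa

from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
    SparkOfflineStore,
)

df = pd.DataFrame(
    {
        "driver_id": [1001],
        "conv_rate": [0.55],
        "event_timestamp": [pd.Timestamp("2022-08-01 00:00:00", tz="UTC")],
        "created": [pd.Timestamp("2022-08-01 00:00:00", tz="UTC")],
    }
)

# Column names and order must match the batch source schema exactly,
# otherwise the method raises the ValueError shown in the diff.
table = pa.Table.from_pandas(df, preserve_index=False)

SparkOfflineStore.offline_write_batch(
    config=repo_config,            # RepoConfig using SparkOfflineStoreConfig
    feature_view=driver_stats_fv,  # FeatureView whose batch_source is a path-based SparkSource
    table=table,
    progress=None,
)
```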
@@ -388,6 +450,24 @@ def _format_datetime(t: datetime) -> str:
     return dt
 
 
+def _cast_data_frame(
+    df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
+) -> pyspark.sql.DataFrame:
+    """Convert new dataframe's columns to the same types as existing dataframe while preserving the order of columns"""
+    existing_dtypes = {k: v for k, v in df_existing.dtypes}
+    new_dtypes = {k: v for k, v in df_new.dtypes}
+
+    select_expression = []
+    for col, new_type in new_dtypes.items():
+        existing_type = existing_dtypes[col]
+        if new_type != existing_type:
+            select_expression.append(f"cast({col} as {existing_type}) as {col}")
+        else:
+            select_expression.append(col)
+
+    return df_new.selectExpr(*select_expression)
+
+
 MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
 /*
  Compute a deterministic hash for the `left_table_query_string` that will be used throughout
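
The `_cast_data_frame` helper aligns the incoming batch with the column types already on disk before the append. A standalone illustration of the same `selectExpr`-based casting follows; the data, column names, and local SparkSession are made up for the demo and are not part of the commit.

```python
# Standalone illustration of the cast-to-existing-schema idea used by
# _cast_data_frame; all names and values here are assumptions for the demo.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("cast-demo").getOrCreate()

# Existing table columns are bigint and double; the new batch arrives as strings.
df_existing = spark.createDataFrame([(1, 0.5)], ["driver_id", "conv_rate"])
df_new = spark.createDataFrame([("2", "0.7")], ["driver_id", "conv_rate"])

# Same logic as _cast_data_frame: cast mismatched columns to the existing
# types while preserving the new frame's column order.
existing_dtypes = dict(df_existing.dtypes)
select_expression = [
    f"cast({col} as {existing_dtypes[col]}) as {col}"
    if new_type != existing_dtypes[col]
    else col
    for col, new_type in df_new.dtypes
]
df_aligned = df_new.selectExpr(*select_expression)
df_aligned.printSchema()  # driver_id: long (bigint), conv_rate: double
```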

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py

Lines changed: 14 additions & 3 deletions
@@ -1,3 +1,6 @@
+import os
+import shutil
+import tempfile
 import uuid
 from typing import Dict, List
 
@@ -48,6 +51,8 @@ def __init__(self, project_name: str, *args, **kwargs):
 
     def teardown(self):
         self.spark_session.stop()
+        for table in self.tables:
+            shutil.rmtree(table)
 
     def create_offline_store_config(self):
         self.spark_offline_store_config = SparkOfflineStoreConfig()
@@ -86,11 +91,17 @@ def create_data_source(
             .appName("pytest-pyspark-local-testing")
             .getOrCreate()
         )
-        self.spark_session.createDataFrame(df).createOrReplaceTempView(destination_name)
-        self.tables.append(destination_name)
 
+        temp_dir = tempfile.mkdtemp(prefix="spark_offline_store_test_data")
+
+        path = os.path.join(temp_dir, destination_name)
+        self.tables.append(path)
+
+        self.spark_session.createDataFrame(df).write.parquet(path)
         return SparkSource(
-            table=destination_name,
+            name=destination_name,
+            file_format="parquet",
+            path=path,
             timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
             field_mapping=field_mapping or {"ts_1": "ts"},
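
Because `offline_write_batch` only supports path-based sources, the test harness now materializes its data as parquet files and builds the `SparkSource` from a path instead of a temp view. A hedged sketch of the resulting source definition is below; the path and timestamp column names are placeholders, and only the constructor arguments shown in the diff (`name`, `file_format`, `path`, `timestamp_field`, `created_timestamp_column`) are taken from the commit.

```python
# Sketch of a file-backed SparkSource like the one the test now returns;
# the path and column names below are placeholders, not values from the test suite.
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
    SparkSource,
)

driver_stats_source = SparkSource(
    name="driver_hourly_stats",
    path="/tmp/spark_offline_store_test_data/driver_hourly_stats",
    file_format="parquet",
    timestamp_field="event_timestamp",
    created_timestamp_column="created",
)
```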
