
Commit 8abc2ef

feat: Add column reordering to write_to_offline_store (#2876)
* Add feature extraction logic to batch writer Signed-off-by: Felix Wang <[email protected]>
* Enable StreamProcessor to write to both online and offline stores Signed-off-by: Felix Wang <[email protected]>
* Fix incorrect columns error message Signed-off-by: Felix Wang <[email protected]>
* Reorder columns in _write_to_offline_store Signed-off-by: Felix Wang <[email protected]>
* Make _write_to_offline_store a public method Signed-off-by: Felix Wang <[email protected]>
* Import FeatureStore correctly Signed-off-by: Felix Wang <[email protected]>
* Remove defaults for `processing_time` and `query_timeout` Signed-off-by: Felix Wang <[email protected]>
* Clean up `test_offline_write.py` Signed-off-by: Felix Wang <[email protected]>
* Do not do any custom logic for double underscore columns Signed-off-by: Felix Wang <[email protected]>
* Lint Signed-off-by: Felix Wang <[email protected]>
* Switch entity values for all tests using push sources to not affect other tests Signed-off-by: Felix Wang <[email protected]>
1 parent 51df8be commit 8abc2ef

File tree

11 files changed: +153 additions, -109 deletions


sdk/python/feast/feature_store.py

Lines changed: 22 additions & 4 deletions
@@ -1383,7 +1383,7 @@ def push(
                     fv.name, df, allow_registry_cache=allow_registry_cache
                 )
             if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE:
-                self._write_to_offline_store(
+                self.write_to_offline_store(
                     fv.name, df, allow_registry_cache=allow_registry_cache
                 )

@@ -1415,14 +1415,18 @@ def write_to_online_store(
         provider.ingest_df(feature_view, entities, df)

     @log_exceptions_and_usage
-    def _write_to_offline_store(
+    def write_to_offline_store(
         self,
         feature_view_name: str,
         df: pd.DataFrame,
         allow_registry_cache: bool = True,
+        reorder_columns: bool = True,
     ):
         """
-        ingests data directly into the Online store
+        Persists the dataframe directly into the batch data source for the given feature view.
+
+        Fails if the dataframe columns do not match the columns of the batch data source. Optionally
+        reorders the columns of the dataframe to match.
         """
         # TODO: restrict this to work with online StreamFeatureViews and validate the FeatureView type
         try:
@@ -1433,7 +1437,21 @@ def _write_to_offline_store(
             feature_view = self.get_feature_view(
                 feature_view_name, allow_registry_cache=allow_registry_cache
             )
-        df.reset_index(drop=True)
+
+        # Get columns of the batch source and the input dataframe.
+        column_names_and_types = feature_view.batch_source.get_table_column_names_and_types(
+            self.config
+        )
+        source_columns = [column for column, _ in column_names_and_types]
+        input_columns = df.columns.values.tolist()
+
+        if set(input_columns) != set(source_columns):
+            raise ValueError(
+                f"The input dataframe has columns {set(input_columns)} but the batch source has columns {set(source_columns)}."
+            )
+
+        if reorder_columns:
+            df = df.reindex(columns=source_columns)

         table = pa.Table.from_pandas(df)
         provider = self._get_provider()

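With `write_to_offline_store` now public and reordering columns by default, callers can push a dataframe straight into a feature view's batch source even when the column order differs. A minimal sketch, assuming a hypothetical `driver_hourly_stats` feature view whose batch source has the columns `driver_id`, `conv_rate`, `event_timestamp`, and `created`:

import pandas as pd

from feast import FeatureStore

store = FeatureStore(repo_path=".")

# Columns are deliberately listed in a different order than the batch source;
# with reorder_columns=True (the default) the dataframe is reindexed to match
# before being converted to a pyarrow table and written.
df = pd.DataFrame(
    {
        "conv_rate": [0.85],
        "driver_id": [1001],
        "created": [pd.Timestamp.utcnow()],
        "event_timestamp": [pd.Timestamp.utcnow()],
    }
)
store.write_to_offline_store("driver_hourly_stats", df)

# A dataframe with missing or extra columns still raises a ValueError, since the
# column sets must match the batch source exactly.

The same path is exercised by `push(..., to=PushMode.OFFLINE)` or `PushMode.ONLINE_AND_OFFLINE`, which now call the public method.
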
sdk/python/feast/infra/contrib/spark_kafka_processor.py

Lines changed: 41 additions & 16 deletions
@@ -1,12 +1,14 @@
 from types import MethodType
-from typing import List
+from typing import List, Optional

+import pandas as pd
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.avro.functions import from_avro
 from pyspark.sql.functions import col, from_json

 from feast.data_format import AvroFormat, JsonFormat
-from feast.data_source import KafkaSource
+from feast.data_source import KafkaSource, PushMode
+from feast.feature_store import FeatureStore
 from feast.infra.contrib.stream_processor import (
     ProcessorConfig,
     StreamProcessor,
@@ -24,16 +26,16 @@ class SparkProcessorConfig(ProcessorConfig):
 class SparkKafkaProcessor(StreamProcessor):
     spark: SparkSession
     format: str
-    write_function: MethodType
+    preprocess_fn: Optional[MethodType]
     join_keys: List[str]

     def __init__(
         self,
+        *,
+        fs: FeatureStore,
         sfv: StreamFeatureView,
         config: ProcessorConfig,
-        write_function: MethodType,
-        processing_time: str = "30 seconds",
-        query_timeout: int = 15,
+        preprocess_fn: Optional[MethodType] = None,
     ):
         if not isinstance(sfv.stream_source, KafkaSource):
             raise ValueError("data source is not kafka source")
@@ -55,15 +57,16 @@ def __init__(
         if not isinstance(config, SparkProcessorConfig):
             raise ValueError("config is not spark processor config")
         self.spark = config.spark_session
-        self.write_function = write_function
-        self.processing_time = processing_time
-        self.query_timeout = query_timeout
-        super().__init__(sfv=sfv, data_source=sfv.stream_source)
+        self.preprocess_fn = preprocess_fn
+        self.processing_time = config.processing_time
+        self.query_timeout = config.query_timeout
+        self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities]
+        super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source)

-    def ingest_stream_feature_view(self) -> None:
+    def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
         ingested_stream_df = self._ingest_stream_data()
         transformed_df = self._construct_transformation_plan(ingested_stream_df)
-        online_store_query = self._write_to_online_store(transformed_df)
+        online_store_query = self._write_stream_data(transformed_df, to)
         return online_store_query

     def _ingest_stream_data(self) -> StreamTable:
@@ -119,13 +122,35 @@ def _ingest_stream_data(self) -> StreamTable:
     def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
         return self.sfv.udf.__call__(df) if self.sfv.udf else df

-    def _write_to_online_store(self, df: StreamTable):
+    def _write_stream_data(self, df: StreamTable, to: PushMode):
         # Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
         def batch_write(row: DataFrame, batch_id: int):
-            pd_row = row.toPandas()
-            self.write_function(
-                pd_row, input_timestamp="event_timestamp", output_timestamp=""
+            rows: pd.DataFrame = row.toPandas()
+
+            # Extract the latest feature values for each unique entity row (i.e. the join keys).
+            # Also add a 'created' column.
+            rows = (
+                rows.sort_values(
+                    by=self.join_keys + [self.sfv.timestamp_field], ascending=True
+                )
+                .groupby(self.join_keys)
+                .nth(0)
             )
+            rows["created"] = pd.to_datetime("now", utc=True)
+
+            # Reset indices to ensure the dataframe has all the required columns.
+            rows = rows.reset_index()
+
+            # Optionally execute preprocessor before writing to the online store.
+            if self.preprocess_fn:
+                rows = self.preprocess_fn(rows)
+
+            # Finally persist the data to the online store and/or offline store.
+            if rows.size > 0:
+                if to == PushMode.ONLINE or to == PushMode.ONLINE_AND_OFFLINE:
+                    self.fs.write_to_online_store(self.sfv.name, rows)
+                if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE:
+                    self.fs.write_to_offline_store(self.sfv.name, rows)

         query = (
             df.writeStream.outputMode("update")

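For readers unfamiliar with the grouping in `batch_write`: each Spark micro-batch is collapsed to one row per set of join keys before it is written. A standalone pandas sketch of that logic, with made-up column names and assuming pandas 1.x semantics (where `groupby(...).nth(0)` moves the group keys into the index, so `reset_index()` restores them as columns):

import pandas as pd

join_keys = ["driver_id"]
timestamp_field = "event_timestamp"

# Toy micro-batch with two rows for driver 1001 and one row for driver 1002.
rows = pd.DataFrame(
    {
        "driver_id": [1001, 1001, 1002],
        "event_timestamp": pd.to_datetime(
            ["2022-07-01 10:00", "2022-07-01 11:00", "2022-07-01 10:30"]
        ),
        "conv_rate": [0.8, 0.9, 0.5],
    }
)

# Keep one row per entity key (.nth(0) on the sorted groups), stamp a
# 'created' column, and restore the join keys as regular columns.
rows = (
    rows.sort_values(by=join_keys + [timestamp_field], ascending=True)
    .groupby(join_keys)
    .nth(0)
)
rows["created"] = pd.to_datetime("now", utc=True)
rows = rows.reset_index()
print(rows)  # one row each for driver 1001 and driver 1002
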
sdk/python/feast/infra/contrib/stream_processor.py

Lines changed: 26 additions & 12 deletions
@@ -1,14 +1,17 @@
 from abc import ABC
-from typing import Callable
+from types import MethodType
+from typing import TYPE_CHECKING, Optional

-import pandas as pd
 from pyspark.sql import DataFrame

-from feast.data_source import DataSource
+from feast.data_source import DataSource, PushMode
 from feast.importer import import_class
 from feast.repo_config import FeastConfigBaseModel
 from feast.stream_feature_view import StreamFeatureView

+if TYPE_CHECKING:
+    from feast.feature_store import FeatureStore
+
 STREAM_PROCESSOR_CLASS_FOR_TYPE = {
     ("spark", "kafka"): "feast.infra.contrib.spark_kafka_processor.SparkKafkaProcessor",
 }
@@ -30,21 +33,26 @@ class StreamProcessor(ABC):
     and persist that data to the online store.

     Attributes:
+        fs: The feature store where data should be persisted.
         sfv: The stream feature view on which the stream processor operates.
         data_source: The stream data source from which data will be ingested.
     """

+    fs: "FeatureStore"
     sfv: StreamFeatureView
     data_source: DataSource

-    def __init__(self, sfv: StreamFeatureView, data_source: DataSource):
+    def __init__(
+        self, fs: "FeatureStore", sfv: StreamFeatureView, data_source: DataSource
+    ):
+        self.fs = fs
         self.sfv = sfv
         self.data_source = data_source

-    def ingest_stream_feature_view(self) -> None:
+    def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
         """
         Ingests data from the stream source attached to the stream feature view; transforms the data
-        and then persists it to the online store.
+        and then persists it to the online store and/or offline store, depending on the 'to' parameter.
         """
         pass

@@ -62,26 +70,32 @@ def _construct_transformation_plan(self, table: StreamTable) -> StreamTable:
         """
         pass

-    def _write_to_online_store(self, table: StreamTable) -> None:
+    def _write_stream_data(self, table: StreamTable, to: PushMode) -> None:
         """
-        Returns query for persisting data to the online store.
+        Launches a job to persist stream data to the online store and/or offline store, depending
+        on the 'to' parameter, and returns a handle for the job.
         """
         pass


 def get_stream_processor_object(
     config: ProcessorConfig,
+    fs: "FeatureStore",
     sfv: StreamFeatureView,
-    write_function: Callable[[pd.DataFrame, str, str], None],
+    preprocess_fn: Optional[MethodType] = None,
 ):
     """
-    Returns a stream processor object based on the config mode and stream source type. The write function is a
-    function that wraps the feature store "write_to_online_store" capability.
+    Returns a stream processor object based on the config.
+
+    The returned object will be capable of launching an ingestion job that reads data from the
+    given stream feature view's stream source, transforms it if the stream feature view has a
+    transformation, and then writes it to the online store. It will also preprocess the data
+    if a preprocessor method is defined.
     """
     if config.mode == "spark" and config.source == "kafka":
         stream_processor = STREAM_PROCESSOR_CLASS_FOR_TYPE[("spark", "kafka")]
         module_name, class_name = stream_processor.rsplit(".", 1)
         cls = import_class(module_name, class_name, "StreamProcessor")
-        return cls(sfv=sfv, config=config, write_function=write_function,)
+        return cls(fs=fs, sfv=sfv, config=config, preprocess_fn=preprocess_fn)
     else:
         raise ValueError("other processors besides spark-kafka not supported")

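Putting the new interface together: a caller now hands the processor a `FeatureStore` and an optional preprocessor instead of a write function, and chooses the destination at ingestion time. A rough sketch (not taken from the repo) assuming a hypothetical stream feature view named `driver_stats_stream` and a local SparkSession:

from pyspark.sql import SparkSession

from feast import FeatureStore
from feast.data_source import PushMode
from feast.infra.contrib.spark_kafka_processor import SparkProcessorConfig
from feast.infra.contrib.stream_processor import get_stream_processor_object

store = FeatureStore(repo_path=".")
sfv = store.get_stream_feature_view("driver_stats_stream")
spark = SparkSession.builder.master("local[*]").getOrCreate()

def preprocess_fn(rows):
    # Optional hook applied to each micro-batch before it is written.
    return rows[rows["conv_rate"] >= 0.0]

processor = get_stream_processor_object(
    config=SparkProcessorConfig(
        mode="spark",
        source="kafka",
        spark_session=spark,
        processing_time="30 seconds",
        query_timeout=15,
    ),
    fs=store,
    sfv=sfv,
    preprocess_fn=preprocess_fn,
)

# Write each micro-batch to both the online and offline stores.
query = processor.ingest_stream_feature_view(to=PushMode.ONLINE_AND_OFFLINE)
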
sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 2 additions & 2 deletions
@@ -329,8 +329,8 @@ def offline_write_batch(
         )
         if column_names != table.column_names:
             raise ValueError(
-                f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
-                f"The columns are expected to be (in this order): {column_names}."
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
             )

         if table.schema != pa_schema:

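For context, the check these messages guard is a strict name-and-order comparison against the batch source schema. A small self-contained sketch (illustrative column names, not taken from the repo) of what the offline stores verify, and how a pyarrow table can be reordered before the write:

import pyarrow as pa

# Expected schema and column order, as a batch source would report it.
pa_schema = pa.schema([("driver_id", pa.int64()), ("conv_rate", pa.float64())])
column_names = pa_schema.names

# An incoming table with the same columns in a different order.
table = pa.table({"conv_rate": [0.9], "driver_id": [1001]})

if column_names != table.column_names:
    # This is the information the improved error message now reports: the
    # actual schema and columns versus the expected schema and exact order.
    print(f"got {table.column_names}, expected {column_names}")
    # Selecting the expected names reorders the columns to match.
    table = table.select(column_names)

assert table.column_names == column_names
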
sdk/python/feast/infra/offline_stores/file.py

Lines changed: 2 additions & 2 deletions
@@ -430,8 +430,8 @@ def offline_write_batch(
         )
         if column_names != table.column_names:
             raise ValueError(
-                f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
-                f"The columns are expected to be (in this order): {column_names}."
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
             )

         file_options = feature_view.batch_source.file_options

sdk/python/feast/infra/offline_stores/redshift.py

Lines changed: 2 additions & 2 deletions
@@ -323,8 +323,8 @@ def offline_write_batch(
         )
         if column_names != table.column_names:
             raise ValueError(
-                f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
-                f"The columns are expected to be (in this order): {column_names}."
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
            )

         if table.schema != pa_schema:

sdk/python/feast/infra/offline_stores/snowflake.py

Lines changed: 2 additions & 2 deletions
@@ -332,8 +332,8 @@ def offline_write_batch(
         )
         if column_names != table.column_names:
             raise ValueError(
-                f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
-                f"The columns are expected to be (in this order): {column_names}."
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
             )

         if table.schema != pa_schema:

sdk/python/tests/integration/e2e/test_python_feature_server.py

Lines changed: 7 additions & 4 deletions
@@ -63,13 +63,16 @@ def test_get_online_features(python_fs_client):
 @pytest.mark.integration
 @pytest.mark.universal_online_stores
 def test_push(python_fs_client):
-    initial_temp = get_temperatures(python_fs_client, location_ids=[1])[0]
+    # TODO(felixwang9817): Note that we choose an entity value of 102 here since it is not included
+    # in the existing range of entity values (1-49). This allows us to push data for this test
+    # without affecting other tests. This decision is tech debt, and should be resolved by finding a
+    # better way to isolate data sources across tests.
     json_data = json.dumps(
         {
             "push_source_name": "location_stats_push_source",
             "df": {
-                "location_id": [1],
-                "temperature": [initial_temp * 100],
+                "location_id": [102],
+                "temperature": [4],
                 "event_timestamp": [str(datetime.utcnow())],
                 "created": [str(datetime.utcnow())],
             },
@@ -79,7 +82,7 @@ def test_push(python_fs_client):

     # Check new pushed temperature is fetched
     assert response.status_code == 200
-    assert get_temperatures(python_fs_client, location_ids=[1]) == [initial_temp * 100]
+    assert get_temperatures(python_fs_client, location_ids=[102]) == [4]


 def get_temperatures(client, location_ids: List[int]):
