Skip to content

Commit 97444e4

Browse files
feat: Implement offline_write_batch for BigQuery and Snowflake (#2840)
* Factor out Redshift pyarrow schema inference logic into helper method (Signed-off-by: Felix Wang <[email protected]>)
* Switch file offline store to use offline_utils for offline_write_batch (Signed-off-by: Felix Wang <[email protected]>)
* Implement offline_write_batch for bigquery (Signed-off-by: Felix Wang <[email protected]>)
* Implement offline_write_batch for snowflake (Signed-off-by: Felix Wang <[email protected]>)
* Enable bigquery and snowflake for test_push_features_and_read_from_offline_store test (Signed-off-by: Felix Wang <[email protected]>)
* Rename get_pyarrow_schema (Signed-off-by: Felix Wang <[email protected]>)
1 parent a88cd30 commit 97444e4

File tree

9 files changed

+155
-35
lines changed

9 files changed

+155
-35
lines changed

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import date, datetime, timedelta
55
from pathlib import Path
66
from typing import (
7+
Any,
78
Callable,
89
ContextManager,
910
Dict,
@@ -303,6 +304,60 @@ def write_logged_features(
303304
job_config=job_config,
304305
)
305306

307+
@staticmethod
308+
def offline_write_batch(
309+
config: RepoConfig,
310+
feature_view: FeatureView,
311+
table: pyarrow.Table,
312+
progress: Optional[Callable[[int], Any]],
313+
):
314+
if not feature_view.batch_source:
315+
raise ValueError(
316+
"feature view does not have a batch source to persist offline data"
317+
)
318+
if not isinstance(config.offline_store, BigQueryOfflineStoreConfig):
319+
raise ValueError(
320+
f"offline store config is of type {type(config.offline_store)} when bigquery type required"
321+
)
322+
if not isinstance(feature_view.batch_source, BigQuerySource):
323+
raise ValueError(
324+
f"feature view batch source is {type(feature_view.batch_source)} not bigquery source"
325+
)
326+
327+
pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
328+
config, feature_view.batch_source
329+
)
330+
if column_names != table.column_names:
331+
raise ValueError(
332+
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
333+
f"The columns are expected to be (in this order): {column_names}."
334+
)
335+
336+
if table.schema != pa_schema:
337+
table = table.cast(pa_schema)
338+
339+
client = _get_bigquery_client(
340+
project=config.offline_store.project_id,
341+
location=config.offline_store.location,
342+
)
343+
344+
job_config = bigquery.LoadJobConfig(
345+
source_format=bigquery.SourceFormat.PARQUET,
346+
schema=arrow_schema_to_bq_schema(pa_schema),
347+
write_disposition="WRITE_APPEND", # Default but included for clarity
348+
)
349+
350+
with tempfile.TemporaryFile() as parquet_temp_file:
351+
pyarrow.parquet.write_table(table=table, where=parquet_temp_file)
352+
353+
parquet_temp_file.seek(0)
354+
355+
client.load_table_from_file(
356+
file_obj=parquet_temp_file,
357+
destination=feature_view.batch_source.table,
358+
job_config=job_config,
359+
)
360+
306361

307362
class BigQueryRetrievalJob(RetrievalJob):
308363
def __init__(

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from feast.infra.offline_stores.offline_utils import (
2929
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
30+
get_pyarrow_schema_from_batch_source,
3031
)
3132
from feast.infra.provider import (
3233
_get_requested_feature_views_to_features_dict,
@@ -408,7 +409,7 @@ def write_logged_features(
408409
def offline_write_batch(
409410
config: RepoConfig,
410411
feature_view: FeatureView,
411-
data: pyarrow.Table,
412+
table: pyarrow.Table,
412413
progress: Optional[Callable[[int], Any]],
413414
):
414415
if not feature_view.batch_source:
@@ -423,20 +424,27 @@ def offline_write_batch(
423424
raise ValueError(
424425
f"feature view batch source is {type(feature_view.batch_source)} not file source"
425426
)
427+
428+
pa_schema, column_names = get_pyarrow_schema_from_batch_source(
429+
config, feature_view.batch_source
430+
)
431+
if column_names != table.column_names:
432+
raise ValueError(
433+
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
434+
f"The columns are expected to be (in this order): {column_names}."
435+
)
436+
426437
file_options = feature_view.batch_source.file_options
427438
filesystem, path = FileSource.create_filesystem_and_path(
428439
file_options.uri, file_options.s3_endpoint_override
429440
)
430-
431441
prev_table = pyarrow.parquet.read_table(path, memory_map=True)
432-
if prev_table.column_names != data.column_names:
433-
raise ValueError(
434-
f"Input dataframe has incorrect schema or wrong order, expected columns are: {prev_table.column_names}"
435-
)
436-
if data.schema != prev_table.schema:
437-
data = data.cast(prev_table.schema)
438-
new_table = pyarrow.concat_tables([data, prev_table])
439-
writer = pyarrow.parquet.ParquetWriter(path, data.schema, filesystem=filesystem)
442+
if table.schema != prev_table.schema:
443+
table = table.cast(prev_table.schema)
444+
new_table = pyarrow.concat_tables([table, prev_table])
445+
writer = pyarrow.parquet.ParquetWriter(
446+
path, table.schema, filesystem=filesystem
447+
)
440448
writer.write_table(new_table)
441449
writer.close()
442450

sdk/python/feast/infra/offline_stores/offline_store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def write_logged_features(
275275
def offline_write_batch(
276276
config: RepoConfig,
277277
feature_view: FeatureView,
278-
data: pyarrow.Table,
278+
table: pyarrow.Table,
279279
progress: Optional[Callable[[int], Any]],
280280
):
281281
"""
@@ -286,8 +286,8 @@ def offline_write_batch(
286286
287287
Args:
288288
config: Repo configuration object
289-
table: FeatureView to write the data to.
290-
data: pyarrow table containing feature data and timestamp column for historical feature retrieval
289+
feature_view: FeatureView to write the data to.
290+
table: pyarrow table containing feature data and timestamp column for historical feature retrieval
291291
progress: Optional function to be called once every mini-batch of rows is written to
292292
the online store. Can be used to display progress.
293293
"""

sdk/python/feast/infra/offline_stores/offline_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
import numpy as np
77
import pandas as pd
8+
import pyarrow as pa
89
from jinja2 import BaseLoader, Environment
910
from pandas import Timestamp
1011

12+
from feast.data_source import DataSource
1113
from feast.errors import (
1214
EntityTimestampInferenceException,
1315
FeastEntityDFMissingColumnsError,
@@ -17,6 +19,8 @@
1719
from feast.infra.offline_stores.offline_store import OfflineStore
1820
from feast.infra.provider import _get_requested_feature_views_to_features_dict
1921
from feast.registry import BaseRegistry
22+
from feast.repo_config import RepoConfig
23+
from feast.type_map import feast_value_type_to_pa
2024
from feast.utils import to_naive_utc
2125

2226
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp"
@@ -217,3 +221,25 @@ def get_offline_store_from_config(offline_store_config: Any) -> OfflineStore:
217221
class_name = qualified_name.replace("Config", "")
218222
offline_store_class = import_class(module_name, class_name, "OfflineStore")
219223
return offline_store_class()
224+
225+
226+
def get_pyarrow_schema_from_batch_source(
227+
config: RepoConfig, batch_source: DataSource
228+
) -> Tuple[pa.Schema, List[str]]:
229+
"""Returns the pyarrow schema and column names for the given batch source."""
230+
column_names_and_types = batch_source.get_table_column_names_and_types(config)
231+
232+
pa_schema = []
233+
column_names = []
234+
for column_name, column_type in column_names_and_types:
235+
pa_schema.append(
236+
(
237+
column_name,
238+
feast_value_type_to_pa(
239+
batch_source.source_datatype_to_feast_value_type()(column_type)
240+
),
241+
)
242+
)
243+
column_names.append(column_name)
244+
245+
return pa.schema(pa_schema), column_names

sdk/python/feast/infra/offline_stores/redshift.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from feast.registry import BaseRegistry
4343
from feast.repo_config import FeastConfigBaseModel, RepoConfig
4444
from feast.saved_dataset import SavedDatasetStorage
45-
from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type
4645
from feast.usage import log_exceptions_and_usage
4746

4847

@@ -318,33 +317,23 @@ def offline_write_batch(
318317
raise ValueError(
319318
f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
320319
)
321-
redshift_options = feature_view.batch_source.redshift_options
322-
redshift_client = aws_utils.get_redshift_data_client(
323-
config.offline_store.region
324-
)
325320

326-
column_name_to_type = feature_view.batch_source.get_table_column_names_and_types(
327-
config
321+
pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
322+
config, feature_view.batch_source
328323
)
329-
pa_schema_list = []
330-
column_names = []
331-
for column_name, redshift_type in column_name_to_type:
332-
pa_schema_list.append(
333-
(
334-
column_name,
335-
feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)),
336-
)
337-
)
338-
column_names.append(column_name)
339-
pa_schema = pa.schema(pa_schema_list)
340324
if column_names != table.column_names:
341325
raise ValueError(
342-
f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}"
326+
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
327+
f"The columns are expected to be (in this order): {column_names}."
343328
)
344329

345330
if table.schema != pa_schema:
346331
table = table.cast(pa_schema)
347332

333+
redshift_options = feature_view.batch_source.redshift_options
334+
redshift_client = aws_utils.get_redshift_data_client(
335+
config.offline_store.region
336+
)
348337
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
349338

350339
aws_utils.upload_arrow_table_to_redshift(

sdk/python/feast/infra/offline_stores/snowflake.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datetime import datetime
44
from pathlib import Path
55
from typing import (
6+
Any,
67
Callable,
78
ContextManager,
89
Dict,
@@ -306,6 +307,47 @@ def write_logged_features(
306307
auto_create_table=True,
307308
)
308309

310+
@staticmethod
311+
def offline_write_batch(
312+
config: RepoConfig,
313+
feature_view: FeatureView,
314+
table: pyarrow.Table,
315+
progress: Optional[Callable[[int], Any]],
316+
):
317+
if not feature_view.batch_source:
318+
raise ValueError(
319+
"feature view does not have a batch source to persist offline data"
320+
)
321+
if not isinstance(config.offline_store, SnowflakeOfflineStoreConfig):
322+
raise ValueError(
323+
f"offline store config is of type {type(config.offline_store)} when snowflake type required"
324+
)
325+
if not isinstance(feature_view.batch_source, SnowflakeSource):
326+
raise ValueError(
327+
f"feature view batch source is {type(feature_view.batch_source)} not snowflake source"
328+
)
329+
330+
pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
331+
config, feature_view.batch_source
332+
)
333+
if column_names != table.column_names:
334+
raise ValueError(
335+
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
336+
f"The columns are expected to be (in this order): {column_names}."
337+
)
338+
339+
if table.schema != pa_schema:
340+
table = table.cast(pa_schema)
341+
342+
snowflake_conn = get_snowflake_conn(config.offline_store)
343+
344+
write_pandas(
345+
snowflake_conn,
346+
table.to_pandas(),
347+
table_name=feature_view.batch_source.table,
348+
auto_create_table=True,
349+
)
350+
309351

310352
class SnowflakeRetrievalJob(RetrievalJob):
311353
def __init__(

sdk/python/tests/integration/feature_repos/repo_configuration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676

7777
OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, DataSourceCreator] = {
7878
"file": ("local", FileDataSourceCreator),
79-
"gcp": ("gcp", BigQueryDataSourceCreator),
79+
"bigquery": ("gcp", BigQueryDataSourceCreator),
8080
"redshift": ("aws", RedshiftDataSourceCreator),
8181
"snowflake": ("aws", RedshiftDataSourceCreator),
8282
}

sdk/python/tests/integration/offline_store/test_offline_write.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources):
109109

110110

111111
@pytest.mark.integration
112-
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
112+
@pytest.mark.universal_offline_stores
113113
@pytest.mark.universal_online_stores(only=["sqlite"])
114114
def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
115115
store = environment.feature_store

sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
@pytest.mark.integration
19-
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
19+
@pytest.mark.universal_offline_stores
2020
@pytest.mark.universal_online_stores(only=["sqlite"])
2121
def test_push_features_and_read_from_offline_store(environment, universal_data_sources):
2222
store = environment.feature_store

Comments (0) — 0 commit comments