Commit 5748a8b

feat: Push to Redshift batch source offline store directly (#2819)
* Scaffolding for offline store push
* Lint
* Fix
* File source offline push
* Fix
* Fix
* Fix
* Fix
* Fix
* Fix
* Address review comments
* Add redshift function
* Add redshift
* Fix
* Lint
* fix
* fix
* Fix errors
* Fix test
* Fix test
* Fix test
* add back in
* Fix
* lint
* Address review comments
* Fix

All commits signed off by: Kevin Zhang <[email protected]>
1 parent 17303d3 commit 5748a8b

File tree: 7 files changed (+189, −43 lines)

sdk/python/feast/feature_store.py

Lines changed: 2 additions & 0 deletions

@@ -1423,6 +1423,8 @@ def _write_to_offline_store(
         feature_view = self.get_feature_view(
             feature_view_name, allow_registry_cache=allow_registry_cache
         )
+        df.reset_index(drop=True)
+
         table = pa.Table.from_pandas(df)
         provider = self._get_provider()
         provider.ingest_df_to_offline_store(feature_view, table)
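
For context, a minimal sketch of how this write path could be exercised from user code, assuming a repo with a registered feature view and a dataframe whose columns already match the batch source schema; the repo path, feature view name, and column names below are illustrative, not part of this commit.

    from datetime import datetime

    import pandas as pd
    from feast import FeatureStore

    # Hypothetical repo and feature view, for illustration only.
    store = FeatureStore(repo_path=".")

    df = pd.DataFrame(
        {
            "driver_id": [1001, 1002],
            "event_timestamp": [datetime(2022, 6, 1), datetime(2022, 6, 2)],
            "conv_rate": [0.45, 0.58],
            "avg_daily_trips": [14, 22],
        }
    )

    # The patched method resets the index, converts the dataframe to an Arrow
    # table, and hands it to the provider's offline-store write path.
    store._write_to_offline_store("driver_hourly_stats", df, allow_registry_cache=True)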

sdk/python/feast/infra/offline_stores/redshift.py

Lines changed: 65 additions & 0 deletions

@@ -3,6 +3,7 @@
 from datetime import datetime
 from pathlib import Path
 from typing import (
+    Any,
     Callable,
     ContextManager,
     Dict,
@@ -41,6 +42,7 @@
 from feast.registry import BaseRegistry
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
 from feast.saved_dataset import SavedDatasetStorage
+from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type
 from feast.usage import log_exceptions_and_usage


@@ -297,6 +299,69 @@ def write_logged_features(
             fail_if_exists=False,
         )

+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, RedshiftOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when redshift type required"
+            )
+        if not isinstance(feature_view.batch_source, RedshiftSource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
+            )
+        redshift_options = feature_view.batch_source.redshift_options
+        redshift_client = aws_utils.get_redshift_data_client(
+            config.offline_store.region
+        )
+
+        column_name_to_type = feature_view.batch_source.get_table_column_names_and_types(
+            config
+        )
+        pa_schema_list = []
+        column_names = []
+        for column_name, redshift_type in column_name_to_type:
+            pa_schema_list.append(
+                (
+                    column_name,
+                    feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)),
+                )
+            )
+            column_names.append(column_name)
+        pa_schema = pa.schema(pa_schema_list)
+        if column_names != table.column_names:
+            raise ValueError(
+                f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}"
+            )
+
+        if table.schema != pa_schema:
+            table = table.cast(pa_schema)
+
+        s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
+
+        aws_utils.upload_arrow_table_to_redshift(
+            table=table,
+            redshift_data_client=redshift_client,
+            cluster_id=config.offline_store.cluster_id,
+            database=redshift_options.database
+            or config.offline_store.database,  # Users can define database in the source if needed but it's not required.
+            user=config.offline_store.user,
+            s3_resource=s3_resource,
+            s3_path=f"{config.offline_store.s3_staging_location}/push/{uuid.uuid4()}.parquet",
+            iam_role=config.offline_store.iam_role,
+            table_name=redshift_options.table,
+            schema=pa_schema,
+            fail_if_exists=False,
+        )
+

 class RedshiftRetrievalJob(RetrievalJob):
     def __init__(
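
To make the new validation step concrete, here is a standalone sketch of the same check using plain pyarrow (column names and types are made up for illustration): the pushed table must carry exactly the expected columns, in the expected order, and is cast to the Arrow schema derived from the Redshift table before being staged to S3.

    import pyarrow as pa

    # Expected schema, as it would be derived from the Redshift table's columns.
    expected_schema = pa.schema(
        [
            ("driver_id", pa.int64()),
            ("conv_rate", pa.float64()),
            ("event_timestamp", pa.timestamp("us")),
        ]
    )
    expected_columns = [field.name for field in expected_schema]

    # Incoming push data as an Arrow table (int32 driver_id on purpose).
    table = pa.table(
        {
            "driver_id": pa.array([1001, 1002], type=pa.int32()),
            "conv_rate": [0.45, 0.58],
            "event_timestamp": pa.array([0, 1], type=pa.timestamp("us")),
        }
    )

    # Same order check as offline_write_batch: names and order must match exactly.
    if table.column_names != expected_columns:
        raise ValueError(
            f"Input dataframe has incorrect schema or wrong order, expected columns are: {expected_columns}"
        )

    # Cast mismatched types (here int32 -> int64) before uploading to Redshift.
    if table.schema != expected_schema:
        table = table.cast(expected_schema)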

sdk/python/feast/infra/passthrough_provider.py

Lines changed: 2 additions & 2 deletions

@@ -103,14 +103,14 @@ def online_write_batch(
     def offline_write_batch(
         self,
         config: RepoConfig,
-        table: FeatureView,
+        feature_view: FeatureView,
         data: pa.Table,
         progress: Optional[Callable[[int], Any]],
     ) -> None:
         set_usage_attribute("provider", self.__class__.__name__)

         if self.offline_store:
-            self.offline_store.offline_write_batch(config, table, data, progress)
+            self.offline_store.offline_write_batch(config, feature_view, data, progress)

     @log_exceptions_and_usage(sampler=RatioSampler(ratio=0.001))
     def online_read(

sdk/python/feast/infra/utils/aws_utils.py

Lines changed: 57 additions & 1 deletion

@@ -235,6 +235,15 @@ def upload_df_to_redshift(
     )


+def delete_redshift_table(
+    redshift_data_client, cluster_id: str, database: str, user: str, table_name: str,
+):
+    drop_query = f"DROP {table_name} IF EXISTS"
+    execute_redshift_statement(
+        redshift_data_client, cluster_id, database, user, drop_query,
+    )
+
+
 def upload_arrow_table_to_redshift(
     table: Union[pyarrow.Table, Path],
     redshift_data_client,
@@ -320,7 +329,7 @@ def upload_arrow_table_to_redshift(
             cluster_id,
             database,
             user,
-            f"{create_query}; {copy_query}",
+            f"{create_query}; {copy_query};",
         )
     finally:
         # Clean up S3 temporary data
@@ -371,6 +380,53 @@ def temporarily_upload_df_to_redshift(
     )


+@contextlib.contextmanager
+def temporarily_upload_arrow_table_to_redshift(
+    table: Union[pyarrow.Table, Path],
+    redshift_data_client,
+    cluster_id: str,
+    database: str,
+    user: str,
+    s3_resource,
+    iam_role: str,
+    s3_path: str,
+    table_name: str,
+    schema: Optional[pyarrow.Schema] = None,
+    fail_if_exists: bool = True,
+) -> Iterator[None]:
+    """Uploads a Arrow Table to Redshift as a new table with cleanup logic.
+
+    This is essentially the same as upload_arrow_table_to_redshift (check out its docstring for full details),
+    but unlike it this method is a generator and should be used with `with` block. For example:
+
+    >>> with temporarily_upload_arrow_table_to_redshift(...): # doctest: +SKIP
+    >>>     # Use `table_name` table in Redshift here
+    >>> # `table_name` will not exist at this point, since it's cleaned up by the `with` block
+
+    """
+    # Upload the dataframe to Redshift
+    upload_arrow_table_to_redshift(
+        table,
+        redshift_data_client,
+        cluster_id,
+        database,
+        user,
+        s3_resource,
+        s3_path,
+        iam_role,
+        table_name,
+        schema,
+        fail_if_exists,
+    )
+
+    yield
+
+    # Clean up the uploaded Redshift table
+    execute_redshift_statement(
+        redshift_data_client, cluster_id, database, user, f"DROP TABLE {table_name}",
+    )
+
+
 def download_s3_directory(s3_resource, bucket: str, key: str, local_dir: str):
     """Download the S3 directory to a local disk"""
     bucket_obj = s3_resource.Bucket(bucket)
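
A hedged usage sketch of the new temporarily_upload_arrow_table_to_redshift helper, assuming a reachable Redshift cluster and an S3 staging bucket; the cluster, database, user, IAM role, and bucket values are placeholders, not values from this commit.

    import uuid

    import pyarrow as pa

    from feast.infra.utils import aws_utils

    table = pa.table({"driver_id": [1001, 1002], "conv_rate": [0.45, 0.58]})

    redshift_client = aws_utils.get_redshift_data_client("us-west-2")
    s3_resource = aws_utils.get_s3_resource("us-west-2")

    with aws_utils.temporarily_upload_arrow_table_to_redshift(
        table=table,
        redshift_data_client=redshift_client,
        cluster_id="my-feast-cluster",  # placeholder
        database="dev",  # placeholder
        user="awsuser",  # placeholder
        s3_resource=s3_resource,
        iam_role="arn:aws:iam::123456789012:role/redshift-s3-access",  # placeholder
        s3_path=f"s3://my-staging-bucket/push/{uuid.uuid4()}.parquet",
        table_name="feast_tmp_push_table",
    ):
        # The staged table exists here and can be queried; it is dropped
        # automatically when the block exits.
        pass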

sdk/python/tests/conftest.py

Lines changed: 49 additions & 36 deletions

@@ -33,6 +33,7 @@
 from tests.integration.feature_repos.repo_configuration import (
     AVAILABLE_OFFLINE_STORES,
     AVAILABLE_ONLINE_STORES,
+    OFFLINE_STORE_TO_PROVIDER_CONFIG,
     Environment,
     TestData,
     construct_test_environment,
@@ -196,16 +197,24 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
     """
     if "environment" in metafunc.fixturenames:
         markers = {m.name: m for m in metafunc.definition.own_markers}
-
+        offline_stores = None
         if "universal_offline_stores" in markers:
-            offline_stores = AVAILABLE_OFFLINE_STORES
+            # Offline stores can be explicitly requested
+            if "only" in markers["universal_offline_stores"].kwargs:
+                offline_stores = [
+                    OFFLINE_STORE_TO_PROVIDER_CONFIG.get(store_name)
+                    for store_name in markers["universal_offline_stores"].kwargs["only"]
+                    if store_name in OFFLINE_STORE_TO_PROVIDER_CONFIG
+                ]
+            else:
+                offline_stores = AVAILABLE_OFFLINE_STORES
         else:
             # default offline store for testing online store dimension
             offline_stores = [("local", FileDataSourceCreator)]

         online_stores = None
         if "universal_online_stores" in markers:
-            # Online stores are explicitly requested
+            # Online stores can be explicitly requested
             if "only" in markers["universal_online_stores"].kwargs:
                 online_stores = [
                     AVAILABLE_ONLINE_STORES.get(store_name)
@@ -240,40 +249,44 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
             extra_dimensions.append({"go_feature_retrieval": True})

         configs = []
-        for provider, offline_store_creator in offline_stores:
-            for online_store, online_store_creator in online_stores:
-                for dim in extra_dimensions:
-                    config = {
-                        "provider": provider,
-                        "offline_store_creator": offline_store_creator,
-                        "online_store": online_store,
-                        "online_store_creator": online_store_creator,
-                        **dim,
-                    }
-                    # temporary Go works only with redis
-                    if config.get("go_feature_retrieval") and (
-                        not isinstance(online_store, dict)
-                        or online_store["type"] != "redis"
-                    ):
-                        continue
-
-                    # aws lambda works only with dynamo
-                    if (
-                        config.get("python_feature_server")
-                        and config.get("provider") == "aws"
-                        and (
-                            not isinstance(online_store, dict)
-                            or online_store["type"] != "dynamodb"
-                        )
-                    ):
-                        continue
-
-                    c = IntegrationTestRepoConfig(**config)
-
-                    if c not in _config_cache:
-                        _config_cache[c] = c
-
-                    configs.append(_config_cache[c])
+        if offline_stores:
+            for provider, offline_store_creator in offline_stores:
+                for online_store, online_store_creator in online_stores:
+                    for dim in extra_dimensions:
+                        config = {
+                            "provider": provider,
+                            "offline_store_creator": offline_store_creator,
+                            "online_store": online_store,
+                            "online_store_creator": online_store_creator,
+                            **dim,
+                        }
+                        # temporary Go works only with redis
+                        if config.get("go_feature_retrieval") and (
+                            not isinstance(online_store, dict)
+                            or online_store["type"] != "redis"
+                        ):
+                            continue
+
+                        # aws lambda works only with dynamo
+                        if (
+                            config.get("python_feature_server")
+                            and config.get("provider") == "aws"
+                            and (
+                                not isinstance(online_store, dict)
+                                or online_store["type"] != "dynamodb"
+                            )
+                        ):
+                            continue
+
+                        c = IntegrationTestRepoConfig(**config)
+
+                        if c not in _config_cache:
+                            _config_cache[c] = c
+
+                        configs.append(_config_cache[c])
+        else:
+            # No offline stores requested -> setting the default or first available
+            offline_stores = [("local", FileDataSourceCreator)]

         metafunc.parametrize(
             "environment", configs, indirect=True, ids=[str(c) for c in configs]

sdk/python/tests/integration/feature_repos/repo_configuration.py

Lines changed: 7 additions & 0 deletions

@@ -74,6 +74,13 @@
     "connection_string": "127.0.0.1:6001,127.0.0.1:6002,127.0.0.1:6003",
 }

+OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, DataSourceCreator] = {
+    "file": ("local", FileDataSourceCreator),
+    "gcp": ("gcp", BigQueryDataSourceCreator),
+    "redshift": ("aws", RedshiftDataSourceCreator),
+    "snowflake": ("aws", RedshiftDataSourceCreator),
+}
+
 AVAILABLE_OFFLINE_STORES: List[Tuple[str, Type[DataSourceCreator]]] = [
     ("local", FileDataSourceCreator),
 ]
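
As a small worked example of how this mapping is consumed by the marker handling added in conftest.py above, a request for ["file", "redshift"] resolves to the local file and AWS Redshift configurations, while unknown names are silently dropped:

    requested = ["file", "redshift", "does_not_exist"]
    offline_stores = [
        OFFLINE_STORE_TO_PROVIDER_CONFIG.get(name)
        for name in requested
        if name in OFFLINE_STORE_TO_PROVIDER_CONFIG
    ]
    # -> [("local", FileDataSourceCreator), ("aws", RedshiftDataSourceCreator)]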

sdk/python/tests/integration/offline_store/test_offline_write.py

Lines changed: 7 additions & 4 deletions

@@ -11,8 +11,9 @@


 @pytest.mark.integration
-@pytest.mark.universal_online_stores
-def test_writing_incorrect_order_fails(environment, universal_data_sources):
+@pytest.mark.universal_offline_stores(only=["file", "redshift"])
+@pytest.mark.universal_online_stores(only=["sqlite"])
+def test_writing_columns_in_incorrect_order_fails(environment, universal_data_sources):
     # TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
     store = environment.feature_store
     _, _, data_sources = universal_data_sources
@@ -59,7 +60,8 @@ def test_writing_incorrect_order_fails(environment, universal_data_sources):


 @pytest.mark.integration
-@pytest.mark.universal_online_stores
+@pytest.mark.universal_offline_stores(only=["file", "redshift"])
+@pytest.mark.universal_online_stores(only=["sqlite"])
 def test_writing_incorrect_schema_fails(environment, universal_data_sources):
     # TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
     store = environment.feature_store
@@ -107,7 +109,8 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources):


 @pytest.mark.integration
-@pytest.mark.universal_online_stores
+@pytest.mark.universal_offline_stores(only=["file", "redshift"])
+@pytest.mark.universal_online_stores(only=["sqlite"])
 def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
     store = environment.feature_store
     _, _, data_sources = universal_data_sources
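
As a rough illustration of the failure mode the first two tests cover (this is not the actual test body), pushing a dataframe whose columns do not match the batch source's order is expected to raise, since the Redshift path added in this commit validates column names and order before staging anything to S3; the feature view and column names here are assumed:

    from datetime import datetime

    import pandas as pd
    import pytest


    def check_wrong_column_order_rejected(store):
        # Columns deliberately swapped relative to the source schema
        # (feature view and column names are illustrative).
        bad_df = pd.DataFrame(
            {
                "conv_rate": [0.5],
                "driver_id": [1001],
                "event_timestamp": [datetime(2022, 6, 1)],
            }
        )
        with pytest.raises(ValueError):
            store._write_to_offline_store(
                "driver_stats", bad_df, allow_registry_cache=False
            )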
