Skip to content

Commit 00ed65a

Browse files
fix: Fix broken proto conversion methods for data sources (#2603)
* Fix Snowflake proto conversion and add test — Signed-off-by: Felix Wang <[email protected]>
* Add proto conversion test for FileSource — Signed-off-by: Felix Wang <[email protected]>
* Fix Redshift proto conversion and add test — Signed-off-by: Felix Wang <[email protected]>
* Add proto conversion test for BigQuerySource — Signed-off-by: Felix Wang <[email protected]>
* Fix tests to use DataSource.from_proto — Signed-off-by: Felix Wang <[email protected]>
* Add proto conversion test for KafkaSource — Signed-off-by: Felix Wang <[email protected]>
* Add proto conversion test for KinesisSource — Signed-off-by: Felix Wang <[email protected]>
* Add proto conversion test for PushSource — Signed-off-by: Felix Wang <[email protected]>
* Add proto conversion test for PushSource — Signed-off-by: Felix Wang <[email protected]>
* Add name and other fixes — Signed-off-by: Felix Wang <[email protected]>
* Fix proto conversion tests — Signed-off-by: Felix Wang <[email protected]>
* Add tags to test — Signed-off-by: Felix Wang <[email protected]>
* Fix BigQuerySource bug — Signed-off-by: Felix Wang <[email protected]>
* Fix bug in RedshiftSource and TrinoSource — Signed-off-by: Felix Wang <[email protected]>
* Remove references to event_timestamp_column — Signed-off-by: Felix Wang <[email protected]>
1 parent c94a69c commit 00ed65a

File tree

28 files changed

+313
-412
lines changed

28 files changed

+313
-412
lines changed

go/cmd/server/logging/feature_repo/example.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
# for more info.
1010
driver_hourly_stats = FileSource(
1111
path="driver_stats.parquet",
12-
event_timestamp_column="event_timestamp",
12+
timestamp_field="event_timestamp",
1313
created_timestamp_column="created",
1414
)
1515

sdk/python/feast/data_source.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -409,7 +409,7 @@ def __init__(
409409

410410
if _message_format is None:
411411
raise ValueError("Message format must be specified for Kafka source")
412-
print("Asdfasdf")
412+
413413
super().__init__(
414414
event_timestamp_column=_event_timestamp_column,
415415
created_timestamp_column=created_timestamp_column,
@@ -467,7 +467,9 @@ def from_proto(data_source: DataSourceProto):
467467
description=data_source.description,
468468
tags=dict(data_source.tags),
469469
owner=data_source.owner,
470-
batch_source=DataSource.from_proto(data_source.batch_source),
470+
batch_source=DataSource.from_proto(data_source.batch_source)
471+
if data_source.batch_source
472+
else None,
471473
)
472474

473475
def to_proto(self) -> DataSourceProto:
@@ -500,17 +502,20 @@ class RequestSource(DataSource):
500502
"""
501503
RequestSource that can be used to provide input features for on demand transforms
502504
503-
Args:
505+
Attributes:
504506
name: Name of the request data source
505-
schema Union[Dict[str, ValueType], List[Field]]: Schema mapping from the input feature name to a ValueType
506-
description (optional): A human-readable description.
507-
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
508-
owner (optional): The owner of the request data source, typically the email of the primary
507+
schema: Schema mapping from the input feature name to a ValueType
508+
description: A human-readable description.
509+
tags: A dictionary of key-value pairs to store arbitrary metadata.
510+
owner: The owner of the request data source, typically the email of the primary
509511
maintainer.
510512
"""
511513

512514
name: str
513515
schema: List[Field]
516+
description: str
517+
tags: Dict[str, str]
518+
owner: str
514519

515520
def __init__(
516521
self,
@@ -697,7 +702,9 @@ def from_proto(data_source: DataSourceProto):
697702
description=data_source.description,
698703
tags=dict(data_source.tags),
699704
owner=data_source.owner,
700-
batch_source=DataSource.from_proto(data_source.batch_source),
705+
batch_source=DataSource.from_proto(data_source.batch_source)
706+
if data_source.batch_source
707+
else None,
701708
)
702709

703710
@staticmethod

sdk/python/feast/inference.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ def update_entities_with_inferred_types_from_feature_views(
7171
def update_data_sources_with_inferred_event_timestamp_col(
7272
data_sources: List[DataSource], config: RepoConfig
7373
) -> None:
74-
ERROR_MSG_PREFIX = "Unable to infer DataSource event_timestamp_column"
74+
ERROR_MSG_PREFIX = "Unable to infer DataSource timestamp_field"
7575

7676
for data_source in data_sources:
7777
if isinstance(data_source, RequestSource):

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def pull_latest_from_table_or_query(
8383
data_source: DataSource,
8484
join_key_columns: List[str],
8585
feature_name_columns: List[str],
86-
event_timestamp_column: str,
86+
timestamp_field: str,
8787
created_timestamp_column: Optional[str],
8888
start_date: datetime,
8989
end_date: datetime,
@@ -96,7 +96,7 @@ def pull_latest_from_table_or_query(
9696
partition_by_join_key_string = (
9797
"PARTITION BY " + partition_by_join_key_string
9898
)
99-
timestamps = [event_timestamp_column]
99+
timestamps = [timestamp_field]
100100
if created_timestamp_column:
101101
timestamps.append(created_timestamp_column)
102102
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
@@ -114,7 +114,7 @@ def pull_latest_from_table_or_query(
114114
SELECT {field_string},
115115
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row
116116
FROM {from_expression}
117-
WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
117+
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
118118
)
119119
WHERE _feast_row = 1
120120
"""
@@ -131,7 +131,7 @@ def pull_all_from_table_or_query(
131131
data_source: DataSource,
132132
join_key_columns: List[str],
133133
feature_name_columns: List[str],
134-
event_timestamp_column: str,
134+
timestamp_field: str,
135135
start_date: datetime,
136136
end_date: datetime,
137137
) -> RetrievalJob:
@@ -143,12 +143,12 @@ def pull_all_from_table_or_query(
143143
location=config.offline_store.location,
144144
)
145145
field_string = ", ".join(
146-
join_key_columns + feature_name_columns + [event_timestamp_column]
146+
join_key_columns + feature_name_columns + [timestamp_field]
147147
)
148148
query = f"""
149149
SELECT {field_string}
150150
FROM {from_expression}
151-
WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
151+
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
152152
"""
153153
return BigQueryRetrievalJob(
154154
query=query, client=client, config=config, full_feature_names=False,
@@ -583,9 +583,9 @@ def _get_bigquery_client(project: Optional[str] = None, location: Optional[str]
583583
584584
1. We first join the current feature_view to the entity dataframe that has been passed.
585585
This JOIN has the following logic:
586-
- For each row of the entity dataframe, only keep the rows where the `event_timestamp_column`
586+
- For each row of the entity dataframe, only keep the rows where the `timestamp_field`
587587
is less than the one provided in the entity dataframe
588-
- If there a TTL for the current feature_view, also keep the rows where the `event_timestamp_column`
588+
- If there a TTL for the current feature_view, also keep the rows where the `timestamp_field`
589589
is higher the the one provided minus the TTL
590590
- For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been
591591
computed previously
@@ -596,16 +596,16 @@ def _get_bigquery_client(project: Optional[str] = None, location: Optional[str]
596596
597597
{{ featureview.name }}__subquery AS (
598598
SELECT
599-
{{ featureview.event_timestamp_column }} as event_timestamp,
599+
{{ featureview.timestamp_field }} as event_timestamp,
600600
{{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }}
601601
{{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %}
602602
{% for feature in featureview.features %}
603603
{{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}{% if loop.last %}{% else %}, {% endif %}
604604
{% endfor %}
605605
FROM {{ featureview.table_subquery }}
606-
WHERE {{ featureview.event_timestamp_column }} <= '{{ featureview.max_event_timestamp }}'
606+
WHERE {{ featureview.timestamp_field }} <= '{{ featureview.max_event_timestamp }}'
607607
{% if featureview.ttl == 0 %}{% else %}
608-
AND {{ featureview.event_timestamp_column }} >= '{{ featureview.min_event_timestamp }}'
608+
AND {{ featureview.timestamp_field }} >= '{{ featureview.min_event_timestamp }}'
609609
{% endif %}
610610
),
611611

sdk/python/feast/infra/offline_stores/bigquery_source.py

Lines changed: 9 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -99,15 +99,9 @@ def __eq__(self, other):
9999
)
100100

101101
return (
102-
self.name == other.name
103-
and self.bigquery_options.table == other.bigquery_options.table
104-
and self.bigquery_options.query == other.bigquery_options.query
105-
and self.timestamp_field == other.timestamp_field
106-
and self.created_timestamp_column == other.created_timestamp_column
107-
and self.field_mapping == other.field_mapping
108-
and self.description == other.description
109-
and self.tags == other.tags
110-
and self.owner == other.owner
102+
super().__eq__(other)
103+
and self.table == other.table
104+
and self.query == other.query
111105
)
112106

113107
@property
@@ -120,7 +114,6 @@ def query(self):
120114

121115
@staticmethod
122116
def from_proto(data_source: DataSourceProto):
123-
124117
assert data_source.HasField("bigquery_options")
125118

126119
return BigQuerySource(
@@ -144,11 +137,10 @@ def to_proto(self) -> DataSourceProto:
144137
description=self.description,
145138
tags=self.tags,
146139
owner=self.owner,
140+
timestamp_field=self.timestamp_field,
141+
created_timestamp_column=self.created_timestamp_column,
147142
)
148143

149-
data_source_proto.timestamp_field = self.timestamp_field
150-
data_source_proto.created_timestamp_column = self.created_timestamp_column
151-
152144
return data_source_proto
153145

154146
def validate(self, config: RepoConfig):
@@ -179,7 +171,7 @@ def get_table_column_names_and_types(
179171
from google.cloud import bigquery
180172

181173
client = bigquery.Client()
182-
if self.table is not None:
174+
if self.table:
183175
schema = client.get_table(self.table).schema
184176
if not isinstance(schema[0], bigquery.schema.SchemaField):
185177
raise TypeError("Could not parse BigQuery table schema.")
@@ -200,42 +192,14 @@ def get_table_column_names_and_types(
200192

201193
class BigQueryOptions:
202194
"""
203-
DataSource BigQuery options used to source features from BigQuery query
195+
Configuration options for a BigQuery data source.
204196
"""
205197

206198
def __init__(
207199
self, table: Optional[str], query: Optional[str],
208200
):
209-
self._table = table
210-
self._query = query
211-
212-
@property
213-
def query(self):
214-
"""
215-
Returns the BigQuery SQL query referenced by this source
216-
"""
217-
return self._query
218-
219-
@query.setter
220-
def query(self, query):
221-
"""
222-
Sets the BigQuery SQL query referenced by this source
223-
"""
224-
self._query = query
225-
226-
@property
227-
def table(self):
228-
"""
229-
Returns the table ref of this BQ table
230-
"""
231-
return self._table
232-
233-
@table.setter
234-
def table(self, table):
235-
"""
236-
Sets the table ref of this BQ table
237-
"""
238-
self._table = table
201+
self.table = table or ""
202+
self.query = query or ""
239203

240204
@classmethod
241205
def from_proto(cls, bigquery_options_proto: DataSourceProto.BigQueryOptions):
@@ -248,7 +212,6 @@ def from_proto(cls, bigquery_options_proto: DataSourceProto.BigQueryOptions):
248212
Returns:
249213
Returns a BigQueryOptions object based on the bigquery_options protobuf
250214
"""
251-
252215
bigquery_options = cls(
253216
table=bigquery_options_proto.table, query=bigquery_options_proto.query,
254217
)
@@ -262,7 +225,6 @@ def to_proto(self) -> DataSourceProto.BigQueryOptions:
262225
Returns:
263226
BigQueryOptionsProto protobuf
264227
"""
265-
266228
bigquery_options_proto = DataSourceProto.BigQueryOptions(
267229
table=self.table, query=self.query,
268230
)

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ def pull_latest_from_table_or_query(
5555
data_source: DataSource,
5656
join_key_columns: List[str],
5757
feature_name_columns: List[str],
58-
event_timestamp_column: str,
58+
timestamp_field: str,
5959
created_timestamp_column: Optional[str],
6060
start_date: datetime,
6161
end_date: datetime,
@@ -68,7 +68,7 @@ def pull_latest_from_table_or_query(
6868
partition_by_join_key_string = (
6969
"PARTITION BY " + partition_by_join_key_string
7070
)
71-
timestamps = [event_timestamp_column]
71+
timestamps = [timestamp_field]
7272
if created_timestamp_column:
7373
timestamps.append(created_timestamp_column)
7474
timestamp_desc_string = " DESC, ".join(_append_alias(timestamps, "a")) + " DESC"
@@ -87,7 +87,7 @@ def pull_latest_from_table_or_query(
8787
SELECT {a_field_string},
8888
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row
8989
FROM ({from_expression}) a
90-
WHERE a."{event_timestamp_column}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz
90+
WHERE a."{timestamp_field}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz
9191
) b
9292
WHERE _feast_row = 1
9393
"""
@@ -191,15 +191,15 @@ def pull_all_from_table_or_query(
191191
data_source: DataSource,
192192
join_key_columns: List[str],
193193
feature_name_columns: List[str],
194-
event_timestamp_column: str,
194+
timestamp_field: str,
195195
start_date: datetime,
196196
end_date: datetime,
197197
) -> RetrievalJob:
198198
assert isinstance(data_source, PostgreSQLSource)
199199
from_expression = data_source.get_table_query_string()
200200

201201
field_string = ", ".join(
202-
join_key_columns + feature_name_columns + [event_timestamp_column]
202+
join_key_columns + feature_name_columns + [timestamp_field]
203203
)
204204

205205
start_date = start_date.astimezone(tz=utc)
@@ -208,7 +208,7 @@ def pull_all_from_table_or_query(
208208
query = f"""
209209
SELECT {field_string}
210210
FROM {from_expression}
211-
WHERE "{event_timestamp_column}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz
211+
WHERE "{timestamp_field}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz
212212
"""
213213

214214
return PostgreSQLRetrievalJob(
@@ -415,9 +415,9 @@ def build_point_in_time_query(
415415
416416
1. We first join the current feature_view to the entity dataframe that has been passed.
417417
This JOIN has the following logic:
418-
- For each row of the entity dataframe, only keep the rows where the `event_timestamp_column`
418+
- For each row of the entity dataframe, only keep the rows where the `timestamp_field`
419419
is less than the one provided in the entity dataframe
420-
- If there a TTL for the current feature_view, also keep the rows where the `event_timestamp_column`
420+
- If there a TTL for the current feature_view, also keep the rows where the `timestamp_field`
421421
is higher the the one provided minus the TTL
422422
- For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been
423423
computed previously
@@ -428,16 +428,16 @@ def build_point_in_time_query(
428428
429429
"{{ featureview.name }}__subquery" AS (
430430
SELECT
431-
"{{ featureview.event_timestamp_column }}" as event_timestamp,
431+
"{{ featureview.timestamp_field }}" as event_timestamp,
432432
{{ '"' ~ featureview.created_timestamp_column ~ '" as created_timestamp,' if featureview.created_timestamp_column else '' }}
433433
{{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %}
434434
{% for feature in featureview.features %}
435435
"{{ feature }}" as {% if full_feature_names %}"{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}"{% else %}"{{ featureview.field_mapping.get(feature, feature) }}"{% endif %}{% if loop.last %}{% else %}, {% endif %}
436436
{% endfor %}
437437
FROM {{ featureview.table_subquery }} AS sub
438-
WHERE "{{ featureview.event_timestamp_column }}" <= (SELECT MAX(entity_timestamp) FROM entity_dataframe)
438+
WHERE "{{ featureview.timestamp_field }}" <= (SELECT MAX(entity_timestamp) FROM entity_dataframe)
439439
{% if featureview.ttl == 0 %}{% else %}
440-
AND "{{ featureview.event_timestamp_column }}" >= (SELECT MIN(entity_timestamp) FROM entity_dataframe) - {{ featureview.ttl }} * interval '1' second
440+
AND "{{ featureview.timestamp_field }}" >= (SELECT MIN(entity_timestamp) FROM entity_dataframe) - {{ featureview.ttl }} * interval '1' second
441441
{% endif %}
442442
),
443443

0 commit comments

Comments (0)