
Commit 26f6b69

fix: Fix file offline store logic for feature views without ttl (#2971)
* Add new test for historical retrieval with feature views with no ttl
  Signed-off-by: Felix Wang <[email protected]>
* Fix no ttl logic
  Signed-off-by: Felix Wang <[email protected]>
1 parent: 3ce5139

2 files changed: +104, -7 lines

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 8 additions & 0 deletions
@@ -635,6 +635,14 @@ def _filter_ttl(
             )
         ]
 
+        df_to_join = df_to_join.persist()
+    else:
+        df_to_join = df_to_join[
+            # do not drop entity rows if one of the sources returns NaNs
+            df_to_join[timestamp_field].isna()
+            | (df_to_join[timestamp_field] <= df_to_join[entity_df_event_timestamp_col])
+        ]
+
         df_to_join = df_to_join.persist()
 
     return df_to_join
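Note: the added else branch keeps feature rows whose timestamp is missing (NaN) or not later than the entity row's event timestamp, and applies no lower bound, so a feature view without a ttl can match arbitrarily old rows. A minimal sketch of that filter, using a plain pandas DataFrame in place of the dask frame and illustrative column names:

import pandas as pd

# Joined rows: "event_timestamp" stands in for timestamp_field,
# "entity_ts" for entity_df_event_timestamp_col.
df = pd.DataFrame(
    {
        "event_timestamp": pd.to_datetime(["2022-07-01", None, "2022-07-10"]),
        "entity_ts": pd.to_datetime(["2022-07-05"] * 3),
        "value": [1, 2, 3],
    }
)

# No-ttl filtering: keep NaN timestamps and anything at or before the entity
# timestamp; there is no lower bound on how old a feature row may be.
filtered = df[df["event_timestamp"].isna() | (df["event_timestamp"] <= df["entity_ts"])]
print(filtered["value"].tolist())  # [1, 2]: the future row (value 3) is dropped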

sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py

Lines changed: 96 additions & 7 deletions
@@ -115,60 +115,70 @@ def get_expected_training_df(
         entity_df.to_dict("records"), event_timestamp
     )
 
+    # Set sufficiently large ttl that it effectively functions as infinite for the calculations below.
+    default_ttl = timedelta(weeks=52)
+
     # Manually do point-in-time join of driver, customer, and order records against
     # the entity df
     for entity_row in entity_rows:
         customer_record = find_asof_record(
             customer_records,
             ts_key=customer_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - customer_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(customer_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["customer_id"],
             filter_values=[entity_row["customer_id"]],
         )
         driver_record = find_asof_record(
             driver_records,
             ts_key=driver_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - driver_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(driver_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["driver_id"],
             filter_values=[entity_row["driver_id"]],
         )
         order_record = find_asof_record(
             order_records,
             ts_key=customer_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - order_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(order_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["customer_id", "driver_id"],
             filter_values=[entity_row["customer_id"], entity_row["driver_id"]],
         )
         origin_record = find_asof_record(
             location_records,
             ts_key=location_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - location_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(location_fv, default_ttl),
             ts_end=order_record[event_timestamp],
             filter_keys=["location_id"],
             filter_values=[order_record["origin_id"]],
         )
         destination_record = find_asof_record(
             location_records,
             ts_key=location_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - location_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(location_fv, default_ttl),
             ts_end=order_record[event_timestamp],
             filter_keys=["location_id"],
             filter_values=[order_record["destination_id"]],
         )
         global_record = find_asof_record(
             global_records,
             ts_key=global_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - global_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(global_fv, default_ttl),
             ts_end=order_record[event_timestamp],
         )
 
         field_mapping_record = find_asof_record(
             field_mapping_records,
             ts_key=field_mapping_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - field_mapping_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(field_mapping_fv, default_ttl),
             ts_end=order_record[event_timestamp],
         )
 
@@ -666,6 +676,78 @@ def test_historical_features_persisting(
     )
 
 
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
+def test_historical_features_with_no_ttl(
+    environment, universal_data_sources, full_feature_names
+):
+    store = environment.feature_store
+
+    (entities, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+
+    # Remove ttls.
+    feature_views.customer.ttl = timedelta(seconds=0)
+    feature_views.order.ttl = timedelta(seconds=0)
+    feature_views.global_fv.ttl = timedelta(seconds=0)
+    feature_views.field_mapping.ttl = timedelta(seconds=0)
+
+    store.apply([driver(), customer(), location(), *feature_views.values()])
+
+    entity_df = datasets.entity_df.drop(
+        columns=["order_id", "origin_id", "destination_id"]
+    )
+
+    job = store.get_historical_features(
+        entity_df=entity_df,
+        features=[
+            "customer_profile:current_balance",
+            "customer_profile:avg_passenger_count",
+            "customer_profile:lifetime_trip_count",
+            "order:order_is_success",
+            "global_stats:num_rides",
+            "global_stats:avg_ride_length",
+            "field_mapping:feature_name",
+        ],
+        full_feature_names=full_feature_names,
+    )
+
+    event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
+    expected_df = get_expected_training_df(
+        datasets.customer_df,
+        feature_views.customer,
+        datasets.driver_df,
+        feature_views.driver,
+        datasets.orders_df,
+        feature_views.order,
+        datasets.location_df,
+        feature_views.location,
+        datasets.global_df,
+        feature_views.global_fv,
+        datasets.field_mapping_df,
+        feature_views.field_mapping,
+        entity_df,
+        event_timestamp,
+        full_feature_names,
+    ).drop(
+        columns=[
+            response_feature_name("conv_rate_plus_100", full_feature_names),
+            response_feature_name("conv_rate_plus_100_rounded", full_feature_names),
+            response_feature_name("avg_daily_trips", full_feature_names),
+            response_feature_name("conv_rate", full_feature_names),
+            "origin__temperature",
+            "destination__temperature",
+        ]
+    )
+
+    assert_frame_equal(
+        expected_df,
+        job.to_df(),
+        keys=[event_timestamp, "driver_id", "customer_id"],
+    )
+
+
 @pytest.mark.integration
 @pytest.mark.universal_offline_stores
 def test_historical_features_from_bigquery_sources_containing_backfills(environment):
@@ -781,6 +863,13 @@ def response_feature_name(feature: str, full_feature_names: bool) -> str:
     return feature
 
 
+def get_feature_view_ttl(
+    feature_view: FeatureView, default_ttl: timedelta
+) -> timedelta:
+    """Returns the ttl of a feature view if it is non-zero. Otherwise returns the specified default."""
+    return feature_view.ttl if feature_view.ttl else default_ttl
+
+
 def assert_feature_service_correctness(
     store, feature_service, full_feature_names, entity_df, expected_df, event_timestamp
 ):
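Note: get_feature_view_ttl relies on timedelta(seconds=0) being falsy in Python, so the zero ttls set in the new test fall back to default_ttl exactly like a missing ttl. A standalone illustration under that assumption, using a stand-in class rather than Feast's FeatureView:

from datetime import timedelta
from typing import Optional


class FakeFeatureView:
    """Stand-in for feast.FeatureView, carrying only the ttl attribute."""

    def __init__(self, ttl: Optional[timedelta]):
        self.ttl = ttl


def get_feature_view_ttl(
    feature_view: FakeFeatureView, default_ttl: timedelta
) -> timedelta:
    """Returns the ttl of a feature view if it is non-zero. Otherwise returns the specified default."""
    return feature_view.ttl if feature_view.ttl else default_ttl


default_ttl = timedelta(weeks=52)
print(get_feature_view_ttl(FakeFeatureView(timedelta(days=2)), default_ttl))     # 2 days, 0:00:00
print(get_feature_view_ttl(FakeFeatureView(timedelta(seconds=0)), default_ttl))  # 364 days, 0:00:00
print(get_feature_view_ttl(FakeFeatureView(None), default_ttl))                  # 364 days, 0:00:00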
