@@ -115,60 +115,70 @@ def get_expected_training_df(
115
115
entity_df .to_dict ("records" ), event_timestamp
116
116
)
117
117
118
+ # Set sufficiently large ttl that it effectively functions as infinite for the calculations below.
119
+ default_ttl = timedelta (weeks = 52 )
120
+
118
121
# Manually do point-in-time join of driver, customer, and order records against
119
122
# the entity df
120
123
for entity_row in entity_rows :
121
124
customer_record = find_asof_record (
122
125
customer_records ,
123
126
ts_key = customer_fv .batch_source .timestamp_field ,
124
- ts_start = entity_row [event_timestamp ] - customer_fv .ttl ,
127
+ ts_start = entity_row [event_timestamp ]
128
+ - get_feature_view_ttl (customer_fv , default_ttl ),
125
129
ts_end = entity_row [event_timestamp ],
126
130
filter_keys = ["customer_id" ],
127
131
filter_values = [entity_row ["customer_id" ]],
128
132
)
129
133
driver_record = find_asof_record (
130
134
driver_records ,
131
135
ts_key = driver_fv .batch_source .timestamp_field ,
132
- ts_start = entity_row [event_timestamp ] - driver_fv .ttl ,
136
+ ts_start = entity_row [event_timestamp ]
137
+ - get_feature_view_ttl (driver_fv , default_ttl ),
133
138
ts_end = entity_row [event_timestamp ],
134
139
filter_keys = ["driver_id" ],
135
140
filter_values = [entity_row ["driver_id" ]],
136
141
)
137
142
order_record = find_asof_record (
138
143
order_records ,
139
144
ts_key = customer_fv .batch_source .timestamp_field ,
140
- ts_start = entity_row [event_timestamp ] - order_fv .ttl ,
145
+ ts_start = entity_row [event_timestamp ]
146
+ - get_feature_view_ttl (order_fv , default_ttl ),
141
147
ts_end = entity_row [event_timestamp ],
142
148
filter_keys = ["customer_id" , "driver_id" ],
143
149
filter_values = [entity_row ["customer_id" ], entity_row ["driver_id" ]],
144
150
)
145
151
origin_record = find_asof_record (
146
152
location_records ,
147
153
ts_key = location_fv .batch_source .timestamp_field ,
148
- ts_start = order_record [event_timestamp ] - location_fv .ttl ,
154
+ ts_start = order_record [event_timestamp ]
155
+ - get_feature_view_ttl (location_fv , default_ttl ),
149
156
ts_end = order_record [event_timestamp ],
150
157
filter_keys = ["location_id" ],
151
158
filter_values = [order_record ["origin_id" ]],
152
159
)
153
160
destination_record = find_asof_record (
154
161
location_records ,
155
162
ts_key = location_fv .batch_source .timestamp_field ,
156
- ts_start = order_record [event_timestamp ] - location_fv .ttl ,
163
+ ts_start = order_record [event_timestamp ]
164
+ - get_feature_view_ttl (location_fv , default_ttl ),
157
165
ts_end = order_record [event_timestamp ],
158
166
filter_keys = ["location_id" ],
159
167
filter_values = [order_record ["destination_id" ]],
160
168
)
161
169
global_record = find_asof_record (
162
170
global_records ,
163
171
ts_key = global_fv .batch_source .timestamp_field ,
164
- ts_start = order_record [event_timestamp ] - global_fv .ttl ,
172
+ ts_start = order_record [event_timestamp ]
173
+ - get_feature_view_ttl (global_fv , default_ttl ),
165
174
ts_end = order_record [event_timestamp ],
166
175
)
167
176
168
177
field_mapping_record = find_asof_record (
169
178
field_mapping_records ,
170
179
ts_key = field_mapping_fv .batch_source .timestamp_field ,
171
- ts_start = order_record [event_timestamp ] - field_mapping_fv .ttl ,
180
+ ts_start = order_record [event_timestamp ]
181
+ - get_feature_view_ttl (field_mapping_fv , default_ttl ),
172
182
ts_end = order_record [event_timestamp ],
173
183
)
174
184
@@ -666,6 +676,78 @@ def test_historical_features_persisting(
666
676
)
667
677
668
678
679
@pytest.mark.integration
@pytest.mark.universal_offline_stores
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
def test_historical_features_with_no_ttl(
    environment, universal_data_sources, full_feature_names
):
    """Historical retrieval should treat a zero ttl as "no expiry".

    Several feature views get their ttl set to ``timedelta(seconds=0)`` before
    being applied; the retrieved training dataframe is then compared against
    the manually-joined expectation, which substitutes a very large default ttl
    for zero/unset ttls (see ``get_feature_view_ttl``), i.e. the join window is
    effectively unbounded for those views.

    Args:
        environment: test fixture providing the configured ``FeatureStore``.
        universal_data_sources: fixture tuple of (entities, datasets, data_sources).
        full_feature_names: parametrized flag forwarded to ``get_historical_features``.
    """
    store = environment.feature_store

    (entities, datasets, data_sources) = universal_data_sources
    feature_views = construct_universal_feature_views(data_sources)

    # Remove ttls.
    feature_views.customer.ttl = timedelta(seconds=0)
    feature_views.order.ttl = timedelta(seconds=0)
    feature_views.global_fv.ttl = timedelta(seconds=0)
    feature_views.field_mapping.ttl = timedelta(seconds=0)

    store.apply([driver(), customer(), location(), *feature_views.values()])

    # Drop the order/location join keys: this test only requests features keyed
    # on driver_id / customer_id (plus globals), so the entity df doesn't need them.
    entity_df = datasets.entity_df.drop(
        columns=["order_id", "origin_id", "destination_id"]
    )

    job = store.get_historical_features(
        entity_df=entity_df,
        features=[
            "customer_profile:current_balance",
            "customer_profile:avg_passenger_count",
            "customer_profile:lifetime_trip_count",
            "order:order_is_success",
            "global_stats:num_rides",
            "global_stats:avg_ride_length",
            "field_mapping:feature_name",
        ],
        full_feature_names=full_feature_names,
    )

    event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
    # Build the expected dataframe via the manual point-in-time join, then drop
    # the columns belonging to features this test did not request.
    expected_df = get_expected_training_df(
        datasets.customer_df,
        feature_views.customer,
        datasets.driver_df,
        feature_views.driver,
        datasets.orders_df,
        feature_views.order,
        datasets.location_df,
        feature_views.location,
        datasets.global_df,
        feature_views.global_fv,
        datasets.field_mapping_df,
        feature_views.field_mapping,
        entity_df,
        event_timestamp,
        full_feature_names,
    ).drop(
        columns=[
            response_feature_name("conv_rate_plus_100", full_feature_names),
            response_feature_name("conv_rate_plus_100_rounded", full_feature_names),
            response_feature_name("avg_daily_trips", full_feature_names),
            response_feature_name("conv_rate", full_feature_names),
            "origin__temperature",
            "destination__temperature",
        ]
    )

    assert_frame_equal(
        expected_df,
        job.to_df(),
        keys=[event_timestamp, "driver_id", "customer_id"],
    )
749
+
750
+
669
751
@pytest .mark .integration
670
752
@pytest .mark .universal_offline_stores
671
753
def test_historical_features_from_bigquery_sources_containing_backfills (environment ):
@@ -781,6 +863,13 @@ def response_feature_name(feature: str, full_feature_names: bool) -> str:
781
863
return feature
782
864
783
865
866
def get_feature_view_ttl(
    feature_view: FeatureView, default_ttl: timedelta
) -> timedelta:
    """Return the feature view's ttl, or ``default_ttl`` when it is unset.

    A ttl of ``None`` or ``timedelta(0)`` is falsy and means the view has no
    expiry, so the caller-supplied default is substituted in that case.
    """
    ttl = feature_view.ttl
    if not ttl:
        return default_ttl
    return ttl
+
872
+
784
873
def assert_feature_service_correctness (
785
874
store , feature_service , full_feature_names , entity_df , expected_df , event_timestamp
786
875
):
0 commit comments