Commit 331a214

fix: Update udf tests and add base functions to streaming fcos and fix some nonetype errors (#2776)

Authored by Kevin Zhang

* Fix lint and add comments. Signed-off-by: Kevin Zhang <[email protected]>
* Fix. Signed-off-by: Kevin Zhang <[email protected]>
* Fix lint. Signed-off-by: Kevin Zhang <[email protected]>

1 parent 83ab682 · commit 331a214

File tree

4 files changed: +220 −23 lines

sdk/python/feast/data_source.py

Lines changed: 1 addition & 1 deletion
@@ -503,7 +503,7 @@ def __hash__(self):
     @staticmethod
     def from_proto(data_source: DataSourceProto):
         watermark = None
-        if data_source.kafka_options.HasField("watermark"):
+        if data_source.kafka_options.watermark:
             watermark = (
                 timedelta(days=0)
                 if data_source.kafka_options.watermark.ToNanoseconds() == 0
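
Read on its own, the guard above is the NoneType fix the commit title mentions: the watermark is only converted when it is actually present, and a zero-length watermark is normalized to timedelta(days=0). A standalone sketch of that conversion (a hypothetical watermark_from_proto helper, not part of the commit):

from datetime import timedelta

from google.protobuf.duration_pb2 import Duration


def watermark_from_proto(watermark: Duration) -> timedelta:
    # Mirrors the guarded conversion in KafkaSource.from_proto above:
    # a zero-nanosecond Duration means "no watermark delay".
    if watermark.ToNanoseconds() == 0:
        return timedelta(days=0)
    return watermark.ToTimedelta()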

sdk/python/feast/stream_feature_view.py

Lines changed: 78 additions & 21 deletions
@@ -1,3 +1,4 @@
+import copy
 import functools
 import warnings
 from datetime import timedelta
@@ -9,7 +10,7 @@
 
 from feast import utils
 from feast.aggregation import Aggregation
-from feast.data_source import DataSource, KafkaSource
+from feast.data_source import DataSource, KafkaSource, PushSource
 from feast.entity import Entity
 from feast.feature_view import FeatureView
 from feast.field import Field
@@ -39,6 +40,26 @@ class StreamFeatureView(FeatureView):
     """
     NOTE: Stream Feature Views are not yet fully implemented and exist to allow users to register their stream sources and
     schemas with Feast.
+
+    Attributes:
+        name: str. The unique name of the stream feature view.
+        entities: Union[List[Entity], List[str]]. List of entities or entity join keys.
+        ttl: timedelta. The amount of time this group of features lives. A ttl of 0 indicates that
+            this group of features lives forever. Note that large ttl's or a ttl of 0
+            can result in extremely computationally intensive queries.
+        tags: Dict[str, str]. A dictionary of key-value pairs to store arbitrary metadata.
+        online: bool. Defines whether this stream feature view is used in online feature retrieval.
+        description: str. A human-readable description.
+        owner: The owner of the stream feature view, typically the email of the primary
+            maintainer.
+        schema: List[Field]. The schema of the feature view, including feature, timestamp, and entity
+            columns. If not specified, can be inferred from the underlying data source.
+        source: DataSource. The stream source of data where this group of features
+            is stored.
+        aggregations (optional): List[Aggregation]. List of aggregations registered with the stream feature view.
+        mode (optional): str. The mode of execution.
+        timestamp_field (optional): Must be specified if aggregations are specified. Defines the timestamp column on which to aggregate windows.
+        udf (optional): MethodType. The user-defined transformation function, which should contain all of its required imports within the function body.
     """
 
     def __init__(
@@ -54,18 +75,19 @@ def __init__(
         schema: Optional[List[Field]] = None,
         source: Optional[DataSource] = None,
         aggregations: Optional[List[Aggregation]] = None,
-        mode: Optional[str] = "spark",  # Mode of ingestion/transformation
-        timestamp_field: Optional[str] = "",  # Timestamp for aggregation
+        mode: Optional[str] = "spark",
+        timestamp_field: Optional[str] = "",
         udf: Optional[MethodType] = None,
     ):
         warnings.warn(
             "Stream Feature Views are experimental features in alpha development. "
             "Some functionality may still be unstable so functionality can change in the future.",
             RuntimeWarning,
         )
+
        if source is None:
-            raise ValueError("Stream Feature views need a source specified")
-        # source uses the batch_source of the kafkasource in feature_view
+            raise ValueError("Stream Feature views need a source to be specified")
+
         if (
             type(source).__name__ not in SUPPORTED_STREAM_SOURCES
             and source.to_proto().type != DataSourceProto.SourceType.CUSTOM_SOURCE
@@ -74,18 +96,26 @@ def __init__(
             f"Stream feature views need a stream source, expected one of {SUPPORTED_STREAM_SOURCES} "
             f"or CUSTOM_SOURCE, got {type(source).__name__}: {source.name} instead "
         )
+
+        if aggregations and not timestamp_field:
+            raise ValueError(
+                "aggregations must have a timestamp field associated with them to perform the aggregations"
+            )
+
         self.aggregations = aggregations or []
-        self.mode = mode
-        self.timestamp_field = timestamp_field
+        self.mode = mode or ""
+        self.timestamp_field = timestamp_field or ""
         self.udf = udf
         _batch_source = None
-        if isinstance(source, KafkaSource):
+        if isinstance(source, KafkaSource) or isinstance(source, PushSource):
             _batch_source = source.batch_source if source.batch_source else None
-
+        _ttl = ttl
+        if not _ttl:
+            _ttl = timedelta(days=0)
         super().__init__(
             name=name,
             entities=entities,
-            ttl=ttl,
+            ttl=_ttl,
             batch_source=_batch_source,
             stream_source=source,
             tags=tags,
@@ -102,7 +132,10 @@ def __eq__(self, other):
 
         if not super().__eq__(other):
             return False
-
+        if not self.udf:
+            return not other.udf
+        if not other.udf:
+            return False
         if (
             self.mode != other.mode
             or self.timestamp_field != other.timestamp_field
@@ -113,13 +146,14 @@ def __eq__(self, other):
 
         return True
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return super().__hash__()
 
     def to_proto(self):
         meta = StreamFeatureViewMetaProto(materialization_intervals=[])
         if self.created_timestamp:
             meta.created_timestamp.FromDatetime(self.created_timestamp)
+
         if self.last_updated_timestamp:
             meta.last_updated_timestamp.FromDatetime(self.last_updated_timestamp)
 
@@ -134,6 +168,7 @@ def to_proto(self):
         ttl_duration = Duration()
         ttl_duration.FromTimedelta(self.ttl)
 
+        batch_source_proto = None
         if self.batch_source:
             batch_source_proto = self.batch_source.to_proto()
             batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}"
@@ -143,23 +178,24 @@ def to_proto(self):
             stream_source_proto = self.stream_source.to_proto()
             stream_source_proto.data_source_class_type = f"{self.stream_source.__class__.__module__}.{self.stream_source.__class__.__name__}"
 
+        udf_proto = None
+        if self.udf:
+            udf_proto = UserDefinedFunctionProto(
+                name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True),
+            )
         spec = StreamFeatureViewSpecProto(
             name=self.name,
             entities=self.entities,
             entity_columns=[field.to_proto() for field in self.entity_columns],
             features=[field.to_proto() for field in self.schema],
-            user_defined_function=UserDefinedFunctionProto(
-                name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True),
-            )
-            if self.udf
-            else None,
+            user_defined_function=udf_proto,
             description=self.description,
             tags=self.tags,
             owner=self.owner,
-            ttl=(ttl_duration if ttl_duration is not None else None),
+            ttl=ttl_duration,
             online=self.online,
             batch_source=batch_source_proto or None,
-            stream_source=stream_source_proto,
+            stream_source=stream_source_proto or None,
             timestamp_field=self.timestamp_field,
             aggregations=[agg.to_proto() for agg in self.aggregations],
             mode=self.mode,
@@ -239,6 +275,25 @@ def from_proto(cls, sfv_proto):
 
         return sfv_feature_view
 
+    def __copy__(self):
+        fv = StreamFeatureView(
+            name=self.name,
+            schema=self.schema,
+            entities=self.entities,
+            ttl=self.ttl,
+            tags=self.tags,
+            online=self.online,
+            description=self.description,
+            owner=self.owner,
+            aggregations=self.aggregations,
+            mode=self.mode,
+            timestamp_field=self.timestamp_field,
+            sources=self.sources,
+            udf=self.udf,
+        )
+        fv.projection = copy.copy(self.projection)
+        return fv
+
 
 def stream_feature_view(
     *,
@@ -251,11 +306,13 @@ def stream_feature_view(
     schema: Optional[List[Field]] = None,
     source: Optional[DataSource] = None,
     aggregations: Optional[List[Aggregation]] = None,
-    mode: Optional[str] = "spark",  # Mode of ingestion/transformation
-    timestamp_field: Optional[str] = "",  # Timestamp for aggregation
+    mode: Optional[str] = "spark",
+    timestamp_field: Optional[str] = "",
 ):
     """
     Creates a StreamFeatureView object with the given user function as udf.
+    Please make sure that the udf contains all non-built-in imports within the function to ensure that the execution
+    of a deserialized function does not miss imports.
     """
 
     def mainify(obj):
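
The docstring added to stream_feature_view() above explains why udfs must keep their imports inside the function body: to_proto() ships the udf as dill.dumps(self.udf, recurse=True) bytes, and the deserialized copy may run in an interpreter where the defining module's top-level imports never happened. A minimal sketch of that round-trip, using dill directly rather than Feast's proto plumbing:

import dill
import pandas as pd


def add_ten(df):
    # Import inside the function so the deserialized copy does not
    # depend on the defining module's globals.
    import pandas as pd

    return df + 10


payload = dill.dumps(add_ten, recurse=True)  # what to_proto() stores as the udf body
restored = dill.loads(payload)               # what from_proto() recovers

print(restored(pd.DataFrame({"A": [1, 2]})))  # column A becomes 11, 12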

sdk/python/tests/integration/registration/test_stream_feature_view_apply.py

Lines changed: 68 additions & 0 deletions
@@ -70,3 +70,71 @@ def simple_sfv(df):
     assert features["test_key"] == [1001]
     assert "dummy_field" in features
     assert features["dummy_field"] == [None]
+
+
+@pytest.mark.integration
+def test_stream_feature_view_udf(environment) -> None:
+    """
+    Test that StreamFeatureView udfs are serialized correctly on apply and are usable.
+    """
+    fs = environment.feature_store
+
+    # Create Feature Views
+    entity = Entity(name="driver_entity", join_keys=["test_key"])
+
+    stream_source = KafkaSource(
+        name="kafka",
+        timestamp_field="event_timestamp",
+        bootstrap_servers="",
+        message_format=AvroFormat(""),
+        topic="topic",
+        batch_source=FileSource(path="test_path", timestamp_field="event_timestamp"),
+        watermark=timedelta(days=1),
+    )
+
+    @stream_feature_view(
+        entities=[entity],
+        ttl=timedelta(days=30),
+
+        online=True,
+        schema=[Field(name="dummy_field", dtype=Float32)],
+        description="desc",
+        aggregations=[
+            Aggregation(
+                column="dummy_field", function="max", time_window=timedelta(days=1),
+            ),
+            Aggregation(
+                column="dummy_field2", function="count", time_window=timedelta(days=24),
+            ),
+        ],
+        timestamp_field="event_timestamp",
+        mode="spark",
+        source=stream_source,
+        tags={},
+    )
+    def pandas_view(pandas_df):
+        import pandas as pd
+
+        assert type(pandas_df) == pd.DataFrame
+        df = pandas_df.transform(lambda x: x + 10, axis=1)
+        df.insert(2, "C", [20.2, 230.0, 34.0], True)
+        return df
+
+    import pandas as pd
+
+    df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+
+    fs.apply([entity, pandas_view])
+    stream_feature_views = fs.list_stream_feature_views()
+    assert len(stream_feature_views) == 1
+    assert stream_feature_views[0].name == "pandas_view"
+    assert stream_feature_views[0] == pandas_view
+
+    sfv = stream_feature_views[0]
+
+    new_df = sfv.udf(df)
+
+    expected_df = pd.DataFrame(
+        {"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]}
+    )
+    assert new_df.equals(expected_df)
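
For anyone checking expected_df by hand: DataFrame.transform with axis=1 applies the lambda row-wise, so every cell gains 10, and insert(2, "C", ...) places the new column third. The same arithmetic outside the test harness (illustrative only):

import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
out = df.transform(lambda x: x + 10, axis=1)  # every cell + 10
out.insert(2, "C", [20.2, 230.0, 34.0])       # new column at position 2
assert out.equals(
    pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]})
)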

sdk/python/tests/unit/test_feature_views.py

Lines changed: 73 additions & 1 deletion
@@ -9,7 +9,7 @@
 from feast.entity import Entity
 from feast.field import Field
 from feast.infra.offline_stores.file_source import FileSource
-from feast.stream_feature_view import StreamFeatureView
+from feast.stream_feature_view import StreamFeatureView, stream_feature_view
 from feast.types import Float32
 
 
@@ -129,3 +129,75 @@ def test_stream_feature_view_serialization():
 
     new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto)
     assert new_sfv == sfv
+
+
+def test_stream_feature_view_udfs():
+    entity = Entity(name="driver_entity", join_keys=["test_key"])
+    stream_source = KafkaSource(
+        name="kafka",
+        timestamp_field="event_timestamp",
+        bootstrap_servers="",
+        message_format=AvroFormat(""),
+        topic="topic",
+        batch_source=FileSource(path="some path"),
+    )
+
+    @stream_feature_view(
+        entities=[entity],
+        ttl=timedelta(days=30),
+
+        online=True,
+        schema=[Field(name="dummy_field", dtype=Float32)],
+        description="desc",
+        aggregations=[
+            Aggregation(
+                column="dummy_field", function="max", time_window=timedelta(days=1),
+            )
+        ],
+        timestamp_field="event_timestamp",
+        source=stream_source,
+    )
+    def pandas_udf(pandas_df):
+        import pandas as pd
+
+        assert type(pandas_df) == pd.DataFrame
+        df = pandas_df.transform(lambda x: x + 10, axis=1)
+        return df
+
+    import pandas as pd
+
+    df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
+    sfv = pandas_udf
+    sfv_proto = sfv.to_proto()
+    new_sfv = StreamFeatureView.from_proto(sfv_proto)
+    new_df = new_sfv.udf(df)
+
+    expected_df = pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40]})
+
+    assert new_df.equals(expected_df)
+
+
+def test_stream_feature_view_initialization_with_optional_fields_omitted():
+    entity = Entity(name="driver_entity", join_keys=["test_key"])
+    stream_source = KafkaSource(
+        name="kafka",
+        timestamp_field="event_timestamp",
+        bootstrap_servers="",
+        message_format=AvroFormat(""),
+        topic="topic",
+        batch_source=FileSource(path="some path"),
+    )
+
+    sfv = StreamFeatureView(
+        name="test kafka stream feature view",
+        entities=[entity],
+        schema=[],
+        description="desc",
+        timestamp_field="event_timestamp",
+        source=stream_source,
+        tags={},
+    )
+    sfv_proto = sfv.to_proto()
+
+    new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto)
+    assert new_sfv == sfv
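
The new test_stream_feature_view_initialization_with_optional_fields_omitted passes because of the defaulting added in __init__ above: falsy optionals are normalized (mode or "", timestamp_field or "", and a missing ttl becomes timedelta(days=0)) before they reach the proto layer, so a to_proto()/from_proto() round trip compares equal. A small standalone illustration of that normalization pattern (hypothetical helper, not Feast API):

from datetime import timedelta

def normalize(mode=None, timestamp_field=None, ttl=None):
    # Same idea as the commit's defaults: never hand None to the proto layer.
    return {
        "mode": mode or "",
        "timestamp_field": timestamp_field or "",
        "ttl": ttl or timedelta(days=0),
    }

assert normalize() == normalize(mode="", timestamp_field="", ttl=timedelta(0))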
