@@ -58,25 +58,31 @@ def _recreate(cls, args: Any, kwargs: dict) -> "Expr":
58
58
self .set_id (_id )
59
59
return self
60
60
61
- def _reduce_helper (self , ivars_map : dict | None = None ) -> tuple [Callable , tuple ]:
61
+ def _reduce_helper (
62
+ self , ivars_map : dict | None = None , args : tuple | None = None
63
+ ) -> tuple [Callable , tuple ]:
62
64
"""
63
65
Return a picklable tuple that calls self._recreate and passes in
64
66
(self.args, combined_ivars).
65
67
66
68
:param ivars_map: A dictionary for renaming keys in the instance variables.
67
69
For each old_key -> new_key in ivars_map, if old_key exists
68
70
in data, it will be removed and stored under new_key instead.
71
+ :param args: Optionally, the args parameter can be specified. If it is None, self.args is used.
69
72
"""
70
- data = dict ( self .get_instance_variables () )
71
- data ["_id" ] = self ._id
73
+ data = self .get_instance_variables ()
74
+ data ["_id" ] = self ._id # type: ignore[index]
72
75
73
76
# Apply key renaming if ivars_map is provided
74
77
if ivars_map :
75
78
for old_key , new_key in ivars_map .items ():
76
79
if old_key in data :
77
- data [new_key ] = data .pop (old_key )
80
+ data [new_key ] = data .pop (old_key ) # type: ignore[index,union-attr]
78
81
79
- return (self ._recreate , (self .args , data ))
82
+ if args is None :
83
+ args = self .args
84
+
85
+ return (self ._recreate , (args , data ))
80
86
81
87
def __reduce__ (self ) -> tuple [Callable , tuple ]:
82
88
"""
@@ -404,15 +410,30 @@ def __new__(
404
410
405
411
return self
406
412
413
+ def _replace_map (self ) -> dict :
414
+ replace = {}
415
+
416
+ if self .count_min and self .count_max :
417
+ assert self .count_min == self .count_max
418
+ replace ["count_min" ] = "threshold"
419
+ elif self .count_min :
420
+ replace ["count_min" ] = "threshold"
421
+ elif self .count_max :
422
+ replace ["count_max" ] = "threshold"
423
+ else :
424
+ raise AttributeError ("At least one of count_min or count_max must be set" )
425
+
426
+ return replace
427
+
407
428
def __reduce__ (self ) -> tuple [Callable , tuple ]:
408
429
"""
409
430
Reduce the expression to its arguments and category.
410
431
411
432
Required for pickling (e.g. when using multiprocessing).
412
433
413
- :return: Tuple of the class, arguments, and category .
434
+ :return: Tuple of the class, argument .
414
435
"""
415
- return self ._reduce_helper ({ "count_min" : "threshold" , "count_max" : "threshold" } )
436
+ return self ._reduce_helper (self . _replace_map () )
416
437
417
438
418
439
class Count (CountOperator ):
@@ -683,6 +704,31 @@ def _validate_time_inputs(
683
704
f"Invalid criterion - expected Criterion or CriterionCombination, got { type (interval_criterion )} "
684
705
)
685
706
707
+ def __reduce__ (self ) -> tuple [Callable , tuple ]:
708
+ """
709
+ Reduce the expression to its arguments and category.
710
+
711
+ Required for pickling (e.g. when using multiprocessing).
712
+
713
+ :return: Tuple of the class, argument.
714
+ """
715
+ if self .interval_criterion :
716
+
717
+ if len (self .args ) <= 1 :
718
+ raise ValueError (
719
+ "More than one argument required if interval_criterion is set"
720
+ )
721
+
722
+ args , pop = self .args [:- 1 ], self .args [- 1 ]
723
+
724
+ if pop != self .interval_criterion :
725
+ raise ValueError (
726
+ f"Expected last argument to be the interval_criterion, got { str (pop )} "
727
+ )
728
+ return self ._reduce_helper (self ._replace_map (), args = args )
729
+
730
+ return super ().__reduce__ ()
731
+
686
732
def dict (self , include_id : bool = False ) -> dict :
687
733
"""
688
734
Get a dictionary representation of the object.
@@ -746,11 +792,11 @@ class TemporalMinCount(TemporalCount):
746
792
def __new__ (
747
793
cls ,
748
794
* args : Any ,
749
- threshold : int | None ,
750
- start_time : time | None ,
751
- end_time : time | None ,
752
- interval_type : TimeIntervalType | None ,
753
- interval_criterion : BaseExpr | None ,
795
+ threshold : int ,
796
+ start_time : time | None = None ,
797
+ end_time : time | None = None ,
798
+ interval_type : TimeIntervalType | None = None ,
799
+ interval_criterion : BaseExpr | None = None ,
754
800
** kwargs : Any ,
755
801
) -> "TemporalMinCount" :
756
802
"""
@@ -776,7 +822,7 @@ def dict(self, include_id: bool = False) -> dict:
776
822
Get a dictionary representation of the object.
777
823
"""
778
824
data = super ().dict (include_id = include_id )
779
- data .update ({"threshold" : self .count_min })
825
+ data [ "data" ] .update ({"threshold" : self .count_min })
780
826
return data
781
827
782
828
@@ -788,11 +834,11 @@ class TemporalMaxCount(TemporalCount):
788
834
def __new__ (
789
835
cls ,
790
836
* args : Any ,
791
- threshold : int | None ,
792
- start_time : time | None ,
793
- end_time : time | None ,
794
- interval_type : TimeIntervalType | None ,
795
- interval_criterion : BaseExpr | None ,
837
+ threshold : int ,
838
+ start_time : time | None = None ,
839
+ end_time : time | None = None ,
840
+ interval_type : TimeIntervalType | None = None ,
841
+ interval_criterion : BaseExpr | None = None ,
796
842
** kwargs : Any ,
797
843
) -> "TemporalMaxCount" :
798
844
"""
@@ -818,7 +864,7 @@ def dict(self, include_id: bool = False) -> dict:
818
864
Get a dictionary representation of the object.
819
865
"""
820
866
data = super ().dict (include_id = include_id )
821
- data .update ({"threshold" : self .count_max })
867
+ data [ "data" ] .update ({"threshold" : self .count_max })
822
868
return data
823
869
824
870
@@ -830,11 +876,11 @@ class TemporalExactCount(TemporalCount):
830
876
def __new__ (
831
877
cls ,
832
878
* args : Any ,
833
- threshold : int | None ,
834
- start_time : time | None ,
835
- end_time : time | None ,
836
- interval_type : TimeIntervalType | None ,
837
- interval_criterion : BaseExpr | None ,
879
+ threshold : int ,
880
+ start_time : time | None = None ,
881
+ end_time : time | None = None ,
882
+ interval_type : TimeIntervalType | None = None ,
883
+ interval_criterion : BaseExpr | None = None ,
838
884
** kwargs : Any ,
839
885
) -> "TemporalExactCount" :
840
886
"""
@@ -860,7 +906,7 @@ def dict(self, include_id: bool = False) -> dict:
860
906
Get a dictionary representation of the object.
861
907
"""
862
908
data = super ().dict (include_id = include_id )
863
- data .update ({"threshold" : self .count_min })
909
+ data [ "data" ] .update ({"threshold" : self .count_min })
864
910
return data
865
911
866
912
0 commit comments