Skip to content

Commit f0b29f6

Browse files
committed
fix: Count operator should return N/A intervals
1 parent 1587c96 commit f0b29f6

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

execution_engine/task/process/rectangle.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ def filter_count_intervals(
479479
min_count: int | None,
480480
max_count: int | None,
481481
keep_no_data: bool = True,
482+
keep_not_applicable: bool = True,
482483
) -> PersonIntervals:
483484
"""
484485
Filters the intervals per dict key in the list by count.
@@ -487,14 +488,18 @@ def filter_count_intervals(
487488
:param min_count: The minimum count of the intervals.
488489
:param max_count: The maximum count of the intervals.
489490
:param keep_no_data: Whether to keep NO_DATA intervals (irrespective of the count).
491+
:param keep_not_applicable: Whether to keep NOT_APPLICABLE intervals (irrespective of the count).
490492
:return: A dict with the unioned intervals.
491493
"""
492494

493495
result: PersonIntervals = {}
494496

495497
interval_filter = []
498+
496499
if keep_no_data:
497500
interval_filter.append(IntervalType.NO_DATA)
501+
if keep_not_applicable:
502+
interval_filter.append(IntervalType.NOT_APPLICABLE)
498503

499504
if min_count is None and max_count is None:
500505
raise ValueError("min_count and max_count cannot both be None")
@@ -668,6 +673,7 @@ def create_time_intervals(
668673
# Prepare to collect intervals
669674
intervals = []
670675
previous_end = None
676+
671677
def add_interval(interval_start, interval_end, interval_type):
672678
nonlocal previous_end
673679
effective_start = max(interval_start, start_datetime)
@@ -681,11 +687,13 @@ def add_interval(interval_start, interval_end, interval_type):
681687
# touching intervals.
682688
if previous_end is not None:
683689
assert previous_end < effective_start
684-
intervals.append(Interval(
685-
lower=effective_start.timestamp(),
686-
upper=effective_end.timestamp(),
687-
type=interval_type,
688-
))
690+
intervals.append(
691+
Interval(
692+
lower=effective_start.timestamp(),
693+
upper=effective_end.timestamp(),
694+
type=interval_type,
695+
)
696+
)
689697
previous_end = effective_end
690698

691699
# Current date to process
@@ -714,25 +722,22 @@ def add_interval(interval_start, interval_end, interval_type):
714722
# overlaps the main datetime range, otherwise fill the day
715723
# with an interval of type "not applicable".
716724
# TODO: what about intervals "before" the main datetime range?
717-
if end_interval < start_datetime: # completely before datetime range
725+
if end_interval < start_datetime: # completely before datetime range
718726
day_start = timezone.localize(
719-
datetime.datetime.combine(
720-
current_date, datetime.time(0, 0, 0)
721-
))
727+
datetime.datetime.combine(current_date, datetime.time(0, 0, 0))
728+
)
722729
day_end = timezone.localize(
723-
datetime.datetime.combine(
724-
current_date, datetime.time(23, 59, 59)
725-
))
730+
datetime.datetime.combine(current_date, datetime.time(23, 59, 59))
731+
)
726732
if (previous_end is not None) and day_start <= previous_end:
727733
start = previous_end + datetime.timedelta(seconds=1)
728734
else:
729735
start = day_start
730736
add_interval(start, day_end, IntervalType.NOT_APPLICABLE)
731-
elif end_datetime < start_interval: # completely after datetime range
737+
elif end_datetime < start_interval: # completely after datetime range
732738
day_start = timezone.localize(
733-
datetime.datetime.combine(
734-
current_date, datetime.time(0, 0, 0)
735-
))
739+
datetime.datetime.combine(current_date, datetime.time(0, 0, 0))
740+
)
736741
if (previous_end is not None) and day_start <= previous_end:
737742
start = previous_end + datetime.timedelta(seconds=1)
738743
else:
@@ -799,10 +804,15 @@ def find_overlapping_personal_windows(
799804

800805
return result
801806

807+
802808
def find_rectangles_with_count(data: list[PersonIntervals]) -> PersonIntervals:
803809
if len(data) == 0:
804810
return {}
805811
else:
806812
keys = data[0].keys()
807-
return {key: _impl.find_rectangles_with_count([ intervals[key] for intervals in data ])
808-
for key in keys}
813+
return {
814+
key: _impl.find_rectangles_with_count(
815+
[intervals[key] for intervals in data]
816+
)
817+
for key in keys
818+
}

tests/recommendation/test_recommendation_base_v2.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,6 @@ def setup_testdata(self, db_session, run_slow_tests):
553553
for item in generate_combinations(c, self.invalid_combinations)
554554
]
555555

556-
# combinations = [combinations[0]]
557-
558556
self.insert_criteria_into_database(db_session, combinations)
559557

560558
df_criterion_entries = self.generate_criterion_entries(combinations)

0 commit comments

Comments
 (0)