Skip to content

Commit 297bbcb

Browse files
committed
feat: compute "interval ratio" in logical count operators
Also implement the operators via find_rectangles, and remove the functions count_intervals and filter_count_intervals from the process module.
1 parent a6d8796 commit 297bbcb

File tree

3 files changed

+44
-194
lines changed

3 files changed

+44
-194
lines changed

execution_engine/task/process/rectangle.py

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -428,110 +428,6 @@ def union_intervals(data: list[PersonIntervals]) -> PersonIntervals:
428428
return _process_intervals(data, _impl.union_interval_lists)
429429

430430

431-
def interval_to_interval_with_count(interval: Interval) -> IntervalWithCount:
    """
    Wrap a single Interval as an IntervalWithCount carrying a count of 1.

    The lower bound, upper bound and type are copied over unchanged.
    """
    return IntervalWithCount(
        lower=interval.lower,
        upper=interval.upper,
        type=interval.type,
        count=1,
    )
436-
437-
438-
def intervals_to_intervals_with_count(
    intervals: list[Interval],
) -> list[IntervalWithCount]:
    """
    Convert every Interval of the list to an IntervalWithCount, keeping the order.
    """
    return list(map(interval_to_interval_with_count, intervals))
445-
446-
447-
def count_intervals(data: list[PersonIntervals]) -> PersonIntervalsWithCount:
    """
    Counts the overlapping intervals per dict key across all dicts in the list.

    Every interval is first converted to an IntervalWithCount; the intervals of
    one key are then combined with ``_impl.union_rects_with_count``, both within
    a single dict and across the dicts of the list.

    :param data: A list of dicts, each mapping a key to a list of intervals.
    :return: A dict mapping each key to its combined intervals with counts.
    """
    if not data:
        return {}

    # NOTE: kept as an assert (rather than a raise) so the exception type seen
    # by existing callers does not change.
    assert isinstance(data, list) and all(
        isinstance(arr, dict) for arr in data
    ), "data must be a list of dicts"

    result: PersonIntervalsWithCount = {}

    for arr in data:
        # an empty dict contributes nothing: the inner loop simply does not run
        for group_keys, intervals in arr.items():
            intervals_with_count = _impl.union_rects_with_count(
                intervals_to_intervals_with_count(intervals)
            )
            if group_keys in result:
                # fold the new intervals into what was accumulated for this key
                intervals_with_count = _impl.union_rects_with_count(
                    result[group_keys] + intervals_with_count
                )
            result[group_keys] = intervals_with_count

    return result
480-
481-
482-
def filter_count_intervals(
    data: PersonIntervalsWithCount,
    min_count: int | None,
    max_count: int | None,
    keep_no_data: bool = True,
    keep_not_applicable: bool = True,
) -> PersonIntervals:
    """
    Filters the intervals per dict key by their count.

    An interval is kept when its count lies within the inclusive range
    [min_count, max_count]; a bound that is None is not checked. The kept
    intervals are returned as plain Intervals, i.e. without their count.

    :param data: A dict mapping a key (person id) to a list of intervals with counts.
    :param min_count: The minimum count of the intervals (inclusive), or None.
    :param max_count: The maximum count of the intervals (inclusive), or None.
    :param keep_no_data: Whether to keep NO_DATA intervals (irrespective of the count).
    :param keep_not_applicable: Whether to keep NOT_APPLICABLE intervals (irrespective of the count).
    :return: A dict with the filtered intervals per key.
    :raises ValueError: If both min_count and max_count are None.
    """
    # validate the arguments before doing any other work
    if min_count is None and max_count is None:
        raise ValueError("min_count and max_count cannot both be None")

    # interval types that are kept irrespective of their count
    always_kept = []
    if keep_no_data:
        always_kept.append(IntervalType.NO_DATA)
    if keep_not_applicable:
        always_kept.append(IntervalType.NOT_APPLICABLE)

    def count_in_range(count: int) -> bool:
        # a bound that is None leaves that side of the range open
        return (min_count is None or min_count <= count) and (
            max_count is None or count <= max_count
        )

    return {
        person_id: [
            Interval(interval.lower, interval.upper, interval.type)
            for interval in intervals
            if count_in_range(interval.count) or interval.type in always_kept
        ]
        for person_id, intervals in data.items()
    }
535431

536432

537433
def intersect_intervals(data: list[PersonIntervals]) -> PersonIntervals:

execution_engine/task/process/rectangle_python.py

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -105,90 +105,6 @@ def union_rects(intervals: list[Interval]) -> list[Interval]:
105105
return union
106106

107107

108-
def union_rects_with_count(
    intervals: list[IntervalWithCount],
) -> list[IntervalWithCount]:
    """
    Sweep over the interval events and union them, recording on every emitted
    interval the accumulated count of the open intervals of its type.

    An empty input yields an empty list; otherwise the emitted intervals are
    passed through merge_adjacent_intervals before being returned.
    """
    if not len(intervals):
        return []

    with IntervalType.union_order():
        events = intervals_to_events(intervals)

        collected = []

        # x_min of the currently open output rectangle
        rect_start = -np.inf
        # x of the last closed interval; initialised with the first event's x so
        # that the first rectangle is not closed at the first x
        last_closed_x = events[0][0]
        prev_x = -np.inf
        # accumulated count of open intervals per interval type
        open_counts = SortedDict()

        def highest_open_type() -> IntervalType | None:
            # largest key (in iteration order) whose count is still positive
            return next(
                (key for key in reversed(open_counts) if open_counts[key] > 0),
                None,
            )

        for x, is_start, event_interval in events:
            y = event_interval.type
            delta = event_interval.count

            if is_start:
                top = highest_open_type()

                if x > prev_x and top is None:
                    # nothing is open right now: begin a new output rectangle
                    rect_start = x
                elif y >= top:
                    if x == last_closed_x or x == rect_start:
                        # a rectangle was already closed/opened at this x, so
                        # only register the count; this path deliberately does
                        # not reach the prev_x update at the end of the loop
                        open_counts[y] = open_counts.get(y, 0) + delta
                        continue

                    collected.append(
                        IntervalWithCount(
                            lower=rect_start,
                            upper=x - 1,
                            type=top,
                            count=open_counts[top],
                        )
                    )
                    last_closed_x = x
                    rect_start = x

                open_counts[y] = open_counts.get(y, 0) + delta
            else:
                open_counts[y] = max(open_counts.get(y, 0) - delta, 0)
                top = highest_open_type()

                if (top is None or (open_counts and top <= y)) and x > last_closed_x:
                    # closing type is above what remains open -> only its own
                    # count; same type still open -> remaining count plus its own
                    count = (
                        delta
                        if (top is None or top < y)
                        else open_counts[y] + delta
                    )
                    collected.append(
                        IntervalWithCount(
                            lower=rect_start, upper=x - 1, type=y, count=count
                        )
                    )
                    last_closed_x = x
                    rect_start = x

            prev_x = x

    return merge_adjacent_intervals(collected)
190-
191-
192108
def merge_adjacent_intervals(
193109
intervals: list[IntervalWithCount],
194110
) -> list[IntervalWithCount]:

execution_engine/task/task.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,50 @@ def handle_binary_logical_operator(
306306
elif isinstance(self.expr, (logic.Or, logic.NonSimplifiableOr)):
307307
result = process.union_intervals(data)
308308
elif isinstance(self.expr, logic.Count):
309-
result = process.count_intervals(data)
310-
result = process.filter_count_intervals(
311-
result,
312-
min_count=self.expr.count_min,
313-
max_count=self.expr.count_max,
314-
)
309+
# result = process.count_intervals(data)
310+
# result = process.filter_count_intervals(
311+
# result,
312+
# min_count=self.expr.count_min,
313+
# max_count=self.expr.count_max,
314+
# )
315+
count_min = self.expr.count_min
316+
count_max = self.expr.count_max
317+
if count_min is None and count_max is None:
318+
raise ValueError("count_min and count_max cannot both be None")
319+
def interval_counts(
    start: int, end: int, intervals: List[AnyInterval]
) -> GeneralizedInterval:
    """
    Decide the resulting interval for the span [start, end] from the given intervals.

    The intervals are tallied by type, with a ``None`` entry counted as
    NEGATIVE. If at least one is POSITIVE, the result is POSITIVE when the
    positive count satisfies the ``count_min`` / ``count_max`` bounds taken from
    the enclosing scope, and NEGATIVE otherwise. With a non-zero ``count_min``
    the result additionally carries the ratio ``positive_count / count_min``.
    Without any POSITIVE interval the result type is chosen in the order
    NO_DATA, NOT_APPLICABLE, NEGATIVE.

    :param start: Start of the span.
    :param end: End of the span.
    :param intervals: The intervals covering the span (entries may be None).
    :return: The interval describing the span.
    """
    positive_count = 0
    negative_count = 0
    not_applicable_count = 0
    no_data_count = 0

    for interval in intervals:
        if interval is None or interval.type is IntervalType.NEGATIVE:
            negative_count += 1
        elif interval.type is IntervalType.POSITIVE:
            positive_count += 1
        elif interval.type is IntervalType.NOT_APPLICABLE:
            not_applicable_count += 1
        elif interval.type is IntervalType.NO_DATA:
            no_data_count += 1

    if positive_count > 0:
        # a bound that is None is not checked
        min_good = count_min is None or count_min <= positive_count
        max_good = count_max is None or positive_count <= count_max
        interval_type = (
            IntervalType.POSITIVE if (min_good and max_good) else IntervalType.NEGATIVE
        )

        if not count_min:
            # No lower bound (None) or a lower bound of 0: there is nothing to
            # relate the count to, and dividing by 0 would raise, so return the
            # interval without a ratio.
            return Interval(start, end, interval_type)

        ratio = positive_count / count_min
        return IntervalWithCount(start, end, interval_type, ratio)

    if no_data_count > 0:
        return Interval(start, end, IntervalType.NO_DATA)
    if not_applicable_count > 0:
        return Interval(start, end, IntervalType.NOT_APPLICABLE)
    if negative_count > 0:
        return Interval(start, end, IntervalType.NEGATIVE)

    # NOTE(review): reached only when no interval was tallied at all (e.g. an
    # empty list); None is returned exactly as before — confirm that
    # find_rectangles accepts None here.
    return None
350+
351+
result = process.find_rectangles(data, interval_counts)
352+
315353
elif isinstance(self.expr, logic.CappedCount):
316354

317355
def interval_counts(

0 commit comments

Comments
 (0)