Skip to content

Commit 102cd54

Browse files
committed
refactor: optimize implementation of TemporalCount
1 parent fcf98d0 commit 102cd54

File tree

1 file changed

+53
-39
lines changed

1 file changed

+53
-39
lines changed

execution_engine/task/task.py

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import logging
55
from enum import Enum, auto
6-
from typing import List
6+
from typing import List, Any
77

88
from sqlalchemy.exc import DBAPIError, IntegrityError, ProgrammingError, SQLAlchemyError
99

@@ -551,49 +551,63 @@ def get_start_end_from_interval_type(
551551
timezone=get_config().timezone,
552552
)
553553

554-
# Create a "temporary window interval" for each window
555-
# interval. Associate with each temporary window interval
556-
# all data intervals that overlap it. The association
557-
# works by assigning a unique id to each temporary window
554+
# Incrementally compute the interval type for each window
558555
# interval.
559-
ids = dict() # window_interval -> unique id
560-
infos = dict() # unique id -> list of overlapping data intervals
561-
def temporary_window_interval(start: int, end: int, intervals: List[AnyInterval]):
556+
window_types: dict[AnyInterval, Any] = dict() # window interval -> interval type
557+
def update_window_type(window_interval, data_interval):
558+
window_type = window_types.get(window_interval.lower, None)
559+
if data_interval is None or data_interval.type is IntervalType.NEGATIVE:
560+
if window_type is not IntervalType.POSITIVE:
561+
window_type = IntervalType.NEGATIVE
562+
elif data_interval.type is IntervalType.POSITIVE:
563+
window_type = IntervalType.POSITIVE
564+
elif data_interval.type is IntervalType.NOT_APPLICABLE:
565+
if window_type is None:
566+
window_type = IntervalType.NOT_APPLICABLE
567+
else:
568+
assert data_interval.type is IntervalType.NO_DATA
569+
if window_type is None:
570+
window_type = IntervalType.NO_DATA
571+
window_types[window_interval.lower] = window_type
572+
return window_type
573+
# The boundaries of the result intervals are identical to
574+
# those of the window intervals. In addition, update the
575+
# result interval window types based on the data
576+
# intervals.
577+
def is_same_interval(left_intervals, right_intervals):
578+
left_window_interval, left_data_interval = left_intervals
579+
right_window_interval, right_data_interval = right_intervals
580+
if right_window_interval is None:
581+
if left_window_interval is None:
582+
return True
583+
else:
584+
update_window_type(left_window_interval, left_data_interval)
585+
return False
586+
else:
587+
update_window_type(right_window_interval, right_data_interval)
588+
if left_window_interval is None:
589+
return False
590+
else:
591+
if left_window_interval is right_window_interval:
592+
return True
593+
else:
594+
update_window_type(left_window_interval, left_data_interval)
595+
return False
596+
# Create result intervals based on the computed interval
597+
# types.
598+
def result_interval(start: int, end: int, intervals: List[AnyInterval]):
562599
window_interval, data_interval = intervals
563-
if window_interval is None or window_interval.type == IntervalType.NOT_APPLICABLE:
600+
if window_interval is None or window_interval.type is IntervalType.NOT_APPLICABLE:
564601
return Interval(start, end, IntervalType.NOT_APPLICABLE)
565602
else:
566-
window_id = ids.get(window_interval, len(ids))
567-
ids[window_interval] = window_id
568-
info = infos.get(window_id, set())
569-
infos[window_id] = info
570-
data_interval_type = data_interval.type if data_interval is not None else IntervalType.NEGATIVE
571-
info.add(data_interval_type)
572-
return IntervalWithCount(start, end, IntervalType.POSITIVE, window_id)
603+
window_type = window_types.get(window_interval.lower, None)
604+
if window_type is None:
605+
window_type = update_window_type(window_interval, data_interval)
606+
return Interval(start, end, window_type)
573607
person_indicator_windows = { key: indicator_windows for key in data_p.keys() }
574-
result = process.find_rectangles([ person_indicator_windows, data_p], temporary_window_interval)
575-
# Turn the temporary window intervals into the final
576-
# intervals by computing the interval types based on the
577-
# respective overlapping data intervals.
578-
def finalize_interval(interval):
579-
if isinstance(interval, IntervalWithCount):
580-
window_id = interval.count
581-
data_intervals = infos[window_id]
582-
# TODO(jmoringe): there should be a way to implement this with max(data_intervals)
583-
if IntervalType.POSITIVE in data_intervals:
584-
interval_type = IntervalType.POSITIVE
585-
elif IntervalType.NEGATIVE in data_intervals:
586-
interval_type = IntervalType.NEGATIVE
587-
elif IntervalType.NOT_APPLICABLE in data_intervals:
588-
interval_type = IntervalType.NOT_APPLICABLE
589-
else:
590-
assert IntervalType.NO_DATA in data_intervals
591-
interval_type = IntervalType.NO_DATA
592-
return Interval(interval.lower, interval.upper, interval_type)
593-
else:
594-
return interval
595-
result = { key: [ finalize_interval(i) for i in intervals ]
596-
for key, intervals in result.items() }
608+
result = process.find_rectangles([ person_indicator_windows, data_p],
609+
result_interval,
610+
is_same_result=is_same_interval)
597611

598612
return result
599613

0 commit comments

Comments
 (0)