2
2
import importlib
3
3
import logging
4
4
import os
5
+ from collections import defaultdict
5
6
from typing import Callable , Dict , List , Set , cast
6
7
7
8
import numpy as np
10
11
from sqlalchemy import CursorResult
11
12
12
13
from execution_engine .util .interval import IntervalType , interval_datetime
13
- from execution_engine .util .types import TimeRange
14
14
15
+ from ...util .types .timerange import TimeRange
15
16
from . import (
16
17
GeneralizedInterval ,
17
18
Interval ,
18
19
IntervalWithCount ,
19
20
interval_like ,
21
+ timerange_to_interval ,
20
22
)
21
23
24
+ IntervalConstructor = Callable [
25
+ [int , int , List [GeneralizedInterval ]], GeneralizedInterval
26
+ ]
27
+ SameResult = Callable [[List [GeneralizedInterval ], List [GeneralizedInterval ]], bool ]
28
+
22
29
PROCESS_RECTANGLE_VERSION = os .getenv ("PROCESS_RECTANGLE_VERSION" , "auto" )
23
30
24
31
@@ -69,7 +76,7 @@ def result_to_intervals(result: CursorResult) -> PersonIntervals:
69
76
"""
70
77
Converts the result of the interval operations to a list of intervals.
71
78
"""
72
- person_interval = {}
79
+ person_interval = defaultdict ( list )
73
80
74
81
for row in result :
75
82
if row .interval_end < row .interval_start :
@@ -81,15 +88,12 @@ def result_to_intervals(result: CursorResult) -> PersonIntervals:
81
88
raise ValueError ("Interval end is None" )
82
89
83
90
interval = Interval (
84
- row .interval_start . timestamp () ,
85
- row .interval_end . timestamp () ,
91
+ row .interval_start ,
92
+ row .interval_end ,
86
93
row .interval_type ,
87
94
)
88
95
89
- if row .person_id not in person_interval :
90
- person_interval [row .person_id ] = [interval ]
91
- else :
92
- person_interval [row .person_id ].append (interval )
96
+ person_interval [row .person_id ].append (interval )
93
97
94
98
for person_id in person_interval :
95
99
person_interval [person_id ] = _impl .union_rects (person_interval [person_id ])
@@ -219,10 +223,10 @@ def forward_fill(
219
223
220
224
if observation_window is not None :
221
225
last_interval = result [person_id ][- 1 ]
222
- if last_interval .upper < observation_window .end .timestamp ():
226
+ if last_interval .upper < int ( observation_window .end .timestamp () ):
223
227
result [person_id ][- 1 ] = Interval (
224
228
last_interval .lower ,
225
- observation_window .end .timestamp (),
229
+ int ( observation_window .end .timestamp () ),
226
230
last_interval .type ,
227
231
)
228
232
@@ -307,15 +311,11 @@ def complementary_intervals(
307
311
"""
308
312
309
313
interval_type_missing_persons = interval_type
310
- baseline_interval = Interval (
311
- observation_window .start .timestamp (),
312
- observation_window .end .timestamp (),
313
- interval_type_missing_persons ,
314
+ baseline_interval = timerange_to_interval (
315
+ observation_window , type_ = interval_type_missing_persons
314
316
)
315
- observation_window_mask = Interval (
316
- observation_window .start .timestamp (),
317
- observation_window .end .timestamp (),
318
- IntervalType .least_intersection_priority (),
317
+ observation_window_mask = timerange_to_interval (
318
+ observation_window , type_ = IntervalType .least_intersection_priority ()
319
319
)
320
320
321
321
result = {}
@@ -595,8 +595,8 @@ def add_interval(
595
595
assert previous_end < effective_start # type: ignore[unreachable]
596
596
intervals .append (
597
597
Interval (
598
- lower = effective_start .timestamp (),
599
- upper = effective_end .timestamp (),
598
+ lower = int ( effective_start .timestamp () ),
599
+ upper = int ( effective_end .timestamp () ),
600
600
type = interval_type ,
601
601
)
602
602
)
@@ -695,12 +695,20 @@ def find_overlapping_personal_windows(
695
695
696
696
def find_rectangles (
697
697
data : list [PersonIntervals ],
698
- interval_constructor : Callable ,
699
- is_same_result : Callable | None = None ,
698
+ interval_constructor : IntervalConstructor ,
699
+ is_same_result : SameResult | None = None ,
700
700
) -> Dict [int , List [GeneralizedInterval ]]:
701
701
"""
702
- Iterates over intervals for each person across all items in `data` and constructs new intervals
703
- ("rectangles") by applying `interval_constructor` to the overlapping intervals in each time range.
702
+ Constructs new intervals ("time slices") by combining multiple parallel tracks of intervals.
703
+
704
+ This iterates over all intervals for each person across the given `data` list. Whenever an
705
+ interval starts or ends on any track, that boundary can produce a new interval. For each
706
+ interval, we invoke `interval_constructor(start, end, active_intervals)` to decide how
707
+ to label it (e.g., POSITIVE, NEGATIVE).
708
+
709
+ If `is_same_result` is provided, it’s used to decide whether two adjacent slices have the
710
+ same "type" (so we can merge them). If not provided, a default routine is used that merges
711
+ slices if they have the same object identity and the same result.
704
712
705
713
:param data: A list of dictionaries, each mapping a person ID to a list of intervals.
706
714
:param interval_constructor: A callable that takes the time boundaries and the corresponding intervals
@@ -712,20 +720,21 @@ def find_rectangles(
712
720
# TODO(jmoringe): can this use _process_interval?
713
721
if len (data ) == 0 :
714
722
return {}
715
- else :
716
- keys : Set [int ] = set ()
717
- result : Dict [int , List [GeneralizedInterval ]] = dict ()
718
-
719
- for track in data :
720
- keys |= track .keys ()
721
-
722
- for key in keys :
723
- key_result = _impl .find_rectangles (
724
- [intervals .get (key , []) for intervals in data ],
725
- interval_constructor ,
726
- is_same_result = is_same_result ,
727
- )
728
- if len (key_result ) > 0 :
729
- result [key ] = key_result
730
723
731
- return result
724
+ # Collect all person IDs across all tracks
725
+ keys : Set [int ] = set ()
726
+ result : Dict [int , List [GeneralizedInterval ]] = dict ()
727
+
728
+ for track in data :
729
+ keys |= track .keys ()
730
+
731
+ for key in keys :
732
+ key_result = _impl .find_rectangles (
733
+ [intervals .get (key , []) for intervals in data ],
734
+ interval_constructor ,
735
+ is_same_result = is_same_result ,
736
+ )
737
+ if len (key_result ) > 0 :
738
+ result [key ] = key_result
739
+
740
+ return result
0 commit comments