Skip to content

Commit 89efdbd

Browse files
committed
refactor: fix window_types problem without reset callback
And add test.
1 parent a9d17b8 commit 89efdbd

File tree

3 files changed

+164
-33
lines changed

3 files changed

+164
-33
lines changed

execution_engine/task/process/rectangle.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def create_time_intervals(
517517
end_time: datetime.time,
518518
interval_type: IntervalType,
519519
timezone: pytz.tzinfo.DstTzInfo | str,
520-
) -> tuple[Interval, ...]:
520+
) -> list[Interval]:
521521
"""
522522
Constructs a list of time intervals within a specified date range, each defined by daily start and end times.
523523
@@ -654,8 +654,7 @@ def add_interval(
654654
# Move to the next day
655655
current_date += datetime.timedelta(days=1)
656656

657-
# use a tuple for windows to make sure it is immutable (and can be shared by all persons)
658-
return tuple(intervals)
657+
return intervals
659658

660659

661660
def find_overlapping_personal_windows(
@@ -697,7 +696,6 @@ def find_rectangles(
697696
data: list[PersonIntervals],
698697
interval_constructor: Callable,
699698
is_same_result: Callable | None = None,
700-
reset: Callable | None = None,
701699
) -> PersonIntervals:
702700
"""
703701
Iterates over intervals for each person across all items in `data` and constructs new intervals
@@ -717,21 +715,11 @@ def find_rectangles(
717715
keys: Set[int] = set()
718716
for track in data:
719717
keys |= track.keys()
720-
result = {}
721-
722-
for key in keys:
723-
724-
if reset:
725-
reset()
726-
727-
intervals_for_person: list[list[Interval]] = [
728-
intervals.get(key, []) for intervals in data
729-
]
730-
intervals = _impl.find_rectangles(
731-
intervals_for_person,
718+
return {
719+
key: _impl.find_rectangles(
720+
[intervals.get(key, []) for intervals in data],
732721
interval_constructor,
733722
is_same_result=is_same_result,
734723
)
735-
result[key] = intervals
736-
737-
return result
724+
for key in keys
725+
}

execution_engine/task/task.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import copy
23
import datetime
34
import json
45
import logging
@@ -655,19 +656,13 @@ def get_start_end_from_interval_type(
655656
)
656657

657658
# Incrementally compute the interval type for each window
658-
# interval.
659-
window_types: dict[AnyInterval, IntervalType] = (
660-
dict()
661-
) # window interval -> interval type
662-
663-
# todo: @moringenj - is this additional function really a good solution?
664-
def reset_window_types() -> None:
665-
window_types.clear()
659+
# interval. Maps id of window interval -> interval type
660+
window_types: dict[int, IntervalType] = dict()
666661

667662
def update_window_type(
668663
window_interval: AnyInterval, data_interval: AnyInterval
669664
) -> IntervalType:
670-
window_type = window_types.get(window_interval.lower, None)
665+
window_type = window_types.get(id(window_interval), None)
671666

672667
if data_interval is None or data_interval.type is IntervalType.NEGATIVE:
673668
if window_type is not IntervalType.POSITIVE:
@@ -681,7 +676,7 @@ def update_window_type(
681676
assert data_interval.type is IntervalType.NO_DATA
682677
if window_type is None:
683678
window_type = IntervalType.NO_DATA
684-
window_types[window_interval.lower] = window_type
679+
window_types[id(window_interval)] = window_type
685680

686681
return window_type
687682

@@ -690,7 +685,7 @@ def update_window_type(
690685
# result interval window types based on the data
691686
# intervals.
692687
def is_same_interval(
693-
left_intervals: tuple[AnyInterval], right_intervals: tuple[AnyInterval]
688+
left_intervals: List[AnyInterval], right_intervals: List[AnyInterval]
694689
) -> bool:
695690
left_window_interval, left_data_interval = left_intervals
696691
right_window_interval, right_data_interval = right_intervals
@@ -723,17 +718,22 @@ def result_interval(
723718
):
724719
return Interval(start, end, IntervalType.NOT_APPLICABLE)
725720
else:
726-
window_type = window_types.get(window_interval.lower, None)
721+
window_type = window_types.get(id(window_interval), None)
727722
if window_type is None:
728723
window_type = update_window_type(window_interval, data_interval)
729724
return Interval(start, end, window_type)
730725

731-
person_indicator_windows = {key: indicator_windows for key in data_p.keys()}
726+
# Make separate copies of the intervals for each person so
727+
# that the object identity of each interval is unique and
728+
# can be used as a dictionary key.
729+
person_indicator_windows = {
730+
key: [ copy.copy(window) for window in indicator_windows ]
731+
for key in data_p.keys()
732+
}
732733
result = process.find_rectangles(
733734
[person_indicator_windows, data_p],
734735
result_interval,
735736
is_same_result=is_same_interval,
736-
reset=reset_window_types,
737737
)
738738

739739
return result

tests/execution_engine/omop/criterion/combination/test_temporal_combination.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2038,3 +2038,146 @@ def test_interval_ratio_on_database(
20382038
result_tuples, expected[person.person_id]
20392039
):
20402040
assert result_tuple == expected_tuple
2041+
2042+
class TestIndicatorWindowsMulitplePatients(TestCriterionCombinationDatabase):
2043+
"""
2044+
This test ensures that the data TemporalCount operator works
2045+
independently between persons within a PersonIntervals data set.
2046+
2047+
This is mostly a regression test since at one point the exact
2048+
problem of cross-talk between the data structures for different
2049+
persons caused the operator to return incorrect results.
2050+
"""
2051+
2052+
@pytest.fixture
2053+
def observation_window(self) -> TimeRange:
2054+
return TimeRange(
2055+
name="observation", start="2025-02-18 14:55:00+01:00", end="2025-02-22 12:00:00+01:00"
2056+
)
2057+
2058+
def patient_events(self, db_session, visit_occurrence):
2059+
person_id = visit_occurrence.person_id
2060+
events = []
2061+
c1 = create_condition(
2062+
vo=visit_occurrence,
2063+
condition_concept_id=concept_covid19.concept_id,
2064+
condition_start_datetime=pendulum.parse("2025-02-19 08:00:00+01:00"),
2065+
condition_end_datetime=pendulum.parse("2025-02-21 02:00:00+01:00"),
2066+
)
2067+
events.append(c1)
2068+
if person_id == 1:
2069+
e1 = create_procedure(
2070+
vo=visit_occurrence,
2071+
procedure_concept_id=concept_delir_screening.concept_id,
2072+
start_datetime=pendulum.parse("2025-02-19 18:00:00+01:00"),
2073+
end_datetime=pendulum.parse("2025-02-19 18:01:00+01:00"),
2074+
)
2075+
events.append(e1)
2076+
db_session.add_all(events)
2077+
db_session.commit()
2078+
2079+
@pytest.mark.parametrize(
2080+
"population,intervention,expected",
2081+
[
2082+
(
2083+
logic.And(c2), # population
2084+
temporal_logic_util.Day(criterion=delir_screening),
2085+
{
2086+
1: [
2087+
(
2088+
IntervalType.NOT_APPLICABLE,
2089+
pendulum.parse("2025-02-18 17:55:00+01:00"),
2090+
pendulum.parse("2025-02-19 07:59:59+01:00"),
2091+
),
2092+
(
2093+
IntervalType.POSITIVE,
2094+
pendulum.parse("2025-02-19 08:00:00+01:00"),
2095+
pendulum.parse("2025-02-19 23:59:59+01:00"),
2096+
),
2097+
(
2098+
IntervalType.NEGATIVE,
2099+
pendulum.parse("2025-02-20 00:00:00+01:00"),
2100+
pendulum.parse("2025-02-21 02:00:00+01:00"),
2101+
),
2102+
(
2103+
IntervalType.NOT_APPLICABLE,
2104+
pendulum.parse("2025-02-21 02:00:01+01:00"),
2105+
pendulum.parse("2025-02-22 05:30:00+01:00"),
2106+
),
2107+
],
2108+
2: [
2109+
(
2110+
IntervalType.NOT_APPLICABLE,
2111+
pendulum.parse("2025-02-18 17:55:00+01:00"),
2112+
pendulum.parse("2025-02-19 07:59:59+01:00"),
2113+
),
2114+
# If cross-talk between the data structures
2115+
# for different persons occurs, parts of the
2116+
# following interval may turn positive because
2117+
# of the results for the first person.
2118+
(
2119+
IntervalType.NEGATIVE,
2120+
pendulum.parse("2025-02-19 08:00:00+01:00"),
2121+
pendulum.parse("2025-02-21 02:00:00+01:00"),
2122+
),
2123+
(
2124+
IntervalType.NOT_APPLICABLE,
2125+
pendulum.parse("2025-02-21 02:00:01+01:00"),
2126+
pendulum.parse("2025-02-22 05:30:00+01:00"),
2127+
),
2128+
],
2129+
},
2130+
),
2131+
],
2132+
)
2133+
def test_multiple_patients_on_database(
2134+
self,
2135+
person,
2136+
db_session,
2137+
population,
2138+
intervention,
2139+
base_criterion,
2140+
expected,
2141+
observation_window,
2142+
criteria,
2143+
):
2144+
persons = person[:2]
2145+
vos = []
2146+
for person in persons:
2147+
visit = create_visit(
2148+
person_id=person.person_id,
2149+
visit_start_datetime=observation_window.start
2150+
+ datetime.timedelta(hours=3),
2151+
visit_end_datetime=observation_window.end
2152+
- datetime.timedelta(hours=6.5),
2153+
visit_concept_id=concepts.INTENSIVE_CARE,
2154+
)
2155+
vos.append(visit)
2156+
self.patient_events(db_session, visit)
2157+
2158+
db_session.add_all(vos)
2159+
db_session.commit()
2160+
2161+
self.insert_expression(
2162+
db_session, population, intervention, base_criterion, observation_window
2163+
)
2164+
2165+
df = self.fetch_interval_result(
2166+
db_session,
2167+
pi_pair_id=self.pi_pair_id,
2168+
criterion_id=None,
2169+
category=CohortCategory.POPULATION_INTERVENTION,
2170+
)
2171+
2172+
for person in persons:
2173+
result = df.query(f"person_id=={person.person_id}")
2174+
result_tuples = list(
2175+
result[ [ "interval_type", "interval_start", "interval_end" ] ]
2176+
.fillna("nan")
2177+
.itertuples(index=False, name=None)
2178+
)
2179+
2180+
for result_tuple, expected_tuple in zip(
2181+
result_tuples, expected[person.person_id]
2182+
):
2183+
assert result_tuple == expected_tuple

0 commit comments

Comments
 (0)