Skip to content

Commit aaa89c3

Browse files
committed
refactor: using int instead of float time in intervals
docs: add code comments
1 parent 63f8d5f commit aaa89c3

22 files changed

+756
-423
lines changed

execution_engine/omop/criterion/abstract.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from execution_engine.util.interval import IntervalType
2828
from execution_engine.util.serializable import SerializableDataClassABC
2929
from execution_engine.util.sql import SelectInto, select_into
30-
from execution_engine.util.types import PersonIntervals, TimeRange
30+
from execution_engine.util.types import PersonIntervals
31+
from execution_engine.util.types.timerange import TimeRange
3132

3233
__all__ = [
3334
"Criterion",

execution_engine/omop/criterion/point_in_time.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from execution_engine.omop.criterion.concept import ConceptCriterion
1111
from execution_engine.task.process import get_processing_module
1212
from execution_engine.util.interval import IntervalType
13-
from execution_engine.util.types import PersonIntervals, TimeRange, Timing
13+
from execution_engine.util.types import PersonIntervals, Timing
14+
from execution_engine.util.types.timerange import TimeRange
1415
from execution_engine.util.value import Value
1516

1617
process = get_processing_module()

execution_engine/omop/sqlclient.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,35 @@ def _enable_database_triggers(
5858
cursor.close()
5959

6060

61+
def datetime_cols_to_epoch(stmt: sqlalchemy.Select) -> sqlalchemy.Select:
    """
    Convert the interval boundary columns of a SELECT to integer epoch seconds.

    Every selected column named 'interval_start' or 'interval_end' is replaced
    by EXTRACT(EPOCH FROM <expr>)::BIGINT under the same label, so it becomes
    an integer timestamp. All other columns are kept unchanged and in their
    original position.

    :param stmt: A SQLAlchemy 2.0 Select statement.
    :return: A new Select object with the replaced columns.
    """
    epoch_labels = ("interval_start", "interval_end")
    new_columns = []

    for col in stmt.selected_columns:
        # Not every selected expression is guaranteed to expose a ``name``
        # attribute (e.g. unlabeled expressions); such columns can never be
        # one of the interval columns, so pass them through instead of
        # failing with an AttributeError.
        label = getattr(col, "name", None)

        if label in epoch_labels:
            # Wrap col in EXTRACT(EPOCH FROM col)::BIGINT and keep its label
            # so that consumers can still address the column by name.
            new_columns.append(
                sqlalchemy.func.extract("epoch", col)
                .cast(sqlalchemy.BigInteger)
                .label(label)
            )
        else:
            new_columns.append(col)

    # maintain_column_froms=True keeps the FROM clause derived from the
    # original column list even though the expressions were replaced.
    return stmt.with_only_columns(*new_columns, maintain_column_froms=True)
88+
89+
6190
class OMOPSQLClient:
6291
"""A client for the OMOP SQL database.
6392

execution_engine/task/process/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from collections import namedtuple
66
from typing import TypeVar
77

8+
from execution_engine.util.interval import IntervalType
9+
from execution_engine.util.types.timerange import TimeRange
10+
811

912
def get_processing_module(
1013
name: str = "rectangle", version: str = "auto"
@@ -61,3 +64,14 @@ def interval_like(interval: TInterval, start: int, end: int) -> TInterval:
6164
"""
6265

6366
return interval._replace(lower=start, upper=end) # type: ignore[return-value]
67+
68+
69+
def timerange_to_interval(tr: TimeRange, type_: IntervalType) -> Interval:
    """
    Build an interval of the given type that spans the supplied time range.

    The bounds are the ``timestamp()`` values of ``tr.start`` and ``tr.end``
    passed through ``int()``, i.e. any fractional part is discarded.

    :param tr: The time range providing the start and end of the interval.
    :param type_: The interval type to assign to the resulting interval.
    :return: An Interval with integer lower/upper bounds and the given type.
    """
    lower_bound = int(tr.start.timestamp())
    upper_bound = int(tr.end.timestamp())

    return Interval(lower=lower_bound, upper=upper_bound, type=type_)

execution_engine/task/process/rectangle.py

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import importlib
33
import logging
44
import os
5+
from collections import defaultdict
56
from typing import Callable, Dict, List, Set, cast
67

78
import numpy as np
@@ -10,15 +11,21 @@
1011
from sqlalchemy import CursorResult
1112

1213
from execution_engine.util.interval import IntervalType, interval_datetime
13-
from execution_engine.util.types import TimeRange
1414

15+
from ...util.types.timerange import TimeRange
1516
from . import (
1617
GeneralizedInterval,
1718
Interval,
1819
IntervalWithCount,
1920
interval_like,
21+
timerange_to_interval,
2022
)
2123

24+
IntervalConstructor = Callable[
25+
[int, int, List[GeneralizedInterval]], GeneralizedInterval
26+
]
27+
SameResult = Callable[[List[GeneralizedInterval], List[GeneralizedInterval]], bool]
28+
2229
PROCESS_RECTANGLE_VERSION = os.getenv("PROCESS_RECTANGLE_VERSION", "auto")
2330

2431

@@ -69,7 +76,7 @@ def result_to_intervals(result: CursorResult) -> PersonIntervals:
6976
"""
7077
Converts the result of the interval operations to a list of intervals.
7178
"""
72-
person_interval = {}
79+
person_interval = defaultdict(list)
7380

7481
for row in result:
7582
if row.interval_end < row.interval_start:
@@ -81,15 +88,12 @@ def result_to_intervals(result: CursorResult) -> PersonIntervals:
8188
raise ValueError("Interval end is None")
8289

8390
interval = Interval(
84-
row.interval_start.timestamp(),
85-
row.interval_end.timestamp(),
91+
row.interval_start,
92+
row.interval_end,
8693
row.interval_type,
8794
)
8895

89-
if row.person_id not in person_interval:
90-
person_interval[row.person_id] = [interval]
91-
else:
92-
person_interval[row.person_id].append(interval)
96+
person_interval[row.person_id].append(interval)
9397

9498
for person_id in person_interval:
9599
person_interval[person_id] = _impl.union_rects(person_interval[person_id])
@@ -219,10 +223,10 @@ def forward_fill(
219223

220224
if observation_window is not None:
221225
last_interval = result[person_id][-1]
222-
if last_interval.upper < observation_window.end.timestamp():
226+
if last_interval.upper < int(observation_window.end.timestamp()):
223227
result[person_id][-1] = Interval(
224228
last_interval.lower,
225-
observation_window.end.timestamp(),
229+
int(observation_window.end.timestamp()),
226230
last_interval.type,
227231
)
228232

@@ -307,15 +311,11 @@ def complementary_intervals(
307311
"""
308312

309313
interval_type_missing_persons = interval_type
310-
baseline_interval = Interval(
311-
observation_window.start.timestamp(),
312-
observation_window.end.timestamp(),
313-
interval_type_missing_persons,
314+
baseline_interval = timerange_to_interval(
315+
observation_window, type_=interval_type_missing_persons
314316
)
315-
observation_window_mask = Interval(
316-
observation_window.start.timestamp(),
317-
observation_window.end.timestamp(),
318-
IntervalType.least_intersection_priority(),
317+
observation_window_mask = timerange_to_interval(
318+
observation_window, type_=IntervalType.least_intersection_priority()
319319
)
320320

321321
result = {}
@@ -595,8 +595,8 @@ def add_interval(
595595
assert previous_end < effective_start # type: ignore[unreachable]
596596
intervals.append(
597597
Interval(
598-
lower=effective_start.timestamp(),
599-
upper=effective_end.timestamp(),
598+
lower=int(effective_start.timestamp()),
599+
upper=int(effective_end.timestamp()),
600600
type=interval_type,
601601
)
602602
)
@@ -695,12 +695,20 @@ def find_overlapping_personal_windows(
695695

696696
def find_rectangles(
697697
data: list[PersonIntervals],
698-
interval_constructor: Callable,
699-
is_same_result: Callable | None = None,
698+
interval_constructor: IntervalConstructor,
699+
is_same_result: SameResult | None = None,
700700
) -> Dict[int, List[GeneralizedInterval]]:
701701
"""
702-
Iterates over intervals for each person across all items in `data` and constructs new intervals
703-
("rectangles") by applying `interval_constructor` to the overlapping intervals in each time range.
702+
Constructs new intervals ("time slices") by combining multiple parallel tracks of intervals.
703+
704+
This iterates over all intervals for each person across the given `data` list. Whenever an
705+
interval starts or ends on any track, that boundary can produce a new interval. For each
706+
interval, we invoke `interval_constructor(start, end, active_intervals)` to decide how
707+
to label it (e.g., POSITIVE, NEGATIVE).
708+
709+
If `is_same_result` is provided, it’s used to decide whether two adjacent slices have the
710+
same "type" (so we can merge them). If not provided, a default routine is used that merges
711+
slices if they have the same object identity and the same result.
704712
705713
:param data: A list of dictionaries, each mapping a person ID to a list of intervals.
706714
:param interval_constructor: A callable that takes the time boundaries and the corresponding intervals
@@ -712,20 +720,21 @@ def find_rectangles(
712720
# TODO(jmoringe): can this use _process_interval?
713721
if len(data) == 0:
714722
return {}
715-
else:
716-
keys: Set[int] = set()
717-
result: Dict[int, List[GeneralizedInterval]] = dict()
718-
719-
for track in data:
720-
keys |= track.keys()
721-
722-
for key in keys:
723-
key_result = _impl.find_rectangles(
724-
[intervals.get(key, []) for intervals in data],
725-
interval_constructor,
726-
is_same_result=is_same_result,
727-
)
728-
if len(key_result) > 0:
729-
result[key] = key_result
730723

731-
return result
724+
# Collect all person IDs across all tracks
725+
keys: Set[int] = set()
726+
result: Dict[int, List[GeneralizedInterval]] = dict()
727+
728+
for track in data:
729+
keys |= track.keys()
730+
731+
for key in keys:
732+
key_result = _impl.find_rectangles(
733+
[intervals.get(key, []) for intervals in data],
734+
interval_constructor,
735+
is_same_result=is_same_result,
736+
)
737+
if len(key_result) > 0:
738+
result[key] = key_result
739+
740+
return result

0 commit comments

Comments
 (0)