Skip to content

Commit 9fecc67

Browse files
committed
feat: allow saving additional interval attributes in result_interval table
1 parent 84d0f61 commit 9fecc67

File tree

6 files changed

+45
-19
lines changed

6 files changed

+45
-19
lines changed

execution_engine/omop/db/celida/tables.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ class ResultInterval(Base): # noqa: D101
169169
interval_start: Mapped[datetime]
170170
interval_end: Mapped[datetime]
171171
interval_type = mapped_column(IntervalTypeEnum)
172-
172+
interval_ratio: Mapped[float] = mapped_column(
173+
nullable=True
174+
)
173175
execution_run: Mapped["ExecutionRun"] = relationship(
174176
primaryjoin="ResultInterval.run_id == ExecutionRun.run_id",
175177
)

execution_engine/omop/db/celida/views.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def interval_result_view() -> Select:
199199
rri.c.interval_type,
200200
rri.c.interval_start,
201201
rri.c.interval_end,
202+
rri.c.interval_ratio,
202203
)
203204
.select_from(rri)
204205
.outerjoin(pip, (rri.c.pi_pair_id == pip.c.pi_pair_id))

execution_engine/task/process/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import importlib
23
import os
34
import sys
@@ -44,4 +45,21 @@ def get_processing_module(
4445
AnyInterval = Interval | IntervalWithCount | IntervalWithTypeCounts
4546
GeneralizedInterval = None | AnyInterval
4647

47-
TInterval = TypeVar('TInterval', bound = AnyInterval)
48+
TInterval = TypeVar("TInterval", bound=AnyInterval)
49+
50+
def interval_like(interval: TInterval, start: int, end: int) -> TInterval:
51+
"""
52+
Return a copy of the given interval with its lower and upper bounds replaced.
53+
54+
Args:
55+
interval (I): The interval to copy. Must be one of Interval, IntervalWithCount, or IntervalWithTypeCounts.
56+
start (datetime): The new lower bound.
57+
end (datetime): The new upper bound.
58+
59+
Returns:
60+
I: A copy of the interval with updated lower and upper bounds.
61+
"""
62+
63+
return copy.copy(interval)._replace(
64+
lower=start, upper=end
65+
) # type: ignore[return-value]m

execution_engine/task/process/rectangle.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import importlib
33
import logging
44
import os
5-
from typing import Callable, cast
5+
from typing import Callable, cast, List
66

77
import numpy as np
88
import pendulum
@@ -12,7 +12,7 @@
1212
from execution_engine.util.interval import IntervalType, interval_datetime
1313
from execution_engine.util.types import TimeRange
1414

15-
from . import Interval, IntervalWithCount
15+
from . import Interval, IntervalWithCount, AnyInterval, GeneralizedInterval, interval_like
1616

1717
PROCESS_RECTANGLE_VERSION = os.getenv("PROCESS_RECTANGLE_VERSION", "auto")
1818

@@ -566,17 +566,11 @@ def mask_intervals(
566566
for person_id, intervals in mask.items()
567567
}
568568

569-
result = {}
570-
for person_id in data:
571-
# intersect every interval in data with every interval in mask
572-
person_result = _impl.intersect_interval_lists(
573-
data[person_id], person_mask[person_id]
574-
)
575-
if not person_result:
576-
continue
577-
578-
result[person_id] = person_result
579-
569+
def intersection_interval(start: int, end: int, intervals: List[GeneralizedInterval]) -> GeneralizedInterval:
570+
left_interval, right_interval = intervals
571+
if left_interval is not None and right_interval is not None:
572+
return interval_like(right_interval, start, end)
573+
result = find_rectangles([person_mask, data], intersection_interval)
580574
return result
581575

582576

execution_engine/task/task.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,18 @@ def store_result_in_db(
578578
run_id=bind_params["run_id"],
579579
cohort_category=self.category,
580580
)
581+
def interval_data(interval):
582+
data = dict(
583+
interval_start=interval.lower,
584+
interval_end=interval.upper,
585+
interval_type=interval.type,
586+
)
587+
if isinstance(interval, Interval):
588+
data["interval_ratio"] = None
589+
else:
590+
assert isinstance(interval, IntervalWithCount)
591+
data["interval_ratio"] = interval.count
592+
return data
581593

582594
try:
583595
with get_engine().begin() as conn:
@@ -586,9 +598,7 @@ def store_result_in_db(
586598
[
587599
{
588600
"person_id": person_id,
589-
"interval_start": normalized_interval.lower,
590-
"interval_end": normalized_interval.upper,
591-
"interval_type": normalized_interval.type,
601+
**interval_data(normalized_interval),
592602
**params,
593603
}
594604
for person_id, intervals in result.items()

execution_engine/util/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
)
1414
from execution_engine.util.value import ValueNumber, ValueNumeric
1515
from execution_engine.util.value.time import ValueCount, ValueDuration, ValuePeriod
16+
from execution_engine.task.process import AnyInterval
1617

17-
PersonIntervals = dict[int, Any]
18+
PersonIntervals = dict[int, AnyInterval]
1819

1920

2021
class TimeRange(BaseModel):

0 commit comments

Comments
 (0)