Skip to content

Commit 71c7978

Browse files
committed
refactor: serializable handling
1 parent 9b47d5d commit 71c7978

40 files changed

+1459
-1253
lines changed

apps/graph/static/script.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,15 @@ async function loadGraph(recommendationId) {
129129
return nodeColors[ele.data('category')] || '#666'; // Assign color based on 'type', with a default
130130
},
131131
'shape': function(ele) {
132-
return nodeShapes[ele.data('type')] || 'star'; // Assign color based on 'type', with a default
132+
return nodeShapes[ele.data("is_atom") ? "Symbol" : ele.data('type')] || 'star'; // Assign color based on 'type', with a default
133133
},
134134
'text-valign': 'center',
135135
'color': '#000000',
136136
'width': function(ele) {
137-
return ele.data('type') === 'Symbol' ? '120px': '40px';
137+
return ele.data('is_atom') ? '120px': '40px';
138138
},
139139
'height': function(ele) {
140-
return ele.data('type') === 'Symbol' ? '80px': '40px';
140+
return ele.data('is_atom') ? '80px': '40px';
141141
},
142142
'font-size': '10px',
143143
'text-wrap': 'wrap',

apps/rest_api/app/routers/recommendation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ async def recommendation_criteria(
4646

4747
data = []
4848

49-
for c in recommendation.flatten():
49+
for c in recommendation.atoms():
5050
data.append(
5151
{
5252
"description": c.description(),

apps/viz-backend/app/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def get_execution_graph(recommendation_id: int, db: Session = Depends(get_db)) -
8989
if not result:
9090
raise HTTPException(status_code=404, detail="Recommendation not found")
9191

92+
print(result)
93+
9294
# Decode the bytes to a string and parse it as JSON
9395
execution_graph = json.loads(result.recommendation_execution_graph.decode("utf-8"))
9496

execution_engine/converter/action/procedure.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,15 @@ def _to_expression(self) -> logic.Symbol:
7171
# as Observation and Measurement normally expect a value.
7272
criterion = Measurement(
7373
concept=self._code,
74-
override_value_required=False,
74+
value_required=False,
7575
timing=self._timing,
7676
)
7777
case "Observation":
7878
# we need to explicitly set the VALUE_REQUIRED flag to false, otherwise creating the query will raise an error
7979
# as Observation and Measurement normally expect a value.
8080
criterion = Observation(
8181
concept=self._code,
82-
override_value_required=False,
82+
value_required=False,
8383
timing=self._timing,
8484
)
8585
case _:

execution_engine/converter/recommendation_factory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def parse_recommendation_from_url(
6363
pi_pairs: list[PopulationInterventionPairExpr] = []
6464

6565
base_criterion = PatientsActiveDuringPeriod()
66-
66+
base_criterion.dict()
6767
for rec_plan in rec.plans():
6868

6969
# parse population and create criteria
@@ -72,6 +72,9 @@ def parse_recommendation_from_url(
7272
# parse intervention and create criteria
7373
actions = parser.parse_actions(rec_plan.actions, rec_plan)
7474

75+
# population_expr is assigned a NoDataPreservingAnd to ensure creation of negative intervals
76+
# todo: not sure we really need this - we can just always create negative intervals when store_results=True
77+
# in the graph
7578
pi_pair = PopulationInterventionPairExpr(
7679
population_expr=population_criteria,
7780
intervention_expr=actions,

execution_engine/converter/time_from_event/abstract.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def valid(cls, fhir: Element) -> bool:
5555
raise NotImplementedError("must be implemented by class")
5656

5757
@abstractmethod
58-
def to_temporal_combination(self, combo: logic.BaseExpr) -> logic.TemporalCount:
58+
def to_temporal_combination(self, expr: logic.BaseExpr) -> logic.Expr:
5959
"""
60-
Wraps Criterion/CriterionCombinaion with a TemporalIndicatorCombination
60+
Wraps Criterion/CriterionCombination with a TemporalIndicatorCombination
6161
"""
6262
raise NotImplementedError("must be implemented by class")
6363

@@ -122,8 +122,8 @@ def valid(cls, fhir: Element) -> bool:
122122
return cls._event_vocabulary.is_system(cc.system) and cc.code == cls._event_code
123123

124124
@abstractmethod
125-
def to_temporal_combination(self, combo: logic.BaseExpr) -> logic.TemporalCount:
125+
def to_temporal_combination(self, expr: logic.BaseExpr) -> logic.Expr:
126126
"""
127-
Wraps Criterion/CriterionCombinaion with a TemporalIndicatorCombination
127+
Wraps expression with a TemporalIndicatorCombination
128128
"""
129129
raise NotImplementedError("must be implemented by class")

execution_engine/execution_engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from execution_engine.omop.cohort import PopulationInterventionPairExpr
1818
from execution_engine.omop.criterion.abstract import Criterion
1919
from execution_engine.omop.db.celida import tables as result_db
20-
from execution_engine.omop.serializable import Serializable
2120
from execution_engine.task import runner
21+
from execution_engine.util.serializable import Serializable
2222

2323

2424
class ExecutionEngine:
@@ -208,7 +208,7 @@ def load_recommendation_from_database(
208208
pi_pair, rec_db.recommendation_id
209209
)
210210

211-
for criterion in recommendation.flatten():
211+
for criterion in recommendation.atoms():
212212
self.register_criterion(criterion)
213213

214214
# All objects in the deserialized object graph must have an id.
@@ -219,7 +219,7 @@ def load_recommendation_from_database(
219219
for pi_pair in recommendation.population_intervention_pairs():
220220
assert pi_pair.id is not None
221221

222-
for criterion in recommendation.flatten():
222+
for criterion in recommendation.atoms():
223223
assert criterion.id is not None
224224

225225
return recommendation
@@ -294,7 +294,7 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None
294294
pi_pair, recommendation_id=recommendation.id
295295
)
296296

297-
for criterion in recommendation.flatten():
297+
for criterion in recommendation.atoms():
298298
self.register_criterion(criterion)
299299

300300
assert recommendation.id is not None
@@ -304,7 +304,7 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None
304304
for pi_pair in recommendation.population_intervention_pairs():
305305
assert pi_pair.id is not None
306306

307-
for criterion in recommendation.flatten():
307+
for criterion in recommendation.atoms():
308308
assert criterion.id is not None
309309

310310
# Update the recommendation in the database with the final

execution_engine/execution_graph/graph.py

Lines changed: 23 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import copy
2-
from typing import Any
1+
from typing import Any, cast
32

43
import networkx as nx
54

@@ -40,27 +39,6 @@ def is_sink(self, expr: logic.Expr) -> bool:
4039
"""
4140
return self.out_degree(expr) == 0
4241

43-
@classmethod
44-
def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr:
45-
"""
46-
Filter (=AND-combine) all symbols by the applied filter function
47-
48-
Used to filter all intervention criteria (symbols) by the population output in order to exclude
49-
all intervention events outside the population intervals, which may otherwise interfere with corrected
50-
determination of temporal combination, i.e. the presence of an intervention event during some time window.
51-
"""
52-
53-
if isinstance(node, logic.Symbol):
54-
return logic.LeftDependentToggle(left=filter_, right=node)
55-
56-
if hasattr(node, "args") and isinstance(node.args, tuple):
57-
converted_args = [cls.filter_symbols(a, filter_) for a in node.args]
58-
59-
if any(a is not b for a, b in zip(node.args, converted_args)):
60-
node.args = tuple(converted_args)
61-
62-
return node
63-
6442
@classmethod
6543
def from_expression(
6644
cls, expr: logic.Expr, base_criterion: Criterion, category: CohortCategory
@@ -71,9 +49,7 @@ def from_expression(
7149

7250
from execution_engine.omop.cohort import PopulationInterventionPairExpr
7351

74-
# we might make changes to the expression (e.g. filtering), so we must preserve
75-
# the original expression from the caller
76-
expr = copy.deepcopy(expr)
52+
expr_hash = hash(expr)
7753

7854
graph = cls()
7955
base_node = base_criterion
@@ -89,42 +65,33 @@ def traverse(
8965
parent: logic.Expr | None = None,
9066
category: CohortCategory = category,
9167
) -> None:
68+
69+
graph.add_node(expr, category=category, store_result=False)
70+
71+
if parent is not None:
72+
assert expr in graph.nodes
73+
assert parent in graph.nodes
74+
graph.add_edge(expr, parent)
75+
9276
if isinstance(expr, PopulationInterventionPairExpr):
9377
# special case for PopulationInterventionPairExpr:
9478
# we need explicitly set the category of the population and intervention nodes
9579

9680
p, i = expr.left, expr.right
9781

98-
# filter all intervention criteria by the output of the population - this is performed to filter out
99-
# intervention events that outside of the population intervals (i.e. the time windows during which
100-
# patients are part of the population) as otherwise events outside of the population time may be picked up
101-
# by Temporal criteria that determine the presence of some event or condition during a specific time window.
102-
103-
# the following command changes expr, i.e. we must not add expr before this command to the graph
104-
105-
i = cls.filter_symbols(i, filter_=p)
106-
10782
traverse(i, parent=expr, category=CohortCategory.INTERVENTION)
10883
traverse(p, parent=expr, category=CohortCategory.POPULATION)
10984

110-
graph.add_node(expr, category=category, store_result=False)
111-
112-
if parent is not None:
113-
graph.add_edge(expr, parent)
114-
115-
subgraph = graph.subgraph(nx.ancestors(graph, expr) | {expr})
116-
117-
subgraph.set_sink_nodes_store(bind_params=dict(pi_pair_id=expr._id))
118-
119-
# children are already traversed
120-
return
121-
122-
graph.add_node(expr, category=category, store_result=False)
85+
# create a subgraph for the pair in order to determine the sink nodes (i.e. the nodes that have no
86+
# outgoing edges) for POPULATION and POPULATION_INTERVENTION and mark them for storing their result
87+
# in the database
88+
subgraph = cast(
89+
ExecutionGraph, graph.subgraph(nx.ancestors(graph, expr) | {expr})
90+
)
91+
subgraph.set_sink_nodes_store(bind_params=dict(pi_pair_id=expr.id))
12392

124-
if parent is not None:
125-
graph.add_edge(expr, parent)
126-
127-
if expr.is_Atom:
93+
elif expr.is_Atom:
94+
assert expr in graph.nodes
12895
graph.nodes[expr]["store_result"] = True
12996
graph.add_edge(base_node, expr)
13097
else:
@@ -133,6 +100,9 @@ def traverse(
133100

134101
traverse(expr, category=category)
135102

103+
if hash(expr) != expr_hash:
104+
raise ValueError("Expression has been modified during traversal")
105+
136106
return graph
137107

138108
def add_node(
@@ -219,6 +189,7 @@ def to_cytoscape_dict(self) -> dict:
219189
self.nodes[node]["store_result"]
220190
), # Convert to string if necessary
221191
"is_sink": self.is_sink(node),
192+
"is_atom": node.is_Atom,
222193
"bind_params": self.nodes[node]["bind_params"],
223194
}
224195
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import copy
2+
3+
import execution_engine.util.logic as logic
4+
from execution_engine.constants import CohortCategory
5+
from execution_engine.execution_graph import ExecutionGraph
6+
from execution_engine.omop.cohort.population_intervention_pair import (
7+
PopulationInterventionPairExpr,
8+
)
9+
from execution_engine.omop.criterion.abstract import Criterion
10+
11+
12+
class RecommendationGraphBuilder:
13+
"""
14+
A builder class for constructing ExecutionGraph objects based on
15+
population/intervention expressions. It provides utility methods to filter
16+
intervention criteria by population constraints, and then converts the
17+
filtered expression into an ExecutionGraph ready for execution and storage.
18+
"""
19+
20+
@classmethod
21+
def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr:
22+
"""
23+
Filter (=AND-combine) all symbols by the applied filter function
24+
25+
Used to filter all intervention criteria (symbols) by the population output in order to exclude
26+
all intervention events outside the population intervals, which may otherwise interfere with corrected
27+
determination of temporal combination, i.e. the presence of an intervention event during some time window.
28+
29+
:param node: The expression node to be filtered.
30+
:type node: logic.Expr
31+
:param filter_: The filter expression to AND-combine with symbols in the node.
32+
:type filter_: logic.Expr
33+
:return: A new expression in which all symbols are constrained by the filter expression.
34+
:rtype: logic.Expr
35+
"""
36+
37+
if isinstance(node, logic.Symbol):
38+
return logic.LeftDependentToggle(left=filter_, right=node)
39+
elif isinstance(node, logic.Expr):
40+
converted_args = [cls.filter_symbols(a, filter_) for a in node.args]
41+
42+
if any(a is not b for a, b in zip(node.args, converted_args)):
43+
node.update_args(*converted_args)
44+
45+
return node
46+
47+
@classmethod
48+
def filter_intervention_criteria_by_population(cls, expr: logic.Expr) -> logic.Expr:
49+
"""
50+
Filter all intervention criteria in a given expression by the population part of the expression.
51+
52+
:param expr: The expression that may contain population and intervention parts.
53+
:type expr: logic.Expr
54+
:return: A new expression where all intervention symbols are constrained by the population intervals.
55+
:rtype: logic.Expr
56+
"""
57+
58+
from execution_engine.omop.cohort import PopulationInterventionPairExpr
59+
60+
# we might make changes to the expression (e.g. filtering), so we must preserve
61+
# the original expression from the caller
62+
expr = copy.deepcopy(expr)
63+
64+
def traverse(
65+
expr: logic.Expr,
66+
) -> None:
67+
if isinstance(expr, PopulationInterventionPairExpr):
68+
p, i = expr.left, expr.right
69+
70+
# filter all intervention criteria by the output of the population - this is performed to filter out
71+
# intervention events that outside of the population intervals (i.e. the time windows during which
72+
# patients are part of the population) as otherwise events outside of the population time may be picked up
73+
# by Temporal criteria that determine the presence of some event or condition during a specific time window.
74+
i = cls.filter_symbols(i, filter_=p)
75+
76+
expr.update_args(p, i)
77+
78+
traverse(i)
79+
traverse(p)
80+
81+
elif not expr.is_Atom:
82+
for child in expr.args:
83+
traverse(child)
84+
85+
traverse(expr)
86+
87+
return expr
88+
89+
@classmethod
90+
def build(cls, expr: logic.Expr, base_criterion: Criterion) -> ExecutionGraph:
91+
"""
92+
Build an ExecutionGraph for a population/intervention expression.
93+
94+
If the expression is a PopulationInterventionPairExpr, it is wrapped in a
95+
NonSimplifiableAnd to ensure a top-level result entry is generated in the database.
96+
Then the expression is filtered and converted into an ExecutionGraph with the
97+
appropriate sink nodes and edges.
98+
99+
:param expr: The population/intervention expression to build the graph from.
100+
:type expr: logic.Expr
101+
:param base_criterion: The base criterion used to label the execution graph.
102+
:type base_criterion: Criterion
103+
:return: The constructed ExecutionGraph for the given expression.
104+
:rtype: ExecutionGraph
105+
"""
106+
if isinstance(expr, PopulationInterventionPairExpr):
107+
expr = logic.NonSimplifiableAnd(expr)
108+
109+
# Make sure the expr is filtered
110+
expr_filtered = cls.filter_intervention_criteria_by_population(expr)
111+
112+
graph = ExecutionGraph.from_expression(
113+
expr_filtered,
114+
base_criterion=base_criterion,
115+
category=CohortCategory.POPULATION_INTERVENTION,
116+
)
117+
118+
p_sink_nodes = graph.sink_nodes(CohortCategory.POPULATION)
119+
graph.set_sink_nodes_store(
120+
bind_params={}, desired_category=CohortCategory.POPULATION_INTERVENTION
121+
)
122+
123+
p_combination_node = logic.NoDataPreservingOr(*p_sink_nodes)
124+
graph.add_node(
125+
p_combination_node, store_result=True, category=CohortCategory.POPULATION
126+
)
127+
graph.add_edges_from((src, p_combination_node) for src in p_sink_nodes)
128+
129+
return graph

0 commit comments

Comments
 (0)