Skip to content

Commit b05950d

Browse files
committed
fix: deserialization, base criterion in graph
1 parent b10c405 commit b05950d

File tree

6 files changed

+197
-43
lines changed

6 files changed

+197
-43
lines changed

execution_engine/execution_graph/graph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,12 @@ def traverse(
9090
)
9191
subgraph.set_sink_nodes_store(bind_params=dict(pi_pair_id=expr.id))
9292

93+
elif expr == base_node:
94+
# don't need to do anything - only non-base criteria are connected to the base criterion,
95+
# otherwise we get a cyclic graph
96+
pass
9397
elif expr.is_Atom:
94-
assert expr in graph.nodes
98+
assert expr in graph.nodes, "Node not found in graph"
9599
graph.nodes[expr]["store_result"] = True
96100
graph.add_edge(base_node, expr)
97101
else:

execution_engine/omop/cohort/graph_builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr:
3636
:rtype: logic.Expr
3737
"""
3838

39+
node = copy.copy(node)
40+
3941
if isinstance(node, logic.Symbol):
4042
return logic.LeftDependentToggle(left=filter_, right=node)
4143
elif isinstance(node, logic.Expr):
@@ -140,7 +142,10 @@ def build(cls, expr: logic.Expr, base_criterion: Criterion) -> ExecutionGraph:
140142
)
141143
graph.add_edges_from((src, p_combination_node) for src in p_sink_nodes)
142144

145+
if graph.in_degree(base_criterion) != 0:
146+
raise AssertionError("Base criterion must not have incoming edges")
147+
143148
if not nx.is_directed_acyclic_graph(graph):
144-
raise ValueError("Graph is not acyclic")
149+
raise AssertionError("Graph is not acyclic")
145150

146151
return graph

execution_engine/task/creator.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,96 @@
1+
import pickle # nosec
2+
13
import networkx as nx
4+
from typing_extensions import Any
25

36
import execution_engine.util.logic as logic
47
from execution_engine.task.task import Task
58

69

10+
def assert_pickle_roundtrip(obj: logic.BaseExpr) -> logic.BaseExpr:
    """
    Serializes 'obj' via pickle (the same method multiprocessing would use),
    then deserializes it, and finally compares the original object to the result.

    :param obj: The object to serialize/deserialize.
    :raises AssertionError: If the object does not match its clone, or if a
        CountOperator has no threshold set.
    :return: The deserialized clone (for further inspection if needed).
    """
    # nosec: we only unpickle data we pickled ourselves in the line above.
    pickled = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)  # nosec
    clone = pickle.loads(pickled)  # nosec

    # Sanity check: a CountOperator must always carry a threshold in its
    # serialized representation.
    if isinstance(obj, logic.CountOperator):
        if obj.dict()["data"]["threshold"] is None:
            raise AssertionError("Threshold must be set")

    if obj == clone:
        # They are considered equal; hand back the clone as documented.
        return clone

    # If they're unequal, compare their dict() representations to find differences.
    d1 = obj.dict()
    d2 = clone.dict()

    if d1 == d2:
        # If they're unequal but dicts are the same, there's some internal difference
        # not visible via .dict(). Just alert that we can't show details.
        raise AssertionError(
            f"Objects differ in __eq__, but their dict() representations are identical.\n"
            f"obj: {obj}\n"
            f"clone: {clone}"
        )

    # Otherwise, gather all leaf-level differences in d1 vs d2.
    diffs = _compare_dicts_leaf_level(d1, d2)
    diff_msg = "\n".join(diffs)

    raise AssertionError(
        f"Object does not match its clone after round-trip!\n\n"
        f"Differences at leaf level in .dict() representations:\n{diff_msg}"
    )
50+
51+
52+
def _compare_dicts_leaf_level(d1: Any, d2: Any, path: str = "") -> list[str]:
53+
"""
54+
Recursively compare two dict/list/tuple/scalar structures and return
55+
a list of strings describing differences at the leaf level.
56+
57+
:param d1, d2: Potentially nested structures (dict, list, tuple, scalar).
58+
:param path: Path string to locate the current point in the structure.
59+
:return: List of difference descriptions.
60+
"""
61+
differences = []
62+
63+
# If both are dicts, recurse into matching keys
64+
if isinstance(d1, dict) and isinstance(d2, dict):
65+
all_keys = set(d1.keys()) | set(d2.keys())
66+
for key in sorted(all_keys):
67+
sub_path = f"{path}.{key}" if path else str(key)
68+
if key not in d1:
69+
differences.append(f"[MISSING IN ORIGINAL] {sub_path} => {d2[key]!r}")
70+
elif key not in d2:
71+
differences.append(f"[MISSING IN CLONE] {sub_path} => {d1[key]!r}")
72+
else:
73+
differences.extend(
74+
_compare_dicts_leaf_level(d1[key], d2[key], sub_path)
75+
)
76+
77+
# If both are lists/tuples, compare element by element
78+
elif isinstance(d1, (list, tuple)) and isinstance(d2, (list, tuple)):
79+
if len(d1) != len(d2):
80+
differences.append(f"[LEN MISMATCH] {path} => {len(d1)} vs {len(d2)}")
81+
else:
82+
for i, (item1, item2) in enumerate(zip(d1, d2)):
83+
sub_path = f"{path}[{i}]"
84+
differences.extend(_compare_dicts_leaf_level(item1, item2, sub_path))
85+
86+
# Otherwise, treat them as leaf values and compare directly
87+
else:
88+
if d1 != d2:
89+
differences.append(f"[VALUE MISMATCH] {path} => {d1!r} vs {d2!r}")
90+
91+
return differences
92+
93+
794
class TaskCreator:
895
"""
996
A TaskCreator object creates a Task tree for an expression and its dependencies.
@@ -43,12 +130,11 @@ def node_to_task(expr: logic.Expr, attr: dict) -> Task:
43130

44131
flattened_tasks = list(tasks.values())
45132

46-
# we will make sure all tasks are depickled correctly
47-
for i, node in enumerate(tasks):
48-
if logic.Expr.from_dict(node.dict(include_id=True)) != node:
49-
raise RuntimeError(
50-
"Expected depickled node to be the same as initial node."
51-
)
133+
# we will make sure all tasks are depickled correctly [commented out for performance reasons]
134+
# from tqdm import tqdm
135+
#
136+
# for node in tqdm(tasks):
137+
# assert_pickle_roundtrip(node)
52138

53139
assert (
54140
len(set(flattened_tasks))

execution_engine/util/logic.py

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,31 @@ def _recreate(cls, args: Any, kwargs: dict) -> "Expr":
5858
self.set_id(_id)
5959
return self
6060

61-
def _reduce_helper(self, ivars_map: dict | None = None) -> tuple[Callable, tuple]:
61+
def _reduce_helper(
62+
self, ivars_map: dict | None = None, args: tuple | None = None
63+
) -> tuple[Callable, tuple]:
6264
"""
6365
Return a picklable tuple that calls self._recreate and passes in
6466
(self.args, combined_ivars).
6567
6668
:param ivars_map: A dictionary for renaming keys in the instance variables.
6769
For each old_key -> new_key in ivars_map, if old_key exists
6870
in data, it will be removed and stored under new_key instead.
71+
:param args: Optionally, the args parameter can be specified. If it is None, self.args is used.
6972
"""
70-
data = dict(self.get_instance_variables())
71-
data["_id"] = self._id
73+
data = self.get_instance_variables()
74+
data["_id"] = self._id # type: ignore[index]
7275

7376
# Apply key renaming if ivars_map is provided
7477
if ivars_map:
7578
for old_key, new_key in ivars_map.items():
7679
if old_key in data:
77-
data[new_key] = data.pop(old_key)
80+
data[new_key] = data.pop(old_key) # type: ignore[index,union-attr]
7881

79-
return (self._recreate, (self.args, data))
82+
if args is None:
83+
args = self.args
84+
85+
return (self._recreate, (args, data))
8086

8187
def __reduce__(self) -> tuple[Callable, tuple]:
8288
"""
@@ -404,15 +410,30 @@ def __new__(
404410

405411
return self
406412

413+
def _replace_map(self) -> dict:
414+
replace = {}
415+
416+
if self.count_min and self.count_max:
417+
assert self.count_min == self.count_max
418+
replace["count_min"] = "threshold"
419+
elif self.count_min:
420+
replace["count_min"] = "threshold"
421+
elif self.count_max:
422+
replace["count_max"] = "threshold"
423+
else:
424+
raise AttributeError("At least one of count_min or count_max must be set")
425+
426+
return replace
427+
407428
def __reduce__(self) -> tuple[Callable, tuple]:
    """
    Reduce the expression to its arguments and category.

    Required for pickling (e.g. when using multiprocessing).

    :return: Tuple of the class, argument.
    """
    # Rename the active count bound to "threshold" for serialization.
    rename = self._replace_map()
    return self._reduce_helper(rename)
416437

417438

418439
class Count(CountOperator):
@@ -683,6 +704,31 @@ def _validate_time_inputs(
683704
f"Invalid criterion - expected Criterion or CriterionCombination, got {type(interval_criterion)}"
684705
)
685706

707+
def __reduce__(self) -> tuple[Callable, tuple]:
    """
    Reduce the expression to its arguments and category.

    Required for pickling (e.g. when using multiprocessing).

    :return: Tuple of the class, argument.
    """
    if not self.interval_criterion:
        # No interval criterion: the base-class reduction suffices.
        return super().__reduce__()

    if len(self.args) <= 1:
        raise ValueError(
            "More than one argument required if interval_criterion is set"
        )

    # The interval criterion is expected to sit at the end of the positional
    # args; strip it off before reducing.
    kept, last = self.args[:-1], self.args[-1]

    if last != self.interval_criterion:
        raise ValueError(
            f"Expected last argument to be the interval_criterion, got {str(last)}"
        )

    return self._reduce_helper(self._replace_map(), args=kept)
731+
686732
def dict(self, include_id: bool = False) -> dict:
687733
"""
688734
Get a dictionary representation of the object.
@@ -746,11 +792,11 @@ class TemporalMinCount(TemporalCount):
746792
def __new__(
747793
cls,
748794
*args: Any,
749-
threshold: int | None,
750-
start_time: time | None,
751-
end_time: time | None,
752-
interval_type: TimeIntervalType | None,
753-
interval_criterion: BaseExpr | None,
795+
threshold: int,
796+
start_time: time | None = None,
797+
end_time: time | None = None,
798+
interval_type: TimeIntervalType | None = None,
799+
interval_criterion: BaseExpr | None = None,
754800
**kwargs: Any,
755801
) -> "TemporalMinCount":
756802
"""
@@ -776,7 +822,7 @@ def dict(self, include_id: bool = False) -> dict:
776822
Get a dictionary representation of the object.
777823
"""
778824
data = super().dict(include_id=include_id)
779-
data.update({"threshold": self.count_min})
825+
data["data"].update({"threshold": self.count_min})
780826
return data
781827

782828

@@ -788,11 +834,11 @@ class TemporalMaxCount(TemporalCount):
788834
def __new__(
789835
cls,
790836
*args: Any,
791-
threshold: int | None,
792-
start_time: time | None,
793-
end_time: time | None,
794-
interval_type: TimeIntervalType | None,
795-
interval_criterion: BaseExpr | None,
837+
threshold: int,
838+
start_time: time | None = None,
839+
end_time: time | None = None,
840+
interval_type: TimeIntervalType | None = None,
841+
interval_criterion: BaseExpr | None = None,
796842
**kwargs: Any,
797843
) -> "TemporalMaxCount":
798844
"""
@@ -818,7 +864,7 @@ def dict(self, include_id: bool = False) -> dict:
818864
Get a dictionary representation of the object.
819865
"""
820866
data = super().dict(include_id=include_id)
821-
data.update({"threshold": self.count_max})
867+
data["data"].update({"threshold": self.count_max})
822868
return data
823869

824870

@@ -830,11 +876,11 @@ class TemporalExactCount(TemporalCount):
830876
def __new__(
831877
cls,
832878
*args: Any,
833-
threshold: int | None,
834-
start_time: time | None,
835-
end_time: time | None,
836-
interval_type: TimeIntervalType | None,
837-
interval_criterion: BaseExpr | None,
879+
threshold: int,
880+
start_time: time | None = None,
881+
end_time: time | None = None,
882+
interval_type: TimeIntervalType | None = None,
883+
interval_criterion: BaseExpr | None = None,
838884
**kwargs: Any,
839885
) -> "TemporalExactCount":
840886
"""
@@ -860,7 +906,7 @@ def dict(self, include_id: bool = False) -> dict:
860906
Get a dictionary representation of the object.
861907
"""
862908
data = super().dict(include_id=include_id)
863-
data.update({"threshold": self.count_min})
909+
data["data"].update({"threshold": self.count_min})
864910
return data
865911

866912

execution_engine/util/serializable.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,24 @@ def __new__(mcs, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> S
105105
return new_class
106106

107107

108+
def immutable_setattr(self: Self, key: str, value: Any) -> None:
109+
"""
110+
Prevent setting attributes on an immutable object.
111+
112+
This function is assigned to an instance's __setattr__ method in order to enforce
113+
immutability after object creation. Any attempt to set an attribute on the instance
114+
after initialization will raise an AttributeError.
115+
116+
:param self: The instance on which the attribute assignment was attempted.
117+
:param key: The name of the attribute being set.
118+
:param value: The value being assigned.
119+
:raises AttributeError: Always, to enforce immutability.
120+
"""
121+
raise AttributeError(
122+
f"Cannot set attribute {key} on immutable object {self.__class__.__name__}"
123+
)
124+
125+
108126
class Serializable(metaclass=RegisteredPostInitMeta):
109127
"""
110128
Base class for making objects serializable.
@@ -137,11 +155,6 @@ def __post_init__(self) -> None:
137155
"""
138156
self.rehash()
139157

140-
def immutable_setattr(self: Self, key: str, value: Any) -> None:
141-
raise AttributeError(
142-
f"Cannot set attribute {key} on immutable object {self.__class__.__name__}"
143-
)
144-
145158
self.__setattr__ = immutable_setattr # type: ignore[assignment]
146159

147160
def set_id(self, value: int, overwrite: bool = False) -> None:

0 commit comments

Comments (0)