Skip to content

Commit dab0154

Browse files
committed
pass datasets as shared memory instead of via default joblib method
1 parent 8ecad3c commit dab0154

File tree

8 files changed

+916
-856
lines changed

8 files changed

+916
-856
lines changed

python/interpret-core/interpret-core.pyproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
<Compile Include="interpret\utils\_preprocessor.py" />
9595
<Compile Include="interpret\utils\_privacy.py" />
9696
<Compile Include="interpret\utils\_seed.py" />
97+
<Compile Include="interpret\utils\_shared_dataset.py" />
9798
<Compile Include="interpret\utils\_synthetic.py" />
9899
<Compile Include="interpret\utils\_unify_predict.py" />
99100
<Compile Include="interpret\utils\_unify_data.py" />

python/interpret-core/interpret/glassbox/_ebm/_boost.py

Lines changed: 272 additions & 266 deletions
Large diffs are not rendered by default.

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 475 additions & 472 deletions
Large diffs are not rendered by default.

python/interpret-core/interpret/utils/_compressed_dataset.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ._clean_x import categorical_encode, unify_columns
99
from ._native import Native
10+
from multiprocessing import shared_memory
1011

1112
_log = logging.getLogger(__name__)
1213

@@ -20,6 +21,7 @@ def bin_native(
2021
sample_weight,
2122
feature_names_in,
2223
feature_types_in,
24+
shared,
2325
):
2426
# called under: fit
2527

@@ -99,7 +101,12 @@ def bin_native(
99101
_log.error(msg)
100102
raise ValueError(msg)
101103

102-
dataset = np.empty(n_bytes, np.ubyte) # joblib loky doesn't support RawArray
104+
shared_mem = shared_memory.SharedMemory(create=True, size=n_bytes, name=None)
105+
shared.shared_memory = shared_mem
106+
shared.name = shared_mem.name
107+
108+
dataset = np.ndarray(n_bytes, dtype=np.ubyte, buffer=shared_mem.buf)
109+
shared.dataset = dataset
103110

104111
native.fill_dataset_header(len(feature_idxs), n_weights, 1, dataset)
105112

@@ -153,8 +160,6 @@ def bin_native(
153160
_log.error(msg)
154161
raise ValueError(msg)
155162

156-
return dataset
157-
158163

159164
def bin_native_by_dimension(
160165
n_classes,
@@ -165,6 +170,7 @@ def bin_native_by_dimension(
165170
sample_weight,
166171
feature_names_in,
167172
feature_types_in,
173+
shared,
168174
):
169175
# called under: fit
170176

@@ -175,7 +181,7 @@ def bin_native_by_dimension(
175181
feature_bins = bin_levels[min(len(bin_levels), n_dimensions) - 1]
176182
bins_iter.append(feature_bins)
177183

178-
return bin_native(
184+
bin_native(
179185
n_classes,
180186
feature_idxs,
181187
bins_iter,
@@ -184,4 +190,5 @@ def bin_native_by_dimension(
184190
sample_weight,
185191
feature_names_in,
186192
feature_types_in,
193+
shared,
187194
)

python/interpret-core/interpret/utils/_measure_interactions.py

Lines changed: 56 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ._native import Native
2727
from ._preprocessor import construct_bins
2828
from ._rank_interactions import rank_interactions
29+
from ._shared_dataset import SharedDataset
2930

3031
_log = logging.getLogger(__name__)
3132

@@ -239,59 +240,61 @@ def measure_interactions(
239240
bins = binning_result[2]
240241
n_features_in = len(bins)
241242

242-
dataset = bin_native_by_dimension(
243-
n_classes=n_classes,
244-
n_dimensions=2,
245-
bins=bins,
246-
X=X,
247-
y=y,
248-
sample_weight=sample_weight,
249-
feature_names_in=feature_names_in,
250-
feature_types_in=feature_types_in,
251-
)
252-
253-
interaction_flags = Native.CalcInteractionFlags_Default
254-
if develop.get_option("full_interaction"):
255-
interaction_flags |= Native.CalcInteractionFlags_Full
256-
257-
if isinstance(interactions, int):
258-
n_output_interactions = interactions
259-
iter_term_features = combinations(range(n_features_in), 2)
260-
elif interactions is None:
261-
n_output_interactions = 0
262-
iter_term_features = combinations(range(n_features_in), 2)
263-
else:
264-
n_output_interactions = 0
265-
iter_term_features = interactions
266-
267-
ranked_interactions = rank_interactions(
268-
None,
269-
0,
270-
dataset=dataset,
271-
intercept=None,
272-
bag=None,
273-
init_scores=init_score,
274-
iter_term_features=iter_term_features,
275-
exclude=set(),
276-
exclude_features=set(),
277-
calc_interaction_flags=interaction_flags,
278-
max_cardinality=max_cardinality,
279-
min_samples_leaf=min_samples_leaf,
280-
min_hessian=min_hessian,
281-
reg_alpha=reg_alpha,
282-
reg_lambda=reg_lambda,
283-
max_delta_step=max_delta_step,
284-
create_interaction_flags=(
285-
Native.CreateInteractionFlags_DifferentialPrivacy
286-
if is_differential_privacy
287-
else Native.CreateInteractionFlags_Default
288-
),
289-
objective=objective,
290-
acceleration=develop.get_option("acceleration"),
291-
experimental_params=None,
292-
n_output_interactions=n_output_interactions,
293-
develop_options=develop._develop_options,
294-
)
243+
with SharedDataset() as shared:
244+
bin_native_by_dimension(
245+
n_classes=n_classes,
246+
n_dimensions=2,
247+
bins=bins,
248+
X=X,
249+
y=y,
250+
sample_weight=sample_weight,
251+
feature_names_in=feature_names_in,
252+
feature_types_in=feature_types_in,
253+
shared=shared,
254+
)
255+
256+
interaction_flags = Native.CalcInteractionFlags_Default
257+
if develop.get_option("full_interaction"):
258+
interaction_flags |= Native.CalcInteractionFlags_Full
259+
260+
if isinstance(interactions, int):
261+
n_output_interactions = interactions
262+
iter_term_features = combinations(range(n_features_in), 2)
263+
elif interactions is None:
264+
n_output_interactions = 0
265+
iter_term_features = combinations(range(n_features_in), 2)
266+
else:
267+
n_output_interactions = 0
268+
iter_term_features = interactions
269+
270+
ranked_interactions = rank_interactions(
271+
None,
272+
0,
273+
dataset_name=shared.name,
274+
intercept=None,
275+
bag=None,
276+
init_scores=init_score,
277+
iter_term_features=iter_term_features,
278+
exclude=set(),
279+
exclude_features=set(),
280+
calc_interaction_flags=interaction_flags,
281+
max_cardinality=max_cardinality,
282+
min_samples_leaf=min_samples_leaf,
283+
min_hessian=min_hessian,
284+
reg_alpha=reg_alpha,
285+
reg_lambda=reg_lambda,
286+
max_delta_step=max_delta_step,
287+
create_interaction_flags=(
288+
Native.CreateInteractionFlags_DifferentialPrivacy
289+
if is_differential_privacy
290+
else Native.CreateInteractionFlags_Default
291+
),
292+
objective=objective,
293+
acceleration=develop.get_option("acceleration"),
294+
experimental_params=None,
295+
n_output_interactions=n_output_interactions,
296+
develop_options=develop._develop_options,
297+
)
295298

296299
if isinstance(ranked_interactions, Exception):
297300
raise ranked_interactions

python/interpret-core/interpret/utils/_rank_interactions.py

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
def rank_interactions(
2121
shm_name,
2222
bag_idx,
23-
dataset,
23+
dataset_name,
2424
intercept,
2525
bag,
2626
init_scores,
@@ -44,55 +44,62 @@ def rank_interactions(
4444
try:
4545
develop._develop_options = develop_options # restore these in this process
4646

47-
shm = None
4847
try:
49-
stop_flag = None
50-
if shm_name is not None:
51-
shm = shared_memory.SharedMemory(name=shm_name)
52-
stop_flag = np.ndarray((1,), dtype=np.bool_, buffer=shm.buf)
48+
shared_dataset = shared_memory.SharedMemory(name=dataset_name)
49+
# we do not know the length of the dataset, so we create a 1-element array
50+
dataset = np.ndarray(1, dtype=np.ubyte, buffer=shared_dataset.buf)
5351

54-
interaction_strengths = []
55-
with InteractionDetector(
56-
dataset,
57-
intercept,
58-
bag,
59-
init_scores,
60-
create_interaction_flags,
61-
objective,
62-
acceleration,
63-
experimental_params,
64-
) as interaction_detector:
65-
for feature_idxs in iter_term_features:
66-
if tuple(sorted(feature_idxs)) in exclude:
67-
continue
68-
if any(i in exclude_features for i in feature_idxs):
69-
continue
52+
shm = None
53+
try:
54+
stop_flag = None
55+
if shm_name is not None:
56+
shm = shared_memory.SharedMemory(name=shm_name)
57+
stop_flag = np.ndarray(1, dtype=np.bool_, buffer=shm.buf)
7058

71-
strength = interaction_detector.calc_interaction_strength(
72-
feature_idxs,
73-
calc_interaction_flags,
74-
max_cardinality,
75-
min_samples_leaf,
76-
min_hessian,
77-
reg_alpha,
78-
reg_lambda,
79-
max_delta_step,
80-
)
81-
item = (strength, feature_idxs)
82-
if n_output_interactions <= 0:
83-
interaction_strengths.append(item)
84-
elif len(interaction_strengths) == n_output_interactions:
85-
heapq.heappushpop(interaction_strengths, item)
86-
else:
87-
heapq.heappush(interaction_strengths, item)
59+
interaction_strengths = []
60+
with InteractionDetector(
61+
dataset,
62+
intercept,
63+
bag,
64+
init_scores,
65+
create_interaction_flags,
66+
objective,
67+
acceleration,
68+
experimental_params,
69+
) as interaction_detector:
70+
for feature_idxs in iter_term_features:
71+
if tuple(sorted(feature_idxs)) in exclude:
72+
continue
73+
if any(i in exclude_features for i in feature_idxs):
74+
continue
8875

89-
if stop_flag is not None and stop_flag[0]:
90-
break
76+
strength = interaction_detector.calc_interaction_strength(
77+
feature_idxs,
78+
calc_interaction_flags,
79+
max_cardinality,
80+
min_samples_leaf,
81+
min_hessian,
82+
reg_alpha,
83+
reg_lambda,
84+
max_delta_step,
85+
)
86+
item = (strength, feature_idxs)
87+
if n_output_interactions <= 0:
88+
interaction_strengths.append(item)
89+
elif len(interaction_strengths) == n_output_interactions:
90+
heapq.heappushpop(interaction_strengths, item)
91+
else:
92+
heapq.heappush(interaction_strengths, item)
9193

92-
interaction_strengths.sort(reverse=True)
93-
return interaction_strengths
94+
if stop_flag is not None and stop_flag[0]:
95+
break
96+
97+
interaction_strengths.sort(reverse=True)
98+
return interaction_strengths
99+
finally:
100+
if shm is not None:
101+
shm.close()
94102
finally:
95-
if shm is not None:
96-
shm.close()
103+
shared_dataset.close()
97104
except Exception as e:
98105
return e
python/interpret-core/interpret/utils/_shared_dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) 2023 The InterpretML Contributors
2+
# Distributed under the MIT software license
3+
4+
from contextlib import AbstractContextManager
5+
6+
class SharedDataset(AbstractContextManager):
    """Context manager that owns a shared-memory segment holding a dataset.

    The instance starts out empty. The code that creates the segment (for
    example ``bin_native``) is expected to fill in the three attributes:

    - ``shared_memory``: the object owning the segment; it must provide
      ``close()`` and ``unlink()``.
    - ``dataset``: an array view over the segment's buffer.
    - ``name``: the segment's name, which can be handed to other processes
      so they can attach to the same memory.

    Leaving the ``with`` block (normally or via an exception) releases the
    segment through :meth:`reset`.
    """

    def __init__(self):
        self.shared_memory = None
        self.dataset = None
        self.name = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so an exception raised inside the ``with`` body is
        # never suppressed.
        self.reset()

    def reset(self):
        """Release the shared-memory segment, if any, and clear all fields.

        Safe to call more than once: after the first call the fields are
        ``None`` and later calls do nothing.

        Raises:
            Whatever ``close()`` or ``unlink()`` of the underlying object
            raises. ``unlink()`` is attempted even when ``close()`` fails.
        """
        shared_memory = self.shared_memory
        # Clear the fields first so the object is left in a consistent empty
        # state even if the cleanup below raises.
        # NOTE(review): presumably dropping ``dataset`` before close() also
        # releases our view over the segment's buffer -- confirm.
        self.name = None
        self.dataset = None
        self.shared_memory = None
        if shared_memory is not None:
            try:
                shared_memory.close()
            finally:
                # Always unlink: if close() raised and we skipped this, the
                # named segment would outlive the process and leak.
                shared_memory.unlink()

python/interpret-core/tests/utils/test_compressed_dataset.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from interpret.utils._clean_x import preclean_X
1010
from interpret.utils._compressed_dataset import bin_native, bin_native_by_dimension
1111
from interpret.utils._preprocessor import construct_bins
12-
12+
from interpret.utils._shared_dataset import SharedDataset
1313

1414
@pytest.mark.skip(reason="skip this until we have support for missing values")
1515
def test_bin_native():
@@ -79,19 +79,27 @@ def test_bin_native():
7979
feature_idxs.append(feature_idx)
8080
bins_iter.append(feature_bins)
8181

82-
shared_dataset = bin_native(
83-
n_classes,
84-
feature_idxs,
85-
bins_iter,
86-
X,
87-
y,
88-
sample_weight,
89-
feature_names_in,
90-
feature_types_in,
91-
)
92-
assert shared_dataset is not None
82+
with SharedDataset() as shared:
83+
bin_native(
84+
n_classes,
85+
feature_idxs,
86+
bins_iter,
87+
X,
88+
y,
89+
sample_weight,
90+
feature_names_in,
91+
feature_types_in,
92+
shared,
93+
)
94+
assert shared.shared_memory is not None
95+
assert shared.dataset is not None
96+
assert shared.name is not None
9397

94-
shared_dataset = bin_native_by_dimension(
95-
n_classes, 1, bins, X, y, sample_weight, feature_names_in, feature_types_in
96-
)
97-
assert shared_dataset is not None
98+
with SharedDataset() as shared:
99+
bin_native_by_dimension(
100+
n_classes, 1, bins, X, y, sample_weight, feature_names_in, feature_types_in,
101+
shared,
102+
)
103+
assert shared.shared_memory is not None
104+
assert shared.dataset is not None
105+
assert shared.name is not None

0 commit comments

Comments
 (0)