Skip to content

Commit 25d8e17

Browse files
committed
add estimate_mem function
1 parent ebe6ed2 commit 25d8e17

File tree

3 files changed

+217
-112
lines changed

3 files changed

+217
-112
lines changed

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 146 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,60 @@ class EbmTags:
329329
input_tags: EbmInputTags = field(default_factory=EbmInputTags)
330330

331331

332+
def clean_interactions(interactions, n_features_in):
    """Normalize the user-facing ``interactions`` setting.

    Args:
        interactions: The requested interactions, in one of these forms:
            - ``None``: no interactions.
            - ``int``: a count of interactions; returned unchanged when
              positive.
            - ``float``: a value below 1.0 is a fraction of
              ``n_features_in`` (rounded up). Values of 1.0 or more are
              only accepted when
              ``develop.get_option("allow_float_interactions")`` is truthy.
            - ``str``: a multiple of ``n_features_in`` written with a
              trailing ``x``/``X`` (for example ``"1.5x"``), rounded up.
            - any other sized object (for example a list): returned
              unchanged unless it is empty.
        n_features_in: Number of features the fraction/multiple scales by.

    Returns:
        ``0`` when no interactions are requested, otherwise a positive count
        or the original non-empty collection.

    Raises:
        ValueError: If the value is negative, is a float of 1.0 or more
            while float interactions are not allowed, or is a string that
            lacks the trailing ``x``, does not parse as a number, or does
            not yield a finite count.
    """
    if interactions is None:
        return 0

    if isinstance(interactions, (float, int)):
        if interactions <= 0:
            if interactions == 0:
                return 0
            msg = "interactions cannot be negative"
            _log.error(msg)
            raise ValueError(msg)

        if isinstance(interactions, float):
            if interactions < 1.0 or develop.get_option("allow_float_interactions"):
                interactions = int(ceil(n_features_in * interactions))
            else:
                msg = "interactions above 1 cannot be a float percentage and need to be an int instead"
                _log.error(msg)
                raise ValueError(msg)

        # at this point interactions will be a positive, nonzero integer
        return interactions

    if isinstance(interactions, str):
        text = interactions.strip()

        if not text.lower().endswith("x"):
            raise ValueError(
                "If passing a string for interactions, it must end in an 'x' character."
            )

        text = text[:-1]
        try:
            multiple = float(text)
        except ValueError as err:
            raise ValueError(
                f"'{text}' is not a valid floating-point number."
            ) from err

        if multiple <= 0.0:
            if multiple == 0.0:
                return 0
            msg = "interactions cannot be negative"
            _log.error(msg)
            raise ValueError(msg)

        # float() accepts "inf" and "nan", and a huge multiple can overflow
        # to infinity; ceil()/int() reject those with OverflowError or an
        # unrelated ValueError, so report a clear ValueError instead.
        try:
            return int(ceil(n_features_in * multiple))
        except (OverflowError, ValueError) as err:
            raise ValueError(
                f"'{text}' is not a valid finite multiple for interactions."
            ) from err

    # Anything else is treated as a sized collection of explicit terms.
    # len() is used instead of truthiness so that array-like inputs work.
    if len(interactions) == 0:
        return 0

    # if it's a non-empty list then just return it
    return interactions
384+
385+
332386
class EBMModel(ExplainerMixin, BaseEstimator):
333387
"""Base class for all EBMs."""
334388

@@ -1183,66 +1237,14 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
11831237
if stop_flag is not None and stop_flag[0]:
11841238
break
11851239

1186-
if interactions is None:
1240+
interactions = clean_interactions(interactions, n_features_in)
1241+
if interactions == 0: # works if interactions is a list
11871242
break
11881243

1189-
if isinstance(interactions, (float, int)):
1190-
if interactions <= 0:
1191-
if interactions == 0:
1192-
break
1193-
msg = "interactions cannot be negative"
1194-
_log.error(msg)
1195-
raise ValueError(msg)
1196-
1197-
if isinstance(interactions, float):
1198-
if interactions < 1.0 or develop.get_option(
1199-
"allow_float_interactions"
1200-
):
1201-
interactions = int(ceil(n_features_in * interactions))
1202-
else:
1203-
msg = "interactions above 1 cannot be a float percentage and need to be an int instead"
1204-
_log.error(msg)
1205-
raise ValueError(msg)
1206-
1207-
if n_classes >= Native.Task_MulticlassPlus:
1208-
warn(
1209-
"For multiclass we cannot currently visualize pairs and they will be stripped from the global explanations. Set interactions=0 to generate a fully interpretable glassbox model."
1210-
)
1211-
1212-
# at this point interactions will be a positive, nonzero integer
1213-
elif isinstance(interactions, str):
1214-
interactions = interactions.strip()
1215-
1216-
if not interactions.lower().endswith("x"):
1217-
raise ValueError(
1218-
"If passing a string for interactions, it must end in an 'x' character."
1219-
)
1220-
1221-
interactions = interactions[:-1]
1222-
try:
1223-
interactions = float(interactions)
1224-
except ValueError:
1225-
raise ValueError(
1226-
f"'{interactions}' is not a valid floating-point number."
1227-
)
1228-
1229-
if interactions <= 0.0:
1230-
if interactions == 0.0:
1231-
break
1232-
msg = "interactions cannot be negative"
1233-
_log.error(msg)
1234-
raise ValueError(msg)
1235-
1236-
if n_classes >= Native.Task_MulticlassPlus:
1237-
warn(
1238-
"For multiclass we cannot currently visualize pairs and they will be stripped from the global explanations. Set interactions=0 to generate a fully interpretable glassbox model."
1239-
)
1240-
1241-
interactions = int(ceil(n_features_in * interactions))
1242-
1243-
# at this point interactions will be a positive, nonzero integer
1244-
elif len(interactions) == 0:
1245-
break
1244+
if n_classes >= Native.Task_MulticlassPlus:
1245+
warn(
1246+
"For multiclass we cannot currently visualize pairs and they will be stripped from the global explanations. Set interactions=0 to generate a fully interpretable glassbox model."
1247+
)
12461248

12471249
# at this point we know we will be making a new one, so delete it now
12481250
shared.reset()
@@ -1691,6 +1693,92 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
16911693

16921694
return self
16931695

1696+
def estimate_mem(self, X):
    """Estimate the memory needed to fit this model on ``X``.

    Args:
        X: dataset

    Returns:
        Estimated memory usage in bytes. The figure is the largest of the
        mains, interaction-detection, and interaction-boosting estimates.
        It does not include the caller's own copy of X, nor the process's
        code or other data.
    """

    # The number of classes barely changes the footprint, so size the
    # dataset as if the task were regression.
    task = Native.Task_Regression

    X, n_samples = preclean_X(X, self.feature_names, self.feature_types, None, None)

    leveled_bins = [self.max_bins, self.max_interaction_bins]
    binned = construct_bins(
        X=X,
        y=None,
        sample_weight=None,
        feature_names_given=self.feature_names,
        feature_types_given=self.feature_types,
        max_bins_leveled=leveled_bins,
    )
    names_in, types_in, bins = binned[0], binned[1], binned[2]

    # placeholder float64 targets so the dataset is sized as regression
    dummy_y = np.zeros(n_samples, dtype=np.float64)

    def dataset_bytes(n_dimensions):
        # A shared-memory holder of None asks only for the byte count.
        return bin_native_by_dimension(
            task,
            n_dimensions,
            bins,
            X,
            dummy_y,
            None,
            names_in,
            types_in,
            None,
        )

    mains_bytes = dataset_bytes(1)

    # One shared memory copy of the data mapped into all processes, plus a
    # copy of the test and train data for each outer bag. Assume all
    # processes are started at some point and are eating up memory.
    peak_bytes = mains_bytes + mains_bytes * self.outer_bags

    n_features_in = len(bins)

    n_pairs = clean_interactions(self.interactions, n_features_in)
    if not isinstance(n_pairs, int):
        # an explicit collection of terms: only its size matters here
        n_pairs = len(n_pairs)

    if n_pairs == 0:
        return peak_bytes

    pairs_bytes = dataset_bytes(2)

    # Each outer bag makes a copy of the features. Only the training
    # features are kept for interaction detection, but that saving is not
    # estimated for now.
    detection_bytes = pairs_bytes + pairs_bytes * self.outer_bags

    # Boosting on pairs is scaled by the ratio of requested interactions to
    # features.
    pair_ratio = float(n_pairs) / float(n_features_in)
    boosting_bytes = pairs_bytes + int(pairs_bytes * pair_ratio * self.outer_bags)

    return max(peak_bytes, detection_bytes, boosting_bytes)
1781+
16941782
def to_jsonable(self, detail="all"):
16951783
"""Convert the model to a JSONable representation.
16961784

python/interpret-core/interpret/utils/_compressed_dataset.py

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -101,64 +101,69 @@ def bin_native(
101101
_log.error(msg)
102102
raise ValueError(msg)
103103

104-
shared_mem = shared_memory.SharedMemory(create=True, size=n_bytes, name=None)
105-
shared.shared_memory = shared_mem
106-
shared.name = shared_mem.name
104+
if shared is not None:
105+
shared_mem = shared_memory.SharedMemory(create=True, size=n_bytes, name=None)
106+
shared.shared_memory = shared_mem
107+
shared.name = shared_mem.name
107108

108-
dataset = np.ndarray(n_bytes, dtype=np.ubyte, buffer=shared_mem.buf)
109-
shared.dataset = dataset
109+
dataset = np.ndarray(n_bytes, dtype=np.ubyte, buffer=shared_mem.buf)
110+
shared.dataset = dataset
110111

111-
native.fill_dataset_header(len(feature_idxs), n_weights, 1, dataset)
112+
native.fill_dataset_header(len(feature_idxs), n_weights, 1, dataset)
112113

113-
get_col = unify_columns(
114-
X, n_samples, feature_names_in, feature_types_in, None, False, False
115-
)
116-
for feature_idx, feature_bins in zip(feature_idxs, bins_iter):
117-
feature_type = feature_types_in[feature_idx]
118-
if feature_type == "ignore":
119-
# TODO: exclude ignored features from the compressed dataset
120-
raise Exception("ignored features not supported yet")
121-
122-
_, nonmissings, uniques, X_col, bad = get_col(feature_idx)
123-
124-
if isinstance(feature_bins, dict):
125-
# categorical feature
126-
127-
X_col = categorical_encode(uniques, X_col, nonmissings, feature_bins)
128-
bad = X_col == -1
129-
if not bad.any():
130-
bad = None
131-
132-
n_bins = 2 if len(feature_bins) == 0 else (max(feature_bins.values()) + 2)
133-
else:
134-
# continuous feature
135-
136-
X_col = native.discretize(X_col, feature_bins)
137-
n_bins = len(feature_bins) + 3
138-
139-
if bad is not None:
140-
X_col[bad] = n_bins - 1
141-
142-
native.fill_feature(
143-
n_bins,
144-
np.count_nonzero(X_col) != len(X_col),
145-
bad is not None,
146-
feature_type == "nominal",
147-
X_col,
148-
dataset,
114+
get_col = unify_columns(
115+
X, n_samples, feature_names_in, feature_types_in, None, False, False
149116
)
117+
for feature_idx, feature_bins in zip(feature_idxs, bins_iter):
118+
feature_type = feature_types_in[feature_idx]
119+
if feature_type == "ignore":
120+
# TODO: exclude ignored features from the compressed dataset
121+
raise Exception("ignored features not supported yet")
122+
123+
_, nonmissings, uniques, X_col, bad = get_col(feature_idx)
124+
125+
if isinstance(feature_bins, dict):
126+
# categorical feature
127+
128+
X_col = categorical_encode(uniques, X_col, nonmissings, feature_bins)
129+
bad = X_col == -1
130+
if not bad.any():
131+
bad = None
132+
133+
n_bins = (
134+
2 if len(feature_bins) == 0 else (max(feature_bins.values()) + 2)
135+
)
136+
else:
137+
# continuous feature
138+
139+
X_col = native.discretize(X_col, feature_bins)
140+
n_bins = len(feature_bins) + 3
141+
142+
if bad is not None:
143+
X_col[bad] = n_bins - 1
144+
145+
native.fill_feature(
146+
n_bins,
147+
np.count_nonzero(X_col) != len(X_col),
148+
bad is not None,
149+
feature_type == "nominal",
150+
X_col,
151+
dataset,
152+
)
153+
154+
if sample_weight is not None:
155+
native.fill_weight(sample_weight, dataset)
156+
157+
if y.dtype == np.float64:
158+
native.fill_regression_target(y, dataset)
159+
elif y.dtype == np.int64:
160+
native.fill_classification_target(n_classes, y, dataset)
161+
else:
162+
msg = "y must be either float64 or int64"
163+
_log.error(msg)
164+
raise ValueError(msg)
150165

151-
if sample_weight is not None:
152-
native.fill_weight(sample_weight, dataset)
153-
154-
if y.dtype == np.float64:
155-
native.fill_regression_target(y, dataset)
156-
elif y.dtype == np.int64:
157-
native.fill_classification_target(n_classes, y, dataset)
158-
else:
159-
msg = "y must be either float64 or int64"
160-
_log.error(msg)
161-
raise ValueError(msg)
166+
return n_bytes
162167

163168

164169
def bin_native_by_dimension(
@@ -181,7 +186,7 @@ def bin_native_by_dimension(
181186
feature_bins = bin_levels[min(len(bin_levels), n_dimensions) - 1]
182187
bins_iter.append(feature_bins)
183188

184-
bin_native(
189+
return bin_native(
185190
n_classes,
186191
feature_idxs,
187192
bins_iter,

python/interpret-core/tests/glassbox/ebm/test_ebm.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,3 +1360,15 @@ def __call__(self, bag_index, step_index, progress, metric):
13601360
# print(ebm.best_iteration_)
13611361

13621362
pred = ebm.predict_proba(X)
1363+
1364+
1365+
def test_estimate_mem():
    """Check estimate_mem gives a positive size that pairs never shrink."""
    X, _, names, types = make_synthetic(seed=42, output_type="float", n_samples=10000)

    # An empty interaction list disables pairs, so this is the mains-only estimate.
    ebm = ExplainableBoostingClassifier(names, types, interactions=[])
    n_bytes_mains = ebm.estimate_mem(X)
    assert n_bytes_mains > 0

    # With the default interactions setting the estimate is the maximum of
    # the mains estimate and the pair estimates, so it cannot be smaller.
    ebm = ExplainableBoostingClassifier(names, types)
    n_bytes_default = ebm.estimate_mem(X)
    assert n_bytes_default >= n_bytes_mains

0 commit comments

Comments
 (0)