Commit 0e83d96
add option to include X, y, and fixed memory in the memory estimation
1 parent 7241c8f commit 0e83d96

File tree

6 files changed: 151 additions & 27 deletions

python/interpret-core/interpret-core.pyproj

Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@
     <Compile Include="interpret\utils\_compressed_dataset.py" />
     <Compile Include="interpret\utils\_histogram.py" />
     <Compile Include="interpret\utils\_link.py" />
+    <Compile Include="interpret\utils\_measure_mem.py" />
     <Compile Include="interpret\utils\_misc.py" />
     <Compile Include="interpret\utils\_purify.py" />
     <Compile Include="interpret\utils\_rank_interactions.py" />

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 58 additions & 13 deletions

@@ -72,6 +72,7 @@
     remove_extra_bins,
 )
 from ...utils._shared_dataset import SharedDataset
+from ...utils._measure_mem import total_bytes

 _log = logging.getLogger(__name__)

@@ -496,8 +497,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
         """Fit model to provided samples.

         Args:
-            X: NumPy array for training samples.
-            y: NumPy array as training labels.
+            X: {array-like, sparse matrix} of shape (n_samples, n_features). Training data.
+            y: array-like of shape (n_samples,). Target values.
             sample_weight: Optional array of weights per sample. Should be same length as X and y.
             bags: Optional bag definitions. The first dimension should have length equal to the number of samples.
                 The second dimension should have length equal to the number of outer_bags. The contents should be
@@ -1695,18 +1696,33 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):

         return self

-    def estimate_mem(self, X, y=None):
+    def estimate_mem(self, X, y=None, data_multiplier=0.0):
         """Estimate memory usage of the model.
         Args:
-            X: dataset
+            X: {array-like, sparse matrix} of shape (n_samples, n_features). Training data.
+            y: array-like of shape (n_samples,). Target values.
+            data_multiplier: The data in X needs to be allocated by the caller.
+                If data_multiplier is set to 0.0 then this function only estimates the additional
+                memory consumed by the fit function. If data_multiplier is set to 1.0 then
+                it will include the memory allocated to X by the caller. Often the caller will make
+                copies of X before calling fit, and in that case the data_multiplier could be set to a
+                value above 1.0 if the caller would like this function to include that in the memory estimate.
+
         Returns:
             Estimated memory usage in bytes.
             The estimate does not include the memory from the
             caller's copy of X, nor the process's code or other data.
             The estimate will be more accurate for larger datasets.
         """

+        n_bytes = total_bytes(X)
+        if y is not None:
+            n_bytes += total_bytes(y)
+
+        n_bytes = int(n_bytes * data_multiplier)
+
         if y is not None:
+            y_id = id(y)
             n_classes = Native.Task_Unknown
             y = clean_dimensions(y, "y")
             if y.ndim != 1:
@@ -1757,10 +1773,18 @@ def estimate_mem(self, X, y=None):
                 _log.error(msg)
                 raise ValueError(msg)

+            if y_id != id(y):
+                # in fit we'll also make a copy of y that cannot be deleted until the end
+                n_bytes += total_bytes(y)
+
         n_samples = None if y is None else len(y)
+        X_id = id(X)
         X, n_samples = preclean_X(
             X, self.feature_names, self.feature_types, n_samples, "y"
         )
+        if X_id != id(X):
+            # a copy was made, and we'll need to also do this in fit, so add the new memory too
+            n_bytes += total_bytes(X)

         if y is None:
             n_classes = Native.Task_Regression
@@ -1794,11 +1818,19 @@ def estimate_mem(self, X, y=None):
             feature_types_in,
             None,
         )
-
-        bin_lengths = [
-            len(x[0]) + 2 if isinstance(x[0], dict) else len(x[0]) + 3 for x in bins
-        ]
-        n_tensor_bytes = sum(bin_lengths) * np.float64().nbytes * self.outer_bags * 2
+        # first calculate the number of cells in the mains for all features
+        n_tensor_bytes = sum(
+            2
+            if len(x[0]) == 0
+            else max(x[0].values()) + 2
+            if isinstance(x[0], dict)
+            else len(x[0]) + 3
+            for x in bins
+            if len(x) != 0
+        )
+        # We have 2 copies of the update tensors in C++ (current and best) and we extract
+        # one more in python for the update before tearing down the C++ data.
+        n_tensor_bytes = n_tensor_bytes * np.float64().nbytes * self.outer_bags * 3

         # One shared memory copy of the data mapped into all processes, plus a copy of
         # the test and train data for each outer bag. Assume all processes are started
@@ -1831,6 +1863,19 @@ def estimate_mem(self, X, y=None):
             None,
         )

+        bin_lengths = [x[0] if len(x) == 1 else x[1] for x in bins if len(x) != 0]
+        bin_lengths = [
+            2
+            if len(x) == 0
+            else max(x.values()) + 2
+            if isinstance(x, dict)
+            else len(x) + 3
+            for x in bin_lengths
+        ]
+        bin_lengths.sort()
+        # we use the 75th percentile bin length to estimate the number of bins
+        n_bad_case_bins = bin_lengths[len(bin_lengths) // 4 * 3]
+
         # each outer bag makes a copy of the features. Only the training features
         # are kept for interaction detection, but don't estimate that for now.
         interaction_detection_bytes = (
@@ -1839,15 +1884,15 @@ def estimate_mem(self, X, y=None):

         max_bytes = max(max_bytes, interaction_detection_bytes)

-        bin_lengths.sort()
-        n_bad_case_bins = bin_lengths[len(bin_lengths) // 4 * 3]
+        # We have 2 copies of the update tensors in C++ (current and best) and we extract
+        # one more in python for the update before tearing down the C++ data.
         interaction_boosting_bytes = (
             n_bad_case_bins
             * n_bad_case_bins
             * np.float64().nbytes
             * self.outer_bags
             * interactions
-            * 2
+            * 3
         )

         # We merge the interactions together to make a combined interaction
@@ -1866,7 +1911,7 @@ def estimate_mem(self, X, y=None):

         max_bytes = max(max_bytes, interaction_boosting_bytes)

-        return max_bytes
+        return int(n_bytes + max_bytes)

     def to_jsonable(self, detail="all"):
         """Convert the model to a JSONable representation.

python/interpret-core/interpret/utils/_clean_simple.py

Lines changed: 3 additions & 3 deletions

@@ -86,7 +86,7 @@ def clean_dimensions(data, param_name):
         data = np.array(data, np.object_)
     elif callable(getattr(data, "__array__", None)):
         data = data.__array__()
-    elif isinstance(data, str):
+    elif isinstance(data, (str, bytes)):
         # we have just 1 item, so re-pack it and return
         ret = np.empty(1, np.object_)
         ret[0] = data
@@ -131,7 +131,7 @@ def clean_dimensions(data, param_name):
     while idx < n:
         item = data[idx]

-        if isinstance(item, str):
+        if isinstance(item, (str, bytes)):
             if n_second_dim is not None and n_second_dim != 1:
                 msg = (
                     f"{param_name} is not consistent in length for the second dimension"
@@ -180,7 +180,7 @@ def clean_dimensions(data, param_name):
         while sub_idx < n_items:
             subitem = item[sub_idx]

-            if isinstance(subitem, str):
+            if isinstance(subitem, (str, bytes)):
                 sub_idx = sub_idx + 1
                 continue
python/interpret-core/interpret/utils/_clean_x.py

Lines changed: 4 additions & 4 deletions

@@ -650,7 +650,7 @@ def _encode_pandas_categorical_initial(X_col, pd_categories, is_ordered, processing):
             _log.error(msg)
             raise ValueError(msg)
     else:
-        if isinstance(processing, str):
+        if isinstance(processing, (str, bytes)):
             # isinstance(, str) also works for np.str_

             # don't allow strings to get to the for loop below
@@ -1133,7 +1133,7 @@ def _process_dict_column(X_col, is_initial, feature_type, min_unique_continuous):
         raise ValueError(msg)
     elif isinstance(X_col, _list_tuple_types):
         X_col = np.array(X_col, np.object_)
-    elif isinstance(X_col, str):
+    elif isinstance(X_col, (str, bytes)):
         # isinstance(, str) also works for np.str_

         # don't allow strings to get to the np.array conversion below
@@ -1814,7 +1814,7 @@ def preclean_X(X, feature_names, feature_types, n_samples=None, sample_source="y"):
         msg = "X cannot be None"
         _log.error(msg)
         raise TypeError(msg)
-    elif isinstance(X, str):
+    elif isinstance(X, (str, bytes)):
         # str objects are iterable, so don't allow them to get to the list() conversion below
         # isinstance(, str) also works for np.str_
         msg = "X cannot be a str type"
@@ -1900,7 +1900,7 @@ def preclean_X(X, feature_names, feature_types, n_samples=None, sample_source="y"):
             is_copied = True
             X = list(X)
             X[idx] = _reshape_1D_if_possible(sample)
-        elif isinstance(sample, str):
+        elif isinstance(sample, (str, bytes)):
             # isinstance(, str) also works for np.str_
             break  # this is only legal if we have one sample
         else:
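
An illustrative aside on why these checks, both in _clean_simple.py and _clean_x.py, now group bytes with str: like str, bytes is iterable, so without the guard a single bytes value would be decomposed by the generic sequence handling:

list("ab")                       # ['a', 'b']
list(b"ab")                      # [97, 98] -- a lone bytes label would decay into ints
isinstance(b"ab", (str, bytes))  # True, so it is treated as one scalar item instead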
python/interpret-core/interpret/utils/_measure_mem.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+# Copyright (c) 2023 The InterpretML Contributors
+# Distributed under the MIT software license
+
+from collections.abc import Iterable
+import sys
+from ._misc import safe_isinstance
+import numpy as np
+from typing import Any
+
+
+def total_bytes(obj: Any) -> int:
+    n_bytes = 0
+    items = [obj]
+
+    seen_ids = set()
+    while items:
+        item = items.pop()
+
+        obj_id = id(item)
+        if obj_id in seen_ids:
+            continue
+        seen_ids.add(obj_id)
+
+        if safe_isinstance(item, "pandas.DataFrame"):
+            n_bytes += item.memory_usage().sum()
+            # pandas only includes the pointer to the object but not the object
+            for col in item.select_dtypes(include=["object"]):
+                for val in item[col]:
+                    try:
+                        n_bytes += sys.getsizeof(val)
+                    except Exception:
+                        pass
+                    if isinstance(val, Iterable) and not isinstance(val, (str, bytes)):
+                        try:
+                            items.extend(val)
+                        except Exception:
+                            pass
+        elif safe_isinstance(item, "pandas.Series"):
+            n_bytes += item.memory_usage()
+            if item.dtype == "O":
+                for val in item:
+                    try:
+                        n_bytes += sys.getsizeof(val)
+                    except Exception:
+                        pass
+                    if isinstance(val, Iterable) and not isinstance(val, (str, bytes)):
+                        try:
+                            items.extend(val)
+                        except Exception:
+                            pass
+        elif isinstance(item, np.ndarray):
+            n_bytes += item.nbytes
+            if item.dtype == "O":
+                items.extend(item.flat)
+        elif safe_isinstance(item, "scipy.sparse.spmatrix") or safe_isinstance(
+            item, "scipy.sparse.sparray"
+        ):
+            n_bytes += item.data.nbytes + item.indptr.nbytes + item.indices.nbytes
+        elif isinstance(item, dict):
+            try:
+                n_bytes += sys.getsizeof(item)
+            except Exception:
+                pass
+            items.extend(item.values())
+            items.extend(item.keys())
+        else:
+            try:
+                n_bytes += sys.getsizeof(item)
+            except Exception:
+                pass
+            if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
+                try:
+                    items.extend(item)
+                except Exception:
+                    pass
+
+    return int(n_bytes)
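
A short illustration (not part of the commit) of what total_bytes measures compared to sys.getsizeof, which only sees the outermost container:

import sys
import numpy as np
from interpret.utils._measure_mem import total_bytes

arr = np.zeros((1000, 10), dtype=np.float64)
print(arr.nbytes)        # 80000 bytes of array data
print(total_bytes(arr))  # 80000: the ndarray branch counts item.nbytes

nested = {"a": arr, "b": ["x" * 100] * 5}
print(sys.getsizeof(nested))  # dict header only; the array and strings are missed
print(total_bytes(nested))    # walks keys and values, deduplicating by id(), so
                              # the 5 references to the same str are counted once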

python/interpret-core/tests/glassbox/ebm/test_ebm.py

Lines changed: 8 additions & 7 deletions

@@ -1386,12 +1386,13 @@ def __call__(self, bag_index, step_index, progress, metric):


 def test_estimate_mem():
-    X, y, names, types = make_synthetic(seed=42, output_type="float", n_samples=10000)
+    for t in ["object", "pandas", "str", "float", "csc_matrix", "csc_array"]:
+        X, y, names, types = make_synthetic(seed=42, output_type=t, n_samples=10000)

-    ebm = ExplainableBoostingClassifier(names, types, interactions=[])
-    n_bytes = ebm.estimate_mem(X, y)
-    # print(n_bytes)
+        ebm = ExplainableBoostingClassifier(names, types, interactions=[])
+        n_bytes_classifier = ebm.estimate_mem(X, y, 1.0)

-    ebm = ExplainableBoostingClassifier(names, types)
-    n_bytes = ebm.estimate_mem(X, y)
-    # print(n_bytes)
+        ebm = ExplainableBoostingClassifier(names, types)
+        n_bytes_regressor = ebm.estimate_mem(X, y, 1.0)
+
+        # print(f"datatype={t}, bytes_classifier[mains]={n_bytes_classifier}, bytes_regressor[+interactions]={n_bytes_regressor}")
