Skip to content

Commit d2bcb67

Browse files
committed
improve memory estimates
1 parent b2906b2 commit d2bcb67

File tree

2 files changed

+89
-14
lines changed

2 files changed

+89
-14
lines changed

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,7 +1695,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
16951695

16961696
return self
16971697

1698-
def estimate_mem(self, X):
1698+
def estimate_mem(self, X, y=None):
16991699
"""Estimate memory usage of the model.
17001700
Args:
17011701
X: dataset
@@ -1706,8 +1706,63 @@ def estimate_mem(self, X):
17061706
The estimate will be more accurate for larger datasets.
17071707
"""
17081708

1709-
# number of classes does not affect memory much, so choose a sensible default
1710-
n_classes = Native.Task_Regression
1709+
if y is not None:
1710+
n_classes = Native.Task_Unknown
1711+
y = clean_dimensions(y, "y")
1712+
if y.ndim != 1:
1713+
msg = "y must be 1 dimensional"
1714+
_log.error(msg)
1715+
raise ValueError(msg)
1716+
if len(y) == 0:
1717+
msg = "y cannot have 0 samples"
1718+
_log.error(msg)
1719+
raise ValueError(msg)
1720+
1721+
native = Native.get_native_singleton()
1722+
1723+
objective = self.objective
1724+
n_classes = Native.Task_Unknown
1725+
if objective is not None:
1726+
if len(objective.strip()) == 0:
1727+
objective = None
1728+
else:
1729+
n_classes = native.determine_task(objective)
1730+
1731+
if is_classifier(self):
1732+
if n_classes == Native.Task_Unknown:
1733+
n_classes = Native.Task_GeneralClassification
1734+
elif n_classes < Native.Task_GeneralClassification:
1735+
msg = f"classifier cannot have objective {self.objective}"
1736+
_log.error(msg)
1737+
raise ValueError(msg)
1738+
1739+
if is_regressor(self):
1740+
if n_classes == Native.Task_Unknown:
1741+
n_classes = Native.Task_Regression
1742+
elif n_classes != Native.Task_Regression:
1743+
msg = f"regressor cannot have objective {self.objective}"
1744+
_log.error(msg)
1745+
raise ValueError(msg)
1746+
1747+
if Native.Task_GeneralClassification <= n_classes:
1748+
y = typify_classification(y)
1749+
# use pure alphabetical ordering for the classes. It's tempting to sort by frequency first
1750+
# but that could lead to a lot of bugs if the # of categories is close and we flip the ordering
1751+
# in two separate runs, which would flip the ordering of the classes within our score tensors.
1752+
classes, y = np.unique(y, return_inverse=True)
1753+
n_classes = len(classes)
1754+
elif n_classes == Native.Task_Regression:
1755+
y = y.astype(np.float64, copy=False)
1756+
else:
1757+
msg = f"Unrecognized objective {self.objective}"
1758+
_log.error(msg)
1759+
raise ValueError(msg)
1760+
else:
1761+
n_classes = Native.Task_Regression
1762+
# create a dummy y array (simulate regression)
1763+
y = np.zeros(n_samples, dtype=np.float64)
1764+
1765+
n_scores = Native.get_count_scores_c(n_classes)
17111766

17121767
X, n_samples = preclean_X(X, self.feature_names, self.feature_types, None, None)
17131768

@@ -1725,9 +1780,6 @@ def estimate_mem(self, X):
17251780
feature_types_in = binning_result[1]
17261781
bins = binning_result[2]
17271782

1728-
# create a dummy y array (simulate regression)
1729-
y = np.zeros(n_samples, dtype=np.float64)
1730-
17311783
n_bytes_mains = bin_native_by_dimension(
17321784
n_classes,
17331785
1,
@@ -1740,12 +1792,22 @@ def estimate_mem(self, X):
17401792
None,
17411793
)
17421794

1795+
bin_lengths = [
1796+
len(x[0]) + 2 if isinstance(x[0], dict) else len(x[0]) + 3 for x in bins
1797+
]
1798+
n_tensor_bytes = sum(bin_lengths) * np.float64().nbytes * self.outer_bags * 2
1799+
17431800
# One shared memory copy of the data mapped into all processes, plus a copy of
17441801
# the test and train data for each outer bag. Assume all processes are started
17451802
# at some point and are eating up memory.
17461803
# When we cannot use shared memory the parent has a copy of the dataset and
17471804
# all the children share one copy.
1748-
max_bytes = n_bytes_mains + n_bytes_mains + n_bytes_mains * self.outer_bags
1805+
max_bytes = (
1806+
n_bytes_mains
1807+
+ n_bytes_mains
1808+
+ n_bytes_mains * self.outer_bags
1809+
+ n_tensor_bytes
1810+
)
17491811

17501812
n_features_in = len(bins)
17511813

@@ -1774,13 +1836,26 @@ def estimate_mem(self, X):
17741836

17751837
max_bytes = max(max_bytes, interaction_detection_bytes)
17761838

1777-
interaction_multiple = float(interactions) / float(n_features_in)
1839+
bin_lengths.sort()
1840+
n_bad_case_bins = bin_lengths[len(bin_lengths) // 4 * 3]
1841+
interaction_boosting_bytes = (
1842+
n_bad_case_bins
1843+
* n_bad_case_bins
1844+
* np.float64().nbytes
1845+
* self.outer_bags
1846+
* interactions
1847+
* 2
1848+
)
1849+
17781850
# We merge the interactions together to make a combined interaction
17791851
# dataset, so if feature1 takes 4 bits and feature2 takes 10 bits
17801852
# then the resulting data storage should take approx 14 bits in total,
1781-
# so as a loose approximation we can add the bits in a pair.
1782-
interaction_multiple *= 2.0
1783-
interaction_boosting_bytes = (
1853+
# so as a loose approximation we can add the bits in a pair, which means
1854+
# roughly multiply by 2.0 for pairs. Multiply by another 2.0 just because
1855+
# we might get unlucky and the pairs used are biased towards the ones
1856+
# that have more bins.
1857+
interaction_multiple = 4.0 * float(interactions) / float(n_features_in)
1858+
interaction_boosting_bytes += (
17841859
n_bytes_pairs
17851860
+ n_bytes_pairs
17861861
+ int(n_bytes_pairs * interaction_multiple * self.outer_bags)

python/interpret-core/tests/glassbox/ebm/test_ebm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,12 +1386,12 @@ def __call__(self, bag_index, step_index, progress, metric):
13861386

13871387

13881388
def test_estimate_mem():
1389-
X, _, names, types = make_synthetic(seed=42, output_type="float", n_samples=10000)
1389+
X, y, names, types = make_synthetic(seed=42, output_type="float", n_samples=10000)
13901390

13911391
ebm = ExplainableBoostingClassifier(names, types, interactions=[])
1392-
n_bytes = ebm.estimate_mem(X)
1392+
n_bytes = ebm.estimate_mem(X, y)
13931393
# print(n_bytes)
13941394

13951395
ebm = ExplainableBoostingClassifier(names, types)
1396-
n_bytes = ebm.estimate_mem(X)
1396+
n_bytes = ebm.estimate_mem(X, y)
13971397
# print(n_bytes)

0 commit comments

Comments
 (0)