Skip to content

Commit 78750c1

Browse files
committed
improve memory estimation for interactions
1 parent ae7cf99 commit 78750c1

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

docs/benchmarks/ebm-benchmark.ipynb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,12 @@
832 832
" \n",
833 833
" # Train\n",
834 834
" print(f\"FIT: {global_counter}, {trial.task.origin}, {trial.task.name}, {trial.method}, {trial.meta}, classes:{trial.task.n_classes}, features:{fit_params['X'].shape[1]}, train_samples:{fit_params['X'].shape[0]}, orig_samples:{trial.task.n_samples}\")\n",
835+
"\n",
836+
" if isinstance(est, (ExplainableBoostingClassifier, ExplainableBoostingRegressor)):\n",
837+
" n_bytes = est.estimate_mem(fit_params[\"X\"])\n",
838+
" print(f\"EBM Memory Required: {n_bytes}\")\n",
839+
" trial.log(\"mem\", n_bytes)\n",
840+
" \n",
835 841
" with warnings.catch_warnings():\n",
836 842
" warnings.filterwarnings(\"ignore\")\n",
837 843
" gc.collect() # clean out garbage to have as much memory available as possible\n",

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1701,6 +1701,7 @@ def estimate_mem(self, X):
1701 1701
Estimated memory usage in bytes.
1702 1702
The estimate does not include the memory from the
1703 1703
caller's copy of X, nor the process's code or other data.
1704+
The estimate will be more accurate for larger datasets.
1704 1705
"""
1705 1706

1706 1707
# number of classes does not affect memory much, so choose a sensible default
@@ -1771,6 +1772,11 @@ def estimate_mem(self, X):
1771 1772
max_bytes = max(max_bytes, interaction_detection_bytes)
1772 1773

1773 1774
interaction_multiple = float(interactions) / float(n_features_in)
1775+
# We merge the interactions together to make a combined interaction
1776+
# dataset, so if feature1 takes 4 bits and feature2 takes 10 bits
1777+
# then the resulting data storage should take approx 14 bits in total,
1778+
# so as a loose approximation we can add the bits in a pair.
1779+
interaction_multiple *= 2.0
1774 1780
interaction_boosting_bytes = n_bytes_pairs + int(
1775 1781
n_bytes_pairs * interaction_multiple * self.outer_bags
1776 1782
)

0 commit comments

Comments
 (0)