Skip to content

Commit 76f2f73

Browse files
committed
remove the outer_bag count hack for release compatibility
1 parent 921e331 commit 76f2f73

File tree

1 file changed

+5
-11
lines changed
  • python/interpret-core/interpret/glassbox/_ebm

1 file changed

+5
-11
lines changed

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -532,11 +532,10 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
532532
_log.error(msg)
533533
raise ValueError(msg)
534534

535-
# TODO: restore
536-
# if (bags is not None) and len(bags) != self.outer_bags:
537-
# msg = f"bags has {len(bags)} bags and self.outer_bags is {self.outer_bags} bags"
538-
# _log.error(msg)
539-
# raise ValueError(msg)
535+
if (bags is not None) and bags.shape[1] != self.outer_bags:
536+
msg = f"bags has {bags.shape[1]} bags and self.outer_bags is {self.outer_bags} bags"
537+
_log.error(msg)
538+
raise ValueError(msg)
540539

541540
if not isinstance(self.validation_size, int) and not isinstance(
542541
self.validation_size, float
@@ -915,12 +914,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
915914
and not is_differential_privacy,
916915
)
917916
else:
918-
if len(bags) == self.outer_bags:
919-
# TODO: hack to avoid breaking callers on the shape of the bags param
920-
warn("The bags param shape has been changed to (n_samples, n_bag).")
921-
bag = bags[idx]
922-
else:
923-
bag = bags[:, idx]
917+
bag = bags[:, idx]
924918
if not isinstance(bag, np.ndarray):
925919
bag = np.array(bag)
926920
if bag.ndim != 1:

0 commit comments

Comments
 (0)