Skip to content

Commit ebe6ed2

Browse files
committed
reformat to ruff
1 parent dab0154 commit ebe6ed2

File tree

4 files changed

+44
-12
lines changed

4 files changed

+44
-12
lines changed

python/interpret-core/interpret/glassbox/_ebm/_boost.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ def boost(
128128
state_idx = 0
129129

130130
nominals = native.extract_nominals(dataset)
131-
random_cyclic_ordering = np.arange(len(term_features), dtype=np.int64)
131+
random_cyclic_ordering = np.arange(
132+
len(term_features), dtype=np.int64
133+
)
132134

133135
while step_idx < max_steps:
134136
if state_idx >= 0:
@@ -142,7 +144,9 @@ def boost(
142144
and develop.get_option(
143145
"randomize_initial_feature_order"
144146
)
145-
or develop.get_option("randomize_greedy_feature_order")
147+
or develop.get_option(
148+
"randomize_greedy_feature_order"
149+
)
146150
and greedy_steps > 0
147151
or develop.get_option("randomize_feature_order")
148152
):
@@ -169,7 +173,10 @@ def boost(
169173
if contains_nominals:
170174
reg_lambda_local += develop.get_option("cat_l2")
171175

172-
if develop.get_option("min_samples_leaf_nominal") is not None:
176+
if (
177+
develop.get_option("min_samples_leaf_nominal")
178+
is not None
179+
):
173180
min_samples_leaf_local = develop.get_option(
174181
"min_samples_leaf_nominal"
175182
)
@@ -179,7 +186,9 @@ def boost(
179186
elif missing == "high":
180187
term_boost_flags_local |= Native.TermBoostFlags_MissingHigh
181188
elif missing == "separate":
182-
term_boost_flags_local |= Native.TermBoostFlags_MissingSeparate
189+
term_boost_flags_local |= (
190+
Native.TermBoostFlags_MissingSeparate
191+
)
183192
elif missing != "gain":
184193
msg = f"Unrecognized missing option {missing}."
185194
raise Exception(msg)
@@ -219,7 +228,9 @@ def boost(
219228
max_delta_step=max_delta_step,
220229
min_cat_samples=min_cat_samples,
221230
cat_smooth=cat_smooth,
222-
max_cat_threshold=develop.get_option("max_cat_threshold"),
231+
max_cat_threshold=develop.get_option(
232+
"max_cat_threshold"
233+
),
223234
cat_include=develop.get_option("cat_include"),
224235
max_leaves=max_leaves,
225236
monotone_constraints=term_monotone,
@@ -260,7 +271,9 @@ def boost(
260271
for f, s, noise in zip(
261272
splits_iter[:-1], splits_iter[1:], noises
262273
):
263-
noisy_update_tensor[f:s] = term_update_tensor[f:s] + noise
274+
noisy_update_tensor[f:s] = (
275+
term_update_tensor[f:s] + noise
276+
)
264277

265278
# Native code will be returning sums of residuals in slices, not averages.
266279
# Compute noisy average by dividing noisy sum by noisy bin weights
@@ -292,7 +305,9 @@ def boost(
292305
min(abs(min_metric), abs(min_prev_metric))
293306
* early_stopping_tolerance
294307
)
295-
if np.isnan(modified_tolerance) or np.isinf(modified_tolerance):
308+
if np.isnan(modified_tolerance) or np.isinf(
309+
modified_tolerance
310+
):
296311
modified_tolerance = 0.0
297312

298313
if cur_metric <= min_metric - min(0.0, modified_tolerance):
@@ -319,14 +334,19 @@ def boost(
319334
circular_idx = (circular_idx + 1) % len(circular)
320335
min_prev_metric = min(toss, min_prev_metric)
321336

322-
if min_prev_metric - modified_tolerance <= circular.min():
337+
if (
338+
min_prev_metric - modified_tolerance
339+
<= circular.min()
340+
):
323341
break
324342

325343
if stop_flag is not None and stop_flag[0]:
326344
break
327345

328346
if callback is not None:
329-
is_done = callback(bag_idx, step_idx, make_progress, cur_metric)
347+
is_done = callback(
348+
bag_idx, step_idx, make_progress, cur_metric
349+
)
330350
if is_done:
331351
if stop_flag is not None:
332352
stop_flag[0] = True

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1419,7 +1419,10 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
14191419
parallel_args = []
14201420
for idx in range(self.outer_bags):
14211421
early_stopping_rounds_local = early_stopping_rounds
1422-
if internal_bags[idx] is None or (internal_bags[idx] >= 0).all():
1422+
if (
1423+
internal_bags[idx] is None
1424+
or (internal_bags[idx] >= 0).all()
1425+
):
14231426
# if there are no validation samples, turn off early stopping
14241427
# because the validation metric cannot improve each round
14251428
early_stopping_rounds_local = 0

python/interpret-core/interpret/utils/_shared_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from contextlib import AbstractContextManager
55

6+
67
class SharedDataset(AbstractContextManager):
78
def __init__(self):
89
self.shared_memory = None

python/interpret-core/tests/utils/test_compressed_dataset.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from interpret.utils._preprocessor import construct_bins
1212
from interpret.utils._shared_dataset import SharedDataset
1313

14+
1415
@pytest.mark.skip(reason="skip this until we have support for missing values")
1516
def test_bin_native():
1617
X = np.array(
@@ -97,8 +98,15 @@ def test_bin_native():
9798

9899
with SharedDataset() as shared:
99100
bin_native_by_dimension(
100-
n_classes, 1, bins, X, y, sample_weight, feature_names_in, feature_types_in,
101-
shared,
101+
n_classes,
102+
1,
103+
bins,
104+
X,
105+
y,
106+
sample_weight,
107+
feature_names_in,
108+
feature_types_in,
109+
shared,
102110
)
103111
assert shared.shared_memory is not None
104112
assert shared.dataset is not None

0 commit comments

Comments
 (0)