Skip to content

Commit ad9978f

Browse files
Balandat authored and facebook-github-bot committed
Robustify prune_inferior_points tests against sorting order (#2548)
Summary: Our nightly CI started failing, likely due to a sorting order change introduced in pytorch/pytorch#127936 This change robustifies the tests against the point order (and also fixes a torch deprecation warning). NOTE: Even though pytorch/pytorch#127936 was unlanded, getting these changes is will help robustify the tests going forward. NOTE: As this makes `torch.sort` use `stable=True`, this will come at a slight performance hit. However, the tensor sizes typically involved in `prune_inferior_points` are quite small (order of a few hundred items maybe), so this should be negligible. Pull Request resolved: #2548 Test Plan: unit tests Reviewed By: sdaulton, saitcakmak Differential Revision: D63260870 Pulled By: Balandat fbshipit-source-id: c5a2676c96581fe74c01208db4aa3ba3fd9ff4be
1 parent 9fc39fa commit ad9978f

File tree

4 files changed

+20
-29
lines changed

4 files changed

+20
-29
lines changed

botorch/acquisition/multi_objective/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def prune_inferior_points_multi_objective(
154154
probs = pareto_mask.to(dtype=X.dtype).mean(dim=0)
155155
idcs = probs.nonzero().view(-1)
156156
if idcs.shape[0] > max_points:
157-
counts, order_idcs = torch.sort(probs, descending=True)
157+
counts, order_idcs = torch.sort(probs, stable=True, descending=True)
158158
idcs = order_idcs[:max_points]
159159
effective_n_w = obj_vals.shape[-2] // X.shape[-2]
160160
idcs = (idcs / effective_n_w).long().unique()

botorch/acquisition/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,15 +335,16 @@ def prune_inferior_points(
335335
marginalize_dim=marginalize_dim,
336336
)
337337
if infeas.any():
338-
# set infeasible points to worse than worst objective
339-
# across all samples
338+
# set infeasible points to worse than worst objective across all samples
339+
# Use clone() here to avoid deprecated `index_put_` on an expanded tensor
340+
obj_vals = obj_vals.clone()
340341
obj_vals[infeas] = obj_vals.min() - 1
341342

342343
is_best = torch.argmax(obj_vals, dim=-1)
343344
idcs, counts = torch.unique(is_best, return_counts=True)
344345

345346
if len(idcs) > max_points:
346-
counts, order_idcs = torch.sort(counts, descending=True)
347+
counts, order_idcs = torch.sort(counts, stable=True, descending=True)
347348
idcs = order_idcs[:max_points]
348349

349350
return X[idcs]

test/acquisition/multi_objective/test_utils.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_get_default_partitioning_alpha(self):
4646

4747

4848
class DummyMCMultiOutputObjective(MCMultiOutputObjective):
49-
def forward(self, samples: Tensor) -> Tensor:
49+
def forward(self, samples: Tensor, X: Tensor | None) -> Tensor:
5050
return samples
5151

5252

@@ -130,13 +130,12 @@ def test_prune_inferior_points_multi_objective(self):
130130
X_pruned = prune_inferior_points_multi_objective(
131131
model=mm, X=X, ref_point=ref_point, max_frac=2 / 3
132132
)
133-
if self.device.type == "cuda":
134-
# sorting has different order on cuda
135-
self.assertTrue(
136-
torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]])
133+
self.assertTrue(
134+
torch.equal(
135+
torch.sort(X_pruned, stable=True).values,
136+
torch.sort(X[:2], stable=True).values,
137137
)
138-
else:
139-
self.assertTrue(torch.equal(X_pruned, X[:2]))
138+
)
140139
# test that zero-probability is in fact pruned
141140
samples[2, 0, 0] = 10
142141
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
@@ -276,10 +275,7 @@ def test_random_search_optimizer(self):
276275
input_dim = 3
277276
num_initial = 5
278277
tkwargs = {"device": self.device}
279-
optimizer_kwargs = {
280-
"pop_size": 1000,
281-
"max_tries": 5,
282-
}
278+
optimizer_kwargs = {"pop_size": 1000, "max_tries": 5}
283279

284280
for (
285281
dtype,
@@ -350,10 +346,7 @@ def test_sample_optimal_points(self):
350346
input_dim = 3
351347
num_initial = 5
352348
tkwargs = {"device": self.device}
353-
optimizer_kwargs = {
354-
"pop_size": 100,
355-
"max_tries": 1,
356-
}
349+
optimizer_kwargs = {"pop_size": 100, "max_tries": 1}
357350
num_samples = 2
358351
num_points = 1
359352

test/acquisition/test_utils.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,12 @@ def test_prune_inferior_points(self):
270270
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
271271
mm = MockModel(MockPosterior(samples=samples))
272272
X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3)
273-
if self.device.type == "cuda":
274-
# sorting has different order on cuda
275-
self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0)))
276-
else:
277-
self.assertTrue(torch.equal(X_pruned, X[:2]))
273+
self.assertTrue(
274+
torch.equal(
275+
torch.sort(X_pruned, stable=True).values,
276+
torch.sort(X[:2], stable=True).values,
277+
)
278+
)
278279
# test that zero-probability is in fact pruned
279280
samples[2, 0, 0] = 10
280281
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
@@ -289,11 +290,7 @@ def test_prune_inferior_points(self):
289290
device=self.device,
290291
dtype=dtype,
291292
)
292-
mm = MockModel(
293-
MockPosterior(
294-
samples=samples,
295-
)
296-
)
293+
mm = MockModel(MockPosterior(samples=samples))
297294
X_pruned = prune_inferior_points(
298295
model=mm,
299296
X=X,

0 commit comments

Comments
 (0)