Skip to content

Commit 10c82c6

Browse files
committed
Fix argument ordering, and make alpha test parameterized
1 parent ec45af2 commit 10c82c6

File tree

2 files changed

+28
-19
lines changed

2 files changed

+28
-19
lines changed

monai/losses/dice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,9 +835,9 @@ def __init__(
835835
gamma: float = 2.0,
836836
focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
837837
weight: Sequence[float] | float | int | torch.Tensor | None = None,
838-
alpha: float | None = None,
839838
lambda_dice: float = 1.0,
840839
lambda_focal: float = 1.0,
840+
alpha: float | None = None,
841841
) -> None:
842842
"""
843843
Args:

tests/test_dice_focal_loss.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -91,28 +91,37 @@ def test_script(self):
9191
test_input = torch.ones(2, 1, 8, 8)
9292
test_script_save(loss, test_input, test_input)
9393

94-
def test_result_with_alpha(self):
94+
@parameterized.expand([
95+
("sum_None_0.5_0.25", "sum", None, 0.5, 0.25),
96+
("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
97+
("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25),
98+
("mean_None_0.5_0.25", "mean", None, 0.5, 0.25),
99+
("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
100+
("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25),
101+
("none_None_0.5_0.25", "none", None, 0.5, 0.25),
102+
("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
103+
("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25),
104+
])
105+
def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
95106
size = [3, 3, 5, 5]
96107
label = torch.randint(low=0, high=2, size=size)
97108
pred = torch.randn(size)
98-
alpha_values = [0.25, 0.5, 0.75]
99-
for reduction in ["sum", "mean", "none"]:
100-
for weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]:
101-
common_params = {
102-
"include_background": True,
103-
"to_onehot_y": False,
104-
"reduction": reduction,
105-
"weight": weight,
106-
}
107-
for lambda_focal in [0.5, 1.0, 1.5]:
108-
for alpha in alpha_values:
109-
dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params)
110-
dice = DiceLoss(**common_params)
111-
focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params)
112-
result = dice_focal(pred, label)
113-
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
114-
np.testing.assert_allclose(result, expected_val)
115109

110+
common_params = {
111+
"include_background": True,
112+
"to_onehot_y": False,
113+
"reduction": reduction,
114+
"weight": weight,
115+
}
116+
117+
dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params)
118+
dice = DiceLoss(**common_params)
119+
focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params)
120+
121+
result = dice_focal(pred, label)
122+
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
123+
124+
np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}")
116125

117126
if __name__ == "__main__":
118127
unittest.main()

0 commit comments

Comments
 (0)