@@ -91,28 +91,37 @@ def test_script(self):
91
91
test_input = torch .ones (2 , 1 , 8 , 8 )
92
92
test_script_save (loss , test_input , test_input )
93
93
94
- def test_result_with_alpha (self ):
94
+ @parameterized .expand ([
95
+ ("sum_None_0.5_0.25" , "sum" , None , 0.5 , 0.25 ),
96
+ ("sum_weight_0.5_0.25" , "sum" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
97
+ ("sum_weight_tuple_0.5_0.25" , "sum" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
98
+ ("mean_None_0.5_0.25" , "mean" , None , 0.5 , 0.25 ),
99
+ ("mean_weight_0.5_0.25" , "mean" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
100
+ ("mean_weight_tuple_0.5_0.25" , "mean" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
101
+ ("none_None_0.5_0.25" , "none" , None , 0.5 , 0.25 ),
102
+ ("none_weight_0.5_0.25" , "none" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
103
+ ("none_weight_tuple_0.5_0.25" , "none" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
104
+ ])
105
+ def test_with_alpha (self , name , reduction , weight , lambda_focal , alpha ):
95
106
size = [3 , 3 , 5 , 5 ]
96
107
label = torch .randint (low = 0 , high = 2 , size = size )
97
108
pred = torch .randn (size )
98
- alpha_values = [0.25 , 0.5 , 0.75 ]
99
- for reduction in ["sum" , "mean" , "none" ]:
100
- for weight in [None , torch .tensor ([1.0 , 1.0 , 2.0 ]), (3 , 2.0 , 1 )]:
101
- common_params = {
102
- "include_background" : True ,
103
- "to_onehot_y" : False ,
104
- "reduction" : reduction ,
105
- "weight" : weight ,
106
- }
107
- for lambda_focal in [0.5 , 1.0 , 1.5 ]:
108
- for alpha in alpha_values :
109
- dice_focal = DiceFocalLoss (gamma = 1.0 , lambda_focal = lambda_focal , alpha = alpha , ** common_params )
110
- dice = DiceLoss (** common_params )
111
- focal = FocalLoss (gamma = 1.0 , alpha = alpha , ** common_params )
112
- result = dice_focal (pred , label )
113
- expected_val = dice (pred , label ) + lambda_focal * focal (pred , label )
114
- np .testing .assert_allclose (result , expected_val )
115
109
110
+ common_params = {
111
+ "include_background" : True ,
112
+ "to_onehot_y" : False ,
113
+ "reduction" : reduction ,
114
+ "weight" : weight ,
115
+ }
116
+
117
+ dice_focal = DiceFocalLoss (gamma = 1.0 , lambda_focal = lambda_focal , alpha = alpha , ** common_params )
118
+ dice = DiceLoss (** common_params )
119
+ focal = FocalLoss (gamma = 1.0 , alpha = alpha , ** common_params )
120
+
121
+ result = dice_focal (pred , label )
122
+ expected_val = dice (pred , label ) + lambda_focal * focal (pred , label )
123
+
124
+ np .testing .assert_allclose (result , expected_val , err_msg = f"Failed on case: { name } " )
116
125
117
126
if __name__ == "__main__" :
118
127
unittest .main ()
0 commit comments