@@ -231,7 +231,7 @@ def batch_forward(self, inputs, labels=None, feature_ext=False, phase='train'):
231
231
var_dic ['feat' ], var_dic ['gt' ] = self .features , labels
232
232
var_dic ['op_opt' ] = self .op_opt
233
233
kappa = var_dic ['kappa' ].detach ()
234
- self .loss_icd , bias , op , self .sim_c = multi_get_loss (** var_dic )
234
+ self .loss_icd , self . loss_cfc , bias , op , self .sim_c = multi_get_loss (** var_dic )
235
235
self .networks ['classifier' ].module .bias = bias
236
236
self .networks ['classifier' ].module .op = op
237
237
self .logits = self .networks ['classifier' ](self .features , bias .repeat (self .num_gpus ))
@@ -263,9 +263,11 @@ def batch_loss(self, labels):
263
263
self .loss_perf = self .criterions ['PerformanceLoss' ](self .logits , labels )
264
264
self .loss_perf *= self .criterion_weights ['PerformanceLoss' ]
265
265
self .loss += self .loss_perf
266
+ weight = self .op_opt ['auxlossweight' ] * (1 - self .current_epoch / self .total_epoch ) ** 0.9
266
267
if 'icd' in self .op_opt ['auxloss' ]:
267
- weight = self .op_opt ['auxlossweight' ] * (1 - self .current_epoch / self .total_epoch ) ** 0.9
268
268
self .loss += weight * self .loss_icd
269
+ if 'cfc' in self .op_opt ['auxloss' ]:
270
+ self .loss += weight * self .loss_cfc
269
271
270
272
271
273
def shuffle_batch (self , x , y ):
@@ -361,6 +363,9 @@ def train(self):
361
363
362
364
if 'icd' in self .op_opt ['auxloss' ]:
363
365
minibatch_loss_icd = self .loss_icd .item ()
366
+ if 'cfc' in self .op_opt ['auxloss' ]:
367
+ minibatch_loss_cfc = self .loss_cfc .item ()
368
+
364
369
365
370
minibatch_acc = mic_acc_cal (preds , labels )
366
371
@@ -377,6 +382,8 @@ def train(self):
377
382
if 'vmf' in self .config ['networks' ]['classifier' ]['def_file' ]:
378
383
if 'icd' in self .op_opt ['auxloss' ]:
379
384
print_str += ['Licd: %.3f' % (minibatch_loss_icd )]
385
+ if 'cfc' in self .op_opt ['auxloss' ]:
386
+ print_str += ['Lcfc: %.3f' % (minibatch_loss_cfc )]
380
387
print_write (print_str , self .log_file )
381
388
382
389
loss_info = {
0 commit comments