Skip to content

Commit 964d491

Browse files
committed
add loss_cfc
1 parent 7e29630 commit 964d491

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

.DS_Store

6 KB
Binary file not shown.

classification/.DS_Store

6 KB
Binary file not shown.

classification/run_networks.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def batch_forward(self, inputs, labels=None, feature_ext=False, phase='train'):
231231
var_dic['feat'], var_dic['gt'] = self.features, labels
232232
var_dic['op_opt'] = self.op_opt
233233
kappa = var_dic['kappa'].detach()
234-
self.loss_icd, bias, op, self.sim_c = multi_get_loss(**var_dic)
234+
self.loss_icd, self.loss_cfc, bias, op, self.sim_c = multi_get_loss(**var_dic)
235235
self.networks['classifier'].module.bias = bias
236236
self.networks['classifier'].module.op = op
237237
self.logits = self.networks['classifier'](self.features, bias.repeat(self.num_gpus))
@@ -263,9 +263,11 @@ def batch_loss(self, labels):
263263
self.loss_perf = self.criterions['PerformanceLoss'](self.logits, labels)
264264
self.loss_perf *= self.criterion_weights['PerformanceLoss']
265265
self.loss += self.loss_perf
266+
weight = self.op_opt['auxlossweight'] * (1 - self.current_epoch / self.total_epoch) ** 0.9
266267
if 'icd' in self.op_opt['auxloss']:
267-
weight = self.op_opt['auxlossweight'] * (1 - self.current_epoch / self.total_epoch) ** 0.9
268268
self.loss += weight * self.loss_icd
269+
if 'cfc' in self.op_opt['auxloss']:
270+
self.loss += weight * self.loss_cfc
269271

270272

271273
def shuffle_batch(self, x, y):
@@ -361,6 +363,9 @@ def train(self):
361363

362364
if 'icd' in self.op_opt['auxloss']:
363365
minibatch_loss_icd = self.loss_icd.item()
366+
if 'cfc' in self.op_opt['auxloss']:
367+
minibatch_loss_cfc = self.loss_cfc.item()
368+
364369

365370
minibatch_acc = mic_acc_cal(preds, labels)
366371

@@ -377,6 +382,8 @@ def train(self):
377382
if 'vmf' in self.config['networks']['classifier']['def_file']:
378383
if 'icd' in self.op_opt['auxloss']:
379384
print_str += ['Licd: %.3f'% (minibatch_loss_icd)]
385+
if 'cfc' in self.op_opt['auxloss']:
386+
print_str += ['Lcfc: %.3f'% (minibatch_loss_cfc)]
380387
print_write(print_str, self.log_file)
381388

382389
loss_info = {

classification/utils_multi_K_vMF.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,26 @@ def get_all_overlap(op_opt, apk, bias, kappa, feat, gt, num_classes, mu, ni):
9696
loss_icd = ( ni * valid_op ).mean().mean() #.sum().sum() / ni.sum().sum()
9797
else:
9898
loss_icd = 0
99+
if 'cfc' in op_opt['auxloss']:
100+
sim_feat_class = ( mu[mask] * mu_feat ).sum(1)
101+
kl = apk[mask] * kappa[mask] * ( 1 - sim_feat_class )
102+
op = 1 / ( 1 + kl )
103+
loss_cfc = ( 1 - op ).mean()
104+
else:
105+
loss_cfc = 0
99106
with torch.no_grad():
100107
sim_c = (ni * sim_class).mean(1).mean(0)
101108
op = valid_op.mean(1)
102-
return loss_icd, op, sim_c
109+
return loss_icd, loss_cfc, op, sim_c
103110

104111
ni = read_ni()
105112
def multi_get_loss(op_opt, p, kappa, weight, feat, gt, num_classes):
106113
mu = get_mu(weight)
107114
apk = get_apk(p, kappa)
108115
bias = get_bias(p, kappa, apk)
109116
if op_opt['auxloss'] != []:
110-
loss_icd, op, sim_c = get_all_overlap(op_opt, apk, bias, kappa, feat, gt, num_classes, mu, ni)
117+
loss_icd, loss_cfc, op, sim_c = get_all_overlap(op_opt, apk, bias, kappa, feat, gt, num_classes, mu, ni)
111118
else:
112-
loss_icd, sim_c, op = 0.0, 0.0, bias.detach()
113-
return loss_icd, bias, op, sim_c
119+
loss_icd, loss_cfc, sim_c, op = 0.0, 0.0, 0.0, bias.detach() # hard code to avoid happening errors when running codes, I will rewrite it.
120+
return loss_icd, loss_cfc, bias, op, sim_c
114121

0 commit comments

Comments
 (0)