🐛 Describe the bug
I get a runtime error when computing multiclass recall with torcheval's `multiclass_recall` function and `average="macro"`. The error is:
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

Steps to Reproduce:
In the minimal example below, I have logits for 3 classes where only classes 0 and 2 are predicted (the argmax of each row is 0 or 2), and every ground-truth label is class 2. Internally, the per-class true-positive counts (`num_tp`) and label counts (`num_labels`) are computed. During the averaging step, however, the code seems to filter classes with no samples out of `num_tp` (here class 1, the only class that is neither predicted nor labeled), shrinking it from length 3 to 2, while leaving the denominator `num_labels` at length 3. The element-wise division then fails with the dimension mismatch shown above.
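The following sketch recomputes the per-class counts for the same data in plain PyTorch and reproduces the mismatch. It is not torcheval's internal code. The mask (keep classes that occur in either the labels or the predictions) is an assumption, chosen because it matches the 3-to-2 shrink reported in the error message.

```python
import torch

logits = torch.tensor([
    [2.0, 0.5, 1.0],
    [0.2, 1.5, 2.1],
    [1.0, 2.0, 3.5],
    [0.1, 1.0, 1.2],
    [1.0, 0.2, 2.1],
])
labels = torch.tensor([2, 2, 2, 2, 2])
num_classes = 3

preds = logits.argmax(dim=1)  # predicted class per row: [0, 2, 2, 2, 2]

# Per-class counts, each of length num_classes
num_tp = torch.bincount(labels[preds == labels], minlength=num_classes)  # [0, 0, 4]
num_labels = torch.bincount(labels, minlength=num_classes)              # [0, 0, 5]
num_predictions = torch.bincount(preds, minlength=num_classes)          # [1, 0, 4]

# Assumed filter: drop classes absent from both labels and predictions (class 1)
mask = (num_labels != 0) | (num_predictions != 0)                        # [True, False, True]

# Masking only the numerator: shapes (2,) and (3,) cannot broadcast
error_message = None
try:
    num_tp[mask] / num_labels
except RuntimeError as err:
    error_message = str(err)
    print(error_message)
```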
MWE
```python
import torch
from torcheval.metrics.functional import (
    multiclass_accuracy,
    multiclass_precision,
    multiclass_recall,
)

# Example predictions and labels
logits = torch.tensor([
    [2.0, 0.5, 1.0],
    [0.2, 1.5, 2.1],
    [1.0, 2.0, 3.5],
    [0.1, 1.0, 1.2],
    [1.0, 0.2, 2.1],
])  # Shape: (5, 3)
labels = torch.tensor([2, 2, 2, 2, 2])  # True class indices

# Compute metrics
accuracy = multiclass_accuracy(logits, labels)
precision = multiclass_precision(
    logits, labels, num_classes=3, average="macro")  # Average over classes
# This call raises the RuntimeError above
recall = multiclass_recall(logits, labels, num_classes=3, average="macro")

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
```

Expected Behavior:
I would expect that:
- Both `num_tp` and `num_labels` are filtered consistently (with the same mask), so the element-wise division operates on tensors of equal length.
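A minimal sketch of consistent filtering, written in plain PyTorch rather than as a patch to torcheval. The helper name `macro_recall_consistent` is hypothetical, and filtering on non-zero support (`num_labels != 0`) is an assumed choice of semantics, not torcheval's documented behavior.

```python
import torch


def macro_recall_consistent(num_tp: torch.Tensor, num_labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: macro recall over classes with at least one label.

    The same boolean mask is applied to numerator and denominator, so the
    element-wise division always sees tensors of equal length. If no class
    has support, the mean of an empty tensor is nan.
    """
    mask = num_labels != 0  # assumed filter: keep classes with non-zero support
    per_class = num_tp[mask].float() / num_labels[mask].float()
    return per_class.mean()


# Per-class counts for the MWE (classes 0, 1, 2)
num_tp = torch.tensor([0, 0, 4])
num_labels = torch.tensor([0, 0, 5])
print(macro_recall_consistent(num_tp, num_labels))  # tensor(0.8000)
```

Filtering on support also avoids a second pitfall: if the mask instead kept every class that is labeled *or* predicted, class 0 (predicted once, never labeled) would survive and yield 0/0 = nan. Whether such classes should count as 0 or be skipped is a design decision for the library.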
Versions
- torcheval version: 0.0.7
- PyTorch version: 2.3.1
- OS: Ubuntu 22.04.4 LTS
Any guidance or fix would be greatly appreciated!