
Dimension mismatch in multiclass_recall with average="macro" when some classes have zero support #216

@AllisonOge


🐛 Describe the bug

I'm experiencing a runtime error when computing multiclass recall using torcheval's multiclass_recall function with average="macro". The error is:

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

Steps to Reproduce:

In the minimal example below, the logits cover 3 classes, but only classes 0 and 2 are predicted (each row's argmax is 0 or 2), and every ground-truth label is class 2. The function computes the true-positive counts (num_tp) and the per-class label counts (num_labels) internally. For this input they should be num_tp = [0, 0, 4] and num_labels = [0, 0, 5]. During the averaging step, however, the code appears to drop zero-support classes from the true positives (shrinking that tensor from 3 to 2 elements) while leaving the denominator unchanged. The element-wise division num_tp / num_labels then fails with the dimension mismatch shown above.
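To make the counts concrete, here is a plain-PyTorch sketch that uses the same inputs as the MWE below. It does not call torcheval. It recomputes the per-class counts and reproduces the shape error. The mask `(num_labels != 0) | (num_predictions != 0)` is an assumption about the filter. It is one rule that keeps exactly 2 of the 3 classes here, which matches the sizes in the error message, but it may not be the rule torcheval actually uses.

```python
import torch

# Same inputs as the MWE
logits = torch.tensor([
    [2.0, 0.5, 1.0],
    [0.2, 1.5, 2.1],
    [1.0, 2.0, 3.5],
    [0.1, 1.0, 1.2],
    [1.0, 0.2, 2.1],
])
labels = torch.tensor([2, 2, 2, 2, 2])
num_classes = 3

preds = logits.argmax(dim=1)  # tensor([0, 2, 2, 2, 2])

# Per-class counts, computed by hand for illustration
num_tp = torch.bincount(preds[preds == labels], minlength=num_classes)  # [0, 0, 4]
num_labels = torch.bincount(labels, minlength=num_classes)              # [0, 0, 5]
num_predictions = torch.bincount(preds, minlength=num_classes)          # [1, 0, 4]

# Hypothetical filter: keep classes that appear in predictions or labels
mask = (num_labels != 0) | (num_predictions != 0)  # [True, False, True]

# Filtering the numerator but not the denominator breaks broadcasting
error_message = None
try:
    _ = num_tp[mask] / num_labels  # shape (2,) / shape (3,)
except RuntimeError as err:
    error_message = str(err)
    print("RuntimeError:", error_message)
```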

MWE

import torch
from torcheval.metrics.functional import (
    multiclass_accuracy,
    multiclass_precision,
    multiclass_recall
)
# Example predictions and labels
logits = torch.tensor([
    [2.0, 0.5, 1.0],
    [0.2, 1.5, 2.1],
    [1.0, 2.0, 3.5],
    [0.1, 1.0, 1.2],
    [1.0, 0.2, 2.1],
])  # Shape: (5, 3)

labels = torch.tensor([2, 2, 2, 2, 2])  # True class indices

# Compute metrics
accuracy = multiclass_accuracy(logits, labels)
precision = multiclass_precision(
    logits, labels, num_classes=3, average="macro")  # Average over classes
recall = multiclass_recall(logits, labels, num_classes=3, average="macro")

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)

Expected Behavior:

I would expect that:

  • Both num_tp and num_labels are filtered consistently (with the same mask), so that the element-wise division succeeds
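Below is a minimal sketch of that expected behavior, written in plain PyTorch rather than torcheval's actual code. `macro_recall_consistent` is a hypothetical helper that applies one mask to both the true positives and the label counts before dividing. Two parts of it are assumptions, not confirmed torcheval behavior: the mask rule, and the choice to count a NaN recall (a class with no ground-truth samples) as 0.

```python
import torch


def macro_recall_consistent(num_tp, num_labels, num_predictions):
    """Hypothetical macro recall that filters numerator and denominator together."""
    # Assumed filter: keep classes seen in either predictions or labels
    mask = (num_labels != 0) | (num_predictions != 0)
    tp = num_tp[mask].float()
    support = num_labels[mask].float()
    # Both tensors have the same length, so element-wise division is well defined
    recall = tp / support
    # A class with zero support gives 0/0 = NaN; count it as 0 recall (assumed convention)
    recall = torch.nan_to_num(recall, nan=0.0)
    return recall.mean()


# Per-class counts for the MWE inputs
num_tp = torch.tensor([0, 0, 4])
num_labels = torch.tensor([0, 0, 5])
num_predictions = torch.tensor([1, 0, 4])

result = macro_recall_consistent(num_tp, num_labels, num_predictions)
print(result)  # tensor(0.4000)
```

Whether a class that is predicted but never labeled (class 0 here) should contribute a recall of 0 or be excluded from the macro average is a design choice. Excluding it would give 0.8 for this input instead of 0.4.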

Versions

  • torcheval version: 0.0.7
  • PyTorch version: 2.3.1
  • OS: Ubuntu 22.04.4 LTS

Any guidance or fix would be greatly appreciated!
