Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ These utility functions perform various operations on bounding boxes.
:template: function.rst

box_area
box_area_center
box_convert
box_iou
box_iou_center
clip_boxes_to_image
complete_box_iou
distance_box_iou
Expand Down
102 changes: 102 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,41 @@ def test_box_area_jit(self):
torch.testing.assert_close(scripted_area, expected)


class TestBoxAreaCenter:
    """Tests for ``ops.box_area_center`` on boxes given in (cx, cy, w, h) format.

    Every test builds its input in xyxy format and converts it with
    ``ops.box_convert`` before computing the area.
    """

    def area_check(self, box, expected, atol=1e-4):
        """Assert that the area of ``box`` equals ``expected`` within ``atol``.

        Only an absolute tolerance is applied (``rtol`` is zero) and dtype
        differences between the result and ``expected`` are ignored.
        """
        computed = ops.box_area_center(box)
        torch.testing.assert_close(computed, expected, atol=atol, rtol=0.0, check_dtype=False)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        # A 100 x 100 box followed by a zero-size box.
        self.area_check(cxcywh, torch.tensor([10000, 0], dtype=torch.int32))

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        xyxy = torch.tensor(FLOAT_BOXES, dtype=dtype)
        cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        self.area_check(cxcywh, torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype))

    def test_float16_box(self):
        xyxy = torch.tensor(
            [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]],
            dtype=torch.float16,
        )
        cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")

        # Half precision needs a looser tolerance than the default.
        self.area_check(cxcywh, torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16), atol=0.01)

    def test_box_area_jit(self):
        xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        # The scripted function must reproduce the eager result.
        eager_area = ops.box_area_center(cxcywh)
        scripted_area = torch.jit.script(ops.box_area_center)(cxcywh)
        torch.testing.assert_close(scripted_area, eager_area)


INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
FLOAT_BOXES = [
Expand All @@ -1459,6 +1494,14 @@ def test_box_area_jit(self):
[279.2440, 197.9812, 1189.4746, 849.2019],
]

# Box fixtures in (cx, cy, w, h) format, used by the *_center IoU tests.
# The first three integer boxes are the cxcywh form of the first three INT_BOXES
# entries; the fourth is a 20 x 20 box, so it is NOT the conversion of
# INT_BOXES[3] (a 25 x 25 box).
INT_BOXES_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100], [10, 10, 20, 20]]
INT_BOXES2_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100]]
# NOTE(review): presumably FLOAT_BOXES converted to cxcywh and rounded to 4 decimals;
# only the last FLOAT_BOXES row is visible here and it matches — confirm the others.
FLOAT_BOXES_CXCYWH = [
[739.4324, 518.5154, 908.1572, 665.8793],
[738.8228, 519.9021, 907.3512, 662.3295],
[734.3593, 523.5916, 910.2306, 651.2207]
]


def gen_box(size, dtype=torch.float):
xy1 = torch.rand((size, 2), dtype=dtype)
Expand Down Expand Up @@ -1525,6 +1568,65 @@ def test_iou_cartesian(self):
self._run_cartesian_test(ops.box_iou)


class TestIouCenterBase:
    """Shared helpers for IoU tests whose target function takes (cx, cy, w, h) boxes."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn(box1, box2)`` against ``expected`` for every dtype in ``dtypes``.

        Args:
            target_fn: callable taking two box tensors and returning a tensor.
            actual_box1: first set of boxes as nested Python lists.
            actual_box2: second set of boxes as nested Python lists.
            dtypes: dtypes the input boxes are built with, one run per dtype.
            atol: absolute tolerance of the comparison (``rtol`` is zero).
            expected: expected output as nested Python lists.
        """
        # The expected tensor does not depend on the input dtype; build it once.
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Always build from the raw inputs. Rebinding the parameters here would make
            # later dtypes start from a tensor already rounded to an earlier dtype (and
            # would trigger torch.tensor's copy-construct warning).
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the scripted ``target_fn`` matches eager mode on ``actual_box`` vs itself."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Build the N x M result by calling ``target_fn`` on one pair of boxes at a time."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                # unsqueeze(0) turns each single box back into a batch of one.
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check that the batched call agrees with the pair-by-pair reference."""
        # gen_box output is converted from xyxy so target_fn receives cxcywh boxes.
        boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh")
        boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh")
        a = TestIouCenterBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)


class TestBoxIouCenter(TestIouCenterBase):
    """Tests for ``ops.box_iou_center``, which takes boxes in (cx, cy, w, h) format.

    Inherits from ``TestIouCenterBase`` so that the cartesian test converts its
    generated boxes to cxcywh before calling the function under test.
    """

    # Expected IoU of INT_BOXES_CXCYWH (rows) against INT_BOXES2_CXCYWH (columns).
    int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]]
    # Expected IoU of FLOAT_BOXES_CXCYWH against itself.
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.box_iou_center, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.box_iou_center, INT_BOXES_CXCYWH)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.box_iou_center)


class TestGeneralizedBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
Expand Down
4 changes: 4 additions & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from .boxes import (
batched_nms,
box_area,
box_area_center,
box_convert,
box_iou,
box_iou_center,
clip_boxes_to_image,
complete_box_iou,
distance_box_iou,
Expand Down Expand Up @@ -40,7 +42,9 @@
"clip_boxes_to_image",
"box_convert",
"box_area",
"box_area_center",
"box_iou",
"box_iou_center",
"generalized_box_iou",
"distance_box_iou",
"complete_box_iou",
Expand Down
55 changes: 55 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,25 @@ def box_area(boxes: Tensor) -> Tensor:
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


def box_area_center(boxes: Tensor) -> Tensor:
    """
    Computes the area of each bounding box in a set of boxes given by their
    (cx, cy, w, h) coordinates.

    Args:
        boxes (Tensor[N, 4]): boxes whose area is computed. They are expected
            to be in (cx, cy, w, h) format with ``0 <= cx``, ``0 <= cy``,
            ``0 <= w`` and ``0 <= h``.

    Returns:
        Tensor[N]: the area for each box
    """
    # NOTE(review): maintainers suggested not exposing separate ``*_center()``
    # functions and instead adding a ``fmt`` (or ``in_fmt``) parameter to the
    # existing ``box_area()`` and ``box_iou()``, to keep the API surface small.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_area_center)
    upcast_boxes = _upcast(boxes)
    # Columns 2 and 3 hold width and height; the centre does not affect the area.
    widths = upcast_boxes[:, 2]
    heights = upcast_boxes[:, 3]
    return widths * heights


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -329,6 +348,42 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
return iou


def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Pairwise intersection and union areas for two sets of (cx, cy, w, h) boxes.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tuple[Tensor[N, M], Tensor[N, M]]: intersection areas and union areas
    """
    area1 = box_area_center(boxes1)
    area2 = box_area_center(boxes2)

    # Centres and half extents; boxes1 gets an extra axis so it broadcasts against boxes2.
    centers1 = boxes1[:, None, :2]
    half_sizes1 = boxes1[:, None, 2:] / 2
    centers2 = boxes2[:, :2]
    half_sizes2 = boxes2[:, 2:] / 2

    top_left = torch.max(centers1 - half_sizes1, centers2 - half_sizes2)  # [N,M,2]
    bottom_right = torch.min(centers1 + half_sizes1, centers2 + half_sizes2)  # [N,M,2]

    # Boxes that do not overlap give a negative extent; clamp it to zero.
    overlap = _upcast(bottom_right - top_left).clamp(min=0)  # [N,M,2]
    inter = overlap[:, :, 0] * overlap[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    return inter, union


def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two sets of boxes.

    Both sets of boxes are expected to be in ``(cx, cy, w, h)`` format with
    ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_iou_center)
    intersection, union = _box_inter_union_center(boxes1, boxes2)
    return intersection / union


# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
"""
Expand Down