Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ All notable changes to this project will be documented in this file.

- Fix DataInputParams Serialization
(<https://github.com/openvinotoolkit/training_extensions/pull/4293>)
- Align KP detection validation with ModelAPI post processing
(<https://github.com/openvinotoolkit/training_extensions/pull/4300>)

### Removed

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ docs = [
base = [
"torch==2.5.1",
"lightning==2.4.0",
"torchmetrics==1.6.0",
"pytorchcv==0.0.67",
"timm==1.0.3",
"openvino==2025.0",
Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/data/dataset/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _get_item_impl(self, index: int) -> TorchDataItem | None:
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=img_shape,
),
label=torch.as_tensor([ann.label for ann in bbox_anns]),
label=torch.as_tensor([ann.label for ann in bbox_anns], dtype=torch.long),
keypoints=torch.as_tensor(keypoints, dtype=torch.float32),
)

Expand Down
26 changes: 22 additions & 4 deletions src/otx/core/model/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,28 @@ def _customize_outputs(
scores = []
# default visibility threshold
visibility_threshold = 0.5
for output in outputs:
if inputs.imgs_info is None:
msg = "The input image information is not provided."
raise ValueError(msg)
for i, output in enumerate(outputs):
if not isinstance(output, tuple):
raise TypeError(output)
kps = torch.as_tensor(output[0], device=self.device)
if inputs.imgs_info[i] is None:
msg = f"The image information for the image {i} is not provided."
raise ValueError(msg)
# scale to the original image size
orig_h, orig_w = inputs.imgs_info[i].ori_shape # type: ignore[union-attr]
kp_scale_h, kp_scale_w = (
orig_h / self.data_input_params.input_size[0],
orig_w / self.data_input_params.input_size[1],
)
inverted_scale = max(kp_scale_h, kp_scale_w)
kp_scale_h = kp_scale_w = inverted_scale
# decode kps
kps = torch.as_tensor(output[0], device=self.device) * torch.tensor(
[kp_scale_w, kp_scale_h],
device=self.device,
)
score = torch.as_tensor(output[1], device=self.device)
visible_keypoints = torch.cat([kps, score.unsqueeze(1) > visibility_threshold], dim=1)
keypoints.append(visible_keypoints)
Expand Down Expand Up @@ -164,8 +182,8 @@ def get_dummy_input(self, batch_size: int = 1) -> TorchDataBatch: # type: ignor
infos.append(
ImageInfo(
img_idx=i,
img_shape=img.shape,
ori_shape=img.shape,
img_shape=img.shape[:2],
ori_shape=img.shape[:2],
),
)

Expand Down
11 changes: 9 additions & 2 deletions src/otx/data/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,16 @@ def collate_fn(items: list[TorchDataItem]) -> TorchDataBatch:
Returns:
Batched TorchDataItems with stacked tensors
"""
# Check if all images have the same size. TODO(kprokofi): remove this check once OV IR models are moved.
if all(item.image.shape == items[0].image.shape for item in items):
images = torch.stack([item.image for item in items])
else:
# we need this only in case of OV inference, where no resize
images = [item.image for item in items]

return TorchDataBatch(
batch_size=len(items),
images=torch.stack([item.image for item in items]),
images=images,
labels=[item.label for item in items],
bboxes=[item.bboxes for item in items],
keypoints=[item.keypoints for item in items],
Expand All @@ -82,7 +89,7 @@ class TorchDataBatch(ValidateBatchMixin):
"""Torch data item batch implementation."""

batch_size: int # TODO(ashwinvaidya17): Remove this
images: torch.Tensor
images: torch.Tensor | list[torch.Tensor]
labels: list[torch.Tensor] | None = None
masks: list[Mask] | None = None
bboxes: list[BoundingBoxes] | None = None
Expand Down
36 changes: 25 additions & 11 deletions src/otx/data/validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,18 +191,32 @@ def __post_init__(self) -> None:
@staticmethod
def _images_validator(image_batch: torch.Tensor) -> torch.Tensor:
"""Validate the image batch."""
if not isinstance(image_batch, torch.Tensor):
msg = f"Image batch must be a torch tensor. Got {type(image_batch)}"
if not isinstance(image_batch, list) and not isinstance(image_batch, torch.Tensor):
msg = f"Image batch must be a torch tensor or list of tensors. Got {type(image_batch)}"
raise TypeError(msg)
if image_batch.dtype != torch.float32:
msg = "Image batch must have dtype float32"
raise ValueError(msg)
if image_batch.ndim != 4:
msg = "Image batch must have 4 dimensions"
raise ValueError(msg)
if image_batch.shape[1] not in [1, 3]:
msg = "Image batch must have 1 or 3 channels"
raise ValueError(msg)
if isinstance(image_batch, torch.Tensor):
if image_batch.dtype != torch.float32:
msg = "Image batch must have dtype float32"
raise ValueError(msg)
if image_batch.ndim != 4:
msg = "Image batch must have 4 dimensions"
raise ValueError(msg)
if image_batch.shape[1] not in [1, 3]:
msg = "Image batch must have 1 or 3 channels"
raise ValueError(msg)
else:
if not all(isinstance(image, torch.Tensor) for image in image_batch):
msg = "Image batch must be a list of torch tensors"
raise TypeError(msg)
if not all(image.dtype == torch.float32 for image in image_batch):
msg = "Image batch must have dtype float32"
raise ValueError(msg)
if not all(image.ndim == 3 for image in image_batch):
msg = "Image batch must have 3 dimensions"
raise ValueError(msg)
if not all(image.shape[0] in [1, 3] for image in image_batch):
msg = "Image batch must have 1 or 3 channels"
raise ValueError(msg)
return image_batch

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions src/otx/recipe/_base_/data/keypoint_detection.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ val_subset:
init_args:
scale: $(input_size)
keep_ratio: true
transform_keypoints: true
transform_keypoints: false
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size: $(input_size)
Expand All @@ -55,7 +55,7 @@ test_subset:
init_args:
scale: $(input_size)
keep_ratio: true
transform_keypoints: true
transform_keypoints: false
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size: $(input_size)
Expand Down
Loading