
Commit 67cd98e

Fixing problem with torch.stack requiring a non-empty element list
1 parent d560489 commit 67cd98e

File tree

6 files changed: +39 −37 lines changed

cellvit/inference/inference.py

Lines changed: 1 addition & 1 deletion

@@ -475,7 +475,7 @@ def _import_postprocessing(self) -> Tuple[Callable, Callable]:
             Tuple[Callable, Callable]: Postprocessing module
         """
         if self.system_configuration["cupy"]:
-            from cellvit.inference.postprocessing_numpy import (
+            from cellvit.inference.postprocessing_cupy import (
                 DetectionCellPostProcessor,
                 create_batch_pooling_actor,
             )
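This one-line change corrects a copy-paste bug: with CuPy enabled in the system configuration, the code previously imported the NumPy postprocessing module anyway. A minimal sketch of the corrected dispatch; the else branch is an assumption that mirrors the structure shown in the hunk, not code taken verbatim from the repository:

    def _import_postprocessing(self):
        if self.system_configuration["cupy"]:
            # GPU path: CuPy-accelerated postprocessing (the fixed import)
            from cellvit.inference.postprocessing_cupy import (
                DetectionCellPostProcessor,
                create_batch_pooling_actor,
            )
        else:
            # CPU fallback: NumPy implementation (assumed else-branch)
            from cellvit.inference.postprocessing_numpy import (
                DetectionCellPostProcessor,
                create_batch_pooling_actor,
            )
        return DetectionCellPostProcessor, create_batch_pooling_actor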

cellvit/inference/postprocessing_cupy.py

Lines changed: 17 additions & 16 deletions

@@ -653,22 +653,23 @@ def convert_batch_to_graph_nodes(
         batch_cell_positions = batch_cell_positions + patch_cell_positions

         if self.detection_cell_postprocessor.classifier is not None:
-            batch_cell_tokens_pt = torch.stack(batch_cell_tokens)
-            updated_preds = self.detection_cell_postprocessor.classifier(
-                batch_cell_tokens_pt
-            )
-            updated_preds = F.softmax(updated_preds, dim=1)
-            updated_classes = torch.argmax(updated_preds, dim=1)
-            updated_class_preds = updated_preds[
-                torch.arange(updated_classes.shape[0]), updated_classes
-            ]
-
-            for f, z in zip(batch_complete, updated_classes):
-                f["type"] = int(z)
-            for f, z in zip(batch_complete, updated_class_preds):
-                f["type_prob"] = int(z)
-            for f, z in zip(batch_detection, updated_classes):
-                f["type"] = int(z)
+            if len(batch_cell_tokens) > 0:
+                batch_cell_tokens_pt = torch.stack(batch_cell_tokens)
+                updated_preds = self.detection_cell_postprocessor.classifier(
+                    batch_cell_tokens_pt
+                )
+                updated_preds = F.softmax(updated_preds, dim=1)
+                updated_classes = torch.argmax(updated_preds, dim=1)
+                updated_class_preds = updated_preds[
+                    torch.arange(updated_classes.shape[0]), updated_classes
+                ]
+
+                for f, z in zip(batch_complete, updated_classes):
+                    f["type"] = int(z)
+                for f, z in zip(batch_complete, updated_class_preds):
+                    f["type_prob"] = int(z)
+                for f, z in zip(batch_detection, updated_classes):
+                    f["type"] = int(z)
         if self.detection_cell_postprocessor.binary:
             for f in batch_complete:
                 f["type"] = 1

cellvit/inference/postprocessing_numpy.py

Lines changed: 17 additions & 16 deletions

@@ -653,22 +653,23 @@ def convert_batch_to_graph_nodes(
         batch_cell_positions = batch_cell_positions + patch_cell_positions

         if self.detection_cell_postprocessor.classifier is not None:
-            batch_cell_tokens_pt = torch.stack(batch_cell_tokens)
-            updated_preds = self.detection_cell_postprocessor.classifier(
-                batch_cell_tokens_pt
-            )
-            updated_preds = F.softmax(updated_preds, dim=1)
-            updated_classes = torch.argmax(updated_preds, dim=1)
-            updated_class_preds = updated_preds[
-                torch.arange(updated_classes.shape[0]), updated_classes
-            ]
-
-            for f, z in zip(batch_complete, updated_classes):
-                f["type"] = int(z)
-            for f, z in zip(batch_complete, updated_class_preds):
-                f["type_prob"] = int(z)
-            for f, z in zip(batch_detection, updated_classes):
-                f["type"] = int(z)
+            if len(batch_cell_tokens) > 0:
+                batch_cell_tokens_pt = torch.stack(batch_cell_tokens)
+                updated_preds = self.detection_cell_postprocessor.classifier(
+                    batch_cell_tokens_pt
+                )
+                updated_preds = F.softmax(updated_preds, dim=1)
+                updated_classes = torch.argmax(updated_preds, dim=1)
+                updated_class_preds = updated_preds[
+                    torch.arange(updated_classes.shape[0]), updated_classes
+                ]
+
+                for f, z in zip(batch_complete, updated_classes):
+                    f["type"] = int(z)
+                for f, z in zip(batch_complete, updated_class_preds):
+                    f["type_prob"] = int(z)
+                for f, z in zip(batch_detection, updated_classes):
+                    f["type"] = int(z)
         if self.detection_cell_postprocessor.binary:
             for f in batch_complete:
                 f["type"] = 1

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "cellvit"
-version = "1.0.1b"
+version = "1.0.2b"
 description = "CellViT Inference Pipeline for Whole Slide Images (WSI) in Memory"
 authors = [
     { name = "Fabian Hörst", email = "[email protected]" }

tests/test_cli/test_cli_yaml.py

Lines changed: 2 additions & 2 deletions

@@ -125,8 +125,8 @@ def test_invalid_batch_size(self, mock_device_count):
            InferenceConfiguration(invalid_config)
        self.assertEqual(
            str(context.exception),
-            "Batch size must be between 2 and 32",
-            "Batch size must be between 2 and 32",
+            "Batch size must be between 2 and 48",
+            "Batch size must be between 2 and 48",
        )

    @patch("torch.cuda.device_count")

tests/test_dataclass/test_wsi_meta.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ def test_load_wsi_meta_with_svs_file_1(self):
        self.assertEqual(
            slide_properties["magnification"], 20.0
        )  # Fill in expected value
-        self.assertEqual(target_mpp, 0.25)  # Fill in expected value
+        self.assertEqual(target_mpp, 0.2495)  # Fill in expected value

    def test_load_wsi_meta_with_svs_file_2(self):
        """Test loading WSI metadata from SVS file 2."""
