Skip to content

Commit 86b872b

Browse files
committed
Merge branch 'cau/thread-safety-fixes' of github.com:DS4SD/docling-ibm-models into cau/batching-layout-model
2 parents aea6bc2 + 2edf60e commit 86b872b

File tree

5 files changed

+44
-6
lines changed

5 files changed

+44
-6
lines changed

docling_ibm_models/code_formula_model/code_formula_predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def __init__(
8989
self._tokenizer = AutoTokenizer.from_pretrained(
9090
artifacts_path, use_fast=True, padding_side="left"
9191
)
92-
self._model = SamOPTForCausalLM.from_pretrained(artifacts_path).to(
93-
self._device
92+
self._model = SamOPTForCausalLM.from_pretrained(
93+
artifacts_path, device_map=self._device
9494
)
9595
self._model.eval()
9696

docling_ibm_models/document_figure_classifier_model/document_figure_classifier_predictor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,9 @@ def __init__(
9090
torch.set_num_threads(self._num_threads)
9191

9292
with _model_init_lock:
93-
model = AutoModelForImageClassification.from_pretrained(artifacts_path)
94-
self._model = model.to(device)
93+
self._model = AutoModelForImageClassification.from_pretrained(
94+
artifacts_path, device_map=device
95+
)
9596
self._model.eval()
9697

9798
self._image_processor = transforms.Compose(

docling_ibm_models/layoutmodel/layout_predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def __init__(
8484
# Use lock to prevent threading issues during model initialization
8585
with _model_init_lock:
8686
self._model = AutoModelForObjectDetection.from_pretrained(
87-
artifact_path, config=self._model_config
88-
).to(self._device)
87+
artifact_path, config=self._model_config, device_map=self._device
88+
)
8989
self._model.eval()
9090

9191
# Set classes map

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ dependencies = [
4343
'transformers (>=4.42.0,<5.0.0)',
4444
'numpy (>=1.24.4,<3.0.0)',
4545
"rtree>=1.0.0",
46+
'accelerate (>=1.2.1,<2.0.0)',
4647
]
4748

4849
[project.urls]

uv.lock

Lines changed: 36 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)