Skip to content

Commit 36037f7

Browse files
committed
Minor fix to torch.load() (#4392)
* fix weights only * change otx version * change Changelog * fix linter
1 parent 26f55c2 commit 36037f7

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## \[2.4.4\]
6+
7+
### Bug fixes
8+
9+
- Fix torch.load() to be able to load all OTX custom snapshots
10+
(<https://github.com/open-edge-platform/training_extensions/pull/4392>)
11+
512
## \[2.4.3\]
613

714
### Enhancements

src/otx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Copyright (C) 2024-2025 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6-
__version__ = "2.4.3"
6+
__version__ = "2.4.4"
77

88
import os
99
from pathlib import Path

src/otx/engine/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def train(
266266
# load the model state from the checkpoint incrementally.
267267
# This means only the model weights are loaded. If there is a mismatch in label_info,
268268
# perform incremental weight loading for the model's classification layer.
269-
ckpt = torch.load(checkpoint)
269+
ckpt = torch.load(checkpoint, weights_only=False)
270270
self.model.load_state_dict_incrementally(ckpt)
271271

272272
with override_metric_callable(model=self.model, new_metric_callable=metric) as model:

0 commit comments

Comments
 (0)