Skip to content

Commit 8dd39fe

Browse files
committed
feat: run any skops model from hf hub
1 parent 05419c1 commit 8dd39fe

File tree

3 files changed

+48
-5
lines changed

3 files changed

+48
-5
lines changed

detectree/classifier.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,10 @@ def __init__(
264264
*,
265265
clf=None,
266266
clf_dict=None,
267+
hf_hub_repo_id=None,
268+
hf_hub_clf_filename=None,
269+
hf_hub_download_kwargs=None,
270+
skops_trusted=None,
267271
tree_val=None,
268272
nontree_val=None,
269273
refine=None,
@@ -285,6 +289,20 @@ def __init__(
285289
clf_dict : dictionary, optional
286290
Dictionary mapping a trained scikit-learn-like classifier to each
287291
first-level cluster label.
292+
hf_hub_repo_id, hf_hub_clf_filename : str, optional
293+
HuggingFace Hub repository id (string with the user or organization and
294+
repository name separated by a `/`) and file name of the skops classifier
295+
respectively. If no value is provided, the values set in
296+
`settings.HF_HUB_REPO_ID` and `settings.HF_HUB_CLF_FILENAME` Ignored if
297+
`clf` or `clf_dict` are provided.
298+
hf_hub_download_kwargs : dict, optional
299+
Additional keyword arguments (besides "repo_id", "filename", "library_name"
300+
and "library_version") to pass to `huggingface_hub.hf_hub_download`.
301+
skosp_trusted : list, optional
302+
List of trusted object types to load the classifier from HuggingFace Hub,
303+
passed to `skops.io.load`. If no value is provided, the value from
304+
`settings.SKOPS_TRUSTED` is used. Ignored if `clf` or `clf_dict` are
305+
provided.
288306
tree_val : int, optional
289307
Label used to denote tree pixels in the predicted images. If no value is
290308
provided, the value set in `settings.CLF_TREE_VAL` is used.
@@ -315,14 +333,32 @@ def __init__(
315333
elif clf is not None:
316334
self.clf = clf
317335
else:
336+
if hf_hub_repo_id is None:
337+
hf_hub_repo_id = settings.HF_HUB_REPO_ID
338+
if hf_hub_clf_filename is None:
339+
hf_hub_clf_filename = settings.HF_HUB_CLF_FILENAME
340+
if hf_hub_download_kwargs is None:
341+
_hf_hub_download_kwargs = {}
342+
else:
343+
_hf_hub_download_kwargs = hf_hub_download_kwargs.copy()
344+
for key in [
345+
"repo_id",
346+
"filename",
347+
"library_name",
348+
"library_version",
349+
]:
350+
_ = _hf_hub_download_kwargs.pop(key, None)
351+
if skops_trusted is None:
352+
skops_trusted = settings.SKOPS_TRUSTED
318353
self.clf = io.load(
319354
hf_hub.hf_hub_download(
320-
repo_id=settings.HF_HUB_REPO_ID,
321-
filename=settings.HF_HUB_FILENAME,
355+
repo_id=hf_hub_repo_id,
356+
filename=hf_hub_clf_filename,
322357
library_name="skops",
323358
library_version=skops.__version__,
359+
**_hf_hub_download_kwargs,
324360
),
325-
trusted=settings.SKOPS_TRUSTED,
361+
trusted=skops_trusted,
326362
)
327363

328364
if tree_val is None:

detectree/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"lightgbm.sklearn.LGBMClassifier",
3737
]
3838
HF_HUB_REPO_ID = "martibosch/detectree"
39-
HF_HUB_FILENAME = "clf.skops"
39+
HF_HUB_CLF_FILENAME = "clf.skops"
4040

4141
# LIDAR
4242
LIDAR_TREE_THRESHOLD = 15

tests/test_detectree.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,14 @@ def test_classifier(self):
610610
# test that for the pre-trained classifier (no init `clf`/`clf_dict` arg) and
611611
# for the classifier initialized with the `clf` arg, the `clf` attribute is set
612612
# but not `clf_dict`
613-
for c in [dtr.Classifier(), dtr.Classifier(clf=self.clf)]:
613+
for c in [
614+
dtr.Classifier(),
615+
dtr.Classifier(clf=self.clf),
616+
dtr.Classifier(
617+
hf_hub_repo_id=settings.HF_HUB_REPO_ID,
618+
hf_hub_clf_filename=settings.HF_CLF_FILENAME,
619+
),
620+
]:
614621
self.assertTrue(hasattr(c, "clf"))
615622
self.assertFalse(hasattr(c, "clf_dict"))
616623
# test that when initializing `clf_dict`, the `clf_dict` attribute is set but

0 commit comments

Comments
 (0)