@@ -264,6 +264,10 @@ def __init__(
264
264
* ,
265
265
clf = None ,
266
266
clf_dict = None ,
267
+ hf_hub_repo_id = None ,
268
+ hf_hub_clf_filename = None ,
269
+ hf_hub_download_kwargs = None ,
270
+ skops_trusted = None ,
267
271
tree_val = None ,
268
272
nontree_val = None ,
269
273
refine = None ,
@@ -285,6 +289,20 @@ def __init__(
285
289
clf_dict : dictionary, optional
286
290
Dictionary mapping a trained scikit-learn-like classifier to each
287
291
first-level cluster label.
292
+ hf_hub_repo_id, hf_hub_clf_filename : str, optional
293
+ HuggingFace Hub repository id (string with the user or organization and
294
+ repository name separated by a `/`) and file name of the skops classifier
295
+ respectively. If no value is provided, the values set in
296
+ `settings.HF_HUB_REPO_ID` and `settings.HF_HUB_CLF_FILENAME` Ignored if
297
+ `clf` or `clf_dict` are provided.
298
+ hf_hub_download_kwargs : dict, optional
299
+ Additional keyword arguments (besides "repo_id", "filename", "library_name"
300
+ and "library_version") to pass to `huggingface_hub.hf_hub_download`.
301
+ skosp_trusted : list, optional
302
+ List of trusted object types to load the classifier from HuggingFace Hub,
303
+ passed to `skops.io.load`. If no value is provided, the value from
304
+ `settings.SKOPS_TRUSTED` is used. Ignored if `clf` or `clf_dict` are
305
+ provided.
288
306
tree_val : int, optional
289
307
Label used to denote tree pixels in the predicted images. If no value is
290
308
provided, the value set in `settings.CLF_TREE_VAL` is used.
@@ -315,14 +333,32 @@ def __init__(
315
333
elif clf is not None :
316
334
self .clf = clf
317
335
else :
336
+ if hf_hub_repo_id is None :
337
+ hf_hub_repo_id = settings .HF_HUB_REPO_ID
338
+ if hf_hub_clf_filename is None :
339
+ hf_hub_clf_filename = settings .HF_HUB_CLF_FILENAME
340
+ if hf_hub_download_kwargs is None :
341
+ _hf_hub_download_kwargs = {}
342
+ else :
343
+ _hf_hub_download_kwargs = hf_hub_download_kwargs .copy ()
344
+ for key in [
345
+ "repo_id" ,
346
+ "filename" ,
347
+ "library_name" ,
348
+ "library_version" ,
349
+ ]:
350
+ _ = _hf_hub_download_kwargs .pop (key , None )
351
+ if skops_trusted is None :
352
+ skops_trusted = settings .SKOPS_TRUSTED
318
353
self .clf = io .load (
319
354
hf_hub .hf_hub_download (
320
- repo_id = settings . HF_HUB_REPO_ID ,
321
- filename = settings . HF_HUB_FILENAME ,
355
+ repo_id = hf_hub_repo_id ,
356
+ filename = hf_hub_clf_filename ,
322
357
library_name = "skops" ,
323
358
library_version = skops .__version__ ,
359
+ ** _hf_hub_download_kwargs ,
324
360
),
325
- trusted = settings . SKOPS_TRUSTED ,
361
+ trusted = skops_trusted ,
326
362
)
327
363
328
364
if tree_val is None :
0 commit comments