
Commit 25a9d99

ENH remove trusted=True from skops.io.load(s) (#422)
* SEC remove trusted=True
* FIX make the rest of the fixes and make sure tests pass
* TST add test for trusted=True
* DOC add changelog
* DOC update PR number
* ENH make the error message more helpful
* Address Benjamin's comments
* CLN fix a typecheck to what we really want
1 parent f69b928 commit 25a9d99

21 files changed

+157 -156 lines changed

docs/changes.rst

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,11 @@ v0.10
1414
- Removes Python 3.8 support and adds Python 3.12 support :pr:`418` by :user:`Thomas Lazarus <lazarust>`.
1515
- Removes a shortcut to add `sklearn-intelex` as a dependency.
1616
:pr:`420` by :user:`Thomas Lazarus <lazarust>`.
17+
- ``trusted=True`` is now removed from ``skops.io.load`` and ``skops.io.loads``.
18+
This is to further encourage users to inspect the input data before loading
19+
it. :func:`skops.io.get_untrusted_types` can be used to get the untrusted types
20+
present in the input.
21+
:pr:`422` by `Adrin Jalali`_.
1722

1823
v0.9
1924
----

docs/persistence.rst

Lines changed: 17 additions & 22 deletions
@@ -54,7 +54,7 @@ The code snippet below illustrates how to use :func:`skops.io.dump` and
5454
from xgboost.sklearn import XGBClassifier
5555
from sklearn.model_selection import GridSearchCV, train_test_split
5656
from sklearn.datasets import load_iris
57-
from skops.io import dump, load
57+
from skops.io import dump, load, get_untrusted_types
5858
5959
X, y = load_iris(return_X_y=True)
6060
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
@@ -64,26 +64,24 @@ The code snippet below illustrates how to use :func:`skops.io.dump` and
6464
0.9666666666666667
6565
dump(clf, "my-model.skops")
6666
# ...
67-
loaded = load("my-model.skops", trusted=True)
67+
unknown_types = get_untrusted_types(file="my-model.skops")
68+
print(unknown_types)
69+
['sklearn.metrics._scorer._passthrough_scorer',
70+
'xgboost.core.Booster', 'xgboost.sklearn.XGBClassifier']
71+
loaded = load("my-model.skops", trusted=unknown_types)
6872
print(loaded.score(X_test, y_test))
6973
0.9666666666666667
7074
7175
# in memory
7276
from skops.io import dumps, loads
7377
serialized = dumps(clf)
74-
loaded = loads(serialized, trusted=True)
75-
76-
Note that you should only load files with ``trusted=True`` if you trust the
77-
source. Otherwise you can get a list of untrusted types present in the dump
78-
using :func:`skops.io.get_untrusted_types`:
79-
80-
.. code:: python
78+
loaded = loads(serialized, trusted=unknown_types)
8179
82-
from skops.io import get_untrusted_types
83-
unknown_types = get_untrusted_types(file="my-model.skops")
84-
print(unknown_types)
85-
['sklearn.metrics._scorer._passthrough_scorer',
86-
'xgboost.core.Booster', 'xgboost.sklearn.XGBClassifier']
80+
Note that the ``get_untrusted_types`` function is used to check which types are
81+
not trusted by default. The user can then decide whether to trust them or not.
82+
In versions before 0.10, users could pass ``trusted=True`` to skip the
83+
audit phase, which is now removed to encourage users to validate the input
84+
before loading.
8785

8886
Note that everything in the above list is safe to load. We already have many
8987
types included as trusted by default, and some of the above values might be
@@ -92,10 +90,6 @@ added to that list in the future.
9290
Once you check the list and you validate that everything in the list is safe,
9391
you can load the file with ``trusted=unknown_types``:
9492

95-
.. code:: python
96-
97-
loaded = load("my-model.skops", trusted=unknown_types)
98-
9993
At the moment, we support the vast majority of sklearn estimators. This
10094
includes complex use cases such as :class:`sklearn.pipeline.Pipeline`,
10195
:class:`sklearn.model_selection.GridSearchCV`, classes using objects defined in
@@ -226,10 +220,11 @@ green to cyan. The ``rich`` docs list the `supported standard colors
226220

227221
Note that the visualization feature is intended to help understand the structure
228222
of the object, e.g. what attributes are identified as untrusted. It is not a
229-
replacement for a proper security check. In particular, just because an object's
230-
visualization looks innocent does *not* mean you can just call `sio.load(<file>,
231-
trusted=True)` on this object -- only pass the types you really trust to the
232-
``trusted`` argument.
223+
replacement for a proper security check of the included types in the file. In
224+
particular, just because an object's visualization looks innocent does *not*
225+
mean you can just call `sio.load(<file>,
226+
trusted=get_untrusted_types(file=<file>))` on this object -- only pass the
227+
types you really trust to the ``trusted`` argument.
233228

234229
Supported libraries
235230
-------------------
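
The updated persistence docs above ask users to inspect the output of `get_untrusted_types` before passing it to `trusted`. One way to make that review explicit is to compare the reported types against an allowlist the project maintains. The following is only a sketch under that assumption: `REVIEWED_TYPES` and `approve_types` are hypothetical names that do not exist in skops, and the type names are the ones printed in the docs snippet.

```python
from typing import Iterable, List

# Hypothetical allowlist, filled in by hand after reviewing each type.
REVIEWED_TYPES = frozenset(
    {
        "sklearn.metrics._scorer._passthrough_scorer",
        "xgboost.core.Booster",
        "xgboost.sklearn.XGBClassifier",
    }
)


def approve_types(
    untrusted: Iterable[str],
    reviewed: Iterable[str] = REVIEWED_TYPES,
) -> List[str]:
    """Return the untrusted types unchanged if all were reviewed, else raise."""
    untrusted = list(untrusted)
    # Anything reported by the file but missing from the allowlist is suspicious.
    unexpected = sorted(set(untrusted) - set(reviewed))
    if unexpected:
        raise ValueError(f"Refusing to load; unreviewed types found: {unexpected}")
    return untrusted
```

With skops installed, the result could be passed on as `load(path, trusted=approve_types(get_untrusted_types(file=path)))`, so that a file containing an unexpected type fails before anything is constructed.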

skops/card/_model_card.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
1212
from hashlib import sha256
1313
from pathlib import Path
1414
from reprlib import Repr
15-
from typing import Any, Iterator, Literal, Sequence, Union
15+
from typing import Any, Iterator, List, Literal, Optional, Sequence, Union
1616

1717
import joblib
1818
from huggingface_hub import ModelCardData
@@ -488,7 +488,7 @@ def __init__(
488488
model_diagram: bool | Literal["auto"] | str = "auto",
489489
metadata: ModelCardData | None = None,
490490
template: Literal["skops"] | dict[str, str] | None = "skops",
491-
trusted: bool = False,
491+
trusted: Optional[List[str]] = None,
492492
) -> None:
493493
self.model = model
494494
self.metadata = metadata or ModelCardData()

skops/card/tests/test_card.py

Lines changed: 14 additions & 7 deletions
@@ -26,7 +26,7 @@
2626
TableSection,
2727
_load_model,
2828
)
29-
from skops.io import dump, load
29+
from skops.io import dump, get_untrusted_types, load
3030
from skops.utils.importutils import import_or_raise
3131

3232

@@ -51,10 +51,14 @@ def save_model_to_file(model_instance, suffix):
5151
def test_load_model(suffix):
5252
model0 = LinearRegression(n_jobs=123)
5353
_, save_file = save_model_to_file(model0, suffix)
54-
loaded_model_str = _load_model(save_file, trusted=True)
54+
if suffix == ".skops":
55+
untrusted_types = get_untrusted_types(file=save_file)
56+
else:
57+
untrusted_types = None
58+
loaded_model_str = _load_model(save_file, trusted=untrusted_types)
5559
save_file_path = Path(save_file)
56-
loaded_model_path = _load_model(save_file_path, trusted=True)
57-
loaded_model_instance = _load_model(model0, trusted=True)
60+
loaded_model_path = _load_model(save_file_path, trusted=untrusted_types)
61+
loaded_model_instance = _load_model(model0, trusted=untrusted_types)
5862

5963
assert loaded_model_str.n_jobs == 123
6064
assert loaded_model_path.n_jobs == 123
@@ -1383,8 +1387,11 @@ def test_with_metadata(self, card: Card, meth, expected_lines):
13831387

13841388

13851389
class TestCardModelAttributeIsPath:
1386-
def path_to_card(self, path):
1387-
card = Card(model=path, trusted=True)
1390+
def path_to_card(self, path, suffix):
1391+
if suffix == ".skops":
1392+
card = Card(model=path, trusted=get_untrusted_types(file=path))
1393+
else:
1394+
card = Card(model=path)
13881395
return card
13891396

13901397
@pytest.mark.parametrize("meth", [repr, str])
@@ -1397,7 +1404,7 @@ def test_model_card_repr(self, meth, suffix):
13971404
model = LinearRegression(fit_intercept=False)
13981405
file_handle, file_name = save_model_to_file(model, suffix)
13991406
os.close(file_handle)
1400-
card_from_path = self.path_to_card(file_name)
1407+
card_from_path = self.path_to_card(file_name, suffix=suffix)
14011408

14021409
result0 = meth(card_from_path)
14031410
expected = "Card(\n model=LinearRegression(fit_intercept=False),"
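
The tests above repeat one pattern: look up untrusted types only for `.skops` files and pass `None` for other formats. A small helper could capture that branch. This is a sketch, not code from the commit: `trusted_for` is a hypothetical name, and the lookup function is injected so the example runs without skops.

```python
from pathlib import Path
from typing import Callable, List, Optional, Union

PathLike = Union[str, Path]


def trusted_for(
    path: PathLike,
    get_types: Callable[[PathLike], List[str]],
) -> Optional[List[str]]:
    """Return the types to trust for a .skops file, or None for other formats."""
    if Path(path).suffix == ".skops":
        # Only the skops format has an audit step that needs a trusted list.
        return get_types(path)
    return None
```

With skops available, the tests could then build a card as `Card(model=path, trusted=trusted_for(path, lambda p: get_untrusted_types(file=p)))`, relying on the new `trusted: Optional[List[str]] = None` signature shown in `_model_card.py`.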

skops/cli/_update.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
99
from pathlib import Path
1010

1111
from skops.cli._utils import get_log_level
12-
from skops.io import dump, load
12+
from skops.io import dump, get_untrusted_types, load
1313
from skops.io._protocol import PROTOCOL
1414

1515

@@ -48,7 +48,7 @@ def _update_file(
4848
" file."
4949
)
5050

51-
input_model = load(input_file, trusted=True)
51+
input_model = load(input_file, trusted=get_untrusted_types(file=input_file))
5252
with zipfile.ZipFile(input_file, "r") as zip_file:
5353
input_file_schema = json.loads(zip_file.read("schema.json"))
5454

skops/cli/tests/test_convert.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
77
import pytest
88

99
from skops.cli import _convert
10-
from skops.io import load
10+
from skops.io import get_untrusted_types, load
1111

1212

1313
class MockUnsafeType:
@@ -61,7 +61,7 @@ def test_unsafe_case_works_as_expected(
6161
):
6262
caplog.set_level(logging.WARNING)
6363
_convert._convert_file(pkl_path, skops_path)
64-
persisted_obj = load(skops_path, trusted=True)
64+
persisted_obj = load(skops_path, trusted=get_untrusted_types(file=skops_path))
6565

6666
assert isinstance(persisted_obj, MockUnsafeType)
6767

skops/io/_audit.py

Lines changed: 12 additions & 23 deletions
@@ -2,7 +2,7 @@
22

33
import io
44
from contextlib import contextmanager
5-
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Type, Union
5+
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union
66

77
from ._protocol import PROTOCOL
88
from ._utils import LoadContext, get_module, get_type_paths
@@ -14,9 +14,7 @@
1414
]
1515

1616

17-
def check_type(
18-
module_name: str, type_name: str, trusted: Literal[True] | Sequence[str]
19-
) -> bool:
17+
def check_type(module_name: str, type_name: str, trusted: Sequence[str]) -> bool:
2018
"""Check if a type is safe to load.
2119
2220
A type is safe to load only if it's present in the trusted list.
@@ -38,16 +36,13 @@ def check_type(
3836
is_safe : bool
3937
True if the type is safe, False otherwise.
4038
"""
41-
if trusted is True:
42-
return True
4339
return module_name + "." + type_name in trusted
4440

4541

4642
def audit_tree(tree: Node) -> None:
4743
"""Audit a tree of nodes.
4844
49-
A tree is safe if it only contains trusted types. Audit is skipped if
50-
trusted is ``True``.
45+
A tree is safe if it only contains trusted types.
5146
5247
Parameters
5348
----------
@@ -59,9 +54,6 @@ def audit_tree(tree: Node) -> None:
5954
UntrustedTypesFoundException
6055
If the tree contains an untrusted type.
6156
"""
62-
if tree.trusted is True:
63-
return
64-
6557
unsafe = tree.get_unsafe_set()
6658
if unsafe:
6759
raise UntrustedTypesFoundException(unsafe)
@@ -142,7 +134,7 @@ def __init__(
142134
self,
143135
state: dict[str, Any],
144136
load_context: LoadContext,
145-
trusted: bool | Sequence[str] = False,
137+
trusted: Optional[Sequence[str]] = None,
146138
memoize: bool = True,
147139
) -> None:
148140
self.class_name, self.module_name = state["__class__"], state["__module__"]
@@ -180,22 +172,19 @@ def _construct(self):
180172

181173
@staticmethod
182174
def _get_trusted(
183-
trusted: bool | Sequence[Union[str, Type]], default: Sequence[Union[str, Type]]
184-
) -> Literal[True] | list[str]:
175+
trusted: Optional[Sequence[Union[str, Type]]],
176+
default: Sequence[Union[str, Type]],
177+
) -> list[str]:
185178
"""Return a trusted list.
186179
187-
If ``trusted`` is ``False``, we return the ``default``. If a list of
180+
If ``trusted`` is ``None``, we return the ``default``. If a list of
188181
types are being passed, those types, as well as default trusted types,
189182
are returned.
190183
191184
This is a convenience method called by child classes.
192185
193186
"""
194-
if trusted is True:
195-
# if trusted is True, we trust the node
196-
return True
197-
198-
if trusted is False:
187+
if trusted is None:
199188
# if trusted is None, we only trust the defaults
200189
return get_type_paths(default)
201190

@@ -289,12 +278,12 @@ def __init__(
289278
self,
290279
state: dict[str, Any],
291280
load_context: LoadContext,
292-
trusted: bool = False,
281+
trusted: Optional[List[str]] = None,
293282
):
294283
# we pass memoize as False because we don't want to memoize the cached
295284
# node.
296285
super().__init__(state, load_context, trusted, memoize=False)
297-
self.trusted = True
286+
self.trusted = self._get_trusted(trusted, default=[])
298287
# TODO: deal with case that __id__ is unknown or prevent it from
299288
# happening
300289
self.cached = load_context.get_object(state.get("__id__")) # type: ignore
@@ -313,7 +302,7 @@ def _construct(self):
313302
def get_tree(
314303
state: dict[str, Any],
315304
load_context: LoadContext,
316-
trusted: bool | Sequence[str],
305+
trusted: Optional[Sequence[str]],
317306
) -> Node:
318307
"""Get the tree of nodes.
319308
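
After this change, `trusted` in `_audit.py` is either `None` or a sequence of type names, with no boolean shortcut. The standalone sketch below illustrates the resulting semantics as described in the diff and its docstrings: `None` falls back to the defaults, a list extends them, and `check_type` is a plain membership test. `type_path` is a hypothetical stand-in for the real `get_type_paths` helper, whose body the diff does not show, and the ordering of the combined list is an assumption.

```python
from typing import List, Optional, Sequence, Type, Union


def type_path(entry: Union[str, Type]) -> str:
    """Stand-in helper: dotted path of a type, or the string itself."""
    if isinstance(entry, str):
        return entry
    return f"{entry.__module__}.{entry.__qualname__}"


def get_trusted(
    trusted: Optional[Sequence[Union[str, Type]]],
    default: Sequence[Union[str, Type]],
) -> List[str]:
    """None means "defaults only"; a list is trusted in addition to the defaults."""
    defaults = [type_path(t) for t in default]
    if trusted is None:
        return defaults
    return [type_path(t) for t in trusted] + defaults


def check_type(module_name: str, type_name: str, trusted: Sequence[str]) -> bool:
    """A type is safe only if its dotted path appears in the trusted list."""
    return f"{module_name}.{type_name}" in trusted
```

Because there is no longer a value that bypasses the membership test, every type in a file has to appear either in the defaults or in the list the caller passed.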
