-
Notifications
You must be signed in to change notification settings - Fork 60
FEAT CLI conversion #249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FEAT CLI conversion #249
Changes from 32 commits
22c4c71
eacedf8
453728d
da6da30
a5abe50
5a9c0fc
b3a561e
1d01ff5
91d1b0f
7af9c26
8ad10cd
864d8ca
7589a06
b39ec14
9e76ee1
aa24050
02b3cf8
245197f
2b681ce
25b539c
50bf189
6ecb8cf
227f23c
426d663
c3b001c
37d3a87
9b9f342
6f5b71e
aa7c3ba
1c0352d
8d9b3c5
452fbdb
8150f9a
88e7be0
c38cb73
5233230
95a5ff3
a17b14b
7e63178
6e35770
02c95ad
c3491ef
7e072da
799f91b
1864f1c
8ddc6b9
bdfbf58
0e6b4cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from __future__ import annotations | ||
|
||
import argparse | ||
import logging | ||
import os | ||
import pathlib | ||
import pickle | ||
from typing import Optional | ||
|
||
from skops.cli._utils import get_log_level | ||
from skops.io import dumps, get_untrusted_types | ||
|
||
|
||
def _convert_file(input_file: os.PathLike, output_file: os.PathLike): | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Function that is called by ``skops convert`` entrypoint. | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Loads a pickle model from the input path, converts to skops format, and saves to | ||
output file. | ||
|
||
Parameters | ||
---------- | ||
input_file : os.PathLike | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A more general question: I wonder if we should still document this, and similar cases, as "str or pathlib.Path", as users may not be familiar with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I honestly think I'm happy with either way There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hard to say if a user who doesn't know the type would guess that |
||
Path of input .pkl model to load. | ||
|
||
output_file : os.PathLike | ||
Path to save .skops model to. | ||
|
||
""" | ||
model_name = pathlib.Path(input_file).stem | ||
|
||
logging.debug(f"Converting {model_name}") | ||
|
||
with open(input_file, "rb") as f: | ||
obj = pickle.load(f) | ||
skops_dump = dumps(obj) | ||
|
||
untrusted_types = get_untrusted_types(data=skops_dump) | ||
|
||
if not untrusted_types: | ||
logging.info(f"No unknown types found in {model_name}.") | ||
else: | ||
untrusted_str = ", ".join(untrusted_types) | ||
|
||
logging.warning( | ||
"Unknown Types Detected! " | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
f"While converting {model_name}, " | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"the following unknown types were found: " | ||
f"{untrusted_str}. " | ||
f"When loading {output_file}, add 'trusted=True' to the skops.load call. " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of asking users to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think from a users' perspective, it might be nicer to give them the easiest fix possible. When they're converting this file, and want to not deal, the log line above would've told them all of the untrusted types anyway, so if they wanted to go that route, they could just add those. We could add a line along the lines of: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this touches on a more general point of how we should deal with |
||
) | ||
|
||
with open(output_file, "wb") as out_file: | ||
logging.debug(f"Writing to {output_file}") | ||
out_file.write(skops_dump) | ||
|
||
|
||
def format_parser( | ||
parser: Optional[argparse.ArgumentParser] = None, | ||
) -> argparse.ArgumentParser: | ||
"""Adds arguments and help to parent CLI parser for the convert method.""" | ||
|
||
if not parser: # used in tests | ||
parser = argparse.ArgumentParser() | ||
|
||
parser_subgroup = parser.add_argument_group("convert") | ||
parser_subgroup.add_argument("input", help="Path to an input file to convert. ") | ||
|
||
parser_subgroup.add_argument( | ||
"-o", | ||
"--output-file", | ||
help=( | ||
"Specify the output file name for the converted skops file. " | ||
"If not provided, will default to using the same name as the input file, " | ||
"and saving to the current working directory with the suffix '.skops'." | ||
), | ||
default=None, | ||
) | ||
parser_subgroup.add_argument( | ||
"-v", | ||
"--verbose", | ||
help=( | ||
"Increases verbosity of logging. Can be used multiple times to increase " | ||
"verbosity further." | ||
), | ||
action="count", | ||
dest="loglevel", | ||
default=0, | ||
) | ||
return parser | ||
|
||
|
||
def main( | ||
parsed_args: argparse.Namespace, | ||
): | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
output_file = parsed_args.output_file | ||
input_file = parsed_args.input | ||
|
||
logging.basicConfig(format="%(message)s", level=get_log_level(parsed_args.loglevel)) | ||
|
||
if not output_file: | ||
# No filename provided, defaulting to base file path | ||
file_name = pathlib.Path(input_file).stem | ||
output_file = pathlib.Path.cwd() / f"{file_name}.skops" | ||
|
||
_convert_file( | ||
input_file=input_file, | ||
output_file=output_file, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import logging | ||
|
||
|
||
def get_log_level(level: int = 0): | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Takes in verbosity from a CLI entrypoint (number of times -v specified), | ||
and sets the logger to the required log level""" | ||
|
||
all_levels = [logging.WARNING, logging.INFO, logging.DEBUG] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC, warning level logs would always be shown, even if I don't add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the python docs:
I followed this because it's the default for python packages. We can adjust the default level for our CLI, or we could also add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree this is true for logging in general, but for CLIs, there is the convention to not print anything if there is no issue. Ideally, I would like to avoid having a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, good point. I would raise that it could be considered an issue if untrusted types were present. If a user had a file with untrusted types in it, and we converted it silently, they wouldn't know they needed to specify anything as trusted when using it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO, this is more of an issue for the person loading the file, and they will know (unless they use |
||
|
||
if level > len(all_levels): | ||
level = len(all_levels) | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elif level < 0: | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
level = 0 | ||
|
||
return all_levels[level] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import argparse | ||
|
||
import skops.cli._convert | ||
|
||
|
||
def main_cli(command_line_args=None): | ||
""" | ||
Main command line interface entrypoint for all command line Skops methods. | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
To add a new entrypoint: | ||
1. Create a new method to call that accepts a namespace | ||
2. Create a new subparser formatter to define the expected CL arguments | ||
3. Add those to the function map. | ||
""" | ||
entry_parser = argparse.ArgumentParser( | ||
prog="Skops", | ||
description="Main entrypoint for all command line Skops methods.", | ||
add_help=True, | ||
) | ||
|
||
subparsers = entry_parser.add_subparsers( | ||
title="Commands", | ||
description="Skops command to call", | ||
dest="cmd", | ||
help="Sub-commands help", | ||
) | ||
|
||
# function_map should map a command to | ||
# method: the command to call | ||
# format_parser: the function used to create a subparser for that command | ||
function_map = { | ||
"convert": { | ||
"method": skops.cli._convert.main, | ||
"format_parser": skops.cli._convert.format_parser, | ||
}, | ||
} | ||
|
||
for func_name, values in function_map.items(): | ||
subparser = subparsers.add_parser(func_name) | ||
subparser.set_defaults(func=values["method"]) | ||
values["format_parser"](subparser) | ||
|
||
args = entry_parser.parse_args(command_line_args) | ||
args.func(args) | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import logging | ||
import pathlib | ||
import pickle as pkl | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from unittest import mock | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from skops.cli import _convert | ||
from skops.io import load | ||
|
||
|
||
class MockUnsafeType: | ||
def __init__(self): | ||
pass | ||
|
||
|
||
class TestConvert: | ||
model_name = "some_model_name" | ||
|
||
@pytest.fixture | ||
def safe_obj(self): | ||
return np.ndarray([1, 2, 3, 4]) | ||
|
||
@pytest.fixture | ||
def unsafe_obj(self): | ||
return MockUnsafeType() | ||
|
||
@pytest.fixture | ||
def pkl_path(self, tmp_path): | ||
return tmp_path / f"{self.model_name}.pkl" | ||
|
||
@pytest.fixture | ||
def skops_path(self, tmp_path): | ||
return tmp_path / f"{self.model_name}.skops" | ||
|
||
@pytest.fixture | ||
def write_safe_file(self, pkl_path, safe_obj): | ||
with open(pkl_path, "wb") as f: | ||
pkl.dump(safe_obj, f) | ||
|
||
@pytest.fixture | ||
def write_unsafe_file(self, pkl_path, unsafe_obj): | ||
with open(pkl_path, "wb") as f: | ||
pkl.dump(unsafe_obj, f) | ||
|
||
def test_base_case_works_as_expected( | ||
self, pkl_path, tmp_path, skops_path, write_safe_file, safe_obj, caplog | ||
): | ||
_convert._convert_file(pkl_path, skops_path) | ||
persisted_obj = load(skops_path) | ||
assert np.array_equal(persisted_obj, safe_obj) | ||
assert MockUnsafeType.__name__ not in caplog.text | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def test_unsafe_case_works_as_expected( | ||
self, pkl_path, tmp_path, skops_path, write_unsafe_file, caplog | ||
): | ||
caplog.set_level(logging.WARNING) | ||
_convert._convert_file(pkl_path, skops_path) | ||
persisted_obj = load(skops_path, trusted=True) | ||
|
||
assert isinstance(persisted_obj, MockUnsafeType) | ||
|
||
# check logging has warned that an unsafe type was found | ||
assert MockUnsafeType.__name__ in caplog.text | ||
|
||
|
||
class TestMain: | ||
@staticmethod | ||
def assert_called_correctly( | ||
mock_convert: mock.MagicMock, | ||
path, | ||
output_file=None, | ||
): | ||
if not output_file: | ||
output_file = pathlib.Path.cwd() / f"{pathlib.Path(path).stem}.skops" | ||
mock_convert.assert_called_once_with(input_file=path, output_file=output_file) | ||
|
||
@mock.patch("skops.cli._convert._convert_file") | ||
def test_base_works_as_expected(self, mock_convert: mock.MagicMock): | ||
path = "123.pkl" | ||
namespace, _ = _convert.format_parser().parse_known_args([path]) | ||
|
||
_convert.main(namespace) | ||
self.assert_called_correctly(mock_convert, path) | ||
|
||
@mock.patch("skops.cli._convert._convert_file") | ||
@pytest.mark.parametrize( | ||
"input_path, output_file, expected_path", | ||
[ | ||
("abc.123", "a/b/c", "a/b/c"), | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
||
("abc.123", None, pathlib.Path.cwd() / "abc.skops"), | ||
], | ||
ids=["Given an output path", "No output path"], | ||
) | ||
def test_with_output_dir_works_as_expected( | ||
self, mock_convert: mock.MagicMock, input_path, output_file, expected_path | ||
): | ||
if output_file is not None: | ||
args = [input_path, "--output", output_file] | ||
else: | ||
args = [input_path] | ||
|
||
namespace, _ = _convert.format_parser().parse_known_args(args) | ||
|
||
_convert.main(namespace) | ||
self.assert_called_correctly( | ||
mock_convert, path=input_path, output_file=expected_path | ||
) | ||
|
||
@mock.patch("skops.cli._convert._convert_file") | ||
@pytest.mark.parametrize( | ||
"verbosity, expected_level", | ||
[ | ||
("", logging.WARNING), | ||
("-v", logging.INFO), | ||
("--verbose", logging.INFO), | ||
("-vv", logging.DEBUG), | ||
("-v -v", logging.DEBUG), | ||
("-vvv", logging.DEBUG), | ||
("--verbose --verbose", logging.DEBUG), | ||
], | ||
) | ||
def test_given_log_levels_works_as_expected( | ||
self, mock_convert: mock.MagicMock, verbosity, expected_level, caplog | ||
): | ||
input_path = "abc.def" | ||
output_path = "bde.skops" | ||
args = [input_path, "--output", output_path, verbosity.split()] | ||
|
||
namespace, _ = _convert.format_parser().parse_known_args(args) | ||
|
||
_convert.main(namespace) | ||
self.assert_called_correctly( | ||
mock_convert, path=input_path, output_file=output_path | ||
) | ||
|
||
assert caplog.at_level(expected_level) | ||
E-Aho marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.