Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
f5e83a9
feat: generate README.md in hub_utils.init
jucamohedano Nov 1, 2022
28b4b0c
Merge branch 'skops-dev:main' into main
jucamohedano Nov 2, 2022
f0e9683
ref: replace _create_readme function with fewer lines
jucamohedano Nov 2, 2022
1aeb14c
test create model card in hub_utils.init
jucamohedano Nov 3, 2022
95c0e1b
test override model card after created by hub_utils.init
jucamohedano Nov 3, 2022
e0e6c7d
Merge branch 'skops-dev:main' into main
jucamohedano Nov 14, 2022
4b6cb73
ref: deduplicate test creation of README in init
jucamohedano Nov 14, 2022
870797f
fix: check that content of new model card is modified
jucamohedano Nov 14, 2022
d3a0eac
Merge branch 'skops-dev:main' into main
jucamohedano Nov 14, 2022
4b3fb8d
Merge branch 'main' of github.com:jucamohedano/skops into main
jucamohedano Nov 14, 2022
eaed93b
Merge branch 'skops-dev:main' into main
jucamohedano Nov 18, 2022
f182ee1
revert lines removed by mistake
jucamohedano Nov 18, 2022
9a41cf2
Merge branch 'main' into main
adrinjalali Nov 22, 2022
7f7d0c2
Merge branch 'skops-dev:main' into main
jucamohedano Nov 25, 2022
1c19795
Merge branch 'skops-dev:main' into main
jucamohedano Dec 4, 2022
56165e4
Merge branch 'skops-dev:main' into main
jucamohedano Dec 5, 2022
6f99565
Merge branch 'main' into main
jucamohedano Jan 19, 2023
c265f50
Merge branch 'main' into main
jucamohedano Jan 21, 2023
0b6d3e2
fix: check model format of model file
jucamohedano Jan 21, 2023
5e1494a
fix: run pre-commit on all files
jucamohedano Jan 23, 2023
5cfa962
Merge branch 'skops-dev:main' into main
jucamohedano Jan 23, 2023
35a30c2
Merge branch 'skops-dev:main' into main
jucamohedano Jan 31, 2023
0c4a66f
fix: check for file suffix to determine format
jucamohedano Jan 31, 2023
5436894
Merge branch 'skops-dev:main' into main
jucamohedano Feb 12, 2023
2171908
feat: implement model caching with sha256 hash
jucamohedano Feb 12, 2023
36b855b
feat: extend test to test cache model loading
jucamohedano Feb 12, 2023
81523ac
add rest of suffixes in test_hash_model
jucamohedano Feb 12, 2023
c8e9281
fix: cache model within the model card object
jucamohedano Feb 26, 2023
3c2151a
fix: run pre-comit on all files
jucamohedano Feb 28, 2023
5af9b4a
ref: remove additional unrelated code
jucamohedano Feb 28, 2023
e40a4d3
Merge remote-tracking branch 'skops-upstream/main' into cache-model-l…
jucamohedano Mar 13, 2023
1245ff3
run precommit on all files and apply fixes
jucamohedano Mar 13, 2023
ff6d3b9
fix test_model_caching with a higher level test
jucamohedano May 3, 2023
c1ba293
Merge remote-tracking branch 'skops-upstream/main' into cache-model-l…
jucamohedano May 3, 2023
24c014d
Merge branch 'cache-model-loading' of github.com:jucamohedano/skops i…
jucamohedano May 3, 2023
dd2eb68
revert changes to origin
jucamohedano May 3, 2023
7826c4b
revert changes to origin
jucamohedano May 3, 2023
1e1237c
apply and test suggestion
jucamohedano May 8, 2023
1879a3e
revert lines
jucamohedano May 8, 2023
78207d9
revert lines
jucamohedano May 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import zipfile
from collections.abc import Mapping
from dataclasses import dataclass, field
from functools import cached_property
from hashlib import sha256
from pathlib import Path
from reprlib import Repr
from typing import Any, Iterator, Literal, Sequence, Union
Expand Down Expand Up @@ -503,6 +505,7 @@ def __init__(

self._data: dict[str, Section] = {}
self._metrics: dict[str, str | float | int] = {}
self._model_hash = ""

self._populate_template(model_diagram=model_diagram)

Expand Down Expand Up @@ -564,9 +567,24 @@ def get_model(self) -> Any:
The model instance.

"""
if isinstance(self.model, (str, Path)) and hasattr(self, "_model"):
hash_obj = sha256()
buf_size = 2**20 # load in chunks to save memory
with open(self.model, "rb") as f:
for chunk in iter(lambda: f.read(buf_size), b""):
hash_obj.update(chunk)
model_hash = hash_obj.hexdigest()

# if hash changed, invalidate cache by deleting attribute
if model_hash != self._model_hash:
del self._model
self._model_hash = model_hash

return self._model

@cached_property
def _model(self):
model = _load_model(self.model, self.trusted)
# Ideally, we would only call the method below if we *know* that the
# model has changed, but at the moment we have no way of knowing that
return model

def add(self, **kwargs: str) -> Self:
Expand Down
31 changes: 30 additions & 1 deletion skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile
import textwrap
from pathlib import Path
from unittest import mock

import numpy as np
import pytest
Expand All @@ -25,7 +26,7 @@
TableSection,
_load_model,
)
from skops.io import dump
from skops.io import dump, load
from skops.utils.importutils import import_or_raise


Expand Down Expand Up @@ -145,6 +146,34 @@ def test_save_model_card(destination_path, model_card):
assert (Path(destination_path) / "README.md").exists()


def test_model_caching(
skops_model_card_metadata_from_config, iris_skops_file, destination_path
):
"""Tests that the model card caches the model to avoid loading it multiple times"""

new_model = LogisticRegression(random_state=4321)
# mock _load_model, it still loads the model but we can track call count
mock_load_model = mock.Mock(side_effect=load)
card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
with mock.patch("skops.card._model_card._load_model", mock_load_model):
model1 = card.get_model()
model2 = card.get_model()
assert model1 is model2
# model is cached, hence _load_model is not called
mock_load_model.assert_not_called()

# override model with new model
dump(new_model, card.model)

model3 = card.get_model()
assert mock_load_model.call_count == 1
assert model3.random_state == 4321
model4 = card.get_model()

assert model3 is model4
assert mock_load_model.call_count == 1 # cached call


CUSTOM_TEMPLATES = [None, {}, {"A Title", "Another Title", "A Title/A Section"}] # type: ignore


Expand Down