-
Notifications
You must be signed in to change notification settings - Fork 60
ENH add implementation for init and push #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
869a145
b0c826e
5cb789b
0fae6ef
22e2d12
a250737
72a0480
b65e189
1682146
84c8903
2a89f5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,16 +3,26 @@ | |
hub. | ||
""" | ||
|
||
import collections | ||
import json | ||
import shutil | ||
from pathlib import Path | ||
from typing import List, Union | ||
|
||
from huggingface_hub import HfApi | ||
from requests import HTTPError | ||
|
||
|
||
def _validate_folder(path: Union[str, Path]): | ||
"""Validate the contents of a folder. | ||
|
||
This function checks if the contents of a folder make a valid repo for a | ||
scikit-learn based repo on the HuggingFace Hub. | ||
|
||
A valid repository is one which is understood by the Hub as well as this | ||
library to run and use the model. Otherwise anything can be put as a model | ||
repository on the Hub and use it as a `git` and `git lfs` server. | ||
|
||
Raises a ``TypeError`` if invalid. | ||
|
||
Parameters | ||
|
@@ -24,12 +34,62 @@ def _validate_folder(path: Union[str, Path]): | |
------- | ||
None | ||
""" | ||
pass | ||
path = Path(path) | ||
if not path.is_dir(): | ||
raise TypeError("The given path is not a directory.") | ||
|
||
config_path = path / "config.json" | ||
if not config_path.exists(): | ||
raise TypeError("Configuration file `config.json` missing.") | ||
|
||
def init( | ||
*, model: Union[str, Path], requirements: List[str], destination: Union[str, Path] | ||
): | ||
with open(config_path, "r") as f: | ||
config = json.load(f) | ||
|
||
model_path = config.get("sklearn", {}).get("model", {}).get("file", None) | ||
if not model_path: | ||
raise TypeError( | ||
"Model file not configured in the configuration file. It should be stored" | ||
" in the hf_hub.sklearn.model key." | ||
) | ||
|
||
if not (path / model_path).exists(): | ||
raise TypeError(f"Model file {model_path} does not exist.") | ||
|
||
|
||
def _create_config(*, model_path: str, requirements: List[str], dst: str): | ||
"""Write the configuration into a `config.json` file. | ||
|
||
Parameters | ||
---------- | ||
model_path : str | ||
The relative path (from the repo root) to the model file. | ||
|
||
requirements : list of str | ||
A list of required packages. The versions are then extracted from the | ||
current environment. | ||
|
||
dst : str, or Path | ||
The path to an existing folder where the config file should be created. | ||
|
||
Returns | ||
------- | ||
None | ||
""" | ||
# so that we don't have to explicitly add keys and they're added as a | ||
# dictionary if they are not found | ||
# see: https://stackoverflow.com/a/13151294/2536294 | ||
def recursively_default_dict(): | ||
return collections.defaultdict(recursively_default_dict) | ||
|
||
config = recursively_default_dict() | ||
config["sklearn"]["model"]["file"] = model_path | ||
config["sklearn"]["environment"] = requirements | ||
|
||
with open(Path(dst) / "config.json", mode="w") as f: | ||
json.dump(config, f, sort_keys=True, indent=4) | ||
|
||
|
||
def init(*, model: Union[str, Path], requirements: List[str], dst: Union[str, Path]): | ||
"""Initialize a scikit-learn based HuggingFace repo. | ||
|
||
Given a model pickle and a set of required packages, this function | ||
|
@@ -44,14 +104,22 @@ def init( | |
A list of required packages. The versions are then extracted from the | ||
current environment. | ||
|
||
destination: str, or Path | ||
The path to a non-existing folder which is to be initializes. | ||
dst: str, or Path | ||
The path to a non-existing or empty folder which is to be initialized. | ||
|
||
Returns | ||
------- | ||
None | ||
""" | ||
pass | ||
dst = Path(dst) | ||
if dst.exists() and next(dst.iterdir(), None): | ||
raise OSError("None-empty dst path already exists!") | ||
dst.mkdir(parents=True, exist_ok=True) | ||
|
||
shutil.copy2(src=model, dst=dst) | ||
|
||
model_name = Path(model).name | ||
_create_config(model_path=model_name, requirements=requirements, dst=dst) | ||
|
||
|
||
def update_env(*, path: Union[str, Path], requirements: List[str] = None): | ||
|
@@ -76,7 +144,14 @@ def update_env(*, path: Union[str, Path], requirements: List[str] = None): | |
pass | ||
|
||
|
||
def push(*, repo_id: str, source: Union[str, Path], token: str = None): | ||
def push( | ||
*, | ||
repo_id: str, | ||
source: Union[str, Path], | ||
token: str = None, | ||
commit_message: str = None, | ||
create_remote: bool = False, | ||
): | ||
"""Pushes the contents of a model repo to HuggingFace Hub. | ||
|
||
This function validates the contents of the folder before pushing it to the | ||
|
@@ -94,6 +169,15 @@ def push(*, repo_id: str, source: Union[str, Path], token: str = None): | |
A token to push to the hub. If not provided, the user should be already | ||
logged in using ``huggingface-cli login``. | ||
|
||
commit_message: str, optional | ||
The commit message to be used when pushing to the repo. | ||
|
||
create_remote: bool, optional | ||
Whether to create the remote repository if it doesn't exist. If the | ||
remote repository doesn't exist and this parameter is ``False``, it | ||
raises an error. Otherwise it checks if the remote repository exists, | ||
and would create it if it doesn't. | ||
|
||
Returns | ||
------- | ||
None | ||
|
@@ -103,4 +187,23 @@ def push(*, repo_id: str, source: Union[str, Path], token: str = None): | |
This function raises a ``TypeError`` if the contents of the source folder | ||
do not make a valid HuggingFace Hub scikit-learn based repo. | ||
""" | ||
pass | ||
_validate_folder(path=source) | ||
client = HfApi() | ||
|
||
if create_remote: | ||
try: | ||
client.model_info(repo_id=repo_id, token=token) | ||
except HTTPError: | ||
client.create_repo(repo_id=repo_id, token=token, repo_type="model") | ||
|
||
client.upload_folder( | ||
repo_id=repo_id, | ||
path_in_repo=".", | ||
folder_path=source, | ||
commit_message=commit_message, | ||
commit_description=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? 😅 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For now I'm not exposing this to the end user, but passing it here. Passing it explicitly means if |
||
token=token, | ||
repo_type=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is by default None and None refers to model BTW There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, same as above, I'm just explicitly setting it in case in the future the default value changes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah gotcha! |
||
revision=None, | ||
create_pr=False, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{ | ||
"sklearn": { | ||
"environment": [ | ||
"scikit-learn=\"1.1.1\"" | ||
], | ||
"model": { | ||
"file": "model.pkl" | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this exists, why do we create_repo in examples and tests? Wouldn't it cause confusion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've removed it from the example, but the test needs to test different scenarios, and the user might create the repo before calling this method, so the test makes sure that case is tested as well.