-
Notifications
You must be signed in to change notification settings - Fork 49
Download weight file from URL #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
6431cbf
9e359bf
1341ddc
b4a568f
74b54ba
610841c
2e3b756
7e7f64d
70a3ecb
0acf8d4
72cc151
0340998
49f78dc
794bdeb
dc6a31b
c03e45e
c2bc971
98ca779
c3a0f77
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,17 @@ | ||
"""The command line entry point for Casanovo.""" | ||
|
||
import datetime | ||
import email.utils | ||
import functools | ||
import hashlib | ||
import logging | ||
import urllib | ||
import os | ||
import re | ||
import shutil | ||
import sys | ||
import time | ||
import urllib.parse | ||
import warnings | ||
from pathlib import Path | ||
from typing import Optional, Tuple | ||
|
@@ -60,10 +64,9 @@ def __init__(self, *args, **kwargs) -> None: | |
click.Option( | ||
("-m", "--model"), | ||
help=""" | ||
The model weights (.ckpt file). If not provided, Casanovo | ||
will try to download the latest release. | ||
Either the model weights (.ckpt file) or a URL pointing to the model weights | ||
file. If not provided, Casanovo will try to download the latest release automatically. | ||
""", | ||
type=click.Path(exists=True, dir_okay=False), | ||
), | ||
click.Option( | ||
("-o", "--output"), | ||
|
@@ -363,22 +366,32 @@ def setup_model( | |
seed_everything(seed=config["random_seed"], workers=True) | ||
|
||
# Download model weights if these were not specified (except when training). | ||
if model is None and not is_train: | ||
try: | ||
model = _get_model_weights() | ||
except github.RateLimitExceededException: | ||
logger.error( | ||
"GitHub API rate limit exceeded while trying to download the " | ||
"model weights. Please download compatible model weights " | ||
"manually from the official Casanovo code website " | ||
"(https://github.com/Noble-Lab/casanovo) and specify these " | ||
"explicitly using the `--model` parameter when running " | ||
"Casanovo." | ||
cache_dir = Path(appdirs.user_cache_dir("casanovo", False, opinion=False)) | ||
if model is None: | ||
if not is_train: | ||
try: | ||
bittremieux marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model = _get_model_weights(cache_dir) | ||
except github.RateLimitExceededException: | ||
logger.error( | ||
"GitHub API rate limit exceeded while trying to download the " | ||
"model weights. Please download compatible model weights " | ||
"manually from the official Casanovo code website " | ||
"(https://github.com/Noble-Lab/casanovo) and specify these " | ||
"explicitly using the `--model` parameter when running " | ||
"Casanovo." | ||
) | ||
raise PermissionError( | ||
"GitHub API rate limit exceeded while trying to download the " | ||
"model weights" | ||
) from None | ||
else: | ||
if _is_valid_url(model): | ||
model = _get_weights_from_url(model, cache_dir) | ||
elif not Path(model).is_file(): | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. todo: Add a similar logging statement. |
||
f"{model} is not a valid URL or checkpoint file path, " | ||
bittremieux marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"--model argument must be a URL or checkpoint file path" | ||
) | ||
raise PermissionError( | ||
"GitHub API rate limit exceeded while trying to download the " | ||
"model weights" | ||
) from None | ||
|
||
# Log the active configuration. | ||
logger.info("Casanovo version %s", str(__version__)) | ||
|
@@ -391,7 +404,7 @@ def setup_model( | |
return config, model | ||
|
||
|
||
def _get_model_weights() -> str: | ||
def _get_model_weights(cache_dir: Path) -> str: | ||
""" | ||
Use cached model weights or download them from GitHub. | ||
|
||
|
@@ -405,12 +418,16 @@ def _get_model_weights() -> str: | |
Note that the GitHub API is limited to 60 requests from the same IP per | ||
hour. | ||
|
||
Parameters | ||
---------- | ||
cache_dir : Path | ||
model weights cache directory path | ||
|
||
Returns | ||
------- | ||
str | ||
The name of the model weights file. | ||
""" | ||
cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False) | ||
os.makedirs(cache_dir, exist_ok=True) | ||
version = utils.split_version(__version__) | ||
version_match: Tuple[Optional[str], Optional[str], int] = None, None, 0 | ||
|
@@ -434,7 +451,7 @@ def _get_model_weights() -> str: | |
"Model weights file %s retrieved from local cache", | ||
version_match[0], | ||
) | ||
return version_match[0] | ||
return Path(version_match[0]) | ||
# Otherwise try to find compatible model weights on GitHub. | ||
else: | ||
repo = github.Github().get_repo("Noble-Lab/casanovo") | ||
|
@@ -467,19 +484,9 @@ def _get_model_weights() -> str: | |
# Download the model weights if a matching release was found. | ||
if version_match[2] > 0: | ||
filename, url, _ = version_match | ||
logger.info( | ||
"Downloading model weights file %s from %s", filename, url | ||
) | ||
r = requests.get(url, stream=True, allow_redirects=True) | ||
r.raise_for_status() | ||
file_size = int(r.headers.get("Content-Length", 0)) | ||
desc = "(Unknown total file size)" if file_size == 0 else "" | ||
r.raw.read = functools.partial(r.raw.read, decode_content=True) | ||
with tqdm.tqdm.wrapattr( | ||
r.raw, "read", total=file_size, desc=desc | ||
) as r_raw, open(filename, "wb") as f: | ||
shutil.copyfileobj(r_raw, f) | ||
return filename | ||
cache_file_path = cache_dir / filename | ||
_download_weights(url, cache_file_path) | ||
return cache_file_path | ||
else: | ||
logger.error( | ||
"No matching model weights for release v%s found, please " | ||
|
@@ -494,5 +501,126 @@ def _get_model_weights() -> str: | |
) | ||
|
||
|
||
def _get_weights_from_url( | ||
file_url: str, | ||
cache_dir: Path, | ||
force_download: Optional[bool] = False, | ||
) -> Path: | ||
""" | ||
Resolve weight file from URL | ||
|
||
Attempt to download weight file from URL if weights are not already | ||
cached - otherwise use cached weights. Downloaded weight files will be | ||
cached. | ||
|
||
Parameters | ||
---------- | ||
file_url : str | ||
URL pointing to model weights file. | ||
cache_dir : Path | ||
Model weights cache directory path. | ||
force_download : Optional[bool], default=False | ||
If True, forces a new download of the weight file even if it exists in | ||
the cache. | ||
|
||
Returns | ||
------- | ||
Path | ||
Path to the cached weights file. | ||
""" | ||
os.makedirs(cache_dir, exist_ok=True) | ||
cache_file_name = Path(urllib.parse.urlparse(file_url).path).name | ||
url_hash = hashlib.shake_256(file_url.encode("utf-8")).hexdigest(5) | ||
cache_file_dir = cache_dir / url_hash | ||
cache_file_path = cache_file_dir / cache_file_name | ||
|
||
if cache_file_path.is_file() and not force_download: | ||
cache_time = cache_file_path.stat() | ||
url_last_modified = 0 | ||
|
||
try: | ||
file_response = requests.head(file_url) | ||
if file_response.ok: | ||
if "Last-Modified" in file_response.headers: | ||
url_last_modified = email.utils.parsedate_to_datetime( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. todo: I'm reluctant to include a library with rather different functionality for just a utility function (even though it's part of the standard libraries). Do we need extra functionality that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good call, I changed it to use strptime instead. |
||
file_response.headers["Last-Modified"] | ||
).timestamp() | ||
else: | ||
logger.warning( | ||
"Attempted HEAD request to %s yielded non-ok status code - using cached file", | ||
file_url, | ||
) | ||
except ( | ||
requests.ConnectionError, | ||
requests.Timeout, | ||
requests.TooManyRedirects, | ||
): | ||
logger.warning( | ||
"Failed to reach %s to get remote last modified time - using cached file", | ||
file_url, | ||
) | ||
|
||
if cache_time.st_mtime > url_last_modified: | ||
Lilferrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
logger.info( | ||
"Model weights %s retrieved from local cache", file_url | ||
) | ||
return cache_file_path | ||
|
||
_download_weights(file_url, cache_file_path) | ||
return cache_file_path | ||
|
||
|
||
def _download_weights(file_url: str, download_path: Path) -> None: | ||
""" | ||
Download weights file from URL | ||
|
||
Download the model weights file from the specified URL and save it to the | ||
given path. Ensures the download directory exists, and uses a progress | ||
bar to indicate download status. | ||
|
||
Parameters | ||
---------- | ||
file_url : str | ||
URL pointing to the model weights file. | ||
download_path : Path | ||
Path where the downloaded weights file will be saved. | ||
""" | ||
download_file_dir = download_path.parent | ||
os.makedirs(download_file_dir, exist_ok=True) | ||
response = requests.get(file_url, stream=True, allow_redirects=True) | ||
response.raise_for_status() | ||
file_size = int(response.headers.get("Content-Length", 0)) | ||
desc = "(Unknown total file size)" if file_size == 0 else "" | ||
response.raw.read = functools.partial( | ||
response.raw.read, decode_content=True | ||
) | ||
|
||
with tqdm.tqdm.wrapattr( | ||
response.raw, "read", total=file_size, desc=desc | ||
) as r_raw, open(download_path, "wb") as file: | ||
shutil.copyfileobj(r_raw, file) | ||
|
||
|
||
def _is_valid_url(file_url: str) -> bool: | ||
""" | ||
Determine whether file URL is a valid URL | ||
|
||
Parameters | ||
---------- | ||
file_url : str | ||
url to verify | ||
|
||
Return | ||
------ | ||
is_url : bool | ||
whether file_url is a valid url | ||
""" | ||
try: | ||
result = urllib.parse.urlparse(file_url) | ||
return all([result.scheme, result.netloc]) | ||
except ValueError: | ||
return False | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
todo: I don't think this is necessary (rather the relevant import is
urllib.parse
a few lines further down).