Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 162 additions & 34 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""The command line entry point for Casanovo."""

import datetime
import email.utils
import functools
import hashlib
import logging
import urllib
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: I don't think this is necessary (rather the relevant import is urllib.parse a few lines further down).

import os
import re
import shutil
import sys
import time
import urllib.parse
import warnings
from pathlib import Path
from typing import Optional, Tuple
Expand Down Expand Up @@ -60,10 +64,9 @@ def __init__(self, *args, **kwargs) -> None:
click.Option(
("-m", "--model"),
help="""
The model weights (.ckpt file). If not provided, Casanovo
will try to download the latest release.
Either the model weights (.ckpt file) or a URL pointing to the model weights
file. If not provided, Casanovo will try to download the latest release automatically.
""",
type=click.Path(exists=True, dir_okay=False),
),
click.Option(
("-o", "--output"),
Expand Down Expand Up @@ -363,22 +366,32 @@ def setup_model(
seed_everything(seed=config["random_seed"], workers=True)

# Download model weights if these were not specified (except when training).
if model is None and not is_train:
try:
model = _get_model_weights()
except github.RateLimitExceededException:
logger.error(
"GitHub API rate limit exceeded while trying to download the "
"model weights. Please download compatible model weights "
"manually from the official Casanovo code website "
"(https://github.com/Noble-Lab/casanovo) and specify these "
"explicitly using the `--model` parameter when running "
"Casanovo."
cache_dir = Path(appdirs.user_cache_dir("casanovo", False, opinion=False))
if model is None:
if not is_train:
try:
model = _get_model_weights(cache_dir)
except github.RateLimitExceededException:
logger.error(
"GitHub API rate limit exceeded while trying to download the "
"model weights. Please download compatible model weights "
"manually from the official Casanovo code website "
"(https://github.com/Noble-Lab/casanovo) and specify these "
"explicitly using the `--model` parameter when running "
"Casanovo."
)
raise PermissionError(
"GitHub API rate limit exceeded while trying to download the "
"model weights"
) from None
else:
if _is_valid_url(model):
model = _get_weights_from_url(model, cache_dir)
elif not Path(model).is_file():
raise ValueError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: Add a similar logging statement.

f"{model} is not a valid URL or checkpoint file path, "
"--model argument must be a URL or checkpoint file path"
)
raise PermissionError(
"GitHub API rate limit exceeded while trying to download the "
"model weights"
) from None

# Log the active configuration.
logger.info("Casanovo version %s", str(__version__))
Expand All @@ -391,7 +404,7 @@ def setup_model(
return config, model


def _get_model_weights() -> str:
def _get_model_weights(cache_dir: Path) -> str:
"""
Use cached model weights or download them from GitHub.

Expand All @@ -405,12 +418,16 @@ def _get_model_weights() -> str:
Note that the GitHub API is limited to 60 requests from the same IP per
hour.

Parameters
----------
cache_dir : Path
model weights cache directory path

Returns
-------
str
The name of the model weights file.
"""
cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False)
os.makedirs(cache_dir, exist_ok=True)
version = utils.split_version(__version__)
version_match: Tuple[Optional[str], Optional[str], int] = None, None, 0
Expand All @@ -434,7 +451,7 @@ def _get_model_weights() -> str:
"Model weights file %s retrieved from local cache",
version_match[0],
)
return version_match[0]
return Path(version_match[0])
# Otherwise try to find compatible model weights on GitHub.
else:
repo = github.Github().get_repo("Noble-Lab/casanovo")
Expand Down Expand Up @@ -467,19 +484,9 @@ def _get_model_weights() -> str:
# Download the model weights if a matching release was found.
if version_match[2] > 0:
filename, url, _ = version_match
logger.info(
"Downloading model weights file %s from %s", filename, url
)
r = requests.get(url, stream=True, allow_redirects=True)
r.raise_for_status()
file_size = int(r.headers.get("Content-Length", 0))
desc = "(Unknown total file size)" if file_size == 0 else ""
r.raw.read = functools.partial(r.raw.read, decode_content=True)
with tqdm.tqdm.wrapattr(
r.raw, "read", total=file_size, desc=desc
) as r_raw, open(filename, "wb") as f:
shutil.copyfileobj(r_raw, f)
return filename
cache_file_path = cache_dir / filename
_download_weights(url, cache_file_path)
return cache_file_path
else:
logger.error(
"No matching model weights for release v%s found, please "
Expand All @@ -494,5 +501,126 @@ def _get_model_weights() -> str:
)


def _get_weights_from_url(
    file_url: str,
    cache_dir: Path,
    force_download: Optional[bool] = False,
) -> Path:
    """
    Resolve weight file from URL.

    Attempt to download weight file from URL if weights are not already
    cached - otherwise use cached weights. Downloaded weight files will be
    cached.

    Parameters
    ----------
    file_url : str
        URL pointing to model weights file.
    cache_dir : Path
        Model weights cache directory path.
    force_download : Optional[bool], default=False
        If True, forces a new download of the weight file even if it exists
        in the cache.

    Returns
    -------
    Path
        Path to the cached weights file.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_file_name = Path(urllib.parse.urlparse(file_url).path).name
    # Hash the full URL so that distinct URLs whose paths end in the same
    # file name do not collide in the cache.
    url_hash = hashlib.shake_256(file_url.encode("utf-8")).hexdigest(5)
    cache_file_dir = cache_dir / url_hash
    cache_file_path = cache_file_dir / cache_file_name

    if cache_file_path.is_file() and not force_download:
        cache_time = cache_file_path.stat()
        url_last_modified = 0

        try:
            # Follow redirects (release assets are commonly served through
            # one) and set a timeout so the Timeout handler below is
            # actually reachable; without `timeout=` requests waits forever.
            file_response = requests.head(
                file_url, allow_redirects=True, timeout=10
            )
            if file_response.ok:
                if "Last-Modified" in file_response.headers:
                    url_last_modified = email.utils.parsedate_to_datetime(
                        file_response.headers["Last-Modified"]
                    ).timestamp()
            else:
                logger.warning(
                    "Attempted HEAD request to %s yielded non-ok status code - using cached file",
                    file_url,
                )
        except (
            requests.ConnectionError,
            requests.Timeout,
            requests.TooManyRedirects,
        ):
            logger.warning(
                "Failed to reach %s to get remote last modified time - using cached file",
                file_url,
            )

        # Prefer the cached copy unless the remote file is strictly newer.
        # If the remote modification time could not be determined,
        # url_last_modified stays 0 and the cache wins.
        if cache_time.st_mtime > url_last_modified:
            logger.info(
                "Model weights %s retrieved from local cache", file_url
            )
            return cache_file_path

    _download_weights(file_url, cache_file_path)
    return cache_file_path


def _download_weights(file_url: str, download_path: Path) -> None:
    """
    Download weights file from URL.

    Download the model weights file from the specified URL and save it to
    the given path. Ensures the download directory exists, and uses a
    progress bar to indicate download status.

    Parameters
    ----------
    file_url : str
        URL pointing to the model weights file.
    download_path : Path
        Path where the downloaded weights file will be saved.

    Raises
    ------
    requests.HTTPError
        If the server responds with an error status code.
    """
    download_file_dir = download_path.parent
    os.makedirs(download_file_dir, exist_ok=True)
    # Use the response as a context manager so the streamed connection is
    # always released, and set a timeout so an unresponsive server cannot
    # hang the download indefinitely.
    with requests.get(
        file_url, stream=True, allow_redirects=True, timeout=10
    ) as response:
        response.raise_for_status()
        file_size = int(response.headers.get("Content-Length", 0))
        desc = "(Unknown total file size)" if file_size == 0 else ""
        # Transparently decompress gzip/deflate content while streaming so
        # the bytes written to disk match the advertised file.
        response.raw.read = functools.partial(
            response.raw.read, decode_content=True
        )

        with tqdm.tqdm.wrapattr(
            response.raw, "read", total=file_size, desc=desc
        ) as r_raw, open(download_path, "wb") as file:
            shutil.copyfileobj(r_raw, file)


def _is_valid_url(file_url: str) -> bool:
"""
Determine whether file URL is a valid URL

Parameters
----------
file_url : str
url to verify

Return
------
is_url : bool
whether file_url is a valid url
"""
try:
result = urllib.parse.urlparse(file_url)
return all([result.scheme, result.netloc])
except ValueError:
return False


# Script entry point: delegate to the Click-based CLI defined in this module.
if __name__ == "__main__":
    main()
Loading