Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 92 additions & 6 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import torch
import tqdm
from lightning.pytorch import seed_everything
from hashlib import shake_256
from urllib.parse import urlparse

from . import __version__
from . import utils
Expand All @@ -59,10 +61,9 @@ def __init__(self, *args, **kwargs) -> None:
click.Option(
("-m", "--model"),
help="""
The model weights (.ckpt file). If not provided, Casanovo
will try to download the latest release.
Either the model weights (.ckpt file) or a URL pointing to the model weights
file. If not provided, Casanovo will try to download the latest release.
""",
type=click.Path(exists=True, dir_okay=False),
),
click.Option(
("-o", "--output"),
Expand Down Expand Up @@ -354,9 +355,10 @@ def setup_model(
seed_everything(seed=config["random_seed"], workers=True)

# Download model weights if these were not specified (except when training).
cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False)
if model is None and not is_train:
try:
model = _get_model_weights()
model = _get_model_weights(cache_dir)
except github.RateLimitExceededException:
logger.error(
"GitHub API rate limit exceeded while trying to download the "
Expand All @@ -371,6 +373,17 @@ def setup_model(
"model weights"
) from None

# Download model from URL if model is a valid url
is_url = _is_valid_url(model)
if (model is not None) and is_url:
model = _get_weights_from_url(model, Path(cache_dir))

if (model is not None) and (not is_url) and (not Path(model).is_file()):
raise ValueError(
f"{model} is not a valid URL or checkpoint file path, "
"--model argument must be a URL or checkpoint file path"
)

# Log the active configuration.
logger.info("Casanovo version %s", str(__version__))
logger.debug("model = %s", model)
Expand All @@ -382,7 +395,76 @@ def setup_model(
return config, model


def _get_model_weights() -> str:
def _get_weights_from_url(
    file_url: Optional[str],
    cache_dir: Path,
) -> str:
    """
    Return a local path to the model weights located at a URL.

    The weights are looked up in the local cache first, keyed by a hash of
    the URL. If they are not cached yet, they are downloaded and stored in
    the cache so that subsequent calls with the same URL reuse them.

    The download is streamed to a temporary file next to the final cache
    file and only moved into place once it has completed, so an interrupted
    download never leaves a truncated file that would later be mistaken for
    valid cached weights.

    Parameters
    ----------
    file_url : str
        URL pointing to the model weights file.
    cache_dir : Path
        Model weights cache directory path. It is created if it does not
        exist yet.

    Returns
    -------
    str
        Path to the cached weights file.

    Raises
    ------
    ValueError
        If no URL was provided.
    ConnectionError
        If the server answered with an unsuccessful status code.
    """
    if not file_url:
        raise ValueError("No model weights URL was provided")

    cache_dir = Path(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    # The cache file name is derived from the URL, so every distinct URL
    # maps to its own cache entry.
    url_hash = shake_256(file_url.encode("utf-8")).hexdigest(10)
    cache_file_path = cache_dir / f"{url_hash}.ckpt"

    if cache_file_path.is_file():
        logger.info(f"Model weights {file_url} retrieved from local cache")
        return str(cache_file_path)

    logger.info(f"Model weights {file_url} not in local cache, downloading")
    # Process-specific temporary name, so concurrent downloads of the same
    # URL do not write into each other's partial file.
    partial_path = cache_file_path.with_name(
        f"{cache_file_path.name}.{os.getpid()}.part"
    )
    try:
        # Stream the body instead of buffering a potentially very large
        # checkpoint in memory. The timeout is (connect, read) in seconds
        # and keeps an unresponsive server from blocking forever.
        with requests.get(
            file_url, stream=True, timeout=(10, 60)
        ) as file_response:
            if not file_response.ok:
                logger.error(f"Failed to download weights from {file_url}")
                logger.error(
                    f"Server Response: {file_response.status_code}: {file_response.reason}"
                )
                raise ConnectionError(
                    f"Failed to download weights file: {file_url}"
                )

            with open(partial_path, "wb") as cache_file:
                for chunk in file_response.iter_content(
                    chunk_size=1024 * 1024
                ):
                    if chunk:
                        cache_file.write(chunk)

        logger.info("Model weights downloaded, writing to cache")
        # Atomic rename: the cache entry only appears once it is complete.
        os.replace(partial_path, cache_file_path)
    finally:
        # Remove the leftover partial file if anything failed before the
        # rename; after a successful rename it no longer exists.
        if partial_path.exists():
            partial_path.unlink()

    logger.info("Model weights cached")
    return str(cache_file_path)


def _is_valid_url(file_url: str) -> bool:
    """
    Determine whether a string is a valid URL.

    A value is considered a URL when it parses with both a scheme
    (e.g. ``https``) and a network location (e.g. ``example.com``). Plain
    file system paths, including Windows paths with a drive letter, have no
    network location and are therefore not URLs.

    Parameters
    ----------
    file_url : str
        The value to verify. Anything that is not a string (e.g. ``None``
        or a ``pathlib.Path``) is never a URL.

    Returns
    -------
    is_url : bool
        Whether file_url is a valid URL.
    """
    # Reject non-string input explicitly instead of relying on how
    # ``urlparse`` happens to coerce (or choke on) other types.
    if not isinstance(file_url, str):
        return False

    try:
        parsed = urlparse(file_url)
    except ValueError:
        # Raised for malformed input such as an unterminated IPv6 host.
        return False

    return bool(parsed.scheme and parsed.netloc)


def _get_model_weights(cache_dir: str) -> str:
"""
Use cached model weights or download them from GitHub.

Expand All @@ -396,12 +478,16 @@ def _get_model_weights() -> str:
Note that the GitHub API is limited to 60 requests from the same IP per
hour.

Parameters
----------
cache_dir : str
model weights cache directory path

Returns
-------
str
The name of the model weights file.
"""
cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False)
os.makedirs(cache_dir, exist_ok=True)
version = utils.split_version(__version__)
version_match: Tuple[Optional[str], Optional[str], int] = None, None, 0
Expand Down
64 changes: 35 additions & 29 deletions docs/images/configure-help.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading