Skip to content

Commit eba99ee

Browse files
committed
Encapsulate common GCS operations in Storage class
1 parent 0de9e3f commit eba99ee

File tree

3 files changed

+164
-23
lines changed

3 files changed

+164
-23
lines changed

nmdc_server/cli.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
import click
1111
import requests
1212

13-
from nmdc_server import jobs, storage
13+
from nmdc_server import jobs
1414
from nmdc_server.config import settings
1515
from nmdc_server.database import SessionLocalIngest
1616
from nmdc_server.ingest import errors
1717
from nmdc_server.static_files import generate_submission_schema_files, initialize_static_directory
18+
from nmdc_server.storage import BucketName, storage
1819

1920

2021
def send_slack_message(text: str) -> bool:
@@ -350,17 +351,19 @@ def generate_static_files(remove_existing):
350351
@cli.command()
351352
def ensure_storage_buckets():
352353
"""Ensure that the storage buckets exist."""
353-
for bucket_name in storage.Bucket:
354+
for bucket_name in BucketName:
354355
click.echo(f"Ensuring bucket '{bucket_name}' exists")
355-
try:
356-
storage.client.get_bucket(bucket_name)
356+
bucket = storage.get_bucket(bucket_name)
357+
if bucket.exists():
357358
click.echo(f"Bucket '{bucket_name}' already exists")
358-
except Exception as e:
359-
if settings.use_fake_gcs_server:
360-
click.echo(f"Creating bucket '{bucket_name}'")
361-
storage.client.create_bucket(bucket_name)
362-
else:
363-
raise RuntimeError(f"Failed to ensure bucket '{bucket_name}' exists: {e}")
359+
elif settings.use_fake_gcs_server:
360+
click.echo(f"Creating bucket '{bucket_name}'")
361+
bucket.create()
362+
else:
363+
raise RuntimeError(
364+
f"Failed to ensure bucket '{bucket_name}' exists. "
365+
f"This bucket may need to be created manually in the cloud storage provider."
366+
)
364367

365368

366369
if __name__ == "__main__":

nmdc_server/schemas.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,34 @@ class VersionInfo(BaseModel):
692692
nmdc_schema: str = version("nmdc-schema")
693693
nmdc_submission_schema: str = version("nmdc-submission-schema")
694694
model_config = ConfigDict(frozen=False)
695+
696+
697+
class SignedUploadUrlRequest(BaseModel):
698+
"""Request to generate a signed upload URL for a file.
699+
700+
This model is used to generate a signed URL for uploading files to the object store.
701+
"""
702+
703+
file_name: str
704+
file_size: int
705+
content_type: str
706+
707+
708+
class SignedUrl(BaseModel):
709+
"""Response containing the signed URL and other metadata."""
710+
711+
url: str
712+
object_name: str
713+
expiration: datetime
714+
715+
716+
class UploadCompleteRequest(BaseModel):
717+
"""Request to mark an upload as complete.
718+
719+
This model is used to mark an upload as complete after the file has been uploaded
720+
to the object store.
721+
"""
722+
723+
object_name: str
724+
file_size: int
725+
content_type: str

nmdc_server/storage.py

Lines changed: 120 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,136 @@
1+
from datetime import datetime, timedelta, timezone
12
from enum import StrEnum
3+
from functools import cached_property
24
from typing import Any
35

4-
from google.auth.credentials import AnonymousCredentials
5-
from google.cloud import storage
6+
from google.cloud import storage as gcs, exceptions as gce
67

78
from nmdc_server.config import settings
9+
from nmdc_server.schemas import SignedUrl
810

911

10-
class Bucket(StrEnum):
12+
class BucketName(StrEnum):
1113
"""Enum for GCS bucket names"""
1214

1315
SUBMISSION_IMAGES = "nmdc-submission-images"
1416

1517

16-
def _initialize_client() -> storage.Client:
17-
client_args: dict[str, Any] = {
18-
"project": settings.gcs_project_id,
19-
}
18+
class Storage:
19+
"""A class to manage Google Cloud Storage interactions."""
2020

21-
if settings.use_fake_gcs_server:
22-
# https://github.com/fsouza/fake-gcs-server/blob/cd43b03fcfb8149c6f57c1a92e19d1a07e291a3c/examples/python/python.py
23-
client_args["credentials"] = AnonymousCredentials()
24-
client_args["client_options"] = {"api_endpoint": "http://storage:4443"}
21+
def __init__(self, project_id: str, use_fake_gcs_server: bool):
22+
self.project_id = project_id
23+
self.use_fake_gcs_server = use_fake_gcs_server
2524

26-
return storage.Client(**client_args)
25+
@cached_property
26+
def _client(self):
27+
"""Google Cloud Storage client."""
28+
client_args: dict[str, Any] = {
29+
"project": self.project_id,
30+
}
2731

32+
if self.use_fake_gcs_server:
33+
client_args["client_options"] = {"api_endpoint": "http://storage:4443"}
2834

29-
client = _initialize_client()
35+
return gcs.Client(**client_args)
36+
37+
def get_bucket(self, bucket_name: BucketName) -> gcs.Bucket:
38+
"""Get a GCS bucket by name.
39+
40+
:param bucket_name: The name of the bucket to retrieve.
41+
"""
42+
return self._client.bucket(bucket_name)
43+
44+
def get_object(self, bucket_name: BucketName, object_name: str) -> gcs.Blob:
45+
"""Get an object from a GCS bucket.
46+
47+
:param bucket_name: The name of the bucket containing the object.
48+
:param object_name: The name of the object to retrieve.
49+
"""
50+
bucket = self.get_bucket(bucket_name)
51+
return bucket.blob(object_name)
52+
53+
def delete_object(
54+
self, bucket_name: BucketName, object_name: str, *, raise_if_not_found: bool = False
55+
) -> None:
56+
"""Delete an object from a GCS bucket.
57+
58+
:param bucket_name: The name of the bucket containing the object.
59+
:param object_name: The name of the object to delete.
60+
:param raise_if_not_found: If True, raise an exception if the object is not found. Default
61+
is False.
62+
"""
63+
bucket = self.get_bucket(bucket_name)
64+
try:
65+
bucket.delete_blob(object_name)
66+
except gce.NotFound as e:
67+
if raise_if_not_found:
68+
raise e
69+
70+
def get_signed_upload_url(
71+
self,
72+
bucket_name: BucketName,
73+
object_name: str,
74+
*,
75+
expiration: int = 15,
76+
content_type: str | None = None,
77+
) -> SignedUrl:
78+
"""Get a signed URL for uploading to an object to a GCS bucket.
79+
80+
:param bucket_name: The name of the bucket that will contain the object.
81+
:param object_name: The name of the object.
82+
:param expiration: The expiration time for the signed URL in minutes. Default is 15 minutes.
83+
:param content_type: The content type of the object being uploaded.
84+
"""
85+
blob = self.get_object(bucket_name, object_name)
86+
expiration_delta = timedelta(minutes=expiration)
87+
expiration_time = datetime.now(timezone.utc) + expiration_delta
88+
89+
url = blob.generate_signed_url(
90+
version="v4",
91+
expiration=expiration_delta,
92+
method="PUT",
93+
content_type=content_type,
94+
)
95+
96+
if self.use_fake_gcs_server:
97+
# If using a fake GCS server, we need to adjust the URL to point localhost instead of
98+
# the docker-compose service name.
99+
url = url.replace("//storage:", "//localhost:")
100+
101+
return SignedUrl(url=url, expiration=expiration_time, object_name=blob.name)
102+
103+
def get_signed_download_url(
104+
self, bucket_name: BucketName, object_name: str, *, expiration: int = 15
105+
) -> SignedUrl:
106+
"""Get a signed URL for downloading an object from a GCS bucket.
107+
108+
:param bucket_name: The name of the bucket containing the object.
109+
:param object_name: The name of the object to download.
110+
:param expiration: The expiration time for the signed URL in minutes. Default is 15 minutes.
111+
"""
112+
expiration_delta = timedelta(minutes=expiration)
113+
expiration_time = datetime.now(timezone.utc) + expiration_delta
114+
blob = self.get_object(bucket_name, object_name)
115+
url = blob.generate_signed_url(
116+
version="v4",
117+
expiration=expiration_delta,
118+
method="GET",
119+
)
120+
121+
if self.use_fake_gcs_server:
122+
# If using a fake GCS server, we need to adjust the URL to point localhost instead of
123+
# the docker-compose service name.
124+
url = url.replace("//storage:", "//localhost:")
125+
126+
return SignedUrl(
127+
url=url,
128+
object_name=blob.name,
129+
expiration=expiration_time,
130+
)
131+
132+
133+
storage = Storage(
134+
project_id=settings.gcs_project_id,
135+
use_fake_gcs_server=settings.use_fake_gcs_server,
136+
)

0 commit comments

Comments
 (0)