✨ Ensure file chunks are uploaded concurrently and improve PaginationIterator #220

Merged
119 changes: 71 additions & 48 deletions clients/python/src/osparc/_api_files_api.py
@@ -5,7 +5,7 @@
import logging
import math
from pathlib import Path
from typing import Any, Iterator, List, Optional, Tuple, Union
from typing import Any, Iterator, List, Optional, Tuple, Union, Set, Final

import httpx
from httpx import Response
@@ -31,13 +31,16 @@
import shutil
from ._utils import (
DEFAULT_TIMEOUT_SECONDS,
PaginationGenerator,
PaginationIterable,
compute_sha256,
file_chunk_generator,
Chunk,
)

_logger = logging.getLogger(__name__)

_MAX_CONCURRENT_UPLOADS: Final[int] = 20


class FilesApi(_FilesApi):
_dev_features = [
@@ -116,16 +119,23 @@ def upload_file(
self,
file: Union[str, Path],
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
max_concurrent_uploads: int = _MAX_CONCURRENT_UPLOADS,
**kwargs,
):
return asyncio.run(
self.upload_file_async(file=file, timeout_seconds=timeout_seconds, **kwargs)
self.upload_file_async(
file=file,
timeout_seconds=timeout_seconds,
max_concurrent_uploads=max_concurrent_uploads,
**kwargs,
)
)

async def upload_file_async(
self,
file: Union[str, Path],
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
max_concurrent_uploads: int = _MAX_CONCURRENT_UPLOADS,
**kwargs,
) -> File:
if isinstance(file, str):
Expand Down Expand Up @@ -159,53 +169,66 @@ async def upload_file_async(
"Did not receive sufficient number of upload URLs from the server."
)

uploaded_parts: list[UploadedPart] = []
abort_body = BodyAbortMultipartUploadV0FilesFileIdAbortPost(
client_file=client_file
)
upload_tasks: Set[asyncio.Task] = set()
uploaded_parts: List[UploadedPart] = []

async with AsyncHttpClient(
configuration=self.api_client.configuration, timeout=timeout_seconds
) as session:
with logging_redirect_tqdm():
_logger.debug("Uploading %s in %i chunk(s)", file.name, n_urls)
async for chunck, size in tqdm(
file_chunk_generator(file, chunk_size),
total=n_urls,
disable=(not _logger.isEnabledFor(logging.DEBUG)),
):
index, url = next(url_iter)
uploaded_parts.append(
await self._upload_chunck(
http_client=session,
chunck=chunck,
chunck_size=size,
upload_link=url,
index=index,
configuration=self.api_client.configuration,
method="post",
url=links.abort_upload,
body=abort_body.to_dict(),
base_url=self.api_client.configuration.host,
follow_redirects=True,
auth=self._auth,
timeout=timeout_seconds,
) as api_server_session:
async with AsyncHttpClient(
configuration=self.api_client.configuration, timeout=timeout_seconds
) as s3_session:
with logging_redirect_tqdm():
_logger.debug("Uploading %s in %i chunk(s)", file.name, n_urls)
async for chunk in tqdm(
file_chunk_generator(file, chunk_size),
total=n_urls,
disable=(not _logger.isEnabledFor(logging.DEBUG)),
): # type: ignore
assert isinstance(chunk, Chunk) # nosec
index, url = next(url_iter)
upload_tasks.add(
asyncio.create_task(
self._upload_chunck(
http_client=s3_session,
chunck=chunk.data,
chunck_size=chunk.nbytes,
upload_link=url,
index=index,
)
)
)
)
while (len(upload_tasks) >= max_concurrent_uploads) or (
chunk.is_last_chunk and len(upload_tasks) > 0
):
done, upload_tasks = await asyncio.wait(
upload_tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in done:
uploaded_parts.append(task.result())

abort_body = BodyAbortMultipartUploadV0FilesFileIdAbortPost(
client_file=client_file
_logger.debug(
("Completing upload of %s " "(this might take a couple of minutes)..."),
file.name,
)
async with AsyncHttpClient(
configuration=self.api_client.configuration,
method="post",
url=links.abort_upload,
body=abort_body.to_dict(),
base_url=self.api_client.configuration.host,
follow_redirects=True,
auth=self._auth,
timeout=timeout_seconds,
) as session:
_logger.debug(
(
"Completing upload of %s "
"(this might take a couple of minutes)..."
),
file.name,
)
server_file: File = await self._complete_multipart_upload(
session, links.complete_upload, client_file, uploaded_parts
)
_logger.debug("File upload complete: %s", file.name)
return server_file
server_file: File = await self._complete_multipart_upload(
api_server_session,
links.complete_upload, # type: ignore
client_file,
uploaded_parts,
)
_logger.debug("File upload complete: %s", file.name)
return server_file

async def _complete_multipart_upload(
self,
@@ -250,7 +273,7 @@ def _search_files(
file_id: Optional[str] = None,
sha256_checksum: Optional[str] = None,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
) -> PaginationGenerator:
) -> PaginationIterable:
kwargs = {
"file_id": file_id,
"sha256_checksum": sha256_checksum,
@@ -262,7 +285,7 @@ def _pagination_method():
**{k: v for k, v in kwargs.items() if v is not None}
)

return PaginationGenerator(
return PaginationIterable(
first_page_callback=_pagination_method,
api_client=self.api_client,
base_url=self.api_client.configuration.host,
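The core change in this file keeps at most `max_concurrent_uploads` chunk uploads in flight: each chunk is wrapped in an `asyncio.Task`, and the task set is drained with `asyncio.wait(..., return_when=FIRST_COMPLETED)` whenever it reaches the cap, or fully once the last chunk has been scheduled. A minimal, self-contained sketch of that pattern follows; the `upload_chunk` coroutine is a placeholder standing in for the real `_upload_chunck` call, not part of the library:

```python
import asyncio
from typing import Dict, List, Set


async def upload_chunk(index: int) -> Dict[str, str]:
    # Placeholder for the real per-chunk upload coroutine; simulates network latency.
    await asyncio.sleep(0.1)
    return {"number": str(index + 1), "e_tag": f"etag-{index}"}


async def upload_all(n_chunks: int, max_concurrent_uploads: int = 20) -> List[Dict[str, str]]:
    upload_tasks: Set[asyncio.Task] = set()
    uploaded_parts: List[Dict[str, str]] = []
    for index in range(n_chunks):
        is_last_chunk = index == n_chunks - 1
        upload_tasks.add(asyncio.create_task(upload_chunk(index)))
        # Cap the number of in-flight uploads; drain everything after the last chunk.
        while len(upload_tasks) >= max_concurrent_uploads or (is_last_chunk and upload_tasks):
            done, upload_tasks = await asyncio.wait(
                upload_tasks, return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                # Parts arrive in completion order; each carries its own part number.
                uploaded_parts.append(task.result())
    return uploaded_parts


if __name__ == "__main__":
    print(len(asyncio.run(upload_all(100))))  # -> 100
```

Because every uploaded part records its part number, completion order does not matter when the multipart upload is finally completed.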
8 changes: 4 additions & 4 deletions clients/python/src/osparc/_api_solvers_api.py
@@ -11,7 +11,7 @@
from ._utils import (
_DEFAULT_PAGINATION_LIMIT,
_DEFAULT_PAGINATION_OFFSET,
PaginationGenerator,
PaginationIterable,
)

import warnings
@@ -47,7 +47,7 @@ def list_solver_ports(
)
return page.items if page.items else []

def iter_jobs(self, solver_key: str, version: str, **kwargs) -> PaginationGenerator:
def iter_jobs(self, solver_key: str, version: str, **kwargs) -> PaginationIterable:
"""Returns an iterator through which one can iterate over
all Jobs submitted to the solver

@@ -72,14 +72,14 @@ def _pagination_method():
**kwargs,
)

return PaginationGenerator(
return PaginationIterable(
first_page_callback=_pagination_method,
api_client=self.api_client,
base_url=self.api_client.configuration.host,
auth=self._auth,
)

def jobs(self, solver_key: str, version: str, **kwargs) -> PaginationGenerator:
def jobs(self, solver_key: str, version: str, **kwargs) -> PaginationIterable:
warnings.warn(
"The 'jobs' method is deprecated and will be removed in a future version. "
"Please use 'iter_jobs' instead.",
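For callers, the rename from `PaginationGenerator` to `PaginationIterable` changes nothing about usage: `iter_jobs` still yields jobs lazily, page by page, and the result can also be sized. A hypothetical usage sketch, assuming standard client configuration; the host, credentials, and solver coordinates are placeholders:

```python
import osparc

cfg = osparc.Configuration(
    host="https://api.osparc.io",              # placeholder host
    username="my-api-key",                     # placeholder credentials
    password="my-api-secret",
)

with osparc.ApiClient(cfg) as api_client:
    solvers_api = osparc.SolversApi(api_client)
    jobs = solvers_api.iter_jobs("simcore/services/comp/isolve", "2.1.16")
    print(f"{len(jobs)} jobs submitted to this solver")  # PaginationIterable is Sized
    for job in jobs:                                     # pages are fetched lazily
        print(job.id)
```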
8 changes: 4 additions & 4 deletions clients/python/src/osparc/_api_studies_api.py
@@ -17,7 +17,7 @@
from ._utils import (
_DEFAULT_PAGINATION_LIMIT,
_DEFAULT_PAGINATION_OFFSET,
PaginationGenerator,
PaginationIterable,
)
import warnings

@@ -65,7 +65,7 @@ def clone_study(self, study_id: str, **kwargs):
kwargs = {**kwargs, **ParentProjectInfo().model_dump(exclude_none=True)}
return super().clone_study(study_id, **kwargs)

def iter_studies(self, **kwargs) -> PaginationGenerator:
def iter_studies(self, **kwargs) -> PaginationIterable:
def _pagination_method():
page_study = self.list_studies(
limit=_DEFAULT_PAGINATION_LIMIT,
@@ -75,14 +75,14 @@ def _pagination_method():
assert isinstance(page_study, PageStudy) # nosec
return page_study

return PaginationGenerator(
return PaginationIterable(
first_page_callback=_pagination_method,
api_client=self.api_client,
base_url=self.api_client.configuration.host,
auth=self._auth,
)

def studies(self, **kwargs) -> PaginationGenerator:
def studies(self, **kwargs) -> PaginationIterable:
warnings.warn(
"The 'studies' method is deprecated and will be removed in a future version. "
"Please use 'iter_studies' instead.",
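The deprecated `studies` and `jobs` entry points keep working through the same shim pattern: warn, then delegate to the new iterator method. Roughly, as a simplified sketch rather than the exact bodies in the diff:

```python
import warnings
from typing import Iterable


class StudiesApi:
    def iter_studies(self, **kwargs) -> Iterable:
        return iter(())  # placeholder for the real paginated iterable

    def studies(self, **kwargs) -> Iterable:
        warnings.warn(
            "The 'studies' method is deprecated and will be removed in a future version. "
            "Please use 'iter_studies' instead.",
            DeprecationWarning,
        )
        return self.iter_studies(**kwargs)


StudiesApi().studies()  # emits a DeprecationWarning, then delegates
```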
33 changes: 26 additions & 7 deletions clients/python/src/osparc/_utils.py
@@ -1,8 +1,16 @@
import asyncio
import hashlib
from pathlib import Path
from typing import AsyncGenerator, Callable, Generator, Optional, Tuple, TypeVar, Union

from typing import (
AsyncGenerator,
Callable,
Optional,
TypeVar,
Union,
NamedTuple,
Generator,
)
from collections.abc import Iterable, Sized
import httpx
from osparc_client import (
ApiClient,
@@ -15,7 +23,7 @@
Study,
)
import aiofiles
from ._exceptions import RequestError
from .exceptions import RequestError

_KB = 1024 # in bytes
_MB = _KB * 1024 # in bytes
@@ -30,8 +38,11 @@
T = TypeVar("T", Job, File, Solver, Study)


class PaginationGenerator:
"""Class for wrapping paginated http methods as generators"""
class PaginationIterable(Iterable, Sized):
"""Class for wrapping paginated http methods as iterables. It supports three simple operations:
- for elm in pagination_iterable
- elm = next(pagination_iterable)
- len(pagination_iterable)"""

def __init__(
self,
@@ -75,9 +86,15 @@ def __iter__(self) -> Generator[T, None, None]:
page = self._api_client._ApiClient__deserialize(response.json(), type(page))


class Chunk(NamedTuple):
data: bytes
nbytes: int
is_last_chunk: bool


async def file_chunk_generator(
file: Path, chunk_size: int
) -> AsyncGenerator[Tuple[bytes, int], None]:
) -> AsyncGenerator[Chunk, None]:
if not file.is_file():
raise RuntimeError(f"{file} must be a file")
if chunk_size <= 0:
@@ -94,8 +111,10 @@ async def file_chunk_generator(
)
assert nbytes > 0
chunk = await f.read(nbytes)
yield chunk, nbytes
bytes_read += nbytes
yield Chunk(
data=chunk, nbytes=nbytes, is_last_chunk=(bytes_read == file_size)
)


S = TypeVar("S")
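`file_chunk_generator` now yields a `Chunk` named tuple instead of a `(bytes, nbytes)` pair, so a consumer can tell when the last chunk has been produced and drain any outstanding uploads. A small consumption sketch, assuming the private helpers shown above are importable and using a placeholder file path:

```python
import asyncio
from pathlib import Path

from osparc._utils import Chunk, file_chunk_generator  # private helpers from the diff above


async def show_chunks(path: Path, chunk_size: int = 10 * 1024 * 1024) -> None:
    async for chunk in file_chunk_generator(path, chunk_size):
        assert isinstance(chunk, Chunk)
        print(f"read {chunk.nbytes} bytes, last chunk: {chunk.is_last_chunk}")


asyncio.run(show_chunks(Path("large_file.bin")))  # placeholder path; must exist
```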
12 changes: 10 additions & 2 deletions clients/python/test/e2e/conftest.py
@@ -18,6 +18,7 @@
from packaging.version import Version
from pydantic import ByteSize
from typing import NamedTuple, Final
from memory_profiler import memory_usage

try:
from osparc._settings import ConfigurationEnvVars
@@ -140,6 +141,7 @@ def async_client() -> Iterable[AsyncClient]:
class ServerFile(NamedTuple):
server_file: osparc.File
local_file: Path
upload_ram_usage: int


@pytest.fixture(scope="session")
@@ -160,9 +162,15 @@ def large_server_file(
assert (
tmp_file.stat().st_size == _file_size
), f"Could not create file of size: {_file_size}"
uploaded_file: osparc.File = files_api.upload_file(tmp_file)
ram_statistics, uploaded_file = memory_usage(
(files_api.upload_file, (tmp_file,)), retval=True
)

yield ServerFile(local_file=tmp_file, server_file=uploaded_file)
yield ServerFile(
local_file=tmp_file,
server_file=uploaded_file,
upload_ram_usage=max(ram_statistics) - min(ram_statistics),
)

files_api.delete_file(uploaded_file.id)

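The fixture now measures the RAM cost of the upload with `memory_profiler.memory_usage`: passing a `(callable, args)` tuple with `retval=True` returns both the sampled memory trace (in MiB) and the callable's return value, so downstream tests can check that concurrent uploads do not hold the whole file in memory. A standalone sketch of the same measurement, with a placeholder workload standing in for `files_api.upload_file`:

```python
from memory_profiler import memory_usage


def workload(n: int) -> bytes:
    # Placeholder for files_api.upload_file(tmp_file): allocate some memory and return a value.
    return b"x" * n


samples, result = memory_usage((workload, (50 * 1024 * 1024,)), retval=True)
ram_delta_mib = max(samples) - min(samples)
print(f"returned {len(result)} bytes, used ~{ram_delta_mib:.1f} MiB extra")
```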