Skip to content

Commit bb942e4

Browse files
✨ Ensure file chunks are uploaded concurrently and improve PaginationIterator (#220)
1 parent 534fb11 commit bb942e4

File tree

8 files changed

+196
-81
lines changed

8 files changed

+196
-81
lines changed

clients/python/src/osparc/_api_files_api.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import math
77
from pathlib import Path
8-
from typing import Any, Iterator, List, Optional, Tuple, Union
8+
from typing import Any, Iterator, List, Optional, Tuple, Union, Set, Final
99

1010
import httpx
1111
from httpx import Response
@@ -31,13 +31,16 @@
3131
import shutil
3232
from ._utils import (
3333
DEFAULT_TIMEOUT_SECONDS,
34-
PaginationGenerator,
34+
PaginationIterable,
3535
compute_sha256,
3636
file_chunk_generator,
37+
Chunk,
3738
)
3839

3940
_logger = logging.getLogger(__name__)
4041

42+
_MAX_CONCURRENT_UPLOADS: Final[int] = 20
43+
4144

4245
class FilesApi(_FilesApi):
4346
_dev_features = [
@@ -116,16 +119,23 @@ def upload_file(
116119
self,
117120
file: Union[str, Path],
118121
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
122+
max_concurrent_uploads: int = _MAX_CONCURRENT_UPLOADS,
119123
**kwargs,
120124
):
121125
return asyncio.run(
122-
self.upload_file_async(file=file, timeout_seconds=timeout_seconds, **kwargs)
126+
self.upload_file_async(
127+
file=file,
128+
timeout_seconds=timeout_seconds,
129+
max_concurrent_uploads=max_concurrent_uploads,
130+
**kwargs,
131+
)
123132
)
124133

125134
async def upload_file_async(
126135
self,
127136
file: Union[str, Path],
128137
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
138+
max_concurrent_uploads: int = _MAX_CONCURRENT_UPLOADS,
129139
**kwargs,
130140
) -> File:
131141
if isinstance(file, str):
@@ -159,53 +169,66 @@ async def upload_file_async(
159169
"Did not receive sufficient number of upload URLs from the server."
160170
)
161171

162-
uploaded_parts: list[UploadedPart] = []
172+
abort_body = BodyAbortMultipartUploadV0FilesFileIdAbortPost(
173+
client_file=client_file
174+
)
175+
upload_tasks: Set[asyncio.Task] = set()
176+
uploaded_parts: List[UploadedPart] = []
177+
163178
async with AsyncHttpClient(
164-
configuration=self.api_client.configuration, timeout=timeout_seconds
165-
) as session:
166-
with logging_redirect_tqdm():
167-
_logger.debug("Uploading %s in %i chunk(s)", file.name, n_urls)
168-
async for chunck, size in tqdm(
169-
file_chunk_generator(file, chunk_size),
170-
total=n_urls,
171-
disable=(not _logger.isEnabledFor(logging.DEBUG)),
172-
):
173-
index, url = next(url_iter)
174-
uploaded_parts.append(
175-
await self._upload_chunck(
176-
http_client=session,
177-
chunck=chunck,
178-
chunck_size=size,
179-
upload_link=url,
180-
index=index,
179+
configuration=self.api_client.configuration,
180+
method="post",
181+
url=links.abort_upload,
182+
body=abort_body.to_dict(),
183+
base_url=self.api_client.configuration.host,
184+
follow_redirects=True,
185+
auth=self._auth,
186+
timeout=timeout_seconds,
187+
) as api_server_session:
188+
async with AsyncHttpClient(
189+
configuration=self.api_client.configuration, timeout=timeout_seconds
190+
) as s3_session:
191+
with logging_redirect_tqdm():
192+
_logger.debug("Uploading %s in %i chunk(s)", file.name, n_urls)
193+
async for chunk in tqdm(
194+
file_chunk_generator(file, chunk_size),
195+
total=n_urls,
196+
disable=(not _logger.isEnabledFor(logging.DEBUG)),
197+
): # type: ignore
198+
assert isinstance(chunk, Chunk) # nosec
199+
index, url = next(url_iter)
200+
upload_tasks.add(
201+
asyncio.create_task(
202+
self._upload_chunck(
203+
http_client=s3_session,
204+
chunck=chunk.data,
205+
chunck_size=chunk.nbytes,
206+
upload_link=url,
207+
index=index,
208+
)
209+
)
181210
)
182-
)
211+
while (len(upload_tasks) >= max_concurrent_uploads) or (
212+
chunk.is_last_chunk and len(upload_tasks) > 0
213+
):
214+
done, upload_tasks = await asyncio.wait(
215+
upload_tasks, return_when=asyncio.FIRST_COMPLETED
216+
)
217+
for task in done:
218+
uploaded_parts.append(task.result())
183219

184-
abort_body = BodyAbortMultipartUploadV0FilesFileIdAbortPost(
185-
client_file=client_file
220+
_logger.debug(
221+
("Completing upload of %s " "(this might take a couple of minutes)..."),
222+
file.name,
186223
)
187-
async with AsyncHttpClient(
188-
configuration=self.api_client.configuration,
189-
method="post",
190-
url=links.abort_upload,
191-
body=abort_body.to_dict(),
192-
base_url=self.api_client.configuration.host,
193-
follow_redirects=True,
194-
auth=self._auth,
195-
timeout=timeout_seconds,
196-
) as session:
197-
_logger.debug(
198-
(
199-
"Completing upload of %s "
200-
"(this might take a couple of minutes)..."
201-
),
202-
file.name,
203-
)
204-
server_file: File = await self._complete_multipart_upload(
205-
session, links.complete_upload, client_file, uploaded_parts
206-
)
207-
_logger.debug("File upload complete: %s", file.name)
208-
return server_file
224+
server_file: File = await self._complete_multipart_upload(
225+
api_server_session,
226+
links.complete_upload, # type: ignore
227+
client_file,
228+
uploaded_parts,
229+
)
230+
_logger.debug("File upload complete: %s", file.name)
231+
return server_file
209232

210233
async def _complete_multipart_upload(
211234
self,
@@ -250,7 +273,7 @@ def _search_files(
250273
file_id: Optional[str] = None,
251274
sha256_checksum: Optional[str] = None,
252275
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
253-
) -> PaginationGenerator:
276+
) -> PaginationIterable:
254277
kwargs = {
255278
"file_id": file_id,
256279
"sha256_checksum": sha256_checksum,
@@ -262,7 +285,7 @@ def _pagination_method():
262285
**{k: v for k, v in kwargs.items() if v is not None}
263286
)
264287

265-
return PaginationGenerator(
288+
return PaginationIterable(
266289
first_page_callback=_pagination_method,
267290
api_client=self.api_client,
268291
base_url=self.api_client.configuration.host,

clients/python/src/osparc/_api_solvers_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ._utils import (
1212
_DEFAULT_PAGINATION_LIMIT,
1313
_DEFAULT_PAGINATION_OFFSET,
14-
PaginationGenerator,
14+
PaginationIterable,
1515
)
1616

1717
import warnings
@@ -47,7 +47,7 @@ def list_solver_ports(
4747
)
4848
return page.items if page.items else []
4949

50-
def iter_jobs(self, solver_key: str, version: str, **kwargs) -> PaginationGenerator:
50+
def iter_jobs(self, solver_key: str, version: str, **kwargs) -> PaginationIterable:
5151
"""Returns an iterator through which one can iterate over
5252
all Jobs submitted to the solver
5353
@@ -72,14 +72,14 @@ def _pagination_method():
7272
**kwargs,
7373
)
7474

75-
return PaginationGenerator(
75+
return PaginationIterable(
7676
first_page_callback=_pagination_method,
7777
api_client=self.api_client,
7878
base_url=self.api_client.configuration.host,
7979
auth=self._auth,
8080
)
8181

82-
def jobs(self, solver_key: str, version: str, **kwargs) -> PaginationGenerator:
82+
def jobs(self, solver_key: str, version: str, **kwargs) -> PaginationIterable:
8383
warnings.warn(
8484
"The 'jobs' method is deprecated and will be removed in a future version. "
8585
"Please use 'iter_jobs' instead.",

clients/python/src/osparc/_api_studies_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ._utils import (
1818
_DEFAULT_PAGINATION_LIMIT,
1919
_DEFAULT_PAGINATION_OFFSET,
20-
PaginationGenerator,
20+
PaginationIterable,
2121
)
2222
import warnings
2323

@@ -65,7 +65,7 @@ def clone_study(self, study_id: str, **kwargs):
6565
kwargs = {**kwargs, **ParentProjectInfo().model_dump(exclude_none=True)}
6666
return super().clone_study(study_id, **kwargs)
6767

68-
def iter_studies(self, **kwargs) -> PaginationGenerator:
68+
def iter_studies(self, **kwargs) -> PaginationIterable:
6969
def _pagination_method():
7070
page_study = self.list_studies(
7171
limit=_DEFAULT_PAGINATION_LIMIT,
@@ -75,14 +75,14 @@ def _pagination_method():
7575
assert isinstance(page_study, PageStudy) # nosec
7676
return page_study
7777

78-
return PaginationGenerator(
78+
return PaginationIterable(
7979
first_page_callback=_pagination_method,
8080
api_client=self.api_client,
8181
base_url=self.api_client.configuration.host,
8282
auth=self._auth,
8383
)
8484

85-
def studies(self, **kwargs) -> PaginationGenerator:
85+
def studies(self, **kwargs) -> PaginationIterable:
8686
warnings.warn(
8787
"The 'studies' method is deprecated and will be removed in a future version. "
8888
"Please use 'iter_studies' instead.",

clients/python/src/osparc/_utils.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
import asyncio
22
import hashlib
33
from pathlib import Path
4-
from typing import AsyncGenerator, Callable, Generator, Optional, Tuple, TypeVar, Union
5-
4+
from typing import (
5+
AsyncGenerator,
6+
Callable,
7+
Optional,
8+
TypeVar,
9+
Union,
10+
NamedTuple,
11+
Generator,
12+
)
13+
from collections.abc import Iterable, Sized
614
import httpx
715
from osparc_client import (
816
ApiClient,
@@ -15,7 +23,7 @@
1523
Study,
1624
)
1725
import aiofiles
18-
from ._exceptions import RequestError
26+
from .exceptions import RequestError
1927

2028
_KB = 1024 # in bytes
2129
_MB = _KB * 1024 # in bytes
@@ -30,8 +38,11 @@
3038
T = TypeVar("T", Job, File, Solver, Study)
3139

3240

33-
class PaginationGenerator:
34-
"""Class for wrapping paginated http methods as generators"""
41+
class PaginationIterable(Iterable, Sized):
42+
"""Class for wrapping paginated http methods as iterables. It supports three simple operations:
43+
- for elm in pagination_iterable
44+
- elm = next(pagination_iterable)
45+
- len(pagination_iterable)"""
3546

3647
def __init__(
3748
self,
@@ -75,9 +86,15 @@ def __iter__(self) -> Generator[T, None, None]:
7586
page = self._api_client._ApiClient__deserialize(response.json(), type(page))
7687

7788

89+
class Chunk(NamedTuple):
90+
data: bytes
91+
nbytes: int
92+
is_last_chunk: bool
93+
94+
7895
async def file_chunk_generator(
7996
file: Path, chunk_size: int
80-
) -> AsyncGenerator[Tuple[bytes, int], None]:
97+
) -> AsyncGenerator[Chunk, None]:
8198
if not file.is_file():
8299
raise RuntimeError(f"{file} must be a file")
83100
if chunk_size <= 0:
@@ -94,8 +111,10 @@ async def file_chunk_generator(
94111
)
95112
assert nbytes > 0
96113
chunk = await f.read(nbytes)
97-
yield chunk, nbytes
98114
bytes_read += nbytes
115+
yield Chunk(
116+
data=chunk, nbytes=nbytes, is_last_chunk=(bytes_read == file_size)
117+
)
99118

100119

101120
S = TypeVar("S")

clients/python/test/e2e/conftest.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from packaging.version import Version
1919
from pydantic import ByteSize
2020
from typing import NamedTuple, Final
21+
from memory_profiler import memory_usage
2122

2223
try:
2324
from osparc._settings import ConfigurationEnvVars
@@ -140,6 +141,7 @@ def async_client() -> Iterable[AsyncClient]:
140141
class ServerFile(NamedTuple):
141142
server_file: osparc.File
142143
local_file: Path
144+
upload_ram_usage: int
143145

144146

145147
@pytest.fixture(scope="session")
@@ -160,9 +162,15 @@ def large_server_file(
160162
assert (
161163
tmp_file.stat().st_size == _file_size
162164
), f"Could not create file of size: {_file_size}"
163-
uploaded_file: osparc.File = files_api.upload_file(tmp_file)
165+
ram_statistics, uploaded_file = memory_usage(
166+
(files_api.upload_file, (tmp_file,)), retval=True
167+
)
164168

165-
yield ServerFile(local_file=tmp_file, server_file=uploaded_file)
169+
yield ServerFile(
170+
local_file=tmp_file,
171+
server_file=uploaded_file,
172+
upload_ram_usage=max(ram_statistics) - min(ram_statistics),
173+
)
166174

167175
files_api.delete_file(uploaded_file.id)
168176

0 commit comments

Comments (0)