Skip to content

Commit 5fd1600

Browse files
committed
Merge branch 'master' into 201-experiment-with-upgrading-openapi-generator
2 parents 4041620 + 0388ebb commit 5fd1600

File tree

8 files changed

+221
-91
lines changed

8 files changed

+221
-91
lines changed

clients/python/src/osparc/_api_files_api.py

Lines changed: 78 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import logging
66
import math
77
from pathlib import Path
8-
from typing import Any, Iterator, List, Optional, Tuple, Union
8+
from typing import Any, Iterator, List, Optional, Tuple, Union, Final, Set
99
from tempfile import NamedTemporaryFile
1010

1111
import httpx
@@ -31,13 +31,16 @@
3131
import shutil
3232
from ._utils import (
3333
DEFAULT_TIMEOUT_SECONDS,
34-
PaginationGenerator,
34+
PaginationIterable,
3535
compute_sha256,
3636
file_chunk_generator,
37+
Chunk,
3738
)
3839

3940
_logger = logging.getLogger(__name__)
4041

42+
_MAX_CONCURRENT_UPLOADS: Final[int] = 20
43+
4144

4245
class FilesApi(_FilesApi):
4346
_dev_features = [
@@ -116,16 +119,23 @@ def upload_file(
116119
self,
117120
file: Union[str, Path],
118121
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
122+
max_concurrent_uploads: int = _MAX_CONCURRENT_UPLOADS,
119123
**kwargs,
120124
):
121125
return asyncio.run(
122-
self.upload_file_async(file=file, timeout_seconds=timeout_seconds, **kwargs)
126+
self.upload_file_async(
127+
file=file,
128+
timeout_seconds=timeout_seconds,
129+
max_concurrent_uploads=max_concurrent_uploads,
130+
**kwargs,
131+
)
123132
)
124133

125134
async def upload_file_async(
126135
self,
127136
file: Union[str, Path],
128137
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
138+
max_concurrent_uploads: int = _MAX_CONCURRENT_UPLOADS,
129139
**kwargs,
130140
) -> File:
131141
if isinstance(file, str):
@@ -159,53 +169,66 @@ async def upload_file_async(
159169
"Did not receive sufficient number of upload URLs from the server."
160170
)
161171

162-
uploaded_parts: list[UploadedPart] = []
172+
abort_body = BodyAbortMultipartUploadV0FilesFileIdAbortPost(
173+
client_file=client_file
174+
)
175+
upload_tasks: Set[asyncio.Task] = set()
176+
uploaded_parts: List[UploadedPart] = []
177+
163178
async with AsyncHttpClient(
164-
configuration=self.api_client.configuration, timeout=timeout_seconds
165-
) as session:
166-
with logging_redirect_tqdm():
167-
_logger.debug("Uploading %s in %i chunk(s)", file.name, n_urls)
168-
async for chunck, size in tqdm(
169-
file_chunk_generator(file, chunk_size),
170-
total=n_urls,
171-
disable=(not _logger.isEnabledFor(logging.DEBUG)),
172-
):
173-
index, url = next(url_iter)
174-
uploaded_parts.append(
175-
await self._upload_chunck(
176-
http_client=session,
177-
chunck=chunck,
178-
chunck_size=size,
179-
upload_link=url,
180-
index=index,
179+
configuration=self.api_client.configuration,
180+
method="post",
181+
url=links.abort_upload,
182+
body=abort_body.to_dict(),
183+
base_url=self.api_client.configuration.host,
184+
follow_redirects=True,
185+
auth=self._auth,
186+
timeout=timeout_seconds,
187+
) as api_server_session:
188+
async with AsyncHttpClient(
189+
configuration=self.api_client.configuration, timeout=timeout_seconds
190+
) as s3_session:
191+
with logging_redirect_tqdm():
192+
_logger.debug("Uploading %s in %i chunk(s)", file.name, n_urls)
193+
async for chunk in tqdm(
194+
file_chunk_generator(file, chunk_size),
195+
total=n_urls,
196+
disable=(not _logger.isEnabledFor(logging.DEBUG)),
197+
): # type: ignore
198+
assert isinstance(chunk, Chunk) # nosec
199+
index, url = next(url_iter)
200+
upload_tasks.add(
201+
asyncio.create_task(
202+
self._upload_chunck(
203+
http_client=s3_session,
204+
chunck=chunk.data,
205+
chunck_size=chunk.nbytes,
206+
upload_link=url,
207+
index=index,
208+
)
209+
)
181210
)
182-
)
211+
while (len(upload_tasks) >= max_concurrent_uploads) or (
212+
chunk.is_last_chunk and len(upload_tasks) > 0
213+
):
214+
done, upload_tasks = await asyncio.wait(
215+
upload_tasks, return_when=asyncio.FIRST_COMPLETED
216+
)
217+
for task in done:
218+
uploaded_parts.append(task.result())
183219

184-
abort_body = BodyAbortMultipartUploadV0FilesFileIdAbortPost(
185-
client_file=client_file
220+
_logger.debug(
221+
("Completing upload of %s " "(this might take a couple of minutes)..."),
222+
file.name,
186223
)
187-
async with AsyncHttpClient(
188-
configuration=self.api_client.configuration,
189-
method="post",
190-
url=links.abort_upload,
191-
body=abort_body.to_dict(),
192-
base_url=self.api_client.configuration.host,
193-
follow_redirects=True,
194-
auth=self._auth,
195-
timeout=timeout_seconds,
196-
) as session:
197-
_logger.debug(
198-
(
199-
"Completing upload of %s "
200-
"(this might take a couple of minutes)..."
201-
),
202-
file.name,
203-
)
204-
server_file: File = await self._complete_multipart_upload(
205-
session, links.complete_upload, client_file, uploaded_parts
206-
)
207-
_logger.debug("File upload complete: %s", file.name)
208-
return server_file
224+
server_file: File = await self._complete_multipart_upload(
225+
api_server_session,
226+
links.complete_upload, # type: ignore
227+
client_file,
228+
uploaded_parts,
229+
)
230+
_logger.debug("File upload complete: %s", file.name)
231+
return server_file
209232

210233
async def _complete_multipart_upload(
211234
self,
@@ -250,15 +273,19 @@ def _search_files(
250273
file_id: Optional[str] = None,
251274
sha256_checksum: Optional[str] = None,
252275
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
253-
) -> PaginationGenerator:
276+
) -> PaginationIterable:
277+
kwargs = {
278+
"file_id": file_id,
279+
"sha256_checksum": sha256_checksum,
280+
"_request_timeout": timeout_seconds,
281+
}
282+
254283
def _pagination_method():
255284
return super(FilesApi, self).search_files_page(
256-
file_id=file_id,
257-
sha256_checksum=sha256_checksum,
258-
_request_timeout=timeout_seconds,
285+
**{k: v for k, v in kwargs.items() if v is not None}
259286
)
260287

261-
return PaginationGenerator(
288+
return PaginationIterable(
262289
first_page_callback=_pagination_method,
263290
api_client=self.api_client,
264291
base_url=self.api_client.configuration.host,

clients/python/src/osparc/_api_solvers_api.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from ._utils import (
2222
_DEFAULT_PAGINATION_LIMIT,
2323
_DEFAULT_PAGINATION_OFFSET,
24-
PaginationGenerator,
24+
PaginationIterable,
2525
)
2626
import warnings
2727
from tempfile import NamedTemporaryFile
@@ -60,7 +60,7 @@ def list_solver_ports(
6060
)
6161
return page.items if page.items else []
6262

63-
def iter_jobs(self, solver_key: str, version: str, **kwargs) -> PaginationGenerator:
63+
def iter_jobs(self, solver_key: str, version: str, **kwargs) -> PaginationIterable:
6464
"""Returns an iterator through which one can iterate over
6565
all Jobs submitted to the solver
6666
@@ -85,14 +85,14 @@ def _pagination_method():
8585
**kwargs,
8686
)
8787

88-
return PaginationGenerator(
88+
return PaginationIterable(
8989
first_page_callback=_pagination_method,
9090
api_client=self.api_client,
9191
base_url=self.api_client.configuration.host,
9292
auth=self._auth,
9393
)
9494

95-
def jobs(self, solver_key: str, version: str, **kwargs) -> PaginationGenerator:
95+
def jobs(self, solver_key: str, version: str, **kwargs) -> PaginationIterable:
9696
warnings.warn(
9797
"The 'jobs' method is deprecated and will be removed in a future version. "
9898
"Please use 'iter_jobs' instead.",

clients/python/src/osparc/_api_studies_api.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
from ._utils import (
2929
_DEFAULT_PAGINATION_LIMIT,
3030
_DEFAULT_PAGINATION_OFFSET,
31-
PaginationGenerator,
31+
PaginationIterable,
3232
)
3333
import warnings
3434

@@ -86,7 +86,7 @@ def clone_study(self, study_id: str, **kwargs):
8686
kwargs = {**kwargs, **ParentProjectInfo().model_dump(exclude_none=True)}
8787
return super().clone_study(study_id, **kwargs)
8888

89-
def iter_studies(self, **kwargs) -> PaginationGenerator:
89+
def iter_studies(self, **kwargs) -> PaginationIterable:
9090
def _pagination_method():
9191
page_study = self.list_studies(
9292
limit=_DEFAULT_PAGINATION_LIMIT,
@@ -96,14 +96,14 @@ def _pagination_method():
9696
assert isinstance(page_study, PageStudy) # nosec
9797
return page_study
9898

99-
return PaginationGenerator(
99+
return PaginationIterable(
100100
first_page_callback=_pagination_method,
101101
api_client=self.api_client,
102102
base_url=self.api_client.configuration.host,
103103
auth=self._auth,
104104
)
105105

106-
def studies(self, **kwargs) -> PaginationGenerator:
106+
def studies(self, **kwargs) -> PaginationIterable:
107107
warnings.warn(
108108
"The 'studies' method is deprecated and will be removed in a future version. "
109109
"Please use 'iter_studies' instead.",

clients/python/src/osparc/_utils.py

Lines changed: 26 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,16 @@
11
import asyncio
22
import hashlib
33
from pathlib import Path
4-
from typing import AsyncGenerator, Callable, Generator, Optional, Tuple, TypeVar, Union
5-
4+
from typing import (
5+
AsyncGenerator,
6+
Callable,
7+
Optional,
8+
TypeVar,
9+
Union,
10+
NamedTuple,
11+
Generator,
12+
)
13+
from collections.abc import Iterable, Sized
614
import httpx
715
from osparc_client import (
816
ApiClient,
@@ -15,7 +23,7 @@
1523
Study,
1624
)
1725
import aiofiles
18-
from ._exceptions import RequestError
26+
from .exceptions import RequestError
1927

2028
_KB = 1024 # in bytes
2129
_MB = _KB * 1024 # in bytes
@@ -30,8 +38,11 @@
3038
T = TypeVar("T", Job, File, Solver, Study)
3139

3240

33-
class PaginationGenerator:
34-
"""Class for wrapping paginated http methods as generators"""
41+
class PaginationIterable(Iterable, Sized):
42+
"""Class for wrapping paginated http methods as iterables. It supports three simple operations:
43+
- for elm in pagination_iterable
44+
- elm = next(pagination_iterable)
45+
- len(pagination_iterable)"""
3546

3647
def __init__(
3748
self,
@@ -75,9 +86,15 @@ def __iter__(self) -> Generator[T, None, None]:
7586
page = self._api_client._ApiClient__deserialize(response.json(), type(page))
7687

7788

89+
class Chunk(NamedTuple):
90+
data: bytes
91+
nbytes: int
92+
is_last_chunk: bool
93+
94+
7895
async def file_chunk_generator(
7996
file: Path, chunk_size: int
80-
) -> AsyncGenerator[Tuple[bytes, int], None]:
97+
) -> AsyncGenerator[Chunk, None]:
8198
if not file.is_file():
8299
raise RuntimeError(f"{file} must be a file")
83100
if chunk_size <= 0:
@@ -94,8 +111,10 @@ async def file_chunk_generator(
94111
)
95112
assert nbytes > 0
96113
chunk = await f.read(nbytes)
97-
yield chunk, nbytes
98114
bytes_read += nbytes
115+
yield Chunk(
116+
data=chunk, nbytes=nbytes, is_last_chunk=(bytes_read == file_size)
117+
)
99118

100119

101120
S = TypeVar("S")

clients/python/test/e2e/conftest.py

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from packaging.version import Version
1919
from pydantic import ByteSize
2020
from typing import NamedTuple, Final
21+
from memory_profiler import memory_usage
2122

2223
try:
2324
from osparc._settings import ConfigurationEnvVars
@@ -29,6 +30,8 @@
2930
_MB: ByteSize = ByteSize(_KB * 1024) # in bytes
3031
_GB: ByteSize = ByteSize(_MB * 1024) # in bytes
3132

33+
_logger = logging.getLogger(__name__)
34+
3235
# Dictionary to store start times of tests
3336
_test_start_times = {}
3437

@@ -140,6 +143,7 @@ def async_client() -> Iterable[AsyncClient]:
140143
class ServerFile(NamedTuple):
141144
server_file: osparc.File
142145
local_file: Path
146+
upload_ram_usage: int
143147

144148

145149
@pytest.fixture(scope="session")
@@ -160,11 +164,23 @@ def large_server_file(
160164
assert (
161165
tmp_file.stat().st_size == _file_size
162166
), f"Could not create file of size: {_file_size}"
163-
uploaded_file: osparc.File = files_api.upload_file(tmp_file)
167+
ram_statistics, uploaded_file = memory_usage(
168+
(files_api.upload_file, (tmp_file,)), retval=True
169+
)
164170

165-
yield ServerFile(local_file=tmp_file, server_file=uploaded_file)
171+
yield ServerFile(
172+
local_file=tmp_file,
173+
server_file=uploaded_file,
174+
upload_ram_usage=max(ram_statistics) - min(ram_statistics),
175+
)
166176

167-
files_api.delete_file(uploaded_file.id)
177+
try:
178+
files_api.delete_file(uploaded_file.id)
179+
except osparc.ApiException:
180+
_logger.warning(
181+
f"Could not delete file on server in {file_with_number.__name__}",
182+
exc_info=True,
183+
)
168184

169185

170186
@pytest.fixture
@@ -198,4 +214,10 @@ def file_with_number(
198214
server_file = files_api.upload_file(file)
199215
yield server_file
200216

201-
files_api.delete_file(server_file.id)
217+
try:
218+
files_api.delete_file(server_file.id)
219+
except osparc.ApiException:
220+
_logger.warning(
221+
f"Could not delete file on server in {file_with_number.__name__}",
222+
exc_info=True,
223+
)

0 commit comments

Comments
 (0)