Skip to content

Commit 7132bad

Browse files
committed
Fix lint
1 parent ba55e3c commit 7132bad

File tree

2 files changed

+93
-14
lines changed

2 files changed

+93
-14
lines changed

queries/polars/cloud_utils.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import base64
2+
import json
3+
import pathlib
4+
from uuid import UUID
5+
6+
import polars_cloud as pc
7+
8+
from settings import Settings
9+
10+
settings = Settings()
11+
12+
13+
def reuse_compute_context(filename: str, log_reuse: bool) -> pc.ComputeContext | None:
    """Try to reconnect to a previously persisted compute context.

    Parameters
    ----------
    filename
        Path to the JSON file holding ``workspace_id`` and ``compute_id``.
    log_reuse
        When True, print a message when an existing context is reused.

    Returns
    -------
    The running compute context, or ``None`` when it could not be reused.

    Raises
    ------
    ValueError
        If the JSON file lacks one of the required keys.
    """
    with pathlib.Path(filename).open("r", encoding="utf8") as r:
        context_args = json.load(r)

    required_keys = ["workspace_id", "compute_id"]
    # Validate explicitly: `assert` is stripped under `python -O`.
    for key in required_keys:
        if key not in context_args:
            msg = f"Key {key} not in {filename}"
            raise ValueError(msg)
    if log_reuse:
        print(f"Reusing existing compute context: {context_args['compute_id']}")
    context_args = {key: UUID(context_args[key]) for key in required_keys}
    try:
        ctx = pc.ComputeContext.connect(**context_args)
        ctx.start(wait=True)
        if ctx.get_status() != pc.ComputeContextStatus.RUNNING:
            # Treat a context that failed to reach RUNNING like any other
            # reuse failure so the caller can fall back to a fresh context.
            print(f"Cannot reuse existing compute context: status {ctx.get_status()}")
            return None
        return ctx
    except RuntimeError as e:
        print(f"Cannot reuse existing compute context: {e.args}")
        return None
31+
32+
33+
def get_compute_context_args() -> dict[str, str | int]:
    """Collect the explicitly configured compute-context settings.

    Settings left at ``None`` are omitted so that Polars Cloud falls back
    to its own defaults for them.
    """
    run = settings.run
    candidates = {
        "cpus": run.polars_cloud_cpus,
        "memory": run.polars_cloud_memory,
        "instance_type": run.polars_cloud_instance_type,
        "cluster_size": run.polars_cloud_cluster_size,
        "workspace": run.polars_cloud_workspace,
    }
    return {name: value for name, value in candidates.items() if value is not None}
45+
46+
47+
def get_compute_context_filename(context_args: dict[str, str | int]) -> str:
48+
hash = base64.b64encode(str(context_args).encode("utf-8")).decode()
49+
return f".polars-cloud-compute-context-{hash}.json"
50+
51+
52+
def get_compute_context(
    *,
    create_if_no_reuse: bool = True,
    log_create: bool = False,
    log_reuse: bool = False,
) -> pc.ComputeContext:
    """Return a running compute context, reusing a persisted one when possible.

    Parameters
    ----------
    create_if_no_reuse
        When False, raise instead of starting a new context.
    log_create
        When True, print a message when a new context is started.
    log_reuse
        When True, print a message when an existing context is reused.

    Raises
    ------
    RuntimeError
        If no context could be reused and ``create_if_no_reuse`` is False,
        or if a newly started context did not reach the RUNNING state.
    """
    context_args = get_compute_context_args()
    context_filename = get_compute_context_filename(context_args)
    if pathlib.Path(context_filename).is_file():
        ctx = reuse_compute_context(context_filename, log_reuse)
        if ctx is not None:
            return ctx

    # Start a new compute context.
    if not create_if_no_reuse:
        raise RuntimeError("Cannot reuse compute context")
    if log_create:
        print(f"Starting new compute context: {context_args}")
    ctx = pc.ComputeContext(**context_args)  # type: ignore[arg-type]
    ctx.start(wait=True)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if ctx.get_status() != pc.ComputeContextStatus.RUNNING:
        msg = f"Compute context failed to start: status {ctx.get_status()}"
        raise RuntimeError(msg)
    # NOTE(review): `_compute_id` is a private attribute of ComputeContext;
    # there appears to be no public accessor — confirm against polars_cloud.
    context_args = {
        "workspace_id": str(ctx.workspace.id),
        "compute_id": str(ctx._compute_id),
    }
    # Persist the connection details so later runs can reconnect.
    with pathlib.Path(context_filename).open("w", encoding="utf8") as w:
        json.dump(context_args, w)
    return ctx
72+
73+
74+
def stop_compute_context(ctx: pc.ComputeContext) -> None:
    """Stop the compute context and remove its persisted connection file."""
    ctx.stop(wait=True)
    cache_file = get_compute_context_filename(get_compute_context_args())
    pathlib.Path(cache_file).unlink(missing_ok=True)

queries/polars/utils.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
from typing import Literal
55

66
import polars as pl
7+
78
from queries.common_utils import (
89
check_query_result_pl,
9-
execute_all as common_execute_all,
1010
get_table_path,
1111
run_query_generic,
1212
)
13+
from queries.common_utils import (
14+
execute_all as common_execute_all,
15+
)
1316
from queries.polars.cloud_utils import get_compute_context, stop_compute_context
1417
from settings import Settings
1518

@@ -18,7 +21,7 @@
1821

1922
def execute_all() -> None:
2023
if not settings.run.polars_cloud:
21-
return execute_all("polars")
24+
return common_execute_all("polars")
2225

2326
# for polars cloud we have to create the compute context,
2427
# reuse it across the queries, and stop it in the end
@@ -32,18 +35,20 @@ def execute_all() -> None:
3235

3336
def _scan_ds(table_name: str) -> pl.LazyFrame:
    """Scan the given benchmark table as a LazyFrame per the configured I/O type.

    Raises
    ------
    ValueError
        If ``settings.run.io_type`` is not one of the supported values.
    """
    path = get_table_path(table_name)
    # pathlib.Path normalizes consecutive slashes ("s3://x" -> "s3:/x"),
    # unless Path.from_uri is used (Python >= 3.13) — restore the scheme.
    path_str = str(path)
    if path_str.startswith("s3:/") and not path_str.startswith("s3://"):
        path_str = f"s3://{path_str[4:]}"

    io_type = settings.run.io_type
    if io_type == "skip":
        # Eagerly read once, then hand back a lazy frame over in-memory data.
        return pl.read_parquet(path_str, rechunk=True).lazy()
    elif io_type == "parquet":
        return pl.scan_parquet(path_str)
    elif io_type == "feather":
        return pl.scan_ipc(path_str)
    elif io_type == "csv":
        return pl.scan_csv(path_str, try_parse_dates=True)
    else:
        msg = f"unsupported file type: {io_type!r}"
        raise ValueError(msg)
@@ -184,11 +189,7 @@ def run_query(query_number: int, lf: pl.LazyFrame) -> None:
184189
ctx = get_compute_context(create_if_no_reuse=False)
185190

186191
def query(): # type: ignore[no-untyped-def]
187-
result = lf.remote(context=ctx).distributed().collect()
188-
189-
if settings.run.show_results:
190-
print(result.plan())
191-
return result.lazy().collect()
192+
return lf.remote(context=ctx).distributed().collect()
192193
else:
193194
query = partial(
194195
lf.collect,

0 commit comments

Comments
 (0)