|
| 1 | +import base64 |
| 2 | +import json |
| 3 | +import pathlib |
| 4 | +from uuid import UUID |
| 5 | + |
| 6 | +import polars_cloud as pc |
| 7 | + |
| 8 | +from settings import Settings |
| 9 | + |
| 10 | +settings = Settings() |
| 11 | + |
| 12 | + |
| 13 | +def reuse_compute_context(filename: str, log_reuse: bool) -> pc.ComputeContext | None: |
| 14 | + with pathlib.Path(filename).open("r", encoding="utf8") as r: |
| 15 | + context_args = json.load(r) |
| 16 | + |
| 17 | + required_keys = ["workspace_id", "compute_id"] |
| 18 | + for key in required_keys: |
| 19 | + assert key in context_args, f"Key {key} not in {filename}" |
| 20 | + if log_reuse: |
| 21 | + print(f"Reusing existing compute context: {context_args['compute_id']}") |
| 22 | + context_args = {key: UUID(context_args.get(key)) for key in required_keys} |
| 23 | + try: |
| 24 | + ctx = pc.ComputeContext.connect(**context_args) |
| 25 | + ctx.start(wait=True) |
| 26 | + assert(ctx.get_status() == pc.ComputeContextStatus.RUNNING) |
| 27 | + except RuntimeError as e: |
| 28 | + print(f"Cannot reuse existing compute context: {e.args}") |
| 29 | + return None |
| 30 | + return ctx |
| 31 | + |
| 32 | + |
| 33 | +def get_compute_context_args() -> dict[str, str | int]: |
| 34 | + return { |
| 35 | + key: value |
| 36 | + for key, value in { |
| 37 | + "cpus": settings.run.polars_cloud_cpus, |
| 38 | + "memory": settings.run.polars_cloud_memory, |
| 39 | + "instance_type": settings.run.polars_cloud_instance_type, |
| 40 | + "cluster_size": settings.run.polars_cloud_cluster_size, |
| 41 | + "workspace": settings.run.polars_cloud_workspace, |
| 42 | + }.items() |
| 43 | + if value is not None |
| 44 | + } |
| 45 | + |
| 46 | + |
| 47 | +def get_compute_context_filename(context_args: dict[str, str | int]) -> str: |
| 48 | + hash = base64.b64encode(str(context_args).encode("utf-8")).decode() |
| 49 | + return f".polars-cloud-compute-context-{hash}.json" |
| 50 | + |
| 51 | + |
| 52 | +def get_compute_context(*, create_if_no_reuse: bool = True, log_create: bool = False, log_reuse: bool = False) -> pc.ComputeContext: |
| 53 | + context_args = get_compute_context_args() |
| 54 | + context_filename = get_compute_context_filename(context_args) |
| 55 | + if pathlib.Path(context_filename).is_file(): |
| 56 | + ctx = reuse_compute_context(context_filename, log_reuse) |
| 57 | + if ctx: |
| 58 | + return ctx |
| 59 | + |
| 60 | + # start new compute context |
| 61 | + if not create_if_no_reuse: |
| 62 | + msg = "Cannot reuse compute context" |
| 63 | + raise RuntimeError(msg) |
| 64 | + if log_create: |
| 65 | + print(f"Starting new compute context: {context_args}") |
| 66 | + ctx = pc.ComputeContext(**context_args) # type: ignore[arg-type] |
| 67 | + ctx.start(wait=True) |
| 68 | + assert(ctx.get_status() == pc.ComputeContextStatus.RUNNING) |
| 69 | + context_args = {"workspace_id": str(ctx.workspace.id), "compute_id": str(ctx._compute_id)} |
| 70 | + with pathlib.Path(context_filename).open("w", encoding="utf8") as w: |
| 71 | + json.dump(context_args, w) |
| 72 | + return ctx |
| 73 | + |
| 74 | + |
| 75 | +def stop_compute_context(ctx: pc.ComputeContext) -> None: |
| 76 | + ctx.stop(wait=True) |
| 77 | + context_args = get_compute_context_args() |
| 78 | + context_filename = get_compute_context_filename(context_args) |
| 79 | + pathlib.Path(context_filename).unlink(missing_ok=True) |
0 commit comments