Skip to content

Commit fb95cb4

Browse files
yzcj105weinbe58
andauthored
Exclusive access task (#1025)
* Removing uneeded interface * WIP:exlcusive access class * WIP: saving results * exclusive access v1 * remove prints for debug * fix: status return code * status code "Submitted" -> Enqueued * refactor: enable saving and serialization and ExclusiveRemoteTask * refactor: comment out debug print statement in ExclusiveRemoteTask * refactor: streamline task status query logic and remove unused geometry field * use dataclass for ExclusiveCustomRemoteTask * refactor: use black to clean up whitespace and improve code formatting in HTTPHandler and ExclusiveRemoteTask * clean up duplicated imports * fix: add missing newline for improved code readability --------- Co-authored-by: Phillip Weinberg <weinbe58@gmail.com>
1 parent 45b44fe commit fb95cb4

File tree

1 file changed

+346
-0
lines changed

1 file changed

+346
-0
lines changed

src/bloqade/analog/task/exclusive.py

Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
import os
2+
import abc
3+
import uuid
4+
import re
5+
6+
from beartype.typing import Dict
7+
from dataclasses import dataclass, field
8+
9+
from bloqade.analog.task.base import Geometry, CustomRemoteTaskABC
10+
from bloqade.analog.builder.typing import ParamType
11+
from bloqade.analog.submission.ir.parallel import ParallelDecoder
12+
from bloqade.analog.submission.ir.task_results import (
13+
QuEraTaskResults,
14+
QuEraTaskStatusCode,
15+
)
16+
from bloqade.analog.submission.ir.task_specification import QuEraTaskSpecification
17+
from requests import request, get
18+
from bloqade.analog.serialize import Serializer
19+
20+
21+
class HTTPHandlerABC:
22+
@abc.abstractmethod
23+
def submit_task_via_zapier(task_ir: QuEraTaskSpecification, task_id: str):
24+
"""Submit a task and add task_id to the task fields for querying later.
25+
26+
args:
27+
task_ir: The task to be submitted.
28+
task_id: The task id to be added to the task fields.
29+
30+
returns
31+
response: The response from the Zapier webhook. used for error handling
32+
33+
"""
34+
...
35+
36+
@abc.abstractmethod
37+
def query_task_status(task_id: str):
38+
"""Query the task status from the AirTable.
39+
40+
args:
41+
task_id: The task id to be queried.
42+
43+
returns
44+
response: The response from the AirTable. used for error handling
45+
46+
"""
47+
...
48+
49+
@abc.abstractmethod
50+
def fetch_results(task_id: str):
51+
"""Fetch the task results from the AirTable.
52+
53+
args:
54+
task_id: The task id to be queried.
55+
56+
returns
57+
response: The response from the AirTable. used for error handling
58+
59+
"""
60+
61+
...
62+
63+
64+
def convert_preview_to_download(preview_url):
65+
# help function to convert the googledrive preview URL to download URL
66+
# Only used in http handler
67+
match = re.search(r"/d/([^/]+)/", preview_url)
68+
if not match:
69+
raise ValueError("Invalid preview URL format")
70+
file_id = match.group(1)
71+
return f"https://drive.usercontent.google.com/download?id={file_id}&export=download"
72+
73+
74+
class HTTPHandler(HTTPHandlerABC):
75+
def __init__(
76+
self,
77+
zapier_webhook_url: str = None,
78+
zapier_webhook_key: str = None,
79+
vercel_api_url: str = None,
80+
):
81+
self.zapier_webhook_url = zapier_webhook_url or os.environ["ZAPIER_WEBHOOK_URL"]
82+
self.zapier_webhook_key = zapier_webhook_key or os.environ["ZAPIER_WEBHOOK_KEY"]
83+
self.verrcel_api_url = vercel_api_url or os.environ["VERCEL_API_URL"]
84+
85+
def submit_task_via_zapier(
86+
self, task_ir: QuEraTaskSpecification, task_id: str, task_note: str
87+
):
88+
# implement http request logic to submit task via Zapier
89+
request_options = dict(params={"key": self.zapier_webhook_key, "note": task_id})
90+
91+
# for metadata, task_ir in self._compile_single(shots, use_experimental, args):
92+
json_request_body = task_ir.json(exclude_none=True, exclude_unset=True)
93+
94+
request_options.update(data=json_request_body)
95+
response = request("POST", self.zapier_webhook_url, **request_options)
96+
97+
if response.status_code == 200:
98+
response_data = response.json()
99+
submit_status = response_data.get("status", None)
100+
return submit_status
101+
else:
102+
print(f"HTTP request failed with status code: {response.status_code}")
103+
print("HTTP responce: ", response.text)
104+
return "Failed"
105+
106+
def query_task_status(self, task_id: str):
107+
response = request(
108+
"GET",
109+
self.verrcel_api_url,
110+
params={
111+
"searchPattern": task_id,
112+
"magicToken": self.zapier_webhook_key,
113+
"useRegex": False,
114+
},
115+
)
116+
if response.status_code != 200:
117+
return "Not Found"
118+
response_data = response.json()
119+
# Get "matched" from the response
120+
matches = response_data.get("matches", None)
121+
# The return is a list of dictionaries
122+
# Verify if the list contains only one element
123+
if matches is None:
124+
print("No task found with the given ID.")
125+
return "Failed"
126+
elif len(matches) > 1:
127+
print("Multiple tasks found with the given ID.")
128+
return "Failed"
129+
130+
# Extract the status from the first dictionary
131+
status = matches[0].get("status")
132+
return status
133+
134+
def fetch_results(self, task_id: str):
135+
response = request(
136+
"GET",
137+
self.verrcel_api_url,
138+
params={
139+
"searchPattern": task_id,
140+
"magicToken": self.zapier_webhook_key,
141+
"useRegex": False,
142+
},
143+
)
144+
if response.status_code != 200:
145+
print(f"HTTP request failed with status code: {response.status_code}")
146+
print("HTTP responce: ", response.text)
147+
return None
148+
149+
response_data = response.json()
150+
# Get "matched" from the response
151+
matches = response_data.get("matches", None)
152+
# The return is a list of dictionaries
153+
# Verify if the list contains only one element
154+
if matches is None:
155+
print("No task found with the given ID.")
156+
return None
157+
elif len(matches) > 1:
158+
print("Multiple tasks found with the given ID.")
159+
return None
160+
record = matches[0]
161+
if record.get("status") == "Completed":
162+
googledoc = record.get("resultsFileUrl")
163+
164+
# convert the preview URL to download URL
165+
googledoc = convert_preview_to_download(googledoc)
166+
res = get(googledoc)
167+
res.raise_for_status()
168+
data = res.json()
169+
170+
task_results = QuEraTaskResults(**data)
171+
return task_results
172+
173+
174+
class TestHTTPHandler(HTTPHandlerABC):
175+
pass
176+
177+
178+
@dataclass
179+
@Serializer.register
180+
class ExclusiveRemoteTask(CustomRemoteTaskABC):
181+
_task_ir: QuEraTaskSpecification | None
182+
_metadata: Dict[str, ParamType]
183+
_parallel_decoder: ParallelDecoder | None
184+
_http_handler: HTTPHandlerABC = field(default_factory=HTTPHandler)
185+
_task_id: str | None = None
186+
_task_result_ir: QuEraTaskResults | None = None
187+
188+
def __post_init__(self):
189+
float_sites = list(
190+
map(lambda x: (float(x[0]), float(x[1])), self._task_ir.lattice.sites)
191+
)
192+
self._geometry = Geometry(
193+
float_sites, self._task_ir.lattice.filling, self._parallel_decoder
194+
)
195+
196+
@classmethod
197+
def from_compile_results(cls, task_ir, metadata, parallel_decoder):
198+
return cls(
199+
_task_ir=task_ir,
200+
_metadata=metadata,
201+
_parallel_decoder=parallel_decoder,
202+
)
203+
204+
def _submit(self, force: bool = False) -> "ExclusiveRemoteTask":
205+
if not force:
206+
if self._task_id is not None:
207+
raise ValueError(
208+
"the task is already submitted with %s" % (self._task_id)
209+
)
210+
self._task_id = str(uuid.uuid4())
211+
212+
if (
213+
self._http_handler.submit_task_via_zapier(
214+
self._task_ir, self._task_id, None
215+
)
216+
== "success"
217+
):
218+
self._task_result_ir = QuEraTaskResults(
219+
task_status=QuEraTaskStatusCode.Accepted
220+
)
221+
else:
222+
self._task_result_ir = QuEraTaskResults(
223+
task_status=QuEraTaskStatusCode.Failed
224+
)
225+
return self
226+
227+
def fetch(self):
228+
if self._task_result_ir.task_status is QuEraTaskStatusCode.Unsubmitted:
229+
raise ValueError("Task ID not found.")
230+
231+
if self._task_result_ir.task_status in [
232+
QuEraTaskStatusCode.Completed,
233+
QuEraTaskStatusCode.Partial,
234+
QuEraTaskStatusCode.Failed,
235+
QuEraTaskStatusCode.Unaccepted,
236+
QuEraTaskStatusCode.Cancelled,
237+
]:
238+
return self
239+
240+
status = self.status()
241+
if status in [QuEraTaskStatusCode.Completed, QuEraTaskStatusCode.Partial]:
242+
self._task_result_ir = self._http_handler.fetch_results(self._task_id)
243+
else:
244+
self._task_result_ir = QuEraTaskResults(task_status=status)
245+
246+
return self
247+
248+
def pull(self):
249+
# Please avoid using this method, it's blocking and the waiting time is hours long
250+
# Throw an error saying this is not supported
251+
raise NotImplementedError(
252+
"Pulling is not supported. Please use fetch() instead."
253+
)
254+
255+
def cancel(self):
256+
# This is not supported
257+
raise NotImplementedError("Cancelling is not supported.")
258+
259+
def status(self) -> QuEraTaskStatusCode:
260+
if self._task_id is None:
261+
return QuEraTaskStatusCode.Unsubmitted
262+
res = self._http_handler.query_task_status(self._task_id)
263+
if res == "Failed":
264+
raise ValueError("Query task status failed.")
265+
elif res == "Submitted":
266+
return QuEraTaskStatusCode.Enqueued
267+
# TODO: please add all possible status
268+
elif res == "Completed":
269+
return QuEraTaskStatusCode.Completed
270+
elif res == "Running":
271+
# Not covered by test
272+
return QuEraTaskStatusCode.Executing
273+
else:
274+
return self._task_result_ir.task_status
275+
276+
def _result_exists(self):
277+
if self._task_result_ir is None:
278+
return False
279+
else:
280+
if self._task_result_ir.task_status == QuEraTaskStatusCode.Completed:
281+
return True
282+
else:
283+
return False
284+
285+
def result(self):
286+
if self._task_result_ir is None:
287+
raise ValueError("Task result not found.")
288+
return self._task_result_ir
289+
290+
@property
291+
def metadata(self):
292+
return self._metadata
293+
294+
@property
295+
def geometry(self):
296+
return self._geometry
297+
298+
@property
299+
def task_ir(self):
300+
return self._task_ir
301+
302+
@property
303+
def task_id(self) -> str:
304+
assert isinstance(self._task_id, str), "Task ID is not set"
305+
return self._task_id
306+
307+
@property
308+
def task_result_ir(self):
309+
return self._task_result_ir
310+
311+
@property
312+
def parallel_decoder(self):
313+
return self._parallel_decoder
314+
315+
@task_result_ir.setter
316+
def task_result_ir(self, task_result_ir: QuEraTaskResults):
317+
self._task_result_ir = task_result_ir
318+
319+
320+
@ExclusiveRemoteTask.set_serializer
321+
def _serialze(obj: ExclusiveRemoteTask) -> Dict[str, ParamType]:
322+
return {
323+
"task_id": obj.task_id or None,
324+
"task_ir": obj.task_ir.dict(by_alias=True, exclude_none=True),
325+
"metadata": obj.metadata,
326+
"parallel_decoder": (
327+
obj.parallel_decoder.dict() if obj.parallel_decoder else None
328+
),
329+
"task_result_ir": obj.task_result_ir.dict() if obj.task_result_ir else None,
330+
}
331+
332+
333+
@ExclusiveRemoteTask.set_deserializer
334+
def _deserializer(d: Dict[str, any]) -> ExclusiveRemoteTask:
335+
d1 = dict()
336+
d1["_task_ir"] = QuEraTaskSpecification(**d["task_ir"])
337+
d1["_parallel_decoder"] = (
338+
ParallelDecoder(**d["parallel_decoder"]) if d["parallel_decoder"] else None
339+
)
340+
d1["_metadata"] = d["metadata"]
341+
d1["_task_result_ir"] = (
342+
QuEraTaskResults(**d["task_result_ir"]) if d["task_result_ir"] else None
343+
)
344+
d1["_task_id"] = d["task_id"]
345+
346+
return ExclusiveRemoteTask(**d1)

0 commit comments

Comments
 (0)