
Commit 877427a

Implement Arrow based wcc endpoints
1 parent be2985e commit 877427a

File tree

11 files changed: +705 -120 lines

graphdatascience/arrow_client/authenticated_arrow_client.py

Lines changed: 77 additions & 24 deletions
@@ -1,33 +1,41 @@
 from __future__ import annotations
 
 import logging
-from typing import Optional, Union, Any
+from dataclasses import dataclass
+from typing import Any, Optional, Union
 
 from pyarrow import __version__ as arrow_version
 from pyarrow import flight
-from pyarrow._flight import FlightTimedOutError, FlightUnavailableError, FlightInternalError, Action
-from tenacity import retry_any, retry_if_exception_type, stop_after_delay, stop_after_attempt, wait_exponential, retry
+from pyarrow._flight import (
+    Action,
+    FlightInternalError,
+    FlightStreamReader,
+    FlightTimedOutError,
+    FlightUnavailableError,
+    Ticket,
+)
+from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
 
 from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
 from graphdatascience.arrow_client.arrow_info import ArrowInfo
 from graphdatascience.retry_utils.retry_config import RetryConfig
-from .middleware.AuthMiddleware import AuthMiddleware, AuthFactory
-from .middleware.UserAgentMiddleware import UserAgentFactory
+
 from ..retry_utils.retry_utils import before_log
 from ..version import __version__
+from .middleware.AuthMiddleware import AuthFactory, AuthMiddleware
+from .middleware.UserAgentMiddleware import UserAgentFactory
 
 
 class AuthenticatedArrowClient:
-
     @staticmethod
     def create(
-        arrow_info: ArrowInfo,
-        auth: Optional[ArrowAuthentication] = None,
-        encrypted: bool = False,
-        disable_server_verification: bool = False,
-        tls_root_certs: Optional[bytes] = None,
-        connection_string_override: Optional[str] = None,
-        retry_config: Optional[RetryConfig] = None,
+        arrow_info: ArrowInfo,
+        auth: Optional[ArrowAuthentication] = None,
+        encrypted: bool = False,
+        disable_server_verification: bool = False,
+        tls_root_certs: Optional[bytes] = None,
+        connection_string_override: Optional[str] = None,
+        retry_config: Optional[RetryConfig] = None,
     ) -> AuthenticatedArrowClient:
         connection_string: str
         if connection_string_override is not None:
@@ -59,15 +67,15 @@ def create(
         )
 
     def __init__(
-        self,
-        host: str,
-        retry_config: RetryConfig,
-        port: int = 8491,
-        auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None,
-        encrypted: bool = False,
-        disable_server_verification: bool = False,
-        tls_root_certs: Optional[bytes] = None,
-        user_agent: Optional[str] = None,
+        self,
+        host: str,
+        retry_config: RetryConfig,
+        port: int = 8491,
+        auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None,
+        encrypted: bool = False,
+        disable_server_verification: bool = False,
+        tls_root_certs: Optional[bytes] = None,
+        user_agent: Optional[str] = None,
     ):
         """Creates a new GdsArrowClient instance.
 
@@ -85,8 +93,6 @@ def __init__(
             A flag that disables server verification for TLS connections (default is False)
         tls_root_certs: Optional[bytes]
             PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
-        arrow_endpoint_version:
-            The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1)
         user_agent: Optional[str]
             The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
         retry_config: Optional[RetryConfig]
@@ -117,6 +123,48 @@ def __init__(
 
         self._flight_client = self._instantiate_flight_client()
 
+    def connection_info(self) -> ConnectionInfo:
+        """
+        Returns the host, port and encryption setting of the GDS Arrow server.
+
+        Returns
+        -------
+        ConnectionInfo
+            the host, port and encryption setting of the GDS Arrow server
+        """
+        return ConnectionInfo(self._host, self._port, self._encrypted)
+
+    def request_token(self) -> Optional[str]:
+        """
+        Requests a token from the server and returns it.
+
+        Returns
+        -------
+        Optional[str]
+            a token from the server
+        """
+
+        @retry(
+            reraise=True,
+            before=before_log("Request token", self._logger, logging.DEBUG),
+            retry=self._retry_config.retry,
+            stop=self._retry_config.stop,
+            wait=self._retry_config.wait,
+        )
+        def auth_with_retry() -> None:
+            client = self._flight_client
+            if self._auth:
+                auth_pair = self._auth.auth_pair()
+                client.authenticate_basic_token(auth_pair[0], auth_pair[1])
+
+        if self._auth:
+            auth_with_retry()
+            return self._auth_middleware.token()
+        else:
+            return "IGNORED"
+
+    def get_stream(self, ticket: Ticket) -> FlightStreamReader:
+        return self._flight_client.do_get(ticket)
 
     def do_action(self, endpoint: str, payload: bytes):
         return self._flight_client.do_action(Action(endpoint, payload))
@@ -155,3 +203,8 @@ def _instantiate_flight_client(self) -> flight.FlightClient:
         return flight.FlightClient(location, **client_options)
 
 
+@dataclass
+class ConnectionInfo:
+    host: str
+    port: int
+    encrypted: bool
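
A minimal usage sketch (not part of this commit's diff) of how the new connection_info() and request_token() methods can be combined into an Arrow write-back configuration; an already-constructed client instance is assumed, and the "IGNORED" fallback mirrors what request_token() returns when no authentication is configured.

from typing import Any

from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient


def arrow_configuration(client: AuthenticatedArrowClient) -> dict[str, Any]:
    # ConnectionInfo is the new dataclass introduced above.
    info = client.connection_info()
    # request_token() already falls back to "IGNORED" when no auth is configured.
    token = client.request_token() or "IGNORED"
    return {"host": info.host, "port": info.port, "token": token, "encrypted": info.encrypted}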
graphdatascience/arrow_client/data_mapper.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
import dataclasses
import json
from dataclasses import fields
from typing import Any, Dict, Iterator, Type, TypeVar

from pyarrow._flight import Result


class DataMapper:
    T = TypeVar("T")

    @staticmethod
    def deserialize_single(input_stream: Iterator[Result], cls: Type[T]) -> T:
        rows = DataMapper.deserialize(input_stream, cls)

        if len(rows) != 1:
            raise ValueError(f"Expected exactly one row, got {len(rows)}")

        return rows[0]

    @staticmethod
    def deserialize(input_stream, cls: Type[T]) -> list[T]:
        def deserialize_row(row: Any):
            result_dicts = json.loads(row.body.to_pybytes().decode())
            if cls == Dict:
                return result_dicts
            return DataMapper.dict_to_dataclass(result_dicts, cls)

        return [deserialize_row(row) for row in list(input_stream)]

    @staticmethod
    def dict_to_dataclass(data: Dict[str, Any], cls: Type[T], strict: bool = False) -> T:
        """
        Convert a dictionary to a dataclass instance with nested dataclass support.
        """
        if not dataclasses.is_dataclass(cls):
            raise ValueError(f"{cls} is not a dataclass")

        field_dict = {f.name: f for f in fields(cls)}
        filtered_data = {}

        for key, value in data.items():
            if key in field_dict:
                field = field_dict[key]
                field_type = field.type

                # Handle nested dataclasses
                if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
                    filtered_data[key] = DataMapper.dict_to_dataclass(value, field_type, strict)
                else:
                    filtered_data[key] = value
            elif strict:
                raise ValueError(f"Extra field '{key}' not allowed in {cls.__name__}")

        return cls(**filtered_data)
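
A small, self-contained sketch of how DataMapper.dict_to_dataclass resolves nested dataclasses; the Inner/Outer classes and the values are illustrative only.

from dataclasses import dataclass

from graphdatascience.arrow_client.data_mapper import DataMapper


@dataclass
class Inner:
    value: int


@dataclass
class Outer:
    name: str
    inner: Inner


# Nested dicts are converted recursively; unknown keys are dropped unless strict=True.
outer = DataMapper.dict_to_dataclass({"name": "wcc", "inner": {"value": 3}, "extra": 1}, Outer)
assert outer.inner.value == 3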
graphdatascience/arrow_client/v2/api_types.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
from dataclasses import dataclass


@dataclass(frozen=True, repr=True)
class JobIdConfig:
    jobId: str


@dataclass(frozen=True, repr=True)
class JobStatus:
    jobId: str
    status: str
    progress: float


@dataclass(frozen=True, repr=True)
class MutateResult:
    nodePropertiesWritten: int
    relationshipsWritten: int
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
import json
from typing import Any, Dict

from pandas import ArrowDtype, DataFrame
from pyarrow._flight import Ticket

from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.data_mapper import DataMapper
from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus

JOB_STATUS_ENDPOINT = "v2/jobs.status"
RESULTS_SUMMARY_ENDPOINT = "v2/results.summary"


class JobClient:
    @staticmethod
    def run_job_and_wait(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str:
        job_id = JobClient.run_job(client, endpoint, config)
        JobClient.wait_for_job(client, job_id)
        return job_id

    @staticmethod
    def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str:
        encoded_config = json.dumps(config).encode("utf-8")
        res = client.do_action_with_retry(endpoint, encoded_config)
        return DataMapper.deserialize_single(res, JobIdConfig).jobId

    @staticmethod
    def wait_for_job(client: AuthenticatedArrowClient, job_id: str):
        while True:
            job_id_config = {"jobId": job_id}
            encoded_config = json.dumps(job_id_config).encode("utf-8")

            arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, encoded_config)
            job_status = DataMapper.deserialize_single(arrow_res, JobStatus)
            if job_status.status == "Done":
                break

    @staticmethod
    def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]:
        job_id_config = {"jobId": job_id}
        encoded_config = json.dumps(job_id_config).encode("utf-8")

        res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, encoded_config)
        return DataMapper.deserialize_single(res, Dict)

    @staticmethod
    def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame:
        job_id_config = {"jobId": job_id}
        encoded_config = json.dumps(job_id_config).encode("utf-8")

        res = client.do_action_with_retry("v2/results.stream", encoded_config)
        export_job_id = DataMapper.deserialize_single(res, JobIdConfig).jobId

        payload = {
            "name": export_job_id,
            "version": 1,
        }

        ticket = Ticket(json.dumps(payload).encode("utf-8"))
        with client.get_stream(ticket) as get:
            arrow_table = get.read_all()

        return arrow_table.to_pandas(types_mapper=ArrowDtype)  # type: ignore
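
A hedged end-to-end sketch of how these helpers might drive a v2 algorithm job and stream its result; the JobClient module path, the "v2/community.wcc" endpoint name, and the config keys are assumptions not shown in this excerpt.

from pandas import DataFrame

from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v2.job_client import JobClient  # module path assumed


def run_wcc_and_stream(client: AuthenticatedArrowClient, graph_name: str) -> DataFrame:
    # Endpoint name and config keys are illustrative; the real WCC constants live in
    # files of this commit that are not shown here.
    job_id = JobClient.run_job_and_wait(client, "v2/community.wcc", {"graphName": graph_name})
    return JobClient.stream_results(client, job_id)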
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
import json

from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.data_mapper import DataMapper
from graphdatascience.arrow_client.v2.api_types import MutateResult


class MutationClient:
    MUTATE_ENDPOINT = "v2/results.mutate"

    @staticmethod
    def mutate_node_property(client: AuthenticatedArrowClient, job_id: str, mutate_property: str) -> MutateResult:
        mutate_config = {"jobId": job_id, "mutateProperty": mutate_property}
        encoded_config = json.dumps(mutate_config).encode("utf-8")
        mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, encoded_config)
        return DataMapper.deserialize_single(mutate_arrow_res, MutateResult)
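
For completeness, a short sketch of mutating a completed job's result back onto the in-memory graph; the MutationClient module path and the "componentId" property name are assumptions.

from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v2.api_types import MutateResult
from graphdatascience.arrow_client.v2.mutation_client import MutationClient  # module path assumed


def mutate_component_ids(client: AuthenticatedArrowClient, job_id: str) -> MutateResult:
    # Writes the job result into the projected graph as a node property.
    return MutationClient.mutate_node_property(client, job_id, "componentId")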
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
import time
from typing import Any, Optional

from graphdatascience import QueryRunner
from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
from graphdatascience.call_parameters import CallParameters
from graphdatascience.query_runner.protocol.write_protocols import WriteProtocol
from graphdatascience.query_runner.termination_flag import TerminationFlagNoop
from graphdatascience.session.dbms.protocol_resolver import ProtocolVersionResolver


class WriteBackClient:
    def __init__(self, arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner):
        self._arrow_client = arrow_client
        self._query_runner = query_runner

        protocol_version = ProtocolVersionResolver(query_runner).resolve()
        self._write_protocol = WriteProtocol.select(protocol_version)

    # TODO: Add progress logging
    # TODO: Support setting custom writeProperties and relationshipTypes
    def write(self, graph_name: str, job_id: str, concurrency: Optional[int]) -> int:
        arrow_config = self._arrow_configuration()

        configuration = {}
        if concurrency is not None:
            configuration["concurrency"] = concurrency

        write_back_params = CallParameters(
            graphName=graph_name,
            jobId=job_id,
            arrowConfiguration=arrow_config,
            configuration=configuration,
        )

        start_time = time.time()

        self._write_protocol.run_write_back(self._query_runner, write_back_params, None, TerminationFlagNoop())

        return int((time.time() - start_time) * 1000)

    def _arrow_configuration(self) -> dict[str, Any]:
        connection_info = self._arrow_client.connection_info()
        token = self._arrow_client.request_token()
        if token is None:
            token = "IGNORED"
        arrow_config = {
            "host": connection_info.host,
            "port": connection_info.port,
            "token": token,
            "encrypted": connection_info.encrypted,
        }

        return arrow_config
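
And a sketch of the database write-back path; the WriteBackClient module path, the graph name, and the concurrency value are assumptions, and an existing QueryRunner is presumed available.

from graphdatascience import QueryRunner
from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v2.write_back_client import WriteBackClient  # module path assumed


def write_back(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, job_id: str) -> int:
    # Streams the job result from the Arrow server into the database and
    # returns the elapsed wall-clock time in milliseconds.
    writer = WriteBackClient(arrow_client, query_runner)
    return writer.write("myGraph", job_id, concurrency=4)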
