Skip to content

Commit a4818c8

Browse files
Merge branch 'main' of github.com:neo4j/graph-data-science-client
2 parents 0b0361f + 6682607 commit a4818c8

File tree

1 file changed

+47
-12
lines changed

1 file changed

+47
-12
lines changed

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 47 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import warnings
88
from dataclasses import dataclass
99
from types import TracebackType
10-
from typing import Any, Callable, Iterable, Optional, Type, Union
10+
from typing import Any, Callable, Dict, Iterable, Optional, Type, Union
1111

1212
import pyarrow
1313
from neo4j.exceptions import ClientError
@@ -89,19 +89,35 @@ def __init__(
8989
self._host = host
9090
self._port = port
9191
self._auth = auth
92+
self._encrypted = encrypted
93+
self._disable_server_verification = disable_server_verification
94+
self._tls_root_certs = tls_root_certs
95+
self._user_agent = user_agent
9296

93-
location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port)
94-
95-
client_options: dict[str, Any] = {"disable_server_verification": disable_server_verification}
9697
if auth:
9798
self._auth_middleware = AuthMiddleware(auth)
98-
if not user_agent:
99-
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
100-
client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)]
101-
if tls_root_certs:
102-
client_options["tls_root_certs"] = tls_root_certs
10399

104-
self._flight_client = flight.FlightClient(location, **client_options)
100+
self._flight_client = self._instantiate_flight_client()
101+
102+
def _instantiate_flight_client(self) -> flight.FlightClient:
    """
    Build a new PyArrow FlightClient from the connection settings stored on this instance.

    The location uses gRPC over TLS when ``self._encrypted`` is truthy and plain
    gRPC over TCP otherwise. The auth and user-agent middleware are attached only
    when ``self._auth`` is set; a caller-supplied ``self._user_agent`` takes
    precedence over the generated default. ``tls_root_certs`` is forwarded only
    when it was provided.
    """
    # Pick the location factory first, then call it once with host and port.
    make_location = flight.Location.for_grpc_tls if self._encrypted else flight.Location.for_grpc_tcp
    location = make_location(self._host, self._port)

    options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}

    if self._auth:
        # Fall back to the library/pyarrow version string when no user agent was given.
        user_agent = self._user_agent or f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
        options["middleware"] = [
            AuthFactory(self._auth_middleware),
            UserAgentFactory(useragent=user_agent),
        ]

    if self._tls_root_certs:
        options["tls_root_certs"] = self._tls_root_certs

    return flight.FlightClient(location, **options)
105121

106122
def connection_info(self) -> tuple[str, int]:
107123
"""
@@ -537,11 +553,28 @@ def upload_triplets(
537553
"""
538554
self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)
539555

556+
def __getstate__(self) -> Dict[str, Any]:
    """
    Return the picklable state of this instance.

    The state is a shallow copy of the instance ``__dict__`` without the
    ``_flight_client`` entry, because a PyArrow FlightClient isn't serializable.
    The instance itself is left untouched.
    """
    state = self.__dict__.copy()
    # pop() with a default is a no-op when the client was never created or was already removed.
    state.pop("_flight_client", None)
    return state
562+
563+
def _client(self) -> flight.FlightClient:
    """
    Return the FlightClient for this instance, constructing it lazily if needed.

    Lazy construction helps pickling this class: a PyArrow FlightClient is not
    serializable, so ``__getstate__`` drops it and the attribute may be absent
    (or ``None``) on a restored instance. In that case a new client is built via
    ``_instantiate_flight_client`` and cached on the instance.
    """
    # Single lookup with an explicit None check instead of hasattr + truthiness of the client object.
    client = getattr(self, "_flight_client", None)
    if client is None:
        client = self._instantiate_flight_client()
        self._flight_client = client
    return client
571+
540572
def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
541573
action_type = self._versioned_action_type(action_type)
542574

543575
try:
544-
result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
576+
client = self._client()
577+
result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
545578

546579
# Consume result fully to sanity check and avoid cancelled streams
547580
collected_result = list(result)
@@ -569,7 +602,9 @@ def _upload_data(
569602

570603
flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type})
571604
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
572-
put_stream, ack_stream = self._flight_client.do_put(upload_descriptor, batches[0].schema)
605+
606+
client = self._client()
607+
put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema)
573608

574609
@retry(
575610
stop=(stop_after_delay(10) | stop_after_attempt(5)),

0 commit comments

Comments
 (0)