|
7 | 7 | import warnings
|
8 | 8 | from dataclasses import dataclass
|
9 | 9 | from types import TracebackType
|
10 |
| -from typing import Any, Callable, Iterable, Optional, Type, Union |
| 10 | +from typing import Any, Callable, Dict, Iterable, Optional, Type, Union |
11 | 11 |
|
12 | 12 | import pyarrow
|
13 | 13 | from neo4j.exceptions import ClientError
|
@@ -89,19 +89,35 @@ def __init__(
|
89 | 89 | self._host = host
|
90 | 90 | self._port = port
|
91 | 91 | self._auth = auth
|
| 92 | + self._encrypted = encrypted |
| 93 | + self._disable_server_verification = disable_server_verification |
| 94 | + self._tls_root_certs = tls_root_certs |
| 95 | + self._user_agent = user_agent |
92 | 96 |
|
93 |
| - location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port) |
94 |
| - |
95 |
| - client_options: dict[str, Any] = {"disable_server_verification": disable_server_verification} |
96 | 97 | if auth:
|
97 | 98 | self._auth_middleware = AuthMiddleware(auth)
|
98 |
| - if not user_agent: |
99 |
| - user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" |
100 |
| - client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)] |
101 |
| - if tls_root_certs: |
102 |
| - client_options["tls_root_certs"] = tls_root_certs |
103 | 99 |
|
104 |
| - self._flight_client = flight.FlightClient(location, **client_options) |
| 100 | + self._flight_client = self._instantiate_flight_client() |
| 101 | + |
def _instantiate_flight_client(self) -> flight.FlightClient:
    """
    Build a new PyArrow FlightClient from the connection settings stored on this instance.

    The location uses gRPC over TLS when `self._encrypted` is truthy and plain
    gRPC over TCP otherwise. The auth and user-agent middleware are attached
    only when `self._auth` is set, and TLS root certificates are forwarded
    only when `self._tls_root_certs` is set.
    """
    # NOTE(review): the nesting of the conditionals below is reconstructed from a
    # rendering that lost its indentation -- confirm against the real file.
    if self._encrypted:
        endpoint = flight.Location.for_grpc_tls(self._host, self._port)
    else:
        endpoint = flight.Location.for_grpc_tcp(self._host, self._port)

    options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}

    if self._auth:
        # A caller-supplied user agent takes precedence over the generated default.
        agent = self._user_agent or f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
        options["middleware"] = [
            AuthFactory(self._auth_middleware),
            UserAgentFactory(useragent=agent),
        ]

    if self._tls_root_certs:
        options["tls_root_certs"] = self._tls_root_certs

    return flight.FlightClient(endpoint, **options)
105 | 121 |
|
106 | 122 | def connection_info(self) -> tuple[str, int]:
|
107 | 123 | """
|
@@ -537,11 +553,28 @@ def upload_triplets(
|
537 | 553 | """
|
538 | 554 | self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)
|
539 | 555 |
|
| 556 | + def __getstate__(self) -> Dict[str, Any]: |
| 557 | + state = self.__dict__.copy() |
| 558 | + # Remove the FlightClient as it isn't serializable |
| 559 | + if "_flight_client" in state: |
| 560 | + del state["_flight_client"] |
| 561 | + return state |
| 562 | + |
def _client(self) -> flight.FlightClient:
    """
    Return this instance's FlightClient, constructing it on demand.

    Construction is lazy to help pickle this class, because a PyArrow
    FlightClient is not serializable. When the ``_flight_client`` attribute is
    missing or falsy, a new client is created, stored on the instance and
    returned.
    """
    existing = getattr(self, "_flight_client", None)
    if existing:
        return existing

    self._flight_client = self._instantiate_flight_client()
    return self._flight_client
| 571 | + |
540 | 572 | def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
|
541 | 573 | action_type = self._versioned_action_type(action_type)
|
542 | 574 |
|
543 | 575 | try:
|
544 |
| - result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) |
| 576 | + client = self._client() |
| 577 | + result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) |
545 | 578 |
|
546 | 579 | # Consume result fully to sanity check and avoid cancelled streams
|
547 | 580 | collected_result = list(result)
|
@@ -569,7 +602,9 @@ def _upload_data(
|
569 | 602 |
|
570 | 603 | flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type})
|
571 | 604 | upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
|
572 |
| - put_stream, ack_stream = self._flight_client.do_put(upload_descriptor, batches[0].schema) |
| 605 | + |
| 606 | + client = self._client() |
| 607 | + put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema) |
573 | 608 |
|
574 | 609 | @retry(
|
575 | 610 | stop=(stop_after_delay(10) | stop_after_attempt(5)),
|
|
0 commit comments