Skip to content

Commit be2985e

Browse files
committed
Introduce AuthenticatedArrowClient
1 parent 6c39bc0 commit be2985e

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import Optional, Union, Any
5+
6+
from pyarrow import __version__ as arrow_version
7+
from pyarrow import flight
8+
from pyarrow._flight import FlightTimedOutError, FlightUnavailableError, FlightInternalError, Action
9+
from tenacity import retry_any, retry_if_exception_type, stop_after_delay, stop_after_attempt, wait_exponential, retry
10+
11+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
12+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
13+
from graphdatascience.retry_utils.retry_config import RetryConfig
14+
from .middleware.AuthMiddleware import AuthMiddleware, AuthFactory
15+
from .middleware.UserAgentMiddleware import UserAgentFactory
16+
from ..retry_utils.retry_utils import before_log
17+
from ..version import __version__
18+
19+
20+
class AuthenticatedArrowClient:
    """Thin wrapper around a ``pyarrow.flight.FlightClient`` with optional auth and retries.

    When authentication is configured, the underlying Flight client is created
    with an authentication middleware and a user-agent middleware. Actions can be
    sent either once (``do_action``) or under the configured retry policy
    (``do_action_with_retry``).
    """

    @staticmethod
    def create(
        arrow_info: ArrowInfo,
        auth: Optional[ArrowAuthentication] = None,
        encrypted: bool = False,
        disable_server_verification: bool = False,
        tls_root_certs: Optional[bytes] = None,
        connection_string_override: Optional[str] = None,
        retry_config: Optional[RetryConfig] = None,
    ) -> AuthenticatedArrowClient:
        """Build a client from server-provided Arrow connection info.

        Parameters
        ----------
        arrow_info: ArrowInfo
            Arrow server information; its ``listenAddress`` is used as the ``host:port``
            connection string unless ``connection_string_override`` is given
        auth: Optional[ArrowAuthentication]
            Authentication to attach to requests, or None for an unauthenticated client
        encrypted: bool
            A flag that indicates whether the connection should be encrypted (default is False)
        disable_server_verification: bool
            A flag that disables server verification for TLS connections (default is False)
        tls_root_certs: Optional[bytes]
            PEM-encoded certificates that are used for the connection to the Arrow Flight server
        connection_string_override: Optional[str]
            A ``host:port`` string that takes precedence over ``arrow_info.listenAddress``
        retry_config: Optional[RetryConfig]
            The retry configuration for ``do_action_with_retry``; a default policy is used when None

        Returns
        -------
        AuthenticatedArrowClient
            A client connected to the resolved host and port.

        Raises
        ------
        ValueError
            If the connection string contains no ``:`` separator or the port is not an integer.
        """
        connection_string: str = (
            connection_string_override if connection_string_override is not None else arrow_info.listenAddress
        )

        # Split on the LAST colon so hosts that themselves contain colons
        # (e.g. a bracketed IPv6 literal such as "[::1]:8491") keep working.
        host, separator, port = connection_string.rpartition(":")
        if not separator:
            raise ValueError(f"Expected a connection string of the form 'host:port', got {connection_string!r}")

        if retry_config is None:
            retry_config = AuthenticatedArrowClient._default_retry_config()

        return AuthenticatedArrowClient(
            host,
            retry_config,
            int(port),
            auth,
            encrypted,
            disable_server_verification,
            tls_root_certs,
        )

    @staticmethod
    def _default_retry_config() -> RetryConfig:
        """Return the retry policy used when the caller does not supply one."""
        return RetryConfig(
            # Retry only on Flight timeout / unavailable / internal errors.
            retry=retry_any(
                retry_if_exception_type(FlightTimedOutError),
                retry_if_exception_type(FlightUnavailableError),
                retry_if_exception_type(FlightInternalError),
            ),
            # Give up once either limit is hit: 10 seconds elapsed or 5 attempts made.
            stop=(stop_after_delay(10) | stop_after_attempt(5)),
            wait=wait_exponential(multiplier=1, min=1, max=10),
        )

    def __init__(
        self,
        host: str,
        retry_config: Optional[RetryConfig],
        port: int = 8491,
        auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None,
        encrypted: bool = False,
        disable_server_verification: bool = False,
        tls_root_certs: Optional[bytes] = None,
        user_agent: Optional[str] = None,
    ):
        """Creates a new AuthenticatedArrowClient instance and its underlying Flight client.

        Parameters
        ----------
        host: str
            The host address of the GDS Arrow server
        retry_config: Optional[RetryConfig]
            The retry configuration to use for the Arrow requests sent by the client.
            When None, a default policy is used.
        port: int
            The host port of the GDS Arrow server (default is 8491)
        auth: Optional[Union[ArrowAuthentication, tuple[str, str]]]
            Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication,
            or a username, password tuple
        encrypted: bool
            A flag that indicates whether the connection should be encrypted (default is False)
        disable_server_verification: bool
            A flag that disables server verification for TLS connections (default is False)
        tls_root_certs: Optional[bytes]
            PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
        user_agent: Optional[str]
            The user agent string to use for the connection
            (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]`).
            The user-agent middleware is only installed when ``auth`` is provided.
        """
        self._host = host
        self._port = port
        self._encrypted = encrypted
        self._disable_server_verification = disable_server_verification
        self._tls_root_certs = tls_root_certs
        self._user_agent = user_agent
        self._logger = logging.getLogger("gds_arrow_client")

        # Honour the caller's retry policy; previously it was unconditionally
        # replaced by the default, so a custom RetryConfig had no effect.
        self._retry_config = retry_config if retry_config is not None else self._default_retry_config()

        self._auth = None
        self._auth_middleware = None
        if auth:
            # NOTE(review): a (username, password) tuple is forwarded to AuthMiddleware
            # unchanged - confirm AuthMiddleware accepts tuples as well as ArrowAuthentication.
            self._auth = auth
            self._auth_middleware = AuthMiddleware(auth)

        self._flight_client = self._instantiate_flight_client()

    def do_action(self, endpoint: str, payload: bytes):
        """Send a single Flight action to the server without retrying.

        Parameters
        ----------
        endpoint: str
            The action type passed to ``Action``
        payload: bytes
            The action body passed to ``Action``

        Returns
        -------
        The value returned by ``FlightClient.do_action`` for this action.
        """
        return self._flight_client.do_action(Action(endpoint, payload))

    def do_action_with_retry(self, endpoint: str, payload: bytes):
        """Send a Flight action, retrying according to the client's retry configuration.

        The last exception is re-raised unchanged (``reraise=True``) once the stop
        condition of the retry configuration is reached.

        Parameters
        ----------
        endpoint: str
            The action type passed to ``Action``
        payload: bytes
            The action body passed to ``Action``

        Returns
        -------
        The value returned by ``do_action`` on the first successful attempt.
        """

        @retry(
            reraise=True,
            before=before_log("Send action", self._logger, logging.DEBUG),
            retry=self._retry_config.retry,
            stop=self._retry_config.stop,
            wait=self._retry_config.wait,
        )
        def run_with_retry():
            return self.do_action(endpoint, payload)

        return run_with_retry()

    def _instantiate_flight_client(self) -> flight.FlightClient:
        """Create the underlying ``FlightClient`` from the stored connection settings."""
        if self._encrypted:
            location = flight.Location.for_grpc_tls(self._host, self._port)
        else:
            location = flight.Location.for_grpc_tcp(self._host, self._port)

        client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}

        if self._auth:
            # A caller-supplied user agent wins over the generated default.
            user_agent = self._user_agent or f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
            client_options["middleware"] = [
                AuthFactory(self._auth_middleware),
                UserAgentFactory(useragent=user_agent),
            ]

        if self._tls_root_certs:
            client_options["tls_root_certs"] = self._tls_root_certs

        return flight.FlightClient(location, **client_options)
156+
157+

graphdatascience/tests/unit/arrow_client/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)