Skip to content

Commit 50d88ab

Browse files
committed
Move arrow client related code into its own package
1 parent c0edacc commit 50d88ab

16 files changed

+106
-19
lines changed

graphdatascience/arrow_client/__init__.py

Whitespace-only changes.

graphdatascience/query_runner/arrow_info.py renamed to graphdatascience/arrow_client/arrow_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from dataclasses import dataclass
44

5-
from ..query_runner.query_runner import QueryRunner
6-
from ..server_version.server_version import ServerVersion
5+
from graphdatascience.query_runner.query_runner import QueryRunner
6+
from graphdatascience.server_version.server_version import ServerVersion
77

88

99
@dataclass(frozen=True)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
import time
5+
from typing import Optional, Any
6+
7+
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory
8+
9+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
10+
11+
12+
class AuthFactory(ClientMiddlewareFactory): # type: ignore
13+
def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
14+
super().__init__(*args, **kwargs)
15+
self._middleware = middleware
16+
17+
def start_call(self, info: Any) -> AuthMiddleware:
18+
return self._middleware
19+
20+
21+
class AuthMiddleware(ClientMiddleware): # type: ignore
22+
def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None:
23+
super().__init__(*args, **kwargs)
24+
self._auth = auth
25+
self._token: Optional[str] = None
26+
self._token_timestamp = 0
27+
28+
def token(self) -> Optional[str]:
29+
# check whether the token is older than 10 minutes. If so, reset it.
30+
if self._token and int(time.time()) - self._token_timestamp > 600:
31+
self._token = None
32+
33+
return self._token
34+
35+
def _set_token(self, token: str) -> None:
36+
self._token = token
37+
self._token_timestamp = int(time.time())
38+
39+
def received_headers(self, headers: dict[str, Any]) -> None:
40+
auth_header = headers.get("authorization", None)
41+
if not auth_header:
42+
return
43+
44+
# the result is always a list
45+
header_value = auth_header[0]
46+
47+
if not isinstance(header_value, str):
48+
raise ValueError(f"Incompatible header value received from server: `{header_value}`")
49+
50+
auth_type, token = header_value.split(" ", 1)
51+
if auth_type == "Bearer":
52+
self._set_token(token)
53+
54+
def sending_headers(self) -> dict[str, str]:
55+
token = self.token()
56+
if token is not None:
57+
return {"authorization": "Bearer " + token}
58+
59+
auth_pair = self._auth.auth_pair()
60+
auth_token = f"{auth_pair[0]}:{auth_pair[1]}"
61+
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
62+
# There seems to be a bug, `authorization` must be lower key
63+
return {"authorization": auth_token}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory
6+
7+
class UserAgentFactory(ClientMiddlewareFactory): # type: ignore
8+
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
9+
super().__init__(*args, **kwargs)
10+
self._middleware = UserAgentMiddleware(useragent)
11+
12+
def start_call(self, info: Any) -> ClientMiddleware:
13+
return self._middleware
14+
15+
class UserAgentMiddleware(ClientMiddleware): # type: ignore
16+
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
17+
super().__init__(*args, **kwargs)
18+
self._useragent = useragent
19+
20+
def sending_headers(self) -> dict[str, str]:
21+
return {"x-gds-user-agent": self._useragent}
22+
23+
def received_headers(self, headers: dict[str, Any]) -> None:
24+
pass

graphdatascience/graph_data_science.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from neo4j import Driver
99
from pandas import DataFrame
1010

11-
from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication
11+
from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
1212
from graphdatascience.procedure_surface.api.wcc_endpoints import WccEndpoints
1313
from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints
1414

1515
from .call_builder import IndirectCallBuilder
1616
from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
1717
from .error.uncallable_namespace import UncallableNamespace
1818
from .graph.graph_proc_runner import GraphProcRunner
19-
from .query_runner.arrow_info import ArrowInfo
19+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
2020
from .query_runner.arrow_query_runner import ArrowQueryRunner
2121
from .query_runner.neo4j_query_runner import Neo4jQueryRunner
2222
from .query_runner.query_runner import QueryRunner

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
from pandas import DataFrame
77

8-
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication
8+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
99
from graphdatascience.query_runner.query_mode import QueryMode
1010
from graphdatascience.retry_utils.retry_config import RetryConfig
1111

1212
from ..call_parameters import CallParameters
13-
from ..query_runner.arrow_info import ArrowInfo
13+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
1414
from ..server_version.server_version import ServerVersion
1515
from .arrow_graph_constructor import ArrowGraphConstructor
1616
from .gds_arrow_client import GdsArrowClient

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535
wait_exponential,
3636
)
3737

38-
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication
38+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication
3939
from graphdatascience.retry_utils.retry_config import RetryConfig
4040
from graphdatascience.retry_utils.retry_utils import before_log
4141

4242
from ..semantic_version.semantic_version import SemanticVersion
4343
from ..version import __version__
4444
from .arrow_endpoint_version import ArrowEndpointVersion
45-
from .arrow_info import ArrowInfo
45+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
4646

4747

4848
class GdsArrowClient:

graphdatascience/session/aura_api_token_authentication.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication
1+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
22
from graphdatascience.session.aura_api import AuraApi
33

44

graphdatascience/session/aura_graph_data_science.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
)
1414
from graphdatascience.error.uncallable_namespace import UncallableNamespace
1515
from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner
16-
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication
17-
from graphdatascience.query_runner.arrow_info import ArrowInfo
16+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
17+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
1818
from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
1919
from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
2020
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner

0 commit comments

Comments
 (0)