Skip to content

Commit f694b8b

Browse files
authored
Merge pull request #107 from DarthMax/fix_authentication_timeout
Fix client authentication
2 parents ac65152 + 4125b79 commit f694b8b

File tree

5 files changed

+85
-24
lines changed

5 files changed

+85
-24
lines changed

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,15 +18,13 @@ def __init__(
1818
query_runner: QueryRunner,
1919
graph_name: str,
2020
flight_client: flight.FlightClient,
21-
flight_options: flight.FlightCallOptions,
2221
concurrency: int,
2322
chunk_size: int = 10_000,
2423
):
2524
self._query_runner = query_runner
2625
self._concurrency = concurrency
2726
self._graph_name = graph_name
2827
self._client = flight_client
29-
self._flight_options = flight_options
3028
self._chunk_size = chunk_size
3129
self._min_batch_size = chunk_size * 10
3230

@@ -58,9 +56,7 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
5856
return partitioned_dfs
5957

6058
def _send_action(self, action_type: str, meta_data: Dict[str, str]) -> None:
61-
result = self._client.do_action(
62-
flight.Action(action_type, json.dumps(meta_data).encode("utf-8")), self._flight_options
63-
)
59+
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
6460

6561
json.loads(next(result).body.to_pybytes().decode())
6662

@@ -71,7 +67,7 @@ def _send_df(self, df: DataFrame, entity_type: str) -> None:
7167
# Write schema
7268
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
7369

74-
writer, _ = self._client.do_put(upload_descriptor, table.schema, self._flight_options)
70+
writer, _ = self._client.do_put(upload_descriptor, table.schema)
7571

7672
with writer:
7773
# Write table in chunks

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 67 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,11 @@
1+
import base64
12
import json
3+
import time
24
from typing import Any, Dict, Optional, Tuple
35

46
import pyarrow.flight as flight
57
from pandas.core.frame import DataFrame
8+
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
69

710
from .arrow_graph_constructor import ArrowGraphConstructor
811
from .graph_constructor import GraphConstructor
@@ -28,16 +31,15 @@ def __init__(
2831
else flight.Location.for_grpc_tcp(host, int(port_string))
2932
)
3033

31-
self._flight_client = flight.FlightClient(location, disable_server_verification=disable_server_verification)
32-
self._flight_options = flight.FlightCallOptions()
33-
34+
client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
3435
if auth:
35-
username, password = auth
36-
header, token = self._flight_client.authenticate_basic_token(username, password)
37-
if header:
38-
self._flight_options = flight.FlightCallOptions(headers=[(header, token)])
36+
client_options["middleware"] = [AuthFactory(auth)]
37+
38+
self._flight_client = flight.FlightClient(location, **client_options)
3939

40-
def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
40+
def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
41+
if params is None:
42+
params = {}
4143
if "gds.graph.streamNodeProperty" in query:
4244
graph_name = params["graph_name"]
4345
property_name = params["properties"]
@@ -57,8 +59,10 @@ def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
5759

5860
return self._fallback_query_runner.run_query(query, params)
5961

60-
def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
62+
def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
6163
# For now there's no logging support with Arrow queries.
64+
if params is None:
65+
params = {}
6266
return self._fallback_query_runner.run_query_with_logging(query, params)
6367

6468
def set_database(self, db: str) -> None:
@@ -79,9 +83,61 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
7983
}
8084
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
8185

82-
result: DataFrame = self._flight_client.do_get(ticket, self._flight_options).read_pandas()
86+
get = self._flight_client.do_get(ticket)
87+
result: DataFrame = get.read_pandas()
8388

8489
return result
8590

8691
def create_graph_constructor(self, graph_name: str, concurrency: int) -> GraphConstructor:
87-
return ArrowGraphConstructor(self, graph_name, self._flight_client, self._flight_options, concurrency)
92+
return ArrowGraphConstructor(self, graph_name, self._flight_client, concurrency)
93+
94+
95+
class AuthFactory(ClientMiddlewareFactory): # type: ignore
96+
def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
97+
super().__init__(*args, **kwargs)
98+
self._auth = auth
99+
self._token: Optional[str] = None
100+
self._token_timestamp = 0
101+
102+
def start_call(self, info: Any) -> "AuthMiddleware":
103+
return AuthMiddleware(self)
104+
105+
def token(self) -> Optional[str]:
106+
# check whether the token is older than 10 minutes. If so, reset it.
107+
if self._token and int(time.time()) - self._token_timestamp > 600:
108+
self._token = None
109+
110+
return self._token
111+
112+
def set_token(self, token: str) -> None:
113+
self._token = token
114+
self._token_timestamp = int(time.time())
115+
116+
@property
117+
def auth(self) -> Tuple[str, str]:
118+
return self._auth
119+
120+
121+
class AuthMiddleware(ClientMiddleware): # type: ignore
122+
def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
123+
super().__init__(*args, **kwargs)
124+
self._factory = factory
125+
126+
def received_headers(self, headers: Dict[str, Any]) -> None:
127+
auth_header: str = headers.get("Authorization", None)
128+
if not auth_header:
129+
return
130+
[auth_type, token] = auth_header.split(" ", 1)
131+
if auth_type == "Bearer":
132+
self._factory.set_token(token)
133+
134+
def sending_headers(self) -> Dict[str, str]:
135+
token = self._factory.token()
136+
if not token:
137+
username, password = self._factory.auth
138+
auth_token = f"{username}:{password}"
139+
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
140+
# There seems to be a bug: the `authorization` header key must be lowercase
141+
return {"authorization": auth_token}
142+
else:
143+
return {"authorization": "Bearer " + token}

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,10 @@ def __init__(self, driver: neo4j.Driver, db: Optional[str] = neo4j.DEFAULT_DATAB
3131
except Exception as e:
3232
raise UnableToConnectError(e)
3333

34-
def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
34+
def run_query(self, query: str, params: Optional[Dict[str, str]] = None) -> DataFrame:
35+
if params is None:
36+
params = {}
37+
3538
with self._driver.session(database=self._db) as session:
3639
result = session.run(query, params)
3740

@@ -44,7 +47,10 @@ def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
4447

4548
return result.to_df() # type: ignore
4649

47-
def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
50+
def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
51+
if params is None:
52+
params = {}
53+
4854
if self._server_version < ServerVersion(2, 1, 0):
4955
return self.run_query(query, params)
5056

graphdatascience/query_runner/query_runner.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Dict
2+
from typing import Any, Dict, Optional
33

44
from pandas.core.frame import DataFrame
55

@@ -9,10 +9,10 @@
99

1010
class QueryRunner(ABC):
1111
@abstractmethod
12-
def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
12+
def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
1313
pass
1414

15-
def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
15+
def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
1616
return self.run_query(query, params)
1717

1818
@abstractmethod

graphdatascience/tests/unit/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Union
1+
from typing import Any, Dict, List, Optional, Union
22

33
import pandas
44
import pytest
@@ -19,7 +19,10 @@ def __init__(self, server_version: Union[str, ServerVersion]) -> None:
1919
self.params: List[Dict[str, Any]] = []
2020
self.server_version = server_version
2121

22-
def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
22+
def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
23+
if params is None:
24+
params = {}
25+
2326
self.queries.append(query)
2427
self.params.append(params)
2528

0 commit comments

Comments
 (0)