1
+ import base64
1
2
import json
3
+ import time
2
4
from typing import Any , Dict , Optional , Tuple
3
5
4
6
import pyarrow .flight as flight
5
7
from pandas .core .frame import DataFrame
8
+ from pyarrow .flight import ClientMiddleware , ClientMiddlewareFactory
6
9
7
10
from .arrow_graph_constructor import ArrowGraphConstructor
8
11
from .graph_constructor import GraphConstructor
@@ -28,16 +31,15 @@ def __init__(
28
31
else flight .Location .for_grpc_tcp (host , int (port_string ))
29
32
)
30
33
31
- self ._flight_client = flight .FlightClient (location , disable_server_verification = disable_server_verification )
32
- self ._flight_options = flight .FlightCallOptions ()
33
-
34
+ client_options : Dict [str , Any ] = {"disable_server_verification" : disable_server_verification }
34
35
if auth :
35
- username , password = auth
36
- header , token = self ._flight_client .authenticate_basic_token (username , password )
37
- if header :
38
- self ._flight_options = flight .FlightCallOptions (headers = [(header , token )])
36
+ client_options ["middleware" ] = [AuthFactory (auth )]
37
+
38
+ self ._flight_client = flight .FlightClient (location , ** client_options )
39
39
40
- def run_query (self , query : str , params : Dict [str , Any ] = {}) -> DataFrame :
40
+ def run_query (self , query : str , params : Optional [Dict [str , Any ]] = None ) -> DataFrame :
41
+ if params is None :
42
+ params = {}
41
43
if "gds.graph.streamNodeProperty" in query :
42
44
graph_name = params ["graph_name" ]
43
45
property_name = params ["properties" ]
@@ -57,8 +59,10 @@ def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
57
59
58
60
return self ._fallback_query_runner .run_query (query , params )
59
61
60
- def run_query_with_logging (self , query : str , params : Dict [str , Any ] = {} ) -> DataFrame :
62
def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
    """Run `query` through the fallback query runner's logging-enabled entry point.

    Args:
        query: The query string to execute.
        params: Query parameters; an empty dict is used when omitted.

    Returns:
        Whatever the fallback runner's `run_query_with_logging` returns.
    """
    # For now there's no logging support with Arrow queries.
    effective_params = {} if params is None else params
    fallback = self._fallback_query_runner
    return fallback.run_query_with_logging(query, effective_params)
63
67
64
68
def set_database (self , db : str ) -> None :
@@ -79,9 +83,61 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
79
83
}
80
84
ticket = flight .Ticket (json .dumps (payload ).encode ("utf-8" ))
81
85
82
- result : DataFrame = self ._flight_client .do_get (ticket , self ._flight_options ).read_pandas ()
86
+ get = self ._flight_client .do_get (ticket )
87
+ result : DataFrame = get .read_pandas ()
83
88
84
89
return result
85
90
86
91
def create_graph_constructor(self, graph_name: str, concurrency: int) -> GraphConstructor:
    """Return an `ArrowGraphConstructor` for `graph_name` that uses this runner and its Flight client.

    Args:
        graph_name: Name of the graph to construct.
        concurrency: Concurrency setting forwarded unchanged to the constructor.
    """
    flight_client = self._flight_client
    return ArrowGraphConstructor(self, graph_name, flight_client, concurrency)
93
+
94
+
95
class AuthFactory(ClientMiddlewareFactory):  # type: ignore
    """Flight client middleware factory holding the credentials and a cached bearer token.

    Every call started on the client gets a fresh `AuthMiddleware` that points
    back at this factory, so the cached token is shared between calls.
    """

    # Maximum age, in seconds, of a cached token before it is discarded.
    _TOKEN_MAX_AGE_SECONDS = 600

    def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
        """Store the `(username, password)` pair; no token is cached initially."""
        super().__init__(*args, **kwargs)
        self._auth = auth
        self._token: Optional[str] = None
        self._token_timestamp = 0

    @staticmethod
    def _now() -> int:
        """Return the current time in whole seconds on a monotonic clock.

        A monotonic clock is used because only elapsed time matters here, and
        it is unaffected by adjustments of the system (wall) clock.
        """
        return int(time.monotonic())

    def start_call(self, info: Any) -> "AuthMiddleware":
        """Create the middleware instance for a single call."""
        return AuthMiddleware(self)

    def token(self) -> Optional[str]:
        """Return the cached token, or `None` if there is none or it is too old."""
        # Drop the token once it is older than the maximum age (10 minutes).
        if self._token and self._now() - self._token_timestamp > self._TOKEN_MAX_AGE_SECONDS:
            self._token = None

        return self._token

    def set_token(self, token: str) -> None:
        """Cache `token` and record when it was stored."""
        self._token = token
        self._token_timestamp = self._now()

    @property
    def auth(self) -> Tuple[str, str]:
        """The `(username, password)` pair given at construction."""
        return self._auth
119
+
120
+
121
class AuthMiddleware(ClientMiddleware):  # type: ignore
    """Per-call Flight client middleware that sends credentials and caches bearer tokens.

    Outgoing calls carry the factory's cached bearer token when one is
    available, otherwise HTTP basic credentials. A bearer token found in the
    response headers is stored back on the factory.
    """

    def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
        """Keep a reference to the factory that owns the credentials and token cache."""
        super().__init__(*args, **kwargs)
        self._factory = factory

    def received_headers(self, headers: Dict[str, Any]) -> None:
        """Store a bearer token from the response headers on the factory, if present.

        Headers without a usable `authorization` entry are ignored.
        """
        # Header names are transmitted lower-cased; the capitalised spelling is
        # still accepted for backward compatibility.
        auth_header = headers.get("authorization") or headers.get("Authorization")
        if not auth_header:
            return

        # pyarrow passes header values as a list; a bare value is tolerated too.
        if isinstance(auth_header, (list, tuple)):
            header_value = auth_header[0]
        else:
            header_value = auth_header
        if not isinstance(header_value, str):
            return

        # `partition` never raises, unlike unpacking `split` on a value without a space.
        auth_type, _, token = header_value.partition(" ")
        if auth_type == "Bearer" and token:
            self._factory.set_token(token)

    def sending_headers(self) -> Dict[str, str]:
        """Return the `authorization` header to attach to the outgoing call."""
        # NOTE: the header name must be lower-case, otherwise it does not work.
        token = self._factory.token()
        if token:
            return {"authorization": "Bearer " + token}

        # No valid token cached: fall back to basic authentication.
        username, password = self._factory.auth
        credentials = f"{username}:{password}".encode("utf-8")
        return {"authorization": "Basic " + base64.b64encode(credentials).decode("ASCII")}
0 commit comments