35
35
wait_exponential ,
36
36
)
37
37
38
- from graphdatascience .query_runner .arrow_authentication import ArrowAuthentication
38
+ from graphdatascience .query_runner .arrow_authentication import ArrowAuthentication , UsernamePasswordAuthentication
39
39
from graphdatascience .retry_utils .retry_config import RetryConfig
40
40
from graphdatascience .retry_utils .retry_utils import before_log
41
41
@@ -49,7 +49,7 @@ class GdsArrowClient:
49
49
@staticmethod
50
50
def create (
51
51
arrow_info : ArrowInfo ,
52
- arrow_authentication : Optional [ArrowAuthentication ] = None ,
52
+ auth : Optional [Union [ ArrowAuthentication , tuple [ str , str ]] ] = None ,
53
53
encrypted : bool = False ,
54
54
disable_server_verification : bool = False ,
55
55
tls_root_certs : Optional [bytes ] = None ,
@@ -81,7 +81,7 @@ def create(
81
81
host ,
82
82
retry_config ,
83
83
int (port ),
84
- arrow_authentication ,
84
+ auth ,
85
85
encrypted ,
86
86
disable_server_verification ,
87
87
tls_root_certs ,
@@ -93,7 +93,7 @@ def __init__(
93
93
host : str ,
94
94
retry_config : RetryConfig ,
95
95
port : int = 8491 ,
96
- auth : Optional [ArrowAuthentication ] = None ,
96
+ auth : Optional [Union [ ArrowAuthentication , tuple [ str , str ]] ] = None ,
97
97
encrypted : bool = False ,
98
98
disable_server_verification : bool = False ,
99
99
tls_root_certs : Optional [bytes ] = None ,
@@ -108,8 +108,8 @@ def __init__(
108
108
The host address of the GDS Arrow server
109
109
port: int
110
110
The host port of the GDS Arrow server (default is 8491)
111
- auth: Optional[ArrowAuthentication]
112
- An implementation of ArrowAuthentication providing a pair to be used for basic authentication
111
+ auth: Optional[Union[ ArrowAuthentication, tuple[str, str]] ]
112
+ Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication, or a username, password tuple
113
113
encrypted: bool
114
114
A flag that indicates whether the connection should be encrypted (default is False)
115
115
disable_server_verification: bool
@@ -126,7 +126,7 @@ def __init__(
126
126
self ._arrow_endpoint_version = arrow_endpoint_version
127
127
self ._host = host
128
128
self ._port = port
129
- self ._auth = auth
129
+ self ._auth = None
130
130
self ._encrypted = encrypted
131
131
self ._disable_server_verification = disable_server_verification
132
132
self ._tls_root_certs = tls_root_certs
@@ -135,6 +135,10 @@ def __init__(
135
135
self ._logger = logging .getLogger ("gds_arrow_client" )
136
136
137
137
if auth :
138
+ if not isinstance (auth , ArrowAuthentication ):
139
+ username , password = auth
140
+ auth = UsernamePasswordAuthentication (username , password )
141
+ self ._auth = auth
138
142
self ._auth_middleware = AuthMiddleware (auth )
139
143
140
144
self ._flight_client = self ._instantiate_flight_client ()
0 commit comments