Skip to content

Commit 8f8111a

Browse files
Restore GdsArrowClient signature
Co-authored-by: Florentin Dörre <florentin.dorre@neotechnology.com>
1 parent 7325b0d commit 8f8111a

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
wait_exponential,
3636
)
3737

38-
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication
38+
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication
3939
from graphdatascience.retry_utils.retry_config import RetryConfig
4040
from graphdatascience.retry_utils.retry_utils import before_log
4141

@@ -49,7 +49,7 @@ class GdsArrowClient:
4949
@staticmethod
5050
def create(
5151
arrow_info: ArrowInfo,
52-
arrow_authentication: Optional[ArrowAuthentication] = None,
52+
auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None,
5353
encrypted: bool = False,
5454
disable_server_verification: bool = False,
5555
tls_root_certs: Optional[bytes] = None,
@@ -81,7 +81,7 @@ def create(
8181
host,
8282
retry_config,
8383
int(port),
84-
arrow_authentication,
84+
auth,
8585
encrypted,
8686
disable_server_verification,
8787
tls_root_certs,
@@ -93,7 +93,7 @@ def __init__(
9393
host: str,
9494
retry_config: RetryConfig,
9595
port: int = 8491,
96-
auth: Optional[ArrowAuthentication] = None,
96+
auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None,
9797
encrypted: bool = False,
9898
disable_server_verification: bool = False,
9999
tls_root_certs: Optional[bytes] = None,
@@ -108,8 +108,8 @@ def __init__(
108108
The host address of the GDS Arrow server
109109
port: int
110110
The host port of the GDS Arrow server (default is 8491)
111-
auth: Optional[ArrowAuthentication]
112-
An implementation of ArrowAuthentication providing a pair to be used for basic authentication
111+
auth: Optional[Union[ArrowAuthentication, tuple[str, str]]]
112+
Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication, or a username, password tuple
113113
encrypted: bool
114114
A flag that indicates whether the connection should be encrypted (default is False)
115115
disable_server_verification: bool
@@ -126,7 +126,7 @@ def __init__(
126126
self._arrow_endpoint_version = arrow_endpoint_version
127127
self._host = host
128128
self._port = port
129-
self._auth = auth
129+
self._auth = None
130130
self._encrypted = encrypted
131131
self._disable_server_verification = disable_server_verification
132132
self._tls_root_certs = tls_root_certs
@@ -135,6 +135,10 @@ def __init__(
135135
self._logger = logging.getLogger("gds_arrow_client")
136136

137137
if auth:
138+
if not isinstance(auth, ArrowAuthentication):
139+
username, password = auth
140+
auth = UsernamePasswordAuthentication(username, password)
141+
self._auth = auth
138142
self._auth_middleware = AuthMiddleware(auth)
139143

140144
self._flight_client = self._instantiate_flight_client()

0 commit comments

Comments
 (0)