Skip to content

Commit f17e876

Browse files
authored
Merge pull request #741 from RafalSkolasinski/readability-changes
minor readability changes
2 parents 862b555 + bac0fd5 commit f17e876

File tree

1 file changed

+33
-23
lines changed

1 file changed

+33
-23
lines changed

graphdatascience/session/aura_graph_data_science.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,56 +35,66 @@ def create(
3535
arrow_tls_root_certs: Optional[bytes] = None,
3636
bookmarks: Optional[Any] = None,
3737
):
38-
session_neo4j_query_runner = Neo4jQueryRunner.create(
39-
gds_session_connection_info.uri, gds_session_connection_info.auth(), aura_ds=True
38+
# we need to explicitly set this as the default value is None
39+
# database in the session is always neo4j
40+
session_bolt_query_runner = Neo4jQueryRunner.create(
41+
endpoint=gds_session_connection_info.uri,
42+
auth=gds_session_connection_info.auth(),
43+
aura_ds=True,
44+
database="neo4j",
4045
)
4146
session_arrow_query_runner = ArrowQueryRunner.create(
42-
session_neo4j_query_runner,
47+
fallback_query_runner=session_bolt_query_runner,
48+
auth=gds_session_connection_info.auth(),
49+
encrypted=session_bolt_query_runner.encrypted(),
50+
disable_server_verification=arrow_disable_server_verification,
51+
tls_root_certs=arrow_tls_root_certs,
52+
)
53+
54+
# TODO: merge with the gds_arrow_client created inside ArrowQueryRunner
55+
session_arrow_client = GdsArrowClient.create(
56+
session_bolt_query_runner,
4357
gds_session_connection_info.auth(),
44-
session_neo4j_query_runner.encrypted(),
58+
session_bolt_query_runner.encrypted(),
4559
arrow_disable_server_verification,
4660
arrow_tls_root_certs,
4761
)
4862

49-
# we need to explicitly set this as the default value is None
50-
# database in the session is always neo4j
51-
session_arrow_query_runner.set_database("neo4j")
52-
53-
db_query_runner = Neo4jQueryRunner.create(
63+
db_bolt_query_runner = Neo4jQueryRunner.create(
5464
db_connection_info.uri,
5565
db_connection_info.auth(),
5666
aura_ds=True,
5767
)
58-
db_query_runner.set_bookmarks(bookmarks)
68+
db_bolt_query_runner.set_bookmarks(bookmarks)
5969

60-
session_arrow_client = GdsArrowClient.create(
61-
session_neo4j_query_runner,
62-
gds_session_connection_info.auth(),
63-
session_neo4j_query_runner.encrypted(),
64-
arrow_disable_server_verification,
65-
arrow_tls_root_certs,
66-
)
67-
aura_db_query_runner = SessionQueryRunner.create(
68-
session_arrow_query_runner, db_query_runner, session_arrow_client
70+
session_query_runner = SessionQueryRunner.create(
71+
session_arrow_query_runner, db_bolt_query_runner, session_arrow_client
6972
)
7073

71-
gds_version = session_neo4j_query_runner.server_version()
72-
return cls(query_runner=aura_db_query_runner, delete_fn=delete_fn, gds_version=gds_version)
74+
gds_version = session_bolt_query_runner.server_version()
75+
return cls(
76+
query_runner=session_query_runner,
77+
delete_fn=delete_fn,
78+
gds_version=gds_version,
79+
)
7380

7481
def __init__(
7582
self,
7683
query_runner: QueryRunner,
7784
delete_fn: Callable[[], bool],
7885
gds_version: ServerVersion,
7986
):
80-
self._server_version = gds_version
8187
self._query_runner = query_runner
8288
self._delete_fn = delete_fn
89+
self._server_version = gds_version
8390

8491
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
8592

8693
def run_cypher(
87-
self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None
94+
self,
95+
query: str,
96+
params: Optional[Dict[str, Any]] = None,
97+
database: Optional[str] = None,
8898
) -> DataFrame:
8999
"""
90100
Run a Cypher query against the Neo4j database.

0 commit comments

Comments
 (0)