@@ -35,56 +35,66 @@ def create(
35
35
arrow_tls_root_certs : Optional [bytes ] = None ,
36
36
bookmarks : Optional [Any ] = None ,
37
37
):
38
- session_neo4j_query_runner = Neo4jQueryRunner .create (
39
- gds_session_connection_info .uri , gds_session_connection_info .auth (), aura_ds = True
38
+ # we need to explicitly set this as the default value is None
39
+ # database in the session is always neo4j
40
+ session_bolt_query_runner = Neo4jQueryRunner .create (
41
+ endpoint = gds_session_connection_info .uri ,
42
+ auth = gds_session_connection_info .auth (),
43
+ aura_ds = True ,
44
+ database = "neo4j" ,
40
45
)
41
46
session_arrow_query_runner = ArrowQueryRunner .create (
42
- session_neo4j_query_runner ,
47
+ fallback_query_runner = session_bolt_query_runner ,
48
+ auth = gds_session_connection_info .auth (),
49
+ encrypted = session_bolt_query_runner .encrypted (),
50
+ disable_server_verification = arrow_disable_server_verification ,
51
+ tls_root_certs = arrow_tls_root_certs ,
52
+ )
53
+
54
+ # TODO: merge with the gds_arrow_client created inside ArrowQueryRunner
55
+ session_arrow_client = GdsArrowClient .create (
56
+ session_bolt_query_runner ,
43
57
gds_session_connection_info .auth (),
44
- session_neo4j_query_runner .encrypted (),
58
+ session_bolt_query_runner .encrypted (),
45
59
arrow_disable_server_verification ,
46
60
arrow_tls_root_certs ,
47
61
)
48
62
49
- # we need to explicitly set this as the default value is None
50
- # database in the session is always neo4j
51
- session_arrow_query_runner .set_database ("neo4j" )
52
-
53
- db_query_runner = Neo4jQueryRunner .create (
63
+ db_bolt_query_runner = Neo4jQueryRunner .create (
54
64
db_connection_info .uri ,
55
65
db_connection_info .auth (),
56
66
aura_ds = True ,
57
67
)
58
- db_query_runner .set_bookmarks (bookmarks )
68
+ db_bolt_query_runner .set_bookmarks (bookmarks )
59
69
60
- session_arrow_client = GdsArrowClient .create (
61
- session_neo4j_query_runner ,
62
- gds_session_connection_info .auth (),
63
- session_neo4j_query_runner .encrypted (),
64
- arrow_disable_server_verification ,
65
- arrow_tls_root_certs ,
66
- )
67
- aura_db_query_runner = SessionQueryRunner .create (
68
- session_arrow_query_runner , db_query_runner , session_arrow_client
70
+ session_query_runner = SessionQueryRunner .create (
71
+ session_arrow_query_runner , db_bolt_query_runner , session_arrow_client
69
72
)
70
73
71
- gds_version = session_neo4j_query_runner .server_version ()
72
- return cls (query_runner = aura_db_query_runner , delete_fn = delete_fn , gds_version = gds_version )
74
+ gds_version = session_bolt_query_runner .server_version ()
75
+ return cls (
76
+ query_runner = session_query_runner ,
77
+ delete_fn = delete_fn ,
78
+ gds_version = gds_version ,
79
+ )
73
80
74
81
def __init__ (
75
82
self ,
76
83
query_runner : QueryRunner ,
77
84
delete_fn : Callable [[], bool ],
78
85
gds_version : ServerVersion ,
79
86
):
80
- self ._server_version = gds_version
81
87
self ._query_runner = query_runner
82
88
self ._delete_fn = delete_fn
89
+ self ._server_version = gds_version
83
90
84
91
super ().__init__ (self ._query_runner , namespace = "gds" , server_version = self ._server_version )
85
92
86
93
def run_cypher (
87
- self , query : str , params : Optional [Dict [str , Any ]] = None , database : Optional [str ] = None
94
+ self ,
95
+ query : str ,
96
+ params : Optional [Dict [str , Any ]] = None ,
97
+ database : Optional [str ] = None ,
88
98
) -> DataFrame :
89
99
"""
90
100
Run a Cypher query against the Neo4j database.
0 commit comments