Skip to content

Commit 714b685

Browse files
authored
Merge pull request #648 from neo4j/push-based-no-jobid
Support remote ops using DBMS Arrow Client
2 parents 4655c91 + 62eea86 commit 714b685

21 files changed

+716
-546
lines changed

changelog.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
* Add the new concept of GDS Sessions, used to manage GDS computations in Aura, based on data from an AuraDB instance.
1010
* Add a new `gds.graph.project` endpoint to project graphs from AuraDB instances to GDS sessions.
11-
* `nodePropertySchema` and `relationshipPropertySchema` can be used to optimise remote projections.
1211
* Add a new top-level class `GdsSessions` to manage GDS sessions in Aura.
1312
* `GdsSessions` support `get_or_create()`, `list()`, and `delete()`.
1413
* Creating a new session supports various sizes.

doc/modules/ROOT/pages/tutorials/gds-sessions.adoc

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ although we do not do that in this notebook.
173173

174174
[source, python, role=no-test]
175175
----
176-
from graphdatascience.session import GdsPropertyTypes
177-
178176
G, result = gds.graph.project(
179177
"people-and-fruits",
180178
"""
@@ -201,15 +199,6 @@ G, result = gds.graph.project(
201199
relationshipType: type(rel)
202200
})
203201
""",
204-
nodePropertySchema={
205-
"age": GdsPropertyTypes.LONG,
206-
"experience": GdsPropertyTypes.LONG,
207-
"hipster": GdsPropertyTypes.LONG,
208-
"tropical": GdsPropertyTypes.LONG,
209-
"sourness": GdsPropertyTypes.DOUBLE,
210-
"sweetness": GdsPropertyTypes.DOUBLE,
211-
},
212-
relationshipPropertySchema={},
213202
)
214203
215204
str(G)

examples/gds-sessions.ipynb

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,6 @@
236236
"metadata": {},
237237
"outputs": [],
238238
"source": [
239-
"from graphdatascience.session import GdsPropertyTypes\n",
240-
"\n",
241239
"G, result = gds.graph.project(\n",
242240
" \"people-and-fruits\",\n",
243241
" \"\"\"\n",
@@ -264,15 +262,6 @@
264262
" relationshipType: type(rel)\n",
265263
" })\n",
266264
" \"\"\",\n",
267-
" nodePropertySchema={\n",
268-
" \"age\": GdsPropertyTypes.LONG,\n",
269-
" \"experience\": GdsPropertyTypes.LONG,\n",
270-
" \"hipster\": GdsPropertyTypes.LONG,\n",
271-
" \"tropical\": GdsPropertyTypes.LONG,\n",
272-
" \"sourness\": GdsPropertyTypes.DOUBLE,\n",
273-
" \"sweetness\": GdsPropertyTypes.DOUBLE,\n",
274-
" },\n",
275-
" relationshipPropertySchema={},\n",
276265
")\n",
277266
"\n",
278267
"str(G)"

graphdatascience/graph/graph_entity_ops_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def add_property(query: str, prop: str) -> str:
177177
return reduce(add_property, db_node_properties, query_prefix)
178178

179179
@compatible_with("write", min_inclusive=ServerVersion(2, 2, 0))
180-
def write(self, G: Graph, node_properties: List[str], node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
180+
def write(self, G: Graph, node_properties: Strings, node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
181181
self._namespace += ".write"
182182
return self._handle_properties(G, node_properties, node_labels, config).squeeze() # type: ignore
183183

graphdatascience/graph/graph_remote_proc_runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,4 @@
55
class GraphRemoteProcRunner(BaseGraphProcRunner):
66
@property
77
def project(self) -> GraphProjectRemoteRunner:
8-
self._namespace += ".project.remoteDb"
98
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)

graphdatascience/graph/graph_remote_project_runner.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import List, Optional
44

55
from ..error.illegal_attr_checker import IllegalAttrChecker
6+
from ..query_runner.aura_db_query_runner import AuraDbQueryRunner
67
from ..server_version.compatible_with import compatible_with
78
from .graph_object import Graph
89
from graphdatascience.call_parameters import CallParameters
@@ -11,28 +12,30 @@
1112

1213

1314
class GraphProjectRemoteRunner(IllegalAttrChecker):
14-
_SCHEMA_KEYS = ["nodePropertySchema", "relationshipPropertySchema"]
15+
@compatible_with("project", min_inclusive=ServerVersion(2, 7, 0))
16+
def __call__(
17+
self,
18+
graph_name: str,
19+
query: str,
20+
concurrency: int = 4,
21+
undirected_relationship_types: Optional[List[str]] = None,
22+
inverse_indexed_relationship_types: Optional[List[str]] = None,
23+
) -> GraphCreateResult:
24+
if inverse_indexed_relationship_types is None:
25+
inverse_indexed_relationship_types = []
26+
if undirected_relationship_types is None:
27+
undirected_relationship_types = []
1528

16-
@compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
17-
def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
18-
placeholder = "<>" # host and token will be added by query runner
19-
self.map_property_types(config)
2029
params = CallParameters(
2130
graph_name=graph_name,
2231
query=query,
23-
token=placeholder,
24-
host=placeholder,
25-
remote_database=self._query_runner.database(),
26-
config=config,
32+
concurrency=concurrency,
33+
undirected_relationship_types=undirected_relationship_types,
34+
inverse_indexed_relationship_types=inverse_indexed_relationship_types,
2735
)
36+
2837
result = self._query_runner.call_procedure(
29-
endpoint=self._namespace,
38+
endpoint=AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME,
3039
params=params,
3140
).squeeze()
3241
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
33-
34-
@staticmethod
35-
def map_property_types(config: dict[str, Any]) -> None:
36-
for key in GraphProjectRemoteRunner._SCHEMA_KEYS:
37-
if key in config:
38-
config[key] = {k: v.value for k, v in config[key].items()}

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
from __future__ import annotations
22

33
import concurrent
4-
import json
54
import math
65
import warnings
76
from concurrent.futures import ThreadPoolExecutor
87
from typing import Any, Dict, List, NoReturn, Optional
98

109
import numpy
11-
import pyarrow.flight as flight
1210
from pandas import DataFrame
1311
from pyarrow import Table
1412
from tqdm.auto import tqdm
1513

16-
from .arrow_endpoint_version import ArrowEndpointVersion
14+
from .gds_arrow_client import GdsArrowClient
1715
from .graph_constructor import GraphConstructor
1816

1917

@@ -22,17 +20,15 @@ def __init__(
2220
self,
2321
database: str,
2422
graph_name: str,
25-
flight_client: flight.FlightClient,
23+
flight_client: GdsArrowClient,
2624
concurrency: int,
27-
arrow_endpoint_version: ArrowEndpointVersion,
2825
undirected_relationship_types: Optional[List[str]],
2926
chunk_size: int = 10_000,
3027
):
3128
self._database = database
3229
self._concurrency = concurrency
3330
self._graph_name = graph_name
3431
self._client = flight_client
35-
self._arrow_endpoint_version = arrow_endpoint_version
3632
self._undirected_relationship_types = (
3733
[] if undirected_relationship_types is None else undirected_relationship_types
3834
)
@@ -49,20 +45,20 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
4945
if self._undirected_relationship_types:
5046
config["undirected_relationship_types"] = self._undirected_relationship_types
5147

52-
self._send_action(
48+
self._client.send_action(
5349
"CREATE_GRAPH",
5450
config,
5551
)
5652

5753
self._send_dfs(node_dfs, "node")
5854

59-
self._send_action("NODE_LOAD_DONE", {"name": self._graph_name})
55+
self._client.send_action("NODE_LOAD_DONE", {"name": self._graph_name})
6056

6157
self._send_dfs(relationship_dfs, "relationship")
6258

63-
self._send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
59+
self._client.send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
6460
except (Exception, KeyboardInterrupt) as e:
65-
self._send_action("ABORT", {"name": self._graph_name})
61+
self._client.send_action("ABORT", {"name": self._graph_name})
6662

6763
raise e
6864

@@ -85,25 +81,12 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
8581

8682
return partitioned_dfs
8783

88-
def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
89-
action_type = self._versioned_action_type(action_type)
90-
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
91-
92-
# Consume result fully to sanity check and avoid cancelled streams
93-
collected_result = list(result)
94-
assert len(collected_result) == 1
95-
96-
json.loads(collected_result[0].body.to_pybytes().decode())
97-
9884
def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None:
9985
table = Table.from_pandas(df)
10086
batches = table.to_batches(self._chunk_size)
10187
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
102-
flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
10388

104-
# Write schema
105-
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
106-
writer, _ = self._client.do_put(upload_descriptor, table.schema)
89+
writer, _ = self._client.start_put(flight_descriptor, table.schema)
10790

10891
with writer:
10992
# Write table in chunks
@@ -126,17 +109,3 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
126109
if not future.exception():
127110
continue
128111
raise future.exception() # type: ignore
129-
130-
def _versioned_action_type(self, action_type: str) -> str:
131-
return self._arrow_endpoint_version.prefix() + action_type
132-
133-
def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
134-
return (
135-
flight_descriptor
136-
if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
137-
else {
138-
"name": "PUT_MESSAGE",
139-
"version": ArrowEndpointVersion.V1.version(),
140-
"body": flight_descriptor,
141-
}
142-
)

0 commit comments

Comments
 (0)