Skip to content

Commit db1116a

Browse files
DarthMaxFlorentinD
andcommitted
Extract common method cypher estimation
Co-authored-by: Florentin Dörre <florentin.dorre@neo4j.com>
1 parent 283251f commit db1116a

File tree

8 files changed

+96
-100
lines changed

8 files changed

+96
-100
lines changed

graphdatascience/procedure_surface/cypher/__init__.py

Whitespace-only changes.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from collections import OrderedDict
2+
from typing import Any, Optional, Union
3+
4+
from graphdatascience.call_parameters import CallParameters
5+
from graphdatascience.graph.graph_object import Graph
6+
from graphdatascience.query_runner.query_runner import QueryRunner
7+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
8+
9+
10+
def estimate_algorithm(
11+
endpoint: str,
12+
query_runner: QueryRunner,
13+
G: Optional[Graph] = None,
14+
projection_config: Optional[dict[str, Any]] = None,
15+
) -> EstimationResult:
16+
"""
17+
Estimate the memory consumption of an algorithm run.
18+
19+
This utility function provides a common implementation for estimation
20+
across all cypher endpoint implementations.
21+
22+
Parameters
23+
----------
24+
query_runner : QueryRunner
25+
The query runner to use for the estimation call
26+
endpoint : str
27+
The full endpoint name for the estimation procedure (e.g., "gds.kcore.stats.estimate")
28+
G : Optional[Graph], optional
29+
The graph to be used in the estimation
30+
projection_config : Optional[dict[str, Any]], optional
31+
Configuration dictionary for the projection
32+
33+
Returns
34+
-------
35+
EstimationResult
36+
An object containing the result of the estimation
37+
38+
Raises
39+
------
40+
ValueError
41+
If neither G nor projection_config is provided
42+
"""
43+
config: Union[dict[str, Any]] = OrderedDict()
44+
45+
if G is not None:
46+
config["graphNameOrConfiguration"] = G.name()
47+
elif projection_config is not None:
48+
config["graphNameOrConfiguration"] = projection_config
49+
else:
50+
raise ValueError("Either graph_name or projection_config must be provided.")
51+
52+
config["algoConfig"] = {}
53+
54+
params = CallParameters(**config)
55+
56+
result = query_runner.call_procedure(endpoint=endpoint, params=params).squeeze()
57+
58+
return EstimationResult(**result.to_dict())

graphdatascience/procedure_surface/cypher/k1coloring_cypher_endpoints.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import OrderedDict
2-
from typing import Any, List, Optional, Union
1+
from typing import Any, List, Optional
32

43
from pandas import DataFrame
54

@@ -14,6 +13,7 @@
1413
K1ColoringWriteResult,
1514
)
1615
from ..utils.config_converter import ConfigConverter
16+
from graphdatascience.procedure_surface.cypher.estimation_utils import estimate_algorithm
1717

1818

1919
class K1ColoringCypherEndpoints(K1ColoringEndpoints):
@@ -169,19 +169,9 @@ def write(
169169
def estimate(
170170
self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
171171
) -> EstimationResult:
172-
config: Union[dict[str, Any]] = OrderedDict()
173-
174-
if G is not None:
175-
config["graphNameOrConfiguration"] = G.name()
176-
elif projection_config is not None:
177-
config["graphNameOrConfiguration"] = projection_config
178-
else:
179-
raise ValueError("Either graph_name or projection_config must be provided.")
180-
181-
config["algoConfig"] = {}
182-
183-
params = CallParameters(**config)
184-
185-
result = self._query_runner.call_procedure(endpoint="gds.k1coloring.stats.estimate", params=params).squeeze()
186-
187-
return EstimationResult(**result.to_dict())
172+
return estimate_algorithm(
173+
endpoint="gds.k1coloring.stats.estimate",
174+
query_runner=self._query_runner,
175+
G=G,
176+
projection_config=projection_config,
177+
)

graphdatascience/procedure_surface/cypher/kcore_cypher_endpoints.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import OrderedDict
2-
from typing import Any, List, Optional, Union
1+
from typing import Any, List, Optional
32

43
from pandas import DataFrame
54

@@ -9,6 +8,7 @@
98
from ..api.estimation_result import EstimationResult
109
from ..api.kcore_endpoints import KCoreEndpoints, KCoreMutateResult, KCoreStatsResult, KCoreWriteResult
1110
from ..utils.config_converter import ConfigConverter
11+
from graphdatascience.procedure_surface.cypher.estimation_utils import estimate_algorithm
1212

1313

1414
class KCoreCypherEndpoints(KCoreEndpoints):
@@ -145,19 +145,9 @@ def write(
145145
def estimate(
146146
self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
147147
) -> EstimationResult:
148-
config: Union[dict[str, Any]] = OrderedDict()
149-
150-
if G is not None:
151-
config["graphNameOrConfiguration"] = G.name()
152-
elif projection_config is not None:
153-
config["graphNameOrConfiguration"] = projection_config
154-
else:
155-
raise ValueError("Either graph_name or projection_config must be provided.")
156-
157-
config["algoConfig"] = {}
158-
159-
params = CallParameters(**config)
160-
161-
result = self._query_runner.call_procedure(endpoint="gds.kcore.stats.estimate", params=params).squeeze()
162-
163-
return EstimationResult(**result.to_dict())
148+
return estimate_algorithm(
149+
endpoint="gds.kcore.stats.estimate",
150+
query_runner=self._query_runner,
151+
G=G,
152+
projection_config=projection_config,
153+
)

graphdatascience/procedure_surface/cypher/louvain_cypher_endpoints.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import OrderedDict
2-
from typing import Any, List, Optional, Union
1+
from typing import Any, List, Optional
32

43
from pandas import DataFrame
54

@@ -9,6 +8,7 @@
98
from ..api.estimation_result import EstimationResult
109
from ..api.louvain_endpoints import LouvainEndpoints, LouvainMutateResult, LouvainStatsResult, LouvainWriteResult
1110
from ..utils.config_converter import ConfigConverter
11+
from graphdatascience.procedure_surface.cypher.estimation_utils import estimate_algorithm
1212

1313

1414
class LouvainCypherEndpoints(LouvainEndpoints):
@@ -57,7 +57,6 @@ def mutate(
5757
username=username,
5858
)
5959

60-
# Run procedure and return results
6160
params = CallParameters(graph_name=G.name(), config=config)
6261
params.ensure_job_id_in_config()
6362

@@ -100,11 +99,10 @@ def stats(
10099
username=username,
101100
)
102101

103-
# Run procedure and return results
104102
params = CallParameters(graph_name=G.name(), config=config)
105103
params.ensure_job_id_in_config()
106104

107-
cypher_result = self._query_runner.call_procedure(endpoint="gds.louvain.stats", params=params).squeeze() # type: ignore
105+
cypher_result = self._query_runner.call_procedure(endpoint="gds.louvain.stats", params=params).squeeze()
108106

109107
return LouvainStatsResult(**cypher_result.to_dict())
110108

@@ -145,7 +143,6 @@ def stream(
145143
username=username,
146144
)
147145

148-
# Run procedure and return results
149146
params = CallParameters(graph_name=G.name(), config=config)
150147
params.ensure_job_id_in_config()
151148

@@ -206,19 +203,9 @@ def write(
206203
def estimate(
207204
self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
208205
) -> EstimationResult:
209-
config: Union[dict[str, Any]] = OrderedDict()
210-
211-
if G is not None:
212-
config["graphNameOrConfiguration"] = G.name()
213-
elif projection_config is not None:
214-
config["graphNameOrConfiguration"] = projection_config
215-
else:
216-
raise ValueError("Either graph_name or projection_config must be provided.")
217-
218-
config["algoConfig"] = {}
219-
220-
params = CallParameters(**config)
221-
222-
result = self._query_runner.call_procedure(endpoint="gds.louvain.stats.estimate", params=params).squeeze()
223-
224-
return EstimationResult(**result.to_dict())
206+
return estimate_algorithm(
207+
endpoint="gds.louvain.stats.estimate",
208+
query_runner=self._query_runner,
209+
G=G,
210+
projection_config=projection_config,
211+
)

graphdatascience/procedure_surface/cypher/scc_cypher_endpoints.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import OrderedDict
2-
from typing import Any, List, Optional, Union
1+
from typing import Any, List, Optional
32

43
from pandas import DataFrame
54

@@ -9,11 +8,12 @@
98
from ..api.estimation_result import EstimationResult
109
from ..api.scc_endpoints import SccEndpoints, SccMutateResult, SccStatsResult, SccWriteResult
1110
from ..utils.config_converter import ConfigConverter
11+
from graphdatascience.procedure_surface.cypher.estimation_utils import estimate_algorithm
1212

1313

1414
class SccCypherEndpoints(SccEndpoints):
1515
"""
16-
Implementation of the SCC algorithm endpoints.
16+
Implementation of the Strongly Connected Components (SCC) algorithm endpoints.
1717
This class handles the actual execution by forwarding calls to the query runner.
1818
"""
1919

@@ -153,19 +153,6 @@ def write(
153153
def estimate(
154154
self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
155155
) -> EstimationResult:
156-
config: Union[dict[str, Any]] = OrderedDict()
157-
158-
if G is not None:
159-
config["graphNameOrConfiguration"] = G.name()
160-
elif projection_config is not None:
161-
config["graphNameOrConfiguration"] = projection_config
162-
else:
163-
raise ValueError("Either graph_name or projection_config must be provided.")
164-
165-
config["algoConfig"] = {}
166-
167-
params = CallParameters(**config)
168-
169-
result = self._query_runner.call_procedure(endpoint="gds.scc.stats.estimate", params=params).squeeze()
170-
171-
return EstimationResult(**result.to_dict())
156+
return estimate_algorithm(
157+
endpoint="gds.scc.stats.estimate", query_runner=self._query_runner, G=G, projection_config=projection_config
158+
)

graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import OrderedDict
2-
from typing import Any, List, Optional, Union
1+
from typing import Any, List, Optional
32

43
from pandas import DataFrame
54

@@ -9,11 +8,12 @@
98
from ..api.estimation_result import EstimationResult
109
from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult
1110
from ..utils.config_converter import ConfigConverter
11+
from graphdatascience.procedure_surface.cypher.estimation_utils import estimate_algorithm
1212

1313

1414
class WccCypherEndpoints(WccEndpoints):
1515
"""
16-
Implementation of the WCC algorithm endpoints.
16+
Implementation of the Weakly Connected Components (WCC) algorithm endpoints.
1717
This class handles the actual execution by forwarding calls to the query runner.
1818
"""
1919

@@ -51,7 +51,6 @@ def mutate(
5151
username=username,
5252
)
5353

54-
# Run procedure and return results
5554
params = CallParameters(graph_name=G.name(), config=config)
5655
params.ensure_job_id_in_config()
5756

@@ -88,11 +87,10 @@ def stats(
8887
username=username,
8988
)
9089

91-
# Run procedure and return results
9290
params = CallParameters(graph_name=G.name(), config=config)
9391
params.ensure_job_id_in_config()
9492

95-
cypher_result = self._query_runner.call_procedure(endpoint="gds.wcc.stats", params=params).squeeze() # type: ignore
93+
cypher_result = self._query_runner.call_procedure(endpoint="gds.wcc.stats", params=params).squeeze()
9694

9795
return WccStatsResult(**cypher_result.to_dict())
9896

@@ -127,7 +125,6 @@ def stream(
127125
username=username,
128126
)
129127

130-
# Run procedure and return results
131128
params = CallParameters(graph_name=G.name(), config=config)
132129
params.ensure_job_id_in_config()
133130

@@ -180,19 +177,6 @@ def write(
180177
def estimate(
181178
self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
182179
) -> EstimationResult:
183-
config: Union[dict[str, Any]] = OrderedDict()
184-
185-
if G is not None:
186-
config["graphNameOrConfiguration"] = G.name()
187-
elif projection_config is not None:
188-
config["graphNameOrConfiguration"] = projection_config
189-
else:
190-
raise ValueError("Either graph_name or projection_config must be provided.")
191-
192-
config["algoConfig"] = {}
193-
194-
params = CallParameters(**config)
195-
196-
result = self._query_runner.call_procedure(endpoint="gds.wcc.stats.estimate", params=params).squeeze()
197-
198-
return EstimationResult(**result.to_dict())
180+
return estimate_algorithm(
181+
endpoint="gds.wcc.stats.estimate", query_runner=self._query_runner, G=G, projection_config=projection_config
182+
)

graphdatascience/procedure_surface/utils/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)