Skip to content

Commit 0db17d8

Browse files
authored
Merge pull request #63 from adamnsch/client-only-errors
Add consistent error handling for client side only methods
2 parents 86efd24 + 76fabda commit 0db17d8

File tree

5 files changed

+34
-6
lines changed

5 files changed

+34
-6
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from abc import ABC
2+
from typing import Any, Callable, TypeVar, cast
3+
4+
F = TypeVar("F", bound=Callable[..., Any])
5+
6+
7+
class WithNamespace(ABC):
8+
_namespace: str
9+
10+
11+
def client_only_endpoint(expected_namespace_prefix: str) -> Callable[[F], F]:
12+
def decorator(func: F) -> F:
13+
def wrapper(self: WithNamespace, *args: Any, **kwargs: Any) -> Any:
14+
if self._namespace != expected_namespace_prefix:
15+
raise SyntaxError(
16+
f"There is no '{self._namespace}.{func.__name__}' to call"
17+
)
18+
19+
return func(self, *args, **kwargs)
20+
21+
return cast(F, wrapper)
22+
23+
return decorator

graphdatascience/graph/graph_proc_runner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Optional, Union
22

3+
from ..error.client_only_endpoint import client_only_endpoint
34
from ..error.illegal_attr_checker import IllegalAttrChecker
45
from ..error.uncallable_namespace import UncallableNamespace
56
from ..query_runner.query_runner import QueryResult, QueryRunner
@@ -67,10 +68,8 @@ def list(self, G: Optional[Graph] = None) -> QueryResult:
6768

6869
return self._query_runner.run_query(query, params)
6970

71+
@client_only_endpoint("gds.graph")
7072
def get(self, graph_name: str) -> Graph:
71-
if self._namespace != "gds.graph":
72-
raise SyntaxError(f"There is no {self._namespace + '.get'} to call")
73-
7473
if not self.exists(graph_name)[0]["exists"]:
7574
raise ValueError(f"No projected graph named '{graph_name}' exists")
7675

graphdatascience/model/model_proc_runner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional, Union
22

3+
from ..error.client_only_endpoint import client_only_endpoint
34
from ..error.illegal_attr_checker import IllegalAttrChecker
45
from ..error.uncallable_namespace import UncallableNamespace
56
from ..pipeline.lp_prediction_pipeline import LPPredictionPipeline
@@ -101,10 +102,8 @@ def delete(self, model_id: ModelId) -> QueryResult:
101102

102103
return self._query_runner.run_query(query, params)
103104

105+
@client_only_endpoint("gds.model")
104106
def get(self, model_name: str) -> Model:
105-
if self._namespace != "gds.model":
106-
raise SyntaxError(f"There is no {self._namespace + '.get'} to call")
107-
108107
self._namespace = "gds.beta.model"
109108
result = self.list(model_name)
110109
if len(result) == 0:

graphdatascience/tests/unit/test_error_handling.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,8 @@ def test_nonexisting_similarity_endpoint(gds: GraphDataScience) -> None:
120120
SyntaxError, match="There is no 'gds.alpha.similarity.pearson.bogus' to call"
121121
):
122122
gds.alpha.similarity.pearson.bogus() # type: ignore
123+
124+
125+
def test_wrong_client_only_prefix(gds: GraphDataScience) -> None:
126+
with pytest.raises(SyntaxError, match="There is no 'gds.beta.model.get' to call"):
127+
gds.beta.model.get("model")

graphdatascience/utils/util_endpoints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List
22

3+
from ..error.client_only_endpoint import client_only_endpoint
34
from ..query_runner.query_runner import QueryResult, QueryRunner
45
from .util_proc_runner import UtilProcRunner
56

@@ -13,6 +14,7 @@ def __init__(self, query_runner: QueryRunner, namespace: str):
1314
def util(self) -> UtilProcRunner:
1415
return UtilProcRunner(self._query_runner, f"{self._namespace}.util")
1516

17+
@client_only_endpoint("gds")
1618
def find_node_id(
1719
self, labels: List[str] = [], properties: Dict[str, Any] = {}
1820
) -> int:

0 commit comments

Comments
 (0)