Skip to content

Commit 5f0fdf1

Browse files
committed
Support GraphSage train + Model catalog ops in v2 endpoints
1 parent 9757b53 commit 5f0fdf1

File tree

13 files changed

+533
-518
lines changed

13 files changed

+533
-518
lines changed

graphdatascience/model/v2/graphsage_model.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import Any
22

33
from pandas import Series
4+
from pydantic import BaseModel
5+
from pydantic.alias_generators import to_camel
46

5-
from ...call_parameters import CallParameters
67
from ...graph.graph_object import Graph
78
from ...graph.graph_type_check import graph_type_check
8-
from ..model import Model
9+
from .model import Model
910

1011

1112
class GraphSageModelV2(Model):
@@ -30,13 +31,7 @@ def predict_write(self, G: Graph, **config: Any) -> "Series[Any]":
3031
The result of the write operation.
3132
3233
"""
33-
endpoint = self._endpoint_prefix() + "write"
34-
config["modelName"] = self.name()
35-
params = CallParameters(graph_name=G.name(), config=config)
36-
37-
return self._query_runner.call_procedure( # type: ignore
38-
endpoint=endpoint, params=params, logging=True
39-
).squeeze()
34+
raise ValueError
4035

4136
@graph_type_check
4237
def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
@@ -51,4 +46,28 @@ def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
5146
The memory needed to generate embeddings for the given graph and write the results to the database.
5247
5348
"""
54-
return self._estimate_predict("write", G.name(), config)
49+
raise ValueError
50+
51+
52+
class GraphSageMutateResult(BaseModel, alias_generator=to_camel):
53+
node_count: int
54+
node_properties_written: int
55+
pre_processing_millis: int
56+
compute_millis: int
57+
mutate_millis: int
58+
configuration: dict[str, Any]
59+
60+
def __getitem__(self, item: str) -> Any:
61+
return self.__dict__[item]
62+
63+
64+
class GraphSageWriteResult(BaseModel, alias_generator=to_camel):
65+
node_count: int
66+
node_properties_written: int
67+
pre_processing_millis: int
68+
compute_millis: int
69+
write_millis: int
70+
configuration: dict[str, Any]
71+
72+
def __getitem__(self, item: str) -> Any:
73+
return self.__dict__[item]

graphdatascience/model/v2/model.py

Lines changed: 14 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,17 @@
11
from __future__ import annotations
22

3-
from abc import ABC, abstractmethod
4-
from datetime import datetime
5-
from typing import Any
3+
from abc import ABC
4+
from typing import Optional
65

7-
from pandas import DataFrame, Series
8-
9-
from graphdatascience.model.v2.model_info import ModelInfo
10-
11-
from ..call_parameters import CallParameters
12-
from ..graph.graph_object import Graph
13-
from ..graph.graph_type_check import graph_type_check
14-
from ..query_runner.query_runner import QueryRunner
15-
from ..server_version.compatible_with import compatible_with
16-
from ..server_version.server_version import ServerVersion
17-
18-
19-
class InfoProvider(ABC):
20-
@abstractmethod
21-
def fetch(self, model_name: str) -> ModelInfo:
22-
"""Return the task with progress for the given job_id."""
23-
pass
6+
from graphdatascience.model.v2.model_api import ModelApi
7+
from graphdatascience.model.v2.model_info import ModelDetails
248

259

10+
# Compared to v1 Model offering typed parameters for predict endpoints
2611
class Model(ABC):
27-
def __init__(self, name: str, info_provider: InfoProvider):
12+
def __init__(self, name: str, model_api: ModelApi):
2813
self._name = name
29-
self._info_provider = info_provider
14+
self._model_api = model_api
3015

3116
# TODO estimate mode, predict modes on here?
3217
# implement Cypher and Arrow info_provider and stuff
@@ -41,95 +26,8 @@ def name(self) -> str:
4126
"""
4227
return self._name
4328

44-
def type(self) -> str:
45-
"""
46-
Get the type of the model.
47-
48-
Returns:
49-
The type of the model.
50-
51-
"""
52-
return self._info_provider.fetch(self._name).type
53-
54-
def train_config(self) -> Series[Any]:
55-
"""
56-
Get the train config of the model.
57-
58-
Returns:
59-
The train config of the model.
60-
61-
"""
62-
return self._info_provider.fetch(self._name).train_config
63-
64-
def graph_schema(self) -> Series[Any]:
65-
"""
66-
Get the graph schema of the model.
67-
68-
Returns:
69-
The graph schema of the model.
70-
71-
"""
72-
return self._info_provider.fetch(self._name).graph_schema
73-
74-
def loaded(self) -> bool:
75-
"""
76-
Check whether the model is loaded in memory.
77-
78-
Returns:
79-
True if the model is loaded in memory, False otherwise.
80-
81-
"""
82-
return self._info_provider.fetch(self._name).loaded
83-
84-
def stored(self) -> bool:
85-
"""
86-
Check whether the model is stored on disk.
87-
88-
Returns:
89-
True if the model is stored on disk, False otherwise.
90-
91-
"""
92-
return self._info_provider.fetch(self._name).stored
93-
94-
def creation_time(self) -> datetime.datetime:
95-
"""
96-
Get the creation time of the model.
97-
98-
Returns:
99-
The creation time of the model.
100-
101-
"""
102-
return self._info_provider.fetch(self._name).creation_time
103-
104-
def shared(self) -> bool:
105-
"""
106-
Check whether the model is shared.
107-
108-
Returns:
109-
True if the model is shared, False otherwise.
110-
111-
"""
112-
return self._info_provider.fetch(self._name).shared
113-
114-
def published(self) -> bool:
115-
"""
116-
Check whether the model is published.
117-
118-
Returns:
119-
True if the model is published, False otherwise.
120-
121-
"""
122-
return self._info_provider.fetch(self._name).published
123-
124-
def model_info(self) -> dict[str, Any]:
125-
"""
126-
Get the model info of the model.
127-
128-
Returns:
129-
The model info of the model.
130-
131-
"""
132-
return self._info_provider.fetch(self._name).model_info
29+
def details(self) -> ModelDetails:
30+
return self._model_api.get(self._name)
13331

13432
def exists(self) -> bool:
13533
"""
@@ -139,9 +37,9 @@ def exists(self) -> bool:
13937
True if the model exists, False otherwise.
14038
14139
"""
142-
raise NotImplementedError()
40+
return self._model_api.exists(self._name)
14341

144-
def drop(self, failIfMissing: bool = False) -> Series[Any]:
42+
def drop(self, failIfMissing: bool = False) -> Optional[ModelDetails]:
14543
"""
14644
Drop the model.
14745
@@ -152,22 +50,10 @@ def drop(self, failIfMissing: bool = False) -> Series[Any]:
15250
The result of the drop operation.
15351
15452
"""
155-
raise NotImplementedError()
156-
157-
def metrics(self) -> Series[Any]:
158-
"""
159-
Get the metrics of the model.
160-
161-
Returns:
162-
The metrics of the model.
163-
164-
"""
165-
model_info = self._info_provider.fetch(self._name).model_info
166-
metrics: Series[Any] = Series(model_info["metrics"])
167-
return metrics
53+
return self._model_api.drop(self._name, failIfMissing)
16854

16955
def __str__(self) -> str:
170-
return f"{self.__class__.__name__}(name={self.name()}, type={self.type()})"
56+
return f"{self.__class__.__name__}(name={self.name()}, type={self.details().type})"
17157

17258
def __repr__(self) -> str:
173-
return f"{self.__class__.__name__}({self._info_provider.fetch(self._name).to_dict()})"
59+
return f"{self.__class__.__name__}({self.details().model_dump()})"
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional
3+
4+
from graphdatascience.model.v2.model_info import ModelDetails
5+
6+
7+
class ModelApi(ABC):
8+
"""
9+
Abstract base class defining the API for model operations.
10+
This class is intended to be subclassed by specific model implementations.
11+
"""
12+
13+
@abstractmethod
14+
def exists(self, model: str) -> bool:
15+
"""
16+
Check if a specific model exists.
17+
18+
Args:
19+
model: The name of the model.
20+
21+
Returns:
22+
True if the model exists, False otherwise.
23+
"""
24+
pass
25+
26+
@abstractmethod
27+
def get(self, model: str) -> ModelDetails:
28+
"""
29+
Get the details of a specific model.
30+
31+
Args:
32+
model: The name of the model.
33+
34+
Returns:
35+
The details of the model.
36+
"""
37+
pass
38+
39+
@abstractmethod
40+
def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]:
41+
"""
42+
Drop a specific model.
43+
44+
Args:
45+
model: The name of the model.
46+
fail_if_missing: If True, an error is thrown if the model does not exist. If False, no error is thrown.
47+
48+
Returns:
49+
The result of the drop operation.
50+
"""
51+
pass
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
import datetime
22
from typing import Any
3-
from pydantic import BaseModel
4-
from abc import ABC, abstractmethod
3+
4+
from pydantic import BaseModel, Field
55
from pydantic.alias_generators import to_camel
66

77

8-
class ModelInfo(BaseModel, alias_generator=to_camel):
9-
name: str
10-
type: str
8+
class ModelDetails(BaseModel, alias_generator=to_camel):
9+
name: str = Field(alias="modelName")
10+
type: str = Field(alias="modelType")
1111
train_config: dict[str, Any]
1212
graph_schema: dict[str, Any]
1313
loaded: bool
1414
stored: bool
15-
shared: bool
1615
published: bool
1716
model_info: dict[str, Any] # TODO better typing in actual model?
18-
creation_time: datetime.datetime # TODO correct type? / conversion needed
17+
creation_time: datetime.datetime
1918

2019
def __getitem__(self, item: str) -> Any:
2120
return getattr(self, item)

0 commit comments

Comments
 (0)