Commit f804d5d

WIP graphsage support

1 parent 083c987 commit f804d5d
6 files changed: +860 −0 lines changed
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
from typing import Any

from pandas import Series

from ...call_parameters import CallParameters
from ...graph.graph_object import Graph
from ...graph.graph_type_check import graph_type_check
from ..model import Model


class GraphSageModelV2(Model):
    """
    Represents a GraphSAGE model in the model catalog.
    Construct this using :func:`gds.graphSage.train()`.
    """

    def _endpoint_prefix(self) -> str:
        return "gds.beta.graphSage."

    @graph_type_check
    def predict_write(self, G: Graph, **config: Any) -> "Series[Any]":
        """
        Generate embeddings for the given graph and write the results to the database.

        Args:
            G: The graph to generate embeddings for.
            **config: The config for the prediction.

        Returns:
            The result of the write operation.

        """
        endpoint = self._endpoint_prefix() + "write"
        config["modelName"] = self.name()
        params = CallParameters(graph_name=G.name(), config=config)

        return self._query_runner.call_procedure(  # type: ignore
            endpoint=endpoint, params=params, logging=True
        ).squeeze()

    @graph_type_check
    def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
        """
        Estimate the memory needed to generate embeddings for the given graph and write the results to the database.

        Args:
            G: The graph to generate embeddings for.
            **config: The config for the prediction.

        Returns:
            The memory needed to generate embeddings for the given graph and write the results to the database.

        """
        return self._estimate_predict("write", G.name(), config)
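
A hypothetical usage sketch for the new class (the `model` and `G` handles are illustrative assumptions, not part of this diff; `writeProperty` is the standard GDS write-mode config key):

# Assuming `model` is a GraphSageModelV2 for a trained GraphSAGE model
# and `G` is a projected Graph object:
write_result = model.predict_write(G, writeProperty="embedding")

# Estimate the memory footprint before writing, with the same prediction config:
estimate = model.predict_write_estimate(G, writeProperty="embedding")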

graphdatascience/model/v2/model.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any

from pandas import DataFrame, Series

from graphdatascience.model.v2.model_info import ModelInfo

from ..call_parameters import CallParameters
from ..graph.graph_object import Graph
from ..graph.graph_type_check import graph_type_check
from ..query_runner.query_runner import QueryRunner
from ..server_version.compatible_with import compatible_with
from ..server_version.server_version import ServerVersion


class InfoProvider(ABC):
    @abstractmethod
    def fetch(self, model_name: str) -> ModelInfo:
        """Return the catalog information for the model with the given name."""
        pass


class Model(ABC):
    def __init__(self, name: str, info_provider: InfoProvider):
        self._name = name
        self._info_provider = info_provider

    # TODO estimate mode, predict modes on here?
    # implement Cypher and Arrow info_provider and stuff

    def name(self) -> str:
        """
        Get the name of the model.

        Returns:
            The name of the model.

        """
        return self._name

    def type(self) -> str:
        """
        Get the type of the model.

        Returns:
            The type of the model.

        """
        return self._info_provider.fetch(self._name).type

    def train_config(self) -> Series[Any]:
        """
        Get the train config of the model.

        Returns:
            The train config of the model.

        """
        return Series(self._info_provider.fetch(self._name).train_config)

    def graph_schema(self) -> Series[Any]:
        """
        Get the graph schema of the model.

        Returns:
            The graph schema of the model.

        """
        return Series(self._info_provider.fetch(self._name).graph_schema)

    def loaded(self) -> bool:
        """
        Check whether the model is loaded in memory.

        Returns:
            True if the model is loaded in memory, False otherwise.

        """
        return self._info_provider.fetch(self._name).loaded

    def stored(self) -> bool:
        """
        Check whether the model is stored on disk.

        Returns:
            True if the model is stored on disk, False otherwise.

        """
        return self._info_provider.fetch(self._name).stored

    def creation_time(self) -> datetime:
        """
        Get the creation time of the model.

        Returns:
            The creation time of the model.

        """
        return self._info_provider.fetch(self._name).creation_time

    def shared(self) -> bool:
        """
        Check whether the model is shared.

        Returns:
            True if the model is shared, False otherwise.

        """
        return self._info_provider.fetch(self._name).shared

    def published(self) -> bool:
        """
        Check whether the model is published.

        Returns:
            True if the model is published, False otherwise.

        """
        return self._info_provider.fetch(self._name).published

    def model_info(self) -> dict[str, Any]:
        """
        Get the model info of the model.

        Returns:
            The model info of the model.

        """
        return self._info_provider.fetch(self._name).model_info

    def exists(self) -> bool:
        """
        Check whether the model exists.

        Returns:
            True if the model exists, False otherwise.

        """
        raise NotImplementedError()

    def drop(self, failIfMissing: bool = False) -> Series[Any]:
        """
        Drop the model.

        Args:
            failIfMissing: If True, an error is thrown if the model does not exist. If False, no error is thrown.

        Returns:
            The result of the drop operation.

        """
        raise NotImplementedError()

    def metrics(self) -> Series[Any]:
        """
        Get the metrics of the model.

        Returns:
            The metrics of the model.

        """
        model_info = self._info_provider.fetch(self._name).model_info
        metrics: Series[Any] = Series(model_info["metrics"])
        return metrics

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name()}, type={self.type()})"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._info_provider.fetch(self._name).model_dump()})"
graphdatascience/model/v2/model_info.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
import datetime
from typing import Any

from pydantic import BaseModel
from pydantic.alias_generators import to_camel


class ModelInfo(BaseModel, alias_generator=to_camel):
    name: str
    type: str
    train_config: dict[str, Any]
    graph_schema: dict[str, Any]
    loaded: bool
    stored: bool
    shared: bool
    published: bool
    model_info: dict[str, Any]  # TODO better typing in actual model?
    creation_time: datetime.datetime  # TODO correct type? / conversion needed

    def __getitem__(self, item: str) -> Any:
        return getattr(self, item)
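
Because of the `to_camel` alias generator, `ModelInfo` validates against camelCase keys, matching the shape of GDS model catalog results. An illustrative construction with made-up values (not taken from this commit):

info = ModelInfo(
    name="my-sage-model",
    type="graphSage",
    trainConfig={"embeddingDimension": 64},
    graphSchema={"nodes": {}, "relationships": {}},
    loaded=True,
    stored=False,
    shared=False,
    published=False,
    modelInfo={"metrics": {}},
    creationTime="2024-01-01T00:00:00",
)
assert info["type"] == "graphSage"  # __getitem__ forwards to attribute access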
