1
1
from __future__ import annotations
2
2
3
- from abc import ABC , abstractmethod
4
- from datetime import datetime
5
- from typing import Any
3
+ from abc import ABC
4
+ from typing import Optional
6
5
7
- from pandas import DataFrame , Series
8
-
9
- from graphdatascience .model .v2 .model_info import ModelInfo
10
-
11
- from ..call_parameters import CallParameters
12
- from ..graph .graph_object import Graph
13
- from ..graph .graph_type_check import graph_type_check
14
- from ..query_runner .query_runner import QueryRunner
15
- from ..server_version .compatible_with import compatible_with
16
- from ..server_version .server_version import ServerVersion
17
-
18
-
19
- class InfoProvider (ABC ):
20
- @abstractmethod
21
- def fetch (self , model_name : str ) -> ModelInfo :
22
- """Return the task with progress for the given job_id."""
23
- pass
6
+ from graphdatascience .model .v2 .model_api import ModelApi
7
+ from graphdatascience .model .v2 .model_info import ModelDetails
24
8
25
9
10
+ # Compared to v1 Model offering typed parameters for predict endpoints
26
11
class Model (ABC ):
27
- def __init__ (self , name : str , info_provider : InfoProvider ):
12
+ def __init__ (self , name : str , model_api : ModelApi ):
28
13
self ._name = name
29
- self ._info_provider = info_provider
14
+ self ._model_api = model_api
30
15
31
16
# TODO estimate mode, predict modes on here?
32
17
# implement Cypher and Arrow info_provider and stuff
@@ -41,95 +26,8 @@ def name(self) -> str:
41
26
"""
42
27
return self ._name
43
28
44
- def type (self ) -> str :
45
- """
46
- Get the type of the model.
47
-
48
- Returns:
49
- The type of the model.
50
-
51
- """
52
- return self ._info_provider .fetch (self ._name ).type
53
-
54
- def train_config (self ) -> Series [Any ]:
55
- """
56
- Get the train config of the model.
57
-
58
- Returns:
59
- The train config of the model.
60
-
61
- """
62
- return self ._info_provider .fetch (self ._name ).train_config
63
-
64
- def graph_schema (self ) -> Series [Any ]:
65
- """
66
- Get the graph schema of the model.
67
-
68
- Returns:
69
- The graph schema of the model.
70
-
71
- """
72
- return self ._info_provider .fetch (self ._name ).graph_schema
73
-
74
- def loaded (self ) -> bool :
75
- """
76
- Check whether the model is loaded in memory.
77
-
78
- Returns:
79
- True if the model is loaded in memory, False otherwise.
80
-
81
- """
82
- return self ._info_provider .fetch (self ._name ).loaded
83
-
84
- def stored (self ) -> bool :
85
- """
86
- Check whether the model is stored on disk.
87
-
88
- Returns:
89
- True if the model is stored on disk, False otherwise.
90
-
91
- """
92
- return self ._info_provider .fetch (self ._name ).stored
93
-
94
- def creation_time (self ) -> datetime .datetime :
95
- """
96
- Get the creation time of the model.
97
-
98
- Returns:
99
- The creation time of the model.
100
-
101
- """
102
- return self ._info_provider .fetch (self ._name ).creation_time
103
-
104
- def shared (self ) -> bool :
105
- """
106
- Check whether the model is shared.
107
-
108
- Returns:
109
- True if the model is shared, False otherwise.
110
-
111
- """
112
- return self ._info_provider .fetch (self ._name ).shared
113
-
114
- def published (self ) -> bool :
115
- """
116
- Check whether the model is published.
117
-
118
- Returns:
119
- True if the model is published, False otherwise.
120
-
121
- """
122
- return self ._info_provider .fetch (self ._name ).published
123
-
124
- def model_info (self ) -> dict [str , Any ]:
125
- """
126
- Get the model info of the model.
127
-
128
- Returns:
129
- The model info of the model.
130
-
131
- """
132
- return self ._info_provider .fetch (self ._name ).model_info
29
+ def details (self ) -> ModelDetails :
30
+ return self ._model_api .get (self ._name )
133
31
134
32
def exists (self ) -> bool :
135
33
"""
@@ -139,9 +37,9 @@ def exists(self) -> bool:
139
37
True if the model exists, False otherwise.
140
38
141
39
"""
142
- raise NotImplementedError ( )
40
+ return self . _model_api . exists ( self . _name )
143
41
144
- def drop (self , failIfMissing : bool = False ) -> Series [ Any ]:
42
+ def drop (self , failIfMissing : bool = False ) -> Optional [ ModelDetails ]:
145
43
"""
146
44
Drop the model.
147
45
@@ -152,22 +50,10 @@ def drop(self, failIfMissing: bool = False) -> Series[Any]:
152
50
The result of the drop operation.
153
51
154
52
"""
155
- raise NotImplementedError ()
156
-
157
- def metrics (self ) -> Series [Any ]:
158
- """
159
- Get the metrics of the model.
160
-
161
- Returns:
162
- The metrics of the model.
163
-
164
- """
165
- model_info = self ._info_provider .fetch (self ._name ).model_info
166
- metrics : Series [Any ] = Series (model_info ["metrics" ])
167
- return metrics
53
+ return self ._model_api .drop (self ._name , failIfMissing )
168
54
169
55
def __str__ (self ) -> str :
170
- return f"{ self .__class__ .__name__ } (name={ self .name ()} , type={ self .type () } )"
56
+ return f"{ self .__class__ .__name__ } (name={ self .name ()} , type={ self .details (). type } )"
171
57
172
58
def __repr__ (self ) -> str :
173
- return f"{ self .__class__ .__name__ } ({ self ._info_provider . fetch ( self . _name ). to_dict ()} )"
59
+ return f"{ self .__class__ .__name__ } ({ self .details (). model_dump ()} )"
0 commit comments