Skip to content

Commit 87b0baf

Browse files
orazvebreakanalysis
andcommitted
Make predict call short, add predictedProbabilityProperty
Co-authored-by: Jacob Sznajdman <breakanalysis@gmail.com>
1 parent 82087b7 commit 87b0baf

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

graphdatascience/gnn/gnn_nc_runner.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77

88
class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
99
def train(
10-
self,
11-
graph_name: str,
12-
model_name: str,
13-
feature_properties: List[str],
14-
target_property: str,
15-
relationship_types: List[str],
16-
target_node_label: str = None,
17-
node_labels: List[str] = None,
10+
self,
11+
graph_name: str,
12+
model_name: str,
13+
feature_properties: List[str],
14+
target_property: str,
15+
relationship_types: List[str],
16+
target_node_label: str = None,
17+
node_labels: List[str] = None,
1818
) -> "Series[Any]": # noqa: F821
1919
mlConfigMap = {
2020
"featureProperties": feature_properties,
@@ -40,26 +40,18 @@ def train(
4040
)
4141

4242
def predict(
43-
self,
44-
graph_name: str,
45-
model_name: str,
46-
feature_properties: List[str],
47-
relationship_types: List[str],
48-
mutateProperty: str,
49-
target_node_label: str = None,
50-
node_labels: List[str] = None,
43+
self,
44+
graph_name: str,
45+
model_name: str,
46+
mutateProperty: str,
47+
predictedProbabilityProperty: str = None,
5148
) -> "Series[Any]": # noqa: F821
5249
mlConfigMap = {
53-
"featureProperties": feature_properties,
5450
"job_type": "predict",
55-
"nodeProperties": feature_properties,
56-
"relationshipTypes": relationship_types,
5751
"mutateProperty": mutateProperty
5852
}
59-
if target_node_label:
60-
mlConfigMap["targetNodeLabel"] = target_node_label
61-
if node_labels:
62-
mlConfigMap["nodeLabels"] = node_labels
53+
if predictedProbabilityProperty:
54+
mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty
6355

6456
mlTrainingConfig = json.dumps(mlConfigMap)
6557
self._query_runner.run_query(

0 commit comments

Comments
 (0)