7
7
8
8
class GNNNodeClassificationRunner (UncallableNamespace , IllegalAttrChecker ):
9
9
def train (
10
- self ,
11
- graph_name : str ,
12
- model_name : str ,
13
- feature_properties : List [str ],
14
- target_property : str ,
15
- relationship_types : List [str ],
16
- target_node_label : str = None ,
17
- node_labels : List [str ] = None ,
10
+ self ,
11
+ graph_name : str ,
12
+ model_name : str ,
13
+ feature_properties : List [str ],
14
+ target_property : str ,
15
+ relationship_types : List [str ],
16
+ target_node_label : str = None ,
17
+ node_labels : List [str ] = None ,
18
18
) -> "Series[Any]" : # noqa: F821
19
19
mlConfigMap = {
20
20
"featureProperties" : feature_properties ,
@@ -40,26 +40,18 @@ def train(
40
40
)
41
41
42
42
def predict (
43
- self ,
44
- graph_name : str ,
45
- model_name : str ,
46
- feature_properties : List [str ],
47
- relationship_types : List [str ],
48
- mutateProperty : str ,
49
- target_node_label : str = None ,
50
- node_labels : List [str ] = None ,
43
+ self ,
44
+ graph_name : str ,
45
+ model_name : str ,
46
+ mutateProperty : str ,
47
+ predictedProbabilityProperty : str = None ,
51
48
) -> "Series[Any]" : # noqa: F821
52
49
mlConfigMap = {
53
- "featureProperties" : feature_properties ,
54
50
"job_type" : "predict" ,
55
- "nodeProperties" : feature_properties ,
56
- "relationshipTypes" : relationship_types ,
57
51
"mutateProperty" : mutateProperty
58
52
}
59
- if target_node_label :
60
- mlConfigMap ["targetNodeLabel" ] = target_node_label
61
- if node_labels :
62
- mlConfigMap ["nodeLabels" ] = node_labels
53
+ if predictedProbabilityProperty :
54
+ mlConfigMap ["predictedProbabilityProperty" ] = predictedProbabilityProperty
63
55
64
56
mlTrainingConfig = json .dumps (mlConfigMap )
65
57
self ._query_runner .run_query (
0 commit comments