@@ -11,11 +11,26 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str],
11
11
configMap = {
12
12
"featureProperties" : feature_properties ,
13
13
"targetProperty" : target_property ,
14
+ "job_type" : "train" ,
14
15
}
15
16
node_properties = feature_properties + [target_property ]
16
17
if target_node_label :
17
18
configMap ["targetNodeLabel" ] = target_node_label
18
19
mlTrainingConfig = json .dumps (configMap )
19
- # TODO query avaiable node labels
20
+ # TODO query available node labels
20
21
node_labels = ["Paper" ] if not node_labels else node_labels
21
22
self ._query_runner .run_query (f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { node_properties } }})" )
23
+
24
+
25
+ def predict (self , graph_name : str , model_name : str , feature_properties : List [str ], target_node_label : str = None , node_labels : List [str ] = None ) -> "Series[Any]" :
26
+ configMap = {
27
+ "featureProperties" : feature_properties ,
28
+ "job_type" : "predict" ,
29
+ }
30
+ if target_node_label :
31
+ configMap ["targetNodeLabel" ] = target_node_label
32
+ mlTrainingConfig = json .dumps (configMap )
33
+ # TODO query available node labels
34
+ node_labels = ["Paper" ] if not node_labels else node_labels
35
+ self ._query_runner .run_query (
36
+ f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { feature_properties } }})" )
0 commit comments