Skip to content

Commit 5117eea

Browse files
Mats-SXbreakanalysisorazve
committed
Add predict endpoint to GNN NC runner
Co-authored-by: Jacob Sznajdman <breakanalysis@gmail.com> Co-authored-by: Olga Razvenskaia <olga.razvenskaia@neo4j.com>
1 parent 7cbb64b commit 5117eea

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

graphdatascience/gnn/gnn_nc_runner.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,26 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str],
1111
configMap = {
1212
"featureProperties": feature_properties,
1313
"targetProperty": target_property,
14+
"job_type": "train",
1415
}
1516
node_properties = feature_properties + [target_property]
1617
if target_node_label:
1718
configMap["targetNodeLabel"] = target_node_label
1819
mlTrainingConfig = json.dumps(configMap)
19-
# TODO query avaiable node labels
20+
# TODO query available node labels
2021
node_labels = ["Paper"] if not node_labels else node_labels
2122
self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})")
23+
24+
25+
def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
26+
configMap = {
27+
"featureProperties": feature_properties,
28+
"job_type": "predict",
29+
}
30+
if target_node_label:
31+
configMap["targetNodeLabel"] = target_node_label
32+
mlTrainingConfig = json.dumps(configMap)
33+
# TODO query available node labels
34+
node_labels = ["Paper"] if not node_labels else node_labels
35+
self._query_runner.run_query(
36+
f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})")

0 commit comments

Comments
 (0)