Skip to content

Commit e5e8ef2

Browse files
FlorentinDbrs96orazve
committed
WIP use arrow endpoint
Co-authored-by: Brian Shi <brian.shi@neotechnology.com> Co-authored-by: Olga Razvenskaia <olga.razvenskaia@neo4j.com>
1 parent 1e8237e commit e5e8ef2

File tree

5 files changed

+58
-49
lines changed

5 files changed

+58
-49
lines changed

examples/python-runtime.ipynb

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
{
44
"cell_type": "code",
55
"execution_count": null,
6-
"metadata": {
7-
"collapsed": true
8-
},
6+
"metadata": {},
97
"outputs": [],
108
"source": [
119
"DBID = \"beefbeef\"\n",
@@ -14,68 +12,46 @@
1412
"\n",
1513
"from graphdatascience import GraphDataScience\n",
1614
"\n",
17-
"gds = GraphDataScience(\n",
18-
" f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD)\n",
19-
")\n",
15+
"gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n",
2016
"gds.set_database(\"neo4j\")\n",
2117
"\n",
22-
"gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n"
18+
"gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
2319
]
2420
},
2521
{
2622
"cell_type": "code",
2723
"execution_count": null,
24+
"metadata": {},
2825
"outputs": [],
2926
"source": [
3027
"try:\n",
3128
" gds.graph.load_cora()\n",
3229
"except:\n",
33-
" pass\n"
34-
],
35-
"metadata": {
36-
"collapsed": false
37-
}
30+
" pass"
31+
]
3832
},
3933
{
4034
"cell_type": "code",
4135
"execution_count": null,
36+
"metadata": {},
4237
"outputs": [],
4338
"source": [
44-
"gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n"
45-
],
46-
"metadata": {
47-
"collapsed": false
48-
}
39+
"gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
40+
]
4941
},
5042
{
5143
"cell_type": "code",
5244
"execution_count": null,
45+
"metadata": {},
5346
"outputs": [],
5447
"source": [
5548
"gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
56-
],
57-
"metadata": {
58-
"collapsed": false
59-
}
49+
]
6050
}
6151
],
6252
"metadata": {
63-
"kernelspec": {
64-
"display_name": "Python 3",
65-
"language": "python",
66-
"name": "python3"
67-
},
6853
"language_info": {
69-
"codemirror_mode": {
70-
"name": "ipython",
71-
"version": 2
72-
},
73-
"file_extension": ".py",
74-
"mimetype": "text/x-python",
75-
"name": "python",
76-
"nbconvert_exporter": "python",
77-
"pygments_lexer": "ipython2",
78-
"version": "2.7.6"
54+
"name": "python"
7955
}
8056
},
8157
"nbformat": 4,

graphdatascience/endpoints.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
"""
3434

3535

36-
class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints):
36+
class DirectEndpoints(
37+
DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints
38+
):
3739
def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion):
3840
super().__init__(query_runner, namespace, server_version)
3941

graphdatascience/gnn/gnn_endpoints.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
from .gnn_nc_runner import GNNNodeClassificationRunner
21
from ..caller_base import CallerBase
32
from ..error.illegal_attr_checker import IllegalAttrChecker
43
from ..error.uncallable_namespace import UncallableNamespace
4+
from .gnn_nc_runner import GNNNodeClassificationRunner
5+
56

67
class GNNRunner(UncallableNamespace, IllegalAttrChecker):
78
@property
89
def nodeClassification(self) -> GNNNodeClassificationRunner:
9-
return GNNNodeClassificationRunner(self._query_runner, f"{self._namespace}.nodeClassification", self._server_version)
10+
return GNNNodeClassificationRunner(
11+
self._query_runner, f"{self._namespace}.nodeClassification", self._server_version
12+
)
13+
1014

1115
class GnnEndpoints(CallerBase):
1216
@property
1317
def gnn(self) -> GNNRunner:
1418
return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version)
15-
16-
17-

graphdatascience/gnn/gnn_nc_runner.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,44 @@
1+
import json
12
from typing import Any, List
23

34
from ..error.illegal_attr_checker import IllegalAttrChecker
45
from ..error.uncallable_namespace import UncallableNamespace
5-
import json
66

77

88
class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
9-
def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str,
10-
target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
9+
def train(
10+
self,
11+
graph_name: str,
12+
model_name: str,
13+
feature_properties: List[str],
14+
target_property: str,
15+
target_node_label: str = None,
16+
node_labels: List[str] = None,
17+
) -> "Series[Any]":
1118
configMap = {
1219
"featureProperties": feature_properties,
1320
"targetProperty": target_property,
1421
"job_type": "train",
1522
}
23+
1624
node_properties = feature_properties + [target_property]
1725
if target_node_label:
1826
configMap["targetNodeLabel"] = target_node_label
1927
mlTrainingConfig = json.dumps(configMap)
2028
# TODO query available node labels
2129
node_labels = ["Paper"] if not node_labels else node_labels
22-
self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})")
23-
30+
self._query_runner.run_query(
31+
f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})"
32+
)
2433

25-
def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
34+
def predict(
35+
self,
36+
graph_name: str,
37+
model_name: str,
38+
feature_properties: List[str],
39+
target_node_label: str = None,
40+
node_labels: List[str] = None,
41+
) -> "Series[Any]":
2642
configMap = {
2743
"featureProperties": feature_properties,
2844
"job_type": "predict",
@@ -33,4 +49,5 @@ def predict(self, graph_name: str, model_name: str, feature_properties: List[str
3349
# TODO query available node labels
3450
node_labels = ["Paper"] if not node_labels else node_labels
3551
self._query_runner.run_query(
36-
f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})")
52+
f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})"
53+
)

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def run_query(
129129
endpoint = "gds.beta.graph.relationships.stream"
130130

131131
return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types})
132+
elif "gds.upload.graph" in query:
133+
self._run_arrow_upload_graph(params["config"])
132134

133135
return self._fallback_query_runner.run_query(query, params, database, custom_error)
134136

@@ -170,6 +172,17 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
170172

171173
return result
172174

175+
def _run_arrow_upload_graph(self, meta_data: Dict[str, Any]) -> None:
176+
result = self._flight_client.do_action()
177+
# TODO : better name of the action -- INIT ML JOB ?
178+
result = self._flight_client().do_put(flight.Action("UPLOAD_GRAPH"), json.dumps(meta_data).encode("utf-8"))
179+
180+
# Consume result fully to sanity check and avoid cancelled streams
181+
collected_result = list(result)
182+
assert len(collected_result) == 1
183+
184+
print(collected_result[0])
185+
173186
def create_graph_constructor(
174187
self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
175188
) -> GraphConstructor:

0 commit comments

Comments
 (0)