diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb
new file mode 100644
index 000000000..ca7d4c2f5
--- /dev/null
+++ b/examples/python-runtime.ipynb
@@ -0,0 +1,59 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "DBID = \"beefbeef\"\n",
+    "ENVIRONMENT = \"\"\n",
+    "PASSWORD = \"\"\n",
+    "\n",
+    "from graphdatascience import GraphDataScience\n",
+    "\n",
+    "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n",
+    "gds.set_database(\"neo4j\")\n",
+    "\n",
+    "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    gds.graph.load_cora()\n",
+    "except:\n",
+    "    pass"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/graphdatascience/endpoints.py b/graphdatascience/endpoints.py
index 4abd44247..e91c1702b 100644
--- a/graphdatascience/endpoints.py
+++ b/graphdatascience/endpoints.py
@@ -1,5 +1,6 @@
 from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints
 from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
+from .gnn.gnn_endpoints import GnnEndpoints
 from .graph.graph_endpoints import (
     GraphAlphaEndpoints,
     GraphBetaEndpoints,
@@ -32,7 +33,9 @@
 """
 
 
-class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints):
+class DirectEndpoints(
+    DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints
+):
     def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion):
         super().__init__(query_runner, namespace, server_version)
 
diff --git a/graphdatascience/gnn/__init__.py b/graphdatascience/gnn/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/gnn/gnn_endpoints.py b/graphdatascience/gnn/gnn_endpoints.py
new file mode 100644
index 000000000..ba1b7b2b7
--- /dev/null
+++ b/graphdatascience/gnn/gnn_endpoints.py
@@ -0,0 +1,18 @@
+from ..caller_base import CallerBase
+from ..error.illegal_attr_checker import IllegalAttrChecker
+from ..error.uncallable_namespace import UncallableNamespace
+from .gnn_nc_runner import GNNNodeClassificationRunner
+
+
+class GNNRunner(UncallableNamespace, IllegalAttrChecker):
+    @property
+    def nodeClassification(self) -> GNNNodeClassificationRunner:
+        return GNNNodeClassificationRunner(
+            self._query_runner, f"{self._namespace}.nodeClassification", self._server_version
+        )
+
+
+class GnnEndpoints(CallerBase):
+    @property
+    def gnn(self) -> GNNRunner:
+        return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version)
diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py
new file mode 100644
index 000000000..27aec8d63
--- /dev/null
+++ b/graphdatascience/gnn/gnn_nc_runner.py
@@ -0,0 +1,79 @@
+import json
+from typing import Any, List
+
+from ..error.illegal_attr_checker import IllegalAttrChecker
+from ..error.uncallable_namespace import UncallableNamespace
+
+
+class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
+    def make_graph_sage_config(self, graph_sage_config):
+        GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5,
+                                     "hidden_channels": 256, "learning_rate": 0.003}
+        final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG
+        if graph_sage_config:
+            bad_keys = []
+            for key in graph_sage_config:
+                if key not in GRAPH_SAGE_DEFAULT_CONFIG:
+                    bad_keys.append(key)
+            if len(bad_keys) > 0:
+                raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.")
+
+            final_sage_config.update(graph_sage_config)
+        return final_sage_config
+
+    def train(
+        self,
+        graph_name: str,
+        model_name: str,
+        feature_properties: List[str],
+        target_property: str,
+        relationship_types: List[str],
+        target_node_label: str = None,
+        node_labels: List[str] = None,
+        graph_sage_config = None
+    ) -> "Series[Any]":  # noqa: F821
+        mlConfigMap = {
+            "featureProperties": feature_properties,
+            "targetProperty": target_property,
+            "job_type": "train",
+            "nodeProperties": feature_properties + [target_property],
+            "relationshipTypes": relationship_types,
+            "graph_sage_config": self.make_graph_sage_config(graph_sage_config)
+        }
+
+        if target_node_label:
+            mlConfigMap["targetNodeLabel"] = target_node_label
+        if node_labels:
+            mlConfigMap["nodeLabels"] = node_labels
+
+        mlTrainingConfig = json.dumps(mlConfigMap)
+
+        # token and uri will be injected by arrow_query_runner
+        self._query_runner.run_query(
+            "CALL gds.upload.graph($config)",
+            params={
+                "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
+            },
+        )
+
+    def predict(
+        self,
+        graph_name: str,
+        model_name: str,
+        mutateProperty: str,
+        predictedProbabilityProperty: str = None,
+    ) -> "Series[Any]":  # noqa: F821
+        mlConfigMap = {
+            "job_type": "predict",
+            "mutateProperty": mutateProperty
+        }
+        if predictedProbabilityProperty:
+            mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty
+
+        mlTrainingConfig = json.dumps(mlConfigMap)
+        self._query_runner.run_query(
+            "CALL gds.upload.graph($config)",
+            params={
+                "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
+            },
+        )  # type: ignore
diff --git a/graphdatascience/ignored_server_endpoints.py b/graphdatascience/ignored_server_endpoints.py
index 89ad9f0b2..d103a90c4 100644
--- a/graphdatascience/ignored_server_endpoints.py
+++ b/graphdatascience/ignored_server_endpoints.py
@@ -47,6 +47,7 @@
     "gds.alpha.pipeline.nodeRegression.predict.stream",
     "gds.alpha.pipeline.nodeRegression.selectFeatures",
     "gds.alpha.pipeline.nodeRegression.train",
+    "gds.gnn.nc",
     "gds.similarity.cosine",
     "gds.similarity.euclidean",
     "gds.similarity.euclideanDistance",
diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py
index cf648879a..eab64398c 100644
--- a/graphdatascience/query_runner/arrow_query_runner.py
+++ b/graphdatascience/query_runner/arrow_query_runner.py
@@ -29,6 +29,9 @@ def __init__(
     ):
         self._fallback_query_runner = fallback_query_runner
         self._server_version = server_version
+        # FIXME handle version where tls cert is given
+        self._auth = auth
+        self._uri = uri
 
         host, port_string = uri.split(":")
 
@@ -39,8 +42,9 @@
         )
 
         client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
+        self._auth_factory = AuthFactory(auth)
         if auth:
-            client_options["middleware"] = [AuthFactory(auth)]
+            client_options["middleware"] = [self._auth_factory]
 
         if tls_root_certs:
             client_options["tls_root_certs"] = tls_root_certs
@@ -129,6 +133,10 @@
             endpoint = "gds.beta.graph.relationships.stream"
 
             return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types})
+        elif "gds.upload.graph" in query:
+            # inject parameters
+            params["config"]["token"] = self._get_or_request_token()
+            params["config"]["arrowEndpoint"] = self._uri
 
         return self._fallback_query_runner.run_query(query, params, database, custom_error)
 
@@ -184,6 +192,10 @@
             database, graph_name, self._flight_client, concurrency, undirected_relationship_types
         )
 
+    def _get_or_request_token(self) -> str:
+        self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1])
+        return self._auth_factory.token()
+
 
 class AuthFactory(ClientMiddlewareFactory): # type: ignore
     def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
@@ -217,9 +229,14 @@
         self._factory = factory
 
     def received_headers(self, headers: Dict[str, Any]) -> None:
-        auth_header: str = headers.get("Authorization", None)
+        auth_header: str = headers.get("authorization", None)
         if not auth_header:
             return
+        # authenticate_basic_token() returns a list.
+        # TODO We should take the first Bearer element here
+        if isinstance(auth_header, list):
+            auth_header = auth_header[0]
+
         [auth_type, token] = auth_header.split(" ", 1)
         if auth_type == "Bearer":
             self._factory.set_token(token)