From ce100331d8951559001686155df711445be04374 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 1 Jul 2024 10:54:43 +0100 Subject: [PATCH 01/24] Copy from branch --- examples/FastPathExamples.ipynb | 760 ++++++++++++++++++++++ graphdatascience/graph_data_science.py | 2 +- graphdatascience/model/fastpath_runner.py | 115 ++++ 3 files changed, 876 insertions(+), 1 deletion(-) create mode 100644 examples/FastPathExamples.ipynb create mode 100644 graphdatascience/model/fastpath_runner.py diff --git a/examples/FastPathExamples.ipynb b/examples/FastPathExamples.ipynb new file mode 100644 index 000000000..2e304b218 --- /dev/null +++ b/examples/FastPathExamples.ipynb @@ -0,0 +1,760 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0f03a290", + "metadata": {}, + "source": [ + "# Path embeddings with FastPATH - Examples" + ] + }, + { + "cell_type": "markdown", + "id": "68b6e21f", + "metadata": {}, + "source": [ + "In this notebook we will show you several examples of constructing path embeddings with the FastPATH algorithm.\n", + "The full documentation for the algorithm can be found [here](https://docs.google.com/document/d/1oCAz6ukn_r19H27ghxnGM_-UQP9rgYJRhLzNLHdQc8Y/edit#heading=h.ya70gurwgyt2)." 
+ ] + }, + { + "cell_type": "markdown", + "id": "c3bf7590", + "metadata": {}, + "source": [ + "## The Dataset\n", + "\n", + "We will use a synthetic medical dataset containg `Patients`, `Encounters`, `Conditions`, `Observations` and more.\n", + "Using FastPATH we will construct (path) embeddings for patient journey in the dataset.\n", + "You need to replace the Neo4j URL and credentials to a database that contains the dataset.\n", + "Contact the GDS team if you're interested in that.\n", + "\n", + "Below is the schema of the database:" + ] + }, + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "id": "316f034f", + "metadata": {}, + "source": [ + "![image.png](attachment:image.png)" + ] + }, + { + "cell_type": "markdown", + "id": "a062d180", + "metadata": {}, + "source": [ + "## Import and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fe27541", + "metadata": {}, + "outputs": [], + "source": [ + "from graphdatascience import GraphDataScience\n", + "import numpy as np\n", + "from sklearn.manifold import TSNE\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from matplotlib import pyplot as plt\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.metrics import f1_score\n", + "from sklearn.utils._testing import ignore_warnings\n", + "from sklearn.exceptions import ConvergenceWarning\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "plt.rcParams[\"figure.figsize\"] = [15, 10]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "238525b8", + "metadata": {}, + "outputs": [], + "source": [ + "gds = GraphDataScience(\n", + " \"neo4j+s://eddb7e19.databases.neo4j.io\",\n", + " auth=(\"neo4j\", \"Oz4oBK--Sx4byHjgHgJuMf5VqQncGHG9mbgpy44rQTU\"),\n", + " database=\"neo4j\",\n", + ")\n", + "gds.set_compute_cluster_ip(\"localhost\")" + ] + }, + { + "cell_type": "markdown", + "id": "c1f7417c", + "metadata": {}, + "source": 
[ + "## Preprocessing\n", + "\n", + "In order to make our dataset amenable to our analysis using FastPATH and downstream machine learning, we must augment it slightly.\n", + "This entails writing some additional node properties to the database with the Cypher code below.\n", + "\n", + "**NOTE: Each preprocessing cell below must be run once, and only once.**" + ] + }, + { + "cell_type": "markdown", + "id": "97bf5fc6", + "metadata": {}, + "source": [ + "First we write a `has_diabetes` property (0 or 1) to each `Patient` node.\n", + "This will give us class labels that enable us to train a classification model on patient journeys later." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b10f9b6", + "metadata": {}, + "outputs": [], + "source": [ + "gds.run_cypher(\"MATCH (p:Patient) SET p.has_diabetes=0\")\n", + "gds.run_cypher(\n", + " \"MATCH (p:Patient)-[:HAS_ENCOUNTER]->(n:Encounter)-[:HAS_CONDITION]-(c:Condition) WHERE c.description='Diabetes' SET p.has_diabetes=1\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5f103fbb", + "metadata": {}, + "source": [ + "Then to each `Encounter` node, we write the number of days that has passed since 1 January 1970 (can be negative), based on the existing `start` node property.\n", + "We do this since the `start` property it already has is not an actual number, which is what the algorithm needs.\n", + "This is needed in the case where we don't rely on `NEXT` relationships for event timestamps, which is one of the examples below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c307b6d", + "metadata": {}, + "outputs": [], + "source": [ + "gds.run_cypher(\n", + " \"MATCH (n:Encounter) WITH toInteger(datetime(n.start).epochseconds/(24 * 3600)) as days, n SET n.days=days\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "014e00e4", + "metadata": {}, + "source": [ + "Next we write two output time properties to each `Patient` based on the last `Encounter` before a diabetes diagnosis, or the last `Encounter` otherwise.\n", + "For the case where we are relying on the `days` node property on `Encounter`s (see above), the new `output_time` node property for `Patient`s will be equal to 1 + the `days` timestamp of their last encounter (before diabetes if they have it).\n", + "For the case where we are relying on `FIRST` and `NEXT` relationships to define the `Encounter`s belonging to a `Patient`, the new `output_time_stepwise` node property for `Patient`s will be equal to the number of encounters up to and including the last encounter (before diabetes if they have it).\n", + "\n", + "With these properties we can specify the point in time for which we want the path embeddings for each `Patient` node.\n", + "I.e. the paths that are embedded will continue up to that point, but no further." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b95b2ba3", + "metadata": {}, + "outputs": [], + "source": [ + "# Writing the `output_time` `Patient` node property\n", + "gds.run_cypher(\"MATCH (p:Patient)-[:LAST]->(n:Encounter) SET p.output_time=n.days+1\")\n", + "gds.run_cypher(\n", + " \"MATCH (p:Patient)-[:HAS_ENCOUNTER]->(e1:Encounter)-[:NEXT]->(e2:Encounter)-[:HAS_CONDITION]->(c:Condition) WHERE c.description='Diabetes' SET p.output_time=e1.days + 1\"\n", + ")\n", + "\n", + "# Writing `output_time_stepwise` `Patient` node property\n", + "gds.run_cypher(\n", + " \"MATCH (p:Patient)-[:HAS_ENCOUNTER]->(e:Encounter) WHERE e.days <= p.output_time - 1 WITH p, count(*) as cc SET p.output_time_stepwise=cc\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "68dffdb0", + "metadata": {}, + "source": [ + "Lastly we write the `class` of each `Encounter` as an integer property `intClass`.\n", + "Doing so enables us to use the class property as input to the algorithm, impacting the internal embeddings of `Encounter` nodes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b8e0f5b", + "metadata": {}, + "outputs": [], + "source": [ + "gds.run_cypher(\n", + " \"\"\"\n", + " MATCH (e:Encounter) with distinct e.class AS class\n", + " WITH collect(class) as clss\n", + " WITH apoc.map.fromLists(clss, range(0, size(clss) - 1)) as classMap\n", + " MATCH (e:Encounter) SET e.intClass = classMap[e.class]\n", + " \"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ba366637", + "metadata": {}, + "source": [ + "## Projection with Timestamps\n", + "\n", + "For the first examples, we rely on the `days` property of `Encounter` nodes for timestamp.\n", + "For this reason we don't need to project `FIRST` and `NEXT` relationships." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cf05097", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " G = gds.graph.get(\"medical\")\n", + " G.drop()\n", + "except:\n", + " pass\n", + "\n", + "G, _ = gds.graph.project(\n", + " \"medical\",\n", + " {\n", + " \"Patient\": {\"properties\": [\"output_time\", \"has_diabetes\"]},\n", + " \"Encounter\": {\"properties\": [\"days\", \"intClass\"]},\n", + " \"Observation\": {\"properties\": []},\n", + " \"Payer\": {\"properties\": []},\n", + " \"Provider\": {\"properties\": []},\n", + " \"Organization\": {\"properties\": []},\n", + " \"Speciality\": {\"properties\": []},\n", + " \"Allergy\": {\"properties\": []},\n", + " \"Reaction\": {\"properties\": []},\n", + " \"Condition\": {\"properties\": []},\n", + " \"Drug\": {\"properties\": []},\n", + " \"Procedure\": {\"properties\": []},\n", + " \"CarePlan\": {\"properties\": []},\n", + " \"Device\": {\"properties\": []},\n", + " \"ConditionDescription\": {\"properties\": []},\n", + " },\n", + " [\n", + " \"HAS_OBSERVATION\",\n", + " \"HAS_ENCOUNTER\",\n", + " \"HAS_PROVIDER\",\n", + " \"AT_ORGANIZATION\",\n", + " \"HAS_PAYER\",\n", + " \"HAS_SPECIALITY\",\n", + " \"BELONGS_TO\",\n", + " \"INSURANCE_START\",\n", + " \"INSURANCE_END\",\n", + " \"HAS_ALLERGY\",\n", + " \"ALLERGY_DETECTED\",\n", + " \"HAS_REACTION\",\n", + " \"CAUSES_REACTION\",\n", + " \"HAS_CONDITION\",\n", + " \"HAS_DRUG\",\n", + " \"HAS_PROCEDURE\",\n", + " \"HAS_CARE_PLAN\",\n", + " \"DEVICE_USED\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e6aacc68", + "metadata": {}, + "source": [ + "## FastRP Features\n", + "\n", + "We should make use of the topological information we have around in `Encounter` node.\n", + "For example, what `Condition`s, `Drug`s, `Procedure`s, etc. 
(see schema above) are connected to it.\n", + "And perhaps one hop in the graph beyond that.\n", + "To do so, we make use of FastRP to create node embeddings.\n", + "Later we can input the node embeddings of the `Encounter` nodes to the FastPATH algorithm using the `event_features` parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7b0c35e", + "metadata": {}, + "outputs": [], + "source": [ + "gds.fastRP.mutate(\n", + " G,\n", + " embeddingDimension=256,\n", + " mutateProperty=\"emb\",\n", + " iterationWeights=[1, 1],\n", + " randomSeed=42,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "76cc51a7", + "metadata": {}, + "source": [ + "## Preparation of machine learning and visualization of embeddings\n", + "\n", + "Below we define a utility function that we can subsequently use to analyze the path embeddings we produce in each example below.\n", + "This function does three things:\n", + "1. Computes the average pairwise distances between embeddings of the different class combinations (no diabetes vs diabetes)\n", + "2. Plot the path embeddings in two dimensions with t-SNE\n", + "3. 
Train and evaluate a logistic regression diabetes classifier which takes path embeddings as input\n", + "\n", + "**NOTE: You don't have to read or understand this function, but can think of it as a black box in the context of this notebook.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f8d6247", + "metadata": {}, + "outputs": [], + "source": [ + "@ignore_warnings(category=ConvergenceWarning)\n", + "def explore(embeddings):\n", + "\n", + " # Compute pairwise distances between embeddings of healty<->healthy, sick<->sick and healthy<->sick.\n", + "\n", + " diabetes_by_nodeId = gds.graph.streamNodeProperties(G, [\"has_diabetes\"], [\"Patient\"]).set_index(\"nodeId\")[\n", + " [\"propertyValue\"]\n", + " ]\n", + " emb_and_diabetes = (\n", + " embeddings[[\"nodeId\", \"embeddings\"]]\n", + " .set_index(\"nodeId\")\n", + " .merge(diabetes_by_nodeId, left_index=True, right_index=True)\n", + " )\n", + " healthy_embs = np.array(emb_and_diabetes[emb_and_diabetes.propertyValue == 0][\"embeddings\"].tolist())\n", + " diabetes_embs = np.array(emb_and_diabetes[emb_and_diabetes.propertyValue == 1][\"embeddings\"].tolist())\n", + "\n", + " diabetes_distances = []\n", + " for i in range(diabetes_embs.shape[0]):\n", + " for j in range(i + 1, diabetes_embs.shape[0]):\n", + " x1 = diabetes_embs[i, :]\n", + " x2 = diabetes_embs[j, :]\n", + " diabetes_distances.append(np.linalg.norm(x1 - x2))\n", + "\n", + " print(f\"Avg diabetes<->diabetes L2-distances: {np.mean(diabetes_distances)}\")\n", + "\n", + " healthy_distances = []\n", + " for i in range(healthy_embs.shape[0]):\n", + " for j in range(i + 1, healthy_embs.shape[0]):\n", + " x1 = healthy_embs[i, :]\n", + " x2 = healthy_embs[j, :]\n", + " healthy_distances.append(np.linalg.norm(x1 - x2))\n", + "\n", + " print(f\"Avg healthy<->healthy L2-distances: {np.mean(healthy_distances)}\")\n", + "\n", + " mixed_distances = []\n", + " for i in range(diabetes_embs.shape[0]):\n", + " for j in 
range(healthy_embs.shape[0]):\n", + " x1 = diabetes_embs[i, :]\n", + " x2 = healthy_embs[j, :]\n", + " mixed_distances.append(np.linalg.norm(x1 - x2))\n", + "\n", + " print(f\"Avg healthy<->diabetes L2-distances: {np.mean(mixed_distances)}\")\n", + "\n", + " # TSNE time\n", + "\n", + " X = np.array(emb_and_diabetes[\"embeddings\"].tolist())\n", + " y = emb_and_diabetes.propertyValue.to_numpy()\n", + " tsne = TSNE(2)\n", + " tsne_result = tsne.fit_transform(X)\n", + " tsne_result_df = pd.DataFrame({\"tsne_1\": tsne_result[:, 0], \"tsne_2\": tsne_result[:, 1], \"label\": y})\n", + " fig, ax = plt.subplots(1)\n", + " sns.scatterplot(x=\"tsne_1\", y=\"tsne_2\", hue=\"label\", data=tsne_result_df, ax=ax, s=10)\n", + " lim = (tsne_result.min() - 5, tsne_result.max() + 5)\n", + " ax.set_xlim(lim)\n", + " ax.set_ylim(lim)\n", + " ax.set_aspect(\"equal\")\n", + " ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)\n", + "\n", + " # Train evaluate diabetes classifier :)\n", + "\n", + " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42, stratify=y)\n", + "\n", + " clf = LogisticRegression()\n", + " clf.fit(X_train, y_train)\n", + "\n", + " y_train_pred = clf.predict(X_train)\n", + " y_test_pred = clf.predict(X_test)\n", + "\n", + " train_f1_score = f1_score(y_train, y_train_pred, average=\"macro\")\n", + " test_f1_score = f1_score(y_test, y_test_pred, average=\"macro\")\n", + "\n", + " print(\"Diabetes classifier scores:\")\n", + " print(f\"Train set f1: {train_f1_score}\")\n", + " print(f\"Test set f1: {test_f1_score}\")" + ] + }, + { + "cell_type": "markdown", + "id": "0236dda1", + "metadata": {}, + "source": [ + "## Examples with timestamp node properties\n", + "\n", + "In the following few examples we will let the `days` node property on `Encounter` nodes dictate when an encounter has occurred." 
+ ] + }, + { + "cell_type": "markdown", + "id": "9f012dae", + "metadata": {}, + "source": [ + "### Global output time\n", + "\n", + "To use a single fixed output time, you can either\n", + "* Use the algorithm parameter `output_time` (and optionally use subgraph filtering to run only up to a certain time), or\n", + "* Use Cypher to write an output time property to the `Patient` nodes holding a fixed timestamp, and then provide it as `output_time_property`\n", + "\n", + "Here we will use the first option.\n", + "\n", + "Note that we are also using the FastRP embeddings for `Encounter` nodes as input features to the events (encounters)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c216bcf", + "metadata": {}, + "outputs": [], + "source": [ + "# try:\n", + "gds.graph.nodeProperties.drop(G, [\"embeddings\"], node_labels=[\"Patient\"])\n", + "# except:\n", + "# pass\n", + "\n", + "gds.fastpath.mutate(\n", + " G,\n", + " base_node_label=\"Patient\",\n", + " event_node_label=\"Encounter\",\n", + " event_features=\"emb\",\n", + " time_node_property=\"days\",\n", + " dimension=256,\n", + " num_elapsed_times=100,\n", + " output_time=365 * 50, # 50 years\n", + " max_elapsed_time=365 * 10, # 10 years\n", + " smoothing_rate=0.004,\n", + " smoothing_window=3,\n", + " decay_factor=1e-5,\n", + " random_seed=42,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4def10f", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = gds.graph.nodeProperties.stream(G, [\"embeddings\"], node_labels=[\"Patient\"], separate_property_columns=True)\n", + "print(embeddings)\n", + "explore(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "8f283d0f", + "metadata": {}, + "source": [ + "## Example with individual output time" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "022a8096", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " gds.graph.nodeProperties.drop(G, 
[\"embeddings\"], node_labels=[\"Patient\"])\n", + "except:\n", + " pass\n", + "\n", + "embeddings = gds.fastpath.mutate(\n", + " G,\n", + " base_node_label=\"Patient\",\n", + " event_node_label=\"Encounter\",\n", + " event_features=\"emb\",\n", + " time_node_property=\"days\",\n", + " dimension=256,\n", + " num_elapsed_times=100,\n", + " output_time_property=\"output_time\",\n", + " max_elapsed_time=365 * 10,\n", + " smoothing_rate=0.004,\n", + " smoothing_window=3,\n", + " decay_factor=1e-4,\n", + " random_seed=42,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dce1f527", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = gds.graph.nodeProperties.stream(G, [\"embeddings\"], node_labels=[\"Patient\"], separate_property_columns=True)\n", + "print(embeddings)\n", + "explore(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "b629f2ab", + "metadata": {}, + "source": [ + "# Example with categorical event property and input event vectors\n", + "As the type (class) of encounter may be important to characterize patient journeys and to classify diabetes, we add 'intClass' as a categorical event property." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa929721", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = gds.fastpath.mutate(\n", + " G,\n", + " base_node_label=\"Patient\",\n", + " event_node_label=\"Encounter\",\n", + " event_features=\"emb\",\n", + " time_node_property=\"days\",\n", + " categorical_event_properties=[\"intClass\"],\n", + " dimension=256,\n", + " num_elapsed_times=100,\n", + " output_time_property=\"output_time\",\n", + " max_elapsed_time=365 * 10,\n", + " smoothing_rate=0.004,\n", + " smoothing_window=3,\n", + " decay_factor=1e-4,\n", + " random_seed=42,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01a6f368", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = gds.graph.nodeProperties.stream(G, [\"embeddings\"], node_labels=[\"Patient\"], separate_property_columns=True)\n", + "print(embeddings)\n", + "explore(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "9bd0798e", + "metadata": {}, + "source": [ + "# Example with context nodes and input event vectors\n", + "As the history of drugs may be important to characterize patient journeys and to classify diabetes, we add 'Drug' as a context_node_label." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5c27d60", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = gds.fastpath.stream(\n", + " G,\n", + " base_node_label=\"Patient\",\n", + " context_node_label=\"Drug\",\n", + " event_node_label=\"Encounter\",\n", + " event_features=\"emb\",\n", + " time_node_property=\"days\",\n", + " dimension=256,\n", + " # num_elapsed_times=100,\n", + " num_elapsed_times=1,\n", + " output_time_property=\"output_time\",\n", + " # max_elapsed_time=365 * 10,\n", + " max_elapsed_time=1,\n", + " smoothing_rate=0.004,\n", + " smoothing_window=0,\n", + " # smoothing_window=3,\n", + " decay_factor=1e-4,\n", + " random_seed=43,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3f46994", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = gds.graph.nodeProperties.stream(G, [\"embeddings\"], node_labels=[\"Patient\"], separate_property_columns=True)\n", + "print(embeddings)\n", + "explore(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "f77b148a", + "metadata": {}, + "source": [ + "# Example with next and first relationship schema\n", + "We will now repeat one of the previous examples but use a different schema for the paths.\n", + "In this case it will give the same graph and embeddings, but the example is useful for illustrating the use of the next-first schema." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36067015", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " G = gds.graph.get(\"medical\")\n", + " G.drop()\n", + "except:\n", + " pass\n", + "\n", + "G, _ = gds.graph.project(\n", + " \"medical\",\n", + " {\n", + " \"Patient\": {\"properties\": [\"output_time\", \"output_time_stepwise\", \"has_diabetes\"]},\n", + " \"Encounter\": {\"properties\": [\"days\", \"intClass\"]},\n", + " \"Observation\": {\"properties\": []},\n", + " \"Payer\": {\"properties\": []},\n", + " \"Provider\": {\"properties\": []},\n", + " \"Organization\": {\"properties\": []},\n", + " \"Speciality\": {\"properties\": []},\n", + " \"Allergy\": {\"properties\": []},\n", + " \"Reaction\": {\"properties\": []},\n", + " \"Condition\": {\"properties\": []},\n", + " \"Drug\": {\"properties\": []},\n", + " \"Procedure\": {\"properties\": []},\n", + " \"CarePlan\": {\"properties\": []},\n", + " \"Device\": {\"properties\": []},\n", + " \"ConditionDescription\": {\"properties\": []},\n", + " },\n", + " [\n", + " \"HAS_OBSERVATION\",\n", + " \"NEXT\",\n", + " \"FIRST\",\n", + " \"HAS_PROVIDER\",\n", + " \"AT_ORGANIZATION\",\n", + " \"HAS_PAYER\",\n", + " \"HAS_SPECIALITY\",\n", + " \"BELONGS_TO\",\n", + " \"INSURANCE_START\",\n", + " \"INSURANCE_END\",\n", + " \"HAS_ALLERGY\",\n", + " \"ALLERGY_DETECTED\",\n", + " \"HAS_REACTION\",\n", + " \"CAUSES_REACTION\",\n", + " \"HAS_CONDITION\",\n", + " \"HAS_DRUG\",\n", + " \"HAS_PROCEDURE\",\n", + " \"HAS_CARE_PLAN\",\n", + " \"DEVICE_USED\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e27d608", + "metadata": {}, + "outputs": [], + "source": [ + "gds.fastRP.mutate(\n", + " G,\n", + " embeddingDimension=256,\n", + " mutateProperty=\"emb\",\n", + " iterationWeights=[1, 1],\n", + " randomSeed=42,\n", + " relationshipTypes=[\n", + " \"HAS_OBSERVATION\",\n", + " \"HAS_PROVIDER\",\n", + " \"AT_ORGANIZATION\",\n", + " 
\"HAS_PAYER\",\n", + " \"HAS_SPECIALITY\",\n", + " \"BELONGS_TO\",\n", + " \"INSURANCE_START\",\n", + " \"INSURANCE_END\",\n", + " \"HAS_ALLERGY\",\n", + " \"ALLERGY_DETECTED\",\n", + " \"HAS_REACTION\",\n", + " \"CAUSES_REACTION\",\n", + " \"HAS_CONDITION\",\n", + " \"HAS_DRUG\",\n", + " \"HAS_PROCEDURE\",\n", + " \"HAS_CARE_PLAN\",\n", + " \"DEVICE_USED\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1aaebe70", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = gds.fastpath.stream(\n", + " G,\n", + " base_node_label=\"Patient\",\n", + " context_node_label=\"Drug\",\n", + " event_node_label=\"Encounter\",\n", + " event_features=\"emb\",\n", + " next_relationship_type=\"NEXT\",\n", + " first_relationship_type=\"FIRST\",\n", + " time_node_property=\"days\",\n", + " dimension=256,\n", + " num_elapsed_times=100,\n", + " output_time_property=\"output_time\",\n", + " max_elapsed_time=365 * 10,\n", + " smoothing_rate=0.003701319681951021,\n", + " smoothing_window=3,\n", + " decay_factor=8.232744730741784e-05,\n", + " random_seed=43,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "faf3ee7a", + "metadata": {}, + "outputs": [], + "source": [ + "explore(embeddings, \"embeddings\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index d280f1a35..b6f37bcb9 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -83,7 +83,7 @@ def __init__( None if arrow is True else arrow, ) - super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) + super().__init__(self._query_runner, "gds", self._server_version) @property def graph(self) -> GraphProcRunner: diff --git a/graphdatascience/model/fastpath_runner.py b/graphdatascience/model/fastpath_runner.py 
new file mode 100644 index 000000000..6455efcf2 --- /dev/null +++ b/graphdatascience/model/fastpath_runner.py @@ -0,0 +1,115 @@ +import logging +import os +import time +from typing import Any, Dict, Optional + +import requests +from pandas import Series + +from ..error.client_only_endpoint import client_only_endpoint +from ..error.illegal_attr_checker import IllegalAttrChecker +from ..error.uncallable_namespace import UncallableNamespace +from ..graph.graph_object import Graph +from ..query_runner.query_runner import QueryRunner +from ..server_version.compatible_with import compatible_with +from ..server_version.server_version import ServerVersion + +logging.basicConfig(level=logging.INFO) + + +class FastPathRunner(UncallableNamespace, IllegalAttrChecker): + def __init__( + self, + query_runner: QueryRunner, + namespace: str, + server_version: ServerVersion, + compute_cluster_ip: str, + encrypted_db_password: str, + arrow_uri: str, + ): + self._query_runner = query_runner + self._namespace = namespace + self._server_version = server_version + self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005" + self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815" + self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080" + self._encrypted_db_password = encrypted_db_password + self._arrow_uri = arrow_uri + + @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0)) + @client_only_endpoint("gds.fastpath") + def mutate( + self, + G: Graph, + graph_filter: Optional[Dict[str, Any]] = None, + mlflow_experiment_name: Optional[str] = None, + **algo_config: Any, + ) -> Series: + if graph_filter is None: + # Take full graph if no filter provided + node_filter = G.node_properties().to_dict() + rel_filter = G.relationship_properties().to_dict() + graph_filter = {"node_filter": node_filter, "rel_filter": rel_filter} + + graph_config = {"name": G.name()} + graph_config.update(graph_filter) + + config = { + "user_name": "DUMMY_USER", + 
"task": "FASTPATH", + "task_config": { + "graph_config": graph_config, + "task_config": algo_config, + "stream_node_results": True, + }, + "encrypted_db_password": self._encrypted_db_password, + "graph_arrow_uri": self._arrow_uri, + } + + if mlflow_experiment_name is not None: + config["task_config"]["mlflow"] = { + "config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name} + } + + job_id = self._start_job(config) + + self._wait_for_job(job_id) + + return Series({"status": "finished"}) + + # return self._stream_results(job_id) + + def _start_job(self, config: Dict[str, Any]) -> str: + res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config) + res.raise_for_status() + job_id = res.json()["job_id"] + logging.info(f"Job with ID '{job_id}' started") + + return job_id + + def _wait_for_job(self, job_id: str) -> None: + while True: + time.sleep(1) + + res = requests.get(f"{self._compute_cluster_web_uri}/api/machine-learning/status/{job_id}") + + res_json = res.json() + if res_json["job_status"] == "exited": + logging.info("FastPath job completed!") + return + elif res_json["job_status"] == "failed": + error = f"FastPath job failed with errors:{os.linesep}{os.linesep.join(res_json['errors'])}" + if res.status_code == 400: + raise ValueError(error) + else: + raise RuntimeError(error) + + # def _stream_results(self, job_id: str) -> DataFrame: + # client = pa.flight.connect(self._compute_cluster_arrow_uri) + + # upload_descriptor = pa.flight.FlightDescriptor.for_path(f"{job_id}.nodes") + # flight = client.get_flight_info(upload_descriptor) + # reader = client.do_get(flight.endpoints[0].ticket) + # read_table = reader.read_all() + + # return read_table.to_pandas() From a3382eb03501ae89eed27baddf0a2f480270ebc0 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 1 Jul 2024 16:57:50 +0100 Subject: [PATCH 02/24] Start for the notebook --- examples/kge-distmult.ipynb | 692 ++++++++++++++++++ 
graphdatascience/graph_data_science.py | 15 +- graphdatascience/model/kge_runner.py | 109 +++ .../resources/field-testing/__init__.py | 0 .../resources/field-testing/pub.pem | 4 + 5 files changed, 812 insertions(+), 8 deletions(-) create mode 100644 examples/kge-distmult.ipynb create mode 100644 graphdatascience/model/kge_runner.py create mode 100644 graphdatascience/resources/field-testing/__init__.py create mode 100644 graphdatascience/resources/field-testing/pub.pem diff --git a/examples/kge-distmult.ipynb b/examples/kge-distmult.ipynb new file mode 100644 index 000000000..6686b85b5 --- /dev/null +++ b/examples/kge-distmult.ipynb @@ -0,0 +1,692 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Knowledge graph embeddings: DistMult" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from graphdatascience import GraphDataScience\n", + "import torch\n", + "import torch.optim as optim\n", + "import collections\n", + "from tqdm import tqdm\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n", + "NEO4J_AUTH = None\n", + "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n", + "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n", + " NEO4J_AUTH = (\n", + " os.environ.get(\"NEO4J_USER\"),\n", + " os.environ.get(\"NEO4J_PASSWORD\"),\n", + " )\n", + "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Downloading and Storing the FB15k-237 Dataset in the Database\n", + "Download the FB15k-237 dataset\n", + 
"Extract the required files: train.txt, valid.txt, and test.txt." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Set a constraint for unique id entries to speed up data uploads." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "**Creating Entity Nodes**:\n", + " Create a node with the label `Entity`. This node should have properties `id` and `text`. \n", + " - Syntax: `(:Entity {id: int, text: str})`\n", + "\n", + "**Creating Relationships for Training with PyG**:\n", + " Based on the training stage, create relationships of type `TRAIN`, `TEST`, or `VALID`. Each of these relationships should have a `rel_id` property.\n", + " - Example Syntax: `[:TRAIN {rel_id: int}]`\n", + "\n", + "**Creating Relationships for Prediction with GDS**:\n", + " For the prediction stage, create relationships of a specific type denoted as `REL_i`. 
Each of these relationships should have `rel_id` and `text` properties.\n", + " - Example Syntax: `[:REL_7 {rel_id: int, text: str}]`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "from ogb.utils.url import download_url\n", + "import os\n", + "import zipfile\n", + "\n", + "url = \"https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip\"\n", + "raw_dir = \"./data_from_zip\"\n", + "download_url(f\"{url}\", raw_dir)\n", + "\n", + "raw_file_names = [\"train.txt\", \"valid.txt\", \"test.txt\"]\n", + "with zipfile.ZipFile(raw_dir + \"/\" + os.path.basename(url), \"r\") as zip_ref:\n", + " for filename in raw_file_names:\n", + " zip_ref.extract(f\"Release/{filename}\", path=raw_dir)\n", + "data_dir = raw_dir + \"/\" + \"Release\"\n", + "\n", + "rel_types = {\n", + " \"train.txt\": \"TRAIN\",\n", + " \"valid.txt\": \"VALID\",\n", + " \"test.txt\": \"TEST\",\n", + "}\n", + "rel_id_to_text_dict = {}\n", + "rel_type_dict = collections.defaultdict(list)\n", + "rel_dict = {}\n", + "\n", + "\n", + "def read_data():\n", + " node_id_set = {}\n", + " dataset = defaultdict(lambda: defaultdict(list))\n", + " for file_name in raw_file_names:\n", + " file_name_path = data_dir + \"/\" + file_name\n", + "\n", + " with open(file_name_path, \"r\") as f:\n", + " data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n", + "\n", + " for i, (src_text, rel_text, dst_text) in enumerate(data):\n", + " if src_text not in node_id_set:\n", + " node_id_set[src_text] = len(node_id_set)\n", + " if dst_text not in node_id_set:\n", + " node_id_set[dst_text] = len(node_id_set)\n", + " if rel_text not in rel_dict:\n", + " rel_dict[rel_text] = len(rel_dict)\n", + " rel_id_to_text_dict[rel_dict[rel_text]] = rel_text\n", + "\n", + " source = node_id_set[src_text]\n", + " target = node_id_set[dst_text]\n", + " rel_type = \"REL_\" + 
str(rel_dict[rel_text])\n", + " rel_split = rel_types[file_name]\n", + "\n", + " dataset[rel_split][rel_type].append(\n", + " {\n", + " \"source\": source,\n", + " \"source_text\": src_text,\n", + " \"target\": target,\n", + " \"target_text\": dst_text,\n", + " # \"rel_text\": rel_text,\n", + " }\n", + " )\n", + "\n", + " print(\"Number of nodes: \", len(node_id_set))\n", + " for rel_split in dataset:\n", + " print(\n", + " f\"Number of relationships of type {rel_split}: \",\n", + " sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),\n", + " )\n", + " return dataset\n", + "\n", + "\n", + "dataset = read_data()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def put_data_in_db(dataset):\n", + " for rel_split in tqdm(dataset, desc=\"Relationship\"):\n", + " for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False):\n", + " edges = dataset[rel_split][rel_type]\n", + "\n", + " # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)\n", + " gds.run_cypher(\n", + " f\"\"\"\n", + " UNWIND $ll as l\n", + " MERGE (n:Entity {{id:l.source, text:l.source_text}})\n", + " MERGE (m:Entity {{id:l.target, text:l.target_text}})\n", + " MERGE (n)-[:{rel_split}]->(m)\n", + " MERGE (n)-[:{rel_type}]->(m)\n", + " \"\"\",\n", + " params={\"ll\": edges},\n", + " )\n", + "\n", + " for rel_split in dataset:\n", + " res = gds.run_cypher(\n", + " f\"\"\"\n", + " MATCH ()-[r:{rel_split}]->()\n", + " RETURN COUNT(r) AS numberOfRelationships\n", + " \"\"\"\n", + " )\n", + " print(f\"Number of relationships of type {rel_split} in db: \", res.numberOfRelationships)\n", + "\n", + "\n", + "put_data_in_db(dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Project all data in graph to get mapping between `id` and internal `nodeId` field from database." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ALL_RELS = dataset[\"TRAIN\"].keys()\n", + "G_train, result = gds.graph.cypher.project(\n", + " \"\"\"\n", + " MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:\"\"\"\n", + " + \"|\".join(ALL_RELS)\n", + " + \"\"\"]-(n)\n", + " RETURN gds.graph.project($graph_name, n, m, {\n", + " sourceNodeLabels: $label,\n", + " targetNodeLabels: $label\n", + " })\n", + " \"\"\", # Cypher query\n", + " database=\"neo4j\", # Target database\n", + " graph_name=\"trainGraph\", # Query parameter\n", + " label=\"Entity\", # Query parameter\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def inspect_graph(G):\n", + " func_names = [\n", + " \"name\",\n", + " # \"database\",\n", + " \"node_count\",\n", + " \"relationship_count\",\n", + " \"node_labels\",\n", + " \"relationship_types\",\n", + " # \"degree_distribution\", \"density\", \"size_in_bytes\", \"memory_usage\", \"exists\", \"configuration\", \"creation_time\", \"modification_time\",\n", + " ]\n", + " for func_name in func_names:\n", + " print(f\"==={func_name}===: {getattr(G, func_name)()}\")\n", + "\n", + "\n", + "inspect_graph(G_train)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.set_compute_cluster_ip(\"localhost\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.kge.model.train(\n", + " G_train,\n", + " scoring_function=\"distmult\",\n", + " num_epochs=10,\n", + " embedding_dimension=100,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "node_projection = {\"Entity\": {\"properties\": \"id\"}}\n", + "relationship_projection = [\n", + " {\"TRAIN\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n", + " {\"TEST\": 
{\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n", + " {\"VALID\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n", + "]\n", + "\n", + "ttv_G, result = gds.graph.project(\n", + " \"fb15k-graph-ttv\",\n", + " node_projection,\n", + " relationship_projection,\n", + ")\n", + "\n", + "node_properties = gds.graph.nodeProperties.stream(\n", + " ttv_G,\n", + " [\"id\"],\n", + " separate_property_columns=True,\n", + ")\n", + "\n", + "nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id))\n", + "id_to_nodeId = dict(zip(node_properties.id, node_properties.nodeId))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Training the TransE Model with PyG\n", + "\n", + "Retrieve data from the database, convert it into torch tensors, and format it into a `Data` structure suitable for training with PyG." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_data_from_graph(relationship_type):\n", + " rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, \"rel_id\", relationship_type)\n", + " topology = [\n", + " rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),\n", + " rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),\n", + " ]\n", + " edge_index = torch.tensor(topology, dtype=torch.long)\n", + " edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)\n", + " data = Data(edge_index=edge_index, edge_type=edge_type)\n", + " data.num_nodes = len(nodeId_to_id)\n", + " display(data)\n", + " return data\n", + "\n", + "\n", + "train_tensor_data = create_data_from_graph(\"TRAIN\")\n", + "test_tensor_data = create_data_from_graph(\"TEST\")\n", + "val_tensor_data = create_data_from_graph(\"VALID\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Drop the projected graph to save memory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.graph.drop(ttv_G)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "The training process of the TransE model follows the corresponding PyG [example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def train_model_with_pyg():\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + " model = TransE(\n", + " num_nodes=train_tensor_data.num_nodes,\n", + " num_relations=train_tensor_data.num_edge_types,\n", + " hidden_channels=50,\n", + " ).to(device)\n", + "\n", + " loader = model.loader(\n", + " head_index=train_tensor_data.edge_index[0],\n", + " rel_type=train_tensor_data.edge_type,\n", + " tail_index=train_tensor_data.edge_index[1],\n", + " batch_size=1000,\n", + " shuffle=True,\n", + " )\n", + "\n", + " optimizer = optim.Adam(model.parameters(), lr=0.01)\n", + "\n", + " def train():\n", + " model.train()\n", + " total_loss = total_examples = 0\n", + " for head_index, rel_type, tail_index in loader:\n", + " optimizer.zero_grad()\n", + " loss = model.loss(head_index, rel_type, tail_index)\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += float(loss) * head_index.numel()\n", + " total_examples += head_index.numel()\n", + " return total_loss / total_examples\n", + "\n", + " @torch.no_grad()\n", + " def test(data):\n", + " model.eval()\n", + " return model.test(\n", + " head_index=data.edge_index[0],\n", + " rel_type=data.edge_type,\n", + " tail_index=data.edge_index[1],\n", + " batch_size=1000,\n", + " k=10,\n", + " )\n", + "\n", + " # Consider increasing the number of epochs\n", + " epoch_count = 5\n", + " for epoch in range(1, epoch_count):\n", + " loss = train()\n", + " print(f\"Epoch: {epoch:03d}, 
Loss: {loss:.4f}\")\n", + " if epoch % 75 == 0:\n", + " rank, hits = test(val_tensor_data)\n", + " print(f\"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, \" f\"Val Hits@10: {hits:.4f}\")\n", + "\n", + " torch.save(model, f\"./model_{epoch_count}.pt\")\n", + "\n", + " mean_rank, mrr, hits_at_k = test(test_tensor_data)\n", + " print(f\"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}\")\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = train_model_with_pyg()\n", + "# The model can be loaded if it was trained before\n", + "# model = torch.load(\"./model_501.pt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Extract node embeddings from the trained model and put them into database." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in tqdm(range(len(nodeId_to_id))):\n", + " gds.run_cypher(\n", + " \"MATCH (n:Entity {id: $i}) SET n.emb=$EMBEDDING\",\n", + " params={\"i\": i, \"EMBEDDING\": model.node_emb.weight[i].tolist()},\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Predict Using GDS Knowledge Graph Edge Embeddings Functionality\n", + "\n", + "Select a relationship type for which to make predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "relationship_to_predict = \"/film/film/genre\"\n", + "rel_id_to_predict = rel_dict[relationship_to_predict]\n", + "rel_label_to_predict = f\"REL_{rel_id_to_predict}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Project the graph with all nodes and existing relationships of the selected type." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "G_test, result = gds.graph.project(\n", + " \"graph_to_predict_\",\n", + " {\"Entity\": {\"properties\": [\"id\", \"emb\"]}},\n", + " rel_label_to_predict,\n", + ")\n", + "\n", + "\n", + "def print_graph_info(G):\n", + " print(f\"Graph '{G.name()}' node count: {G.node_count()}\")\n", + " print(f\"Graph '{G.name()}' node labels: {G.node_labels()}\")\n", + " print(f\"Graph '{G.name()}' relationship types: {G.relationship_types()}\")\n", + " print(f\"Graph '{G.name()}' relationship count: {G.relationship_count()}\")\n", + "\n", + "\n", + "print_graph_info(G_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Retrieve the embedding for the selected relationship from the PyG model. Then, create a GDS TransE model using the graph, node embeddings property, and the embedding for the relationship to be predicted." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target_emb = model.node_emb.weight[rel_id_to_predict].tolist()\n", + "transe_model = gds.model.transe.create(G_test, \"emb\", {rel_label_to_predict: target_emb})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source_node_list = [\"/m/07l450\", \"/m/0ds2l81\", \"/m/0jvt9\"]\n", + "source_ids_df = gds.run_cypher(\n", + " \"UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId\",\n", + " params={\"node_text_list\": source_node_list},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Now, we can use the model to make prediction." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = transe_model.predict_stream(\n", + " source_node_filter=source_ids_df.nodeId,\n", + " target_node_filter=\"Entity\",\n", + " relationship_type=rel_label_to_predict,\n", + " top_k=3,\n", + " concurrency=4,\n", + ")\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Augment the predicted result with node identifiers and their text values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId]))\n", + "\n", + "ids_to_text = gds.run_cypher(\n", + " \"UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id\",\n", + " params={\"ids\": ids_in_result},\n", + ")\n", + "\n", + "nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag))\n", + "nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id))\n", + "\n", + "result.insert(1, \"sourceTag\", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x]))\n", + "result.insert(2, \"sourceId\", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x]))\n", + "result.insert(4, \"targetTag\", result.targetNodeId.map(lambda x: nodeId_to_text_res[x]))\n", + "result.insert(5, \"targetId\", result.targetNodeId.map(lambda x: nodeId_to_id_res[x]))\n", + "\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Using Write Mode\n", + "\n", + "Write mode allows you to write results directly to the database as a new relationship type. This approach helps to avoid mapping from `nodeId` to `id`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "write_relationship_type = \"PREDICTED_\" + rel_label_to_predict\n", + "result_write = transe_model.predict_write(\n", + " source_node_filter=source_ids_df.nodeId,\n", + " target_node_filter=\"Entity\",\n", + " relationship_type=rel_label_to_predict,\n", + " write_relationship_type=write_relationship_type,\n", + " write_property=\"transe_score\",\n", + " top_k=3,\n", + " concurrency=4,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Extract the result from the database." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.run_cypher(\n", + " \"MATCH (n)-[r:\"\n", + " + write_relationship_type\n", + " + \"]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.graph.drop(G_test)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index b6f37bcb9..d75d19330 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -8,12 +8,12 @@ from .call_builder import IndirectCallBuilder from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints from .error.uncallable_namespace import UncallableNamespace -from .graph.graph_proc_runner import GraphProcRunner from .query_runner.arrow_query_runner import ArrowQueryRunner from .query_runner.neo4j_query_runner import Neo4jQueryRunner from .query_runner.query_runner import QueryRunner from .server_version.server_version import ServerVersion -from .utils.util_proc_runner import UtilProcRunner +from 
graphdatascience.graph.graph_proc_runner import GraphProcRunner +from graphdatascience.utils.util_proc_runner import UtilProcRunner class GraphDataScience(DirectEndpoints, UncallableNamespace): @@ -49,12 +49,11 @@ def __init__( database: Optional[str], default None The Neo4j database to query against. arrow : Union[str, bool], default True - Arrow connection information. This is either a string or a bool. - - - If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server. - - If it is a bool: - - True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint. - - False will make the client use Bolt for all operations. + Arrow connection information. This is either a bool or a string. + If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server. + If it is a bool, + True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint, + while False will make the client use Bolt for all operations. arrow_disable_server_verification : bool, default True A flag that overrides other TLS settings and disables server verification for TLS connections. 
import logging
import os
import time
from typing import Any, Dict, Optional

import requests
from pandas import Series

from ..error.client_only_endpoint import client_only_endpoint
from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace
from ..graph.graph_object import Graph
from ..query_runner.query_runner import QueryRunner
from ..server_version.compatible_with import compatible_with
from ..server_version.server_version import ServerVersion

# NOTE(review): configuring the root logger at import time affects every
# application importing this package. Kept so existing users still see the
# job progress messages; consider moving this to application code.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

# Seconds to wait between two job-status requests.
_POLL_INTERVAL_SECONDS = 1


class KgeRunner(UncallableNamespace, IllegalAttrChecker):
    """Client-side runner that submits KGE training jobs to a compute cluster.

    Jobs are started and monitored through the cluster's HTTP API; the graph
    itself is referenced by name together with the Arrow URI it can be
    fetched from.
    """

    def __init__(
        self,
        query_runner: QueryRunner,
        namespace: str,
        server_version: ServerVersion,
        compute_cluster_ip: str,
        encrypted_db_password: str,
        arrow_uri: str,
    ):
        """Store the connection details used by later job requests.

        Parameters
        ----------
        query_runner : QueryRunner
            Query runner of the owning client.
        namespace : str
            Namespace this runner was reached through (e.g. ``"gds.kge"``).
        server_version : ServerVersion
            Version of the GDS server the client is connected to.
        compute_cluster_ip : str
            Host of the compute cluster; the web (5005), Arrow (8815) and
            MLflow (8080) endpoints are derived from it.
        encrypted_db_password : str
            Encrypted database password forwarded with every job.
        arrow_uri : str
            URI of the GDS Arrow server holding the projected graph.
        """
        self._query_runner = query_runner
        self._namespace = namespace
        self._server_version = server_version
        self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
        self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
        self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
        self._encrypted_db_password = encrypted_db_password
        self._arrow_uri = arrow_uri

    # NOTE(review): the example scripts call ``gds.kge.model.train(...)``
    # without invoking ``model``; a plain method does not support that
    # attribute-style access -- confirm whether this should be a property.
    @client_only_endpoint("gds.kge")
    def model(self):
        """Return this runner so that model operations can be chained."""
        return self

    @client_only_endpoint("gds.kge.model")
    def train(
        self,
        G: Graph,
        scoring_function: str,
        num_epochs: int,
        embedding_dimension: int,
        mlflow_experiment_name: Optional[str] = None,
    ) -> Series:
        """Train a KGE model on ``G`` and block until the job has finished.

        Parameters
        ----------
        G : Graph
            Projected graph to train on; only its name is sent to the cluster.
        scoring_function : str
            Scoring function identifier, e.g. ``"distmult"``.
        num_epochs : int
            Number of training epochs.
        embedding_dimension : int
            Size of the learned embeddings.
        mlflow_experiment_name : Optional[str], default None
            When given, training is tracked in this MLflow experiment on the
            compute cluster's MLflow server.

        Returns
        -------
        Series
            ``{"status": "finished"}`` once the job has exited successfully.

        Raises
        ------
        ValueError
            If the job failed and the status endpoint answered with HTTP 400.
        RuntimeError
            If the job failed for any other reason.
        requests.HTTPError
            If the cluster rejects the start or status request.
        """
        algo_config = {
            "scoring_function": scoring_function,
            "num_epochs": num_epochs,
            "embedding_dimension": embedding_dimension,
        }

        config: Dict[str, Any] = {
            # TODO(review): placeholder user name, replace with the real user.
            "user_name": "DUMMY_USER",
            "task": "KGE_TRAINING_PYG",
            "task_config": {
                "graph_config": {"name": G.name()},
                "task_config": algo_config,
            },
            "encrypted_db_password": self._encrypted_db_password,
            "graph_arrow_uri": self._arrow_uri,
        }

        if mlflow_experiment_name is not None:
            config["task_config"]["mlflow"] = {
                "config": {
                    "tracking_uri": self._compute_cluster_mlflow_uri,
                    "experiment_name": mlflow_experiment_name,
                }
            }

        job_id = self._start_job(config)
        self._wait_for_job(job_id)

        return Series({"status": "finished"})

    def _start_job(self, config: Dict[str, Any]) -> str:
        """Submit ``config`` to the cluster and return the id of the new job."""
        res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config)
        res.raise_for_status()
        job_id = res.json()["job_id"]
        logger.info("Job with ID '%s' started", job_id)

        return job_id

    def _wait_for_job(self, job_id: str) -> None:
        """Poll the job status until the job has exited or failed.

        Raises ``ValueError`` (HTTP 400) or ``RuntimeError`` when the job
        reports failure, and ``requests.HTTPError`` when the status endpoint
        itself answers with an error that carries no job status.
        """
        status_url = f"{self._compute_cluster_web_uri}/api/machine-learning/status/{job_id}"

        while True:
            time.sleep(_POLL_INTERVAL_SECONDS)

            res = requests.get(status_url)

            try:
                res_json = res.json()
            except ValueError:
                # Body is not JSON: report the HTTP error if there is one,
                # otherwise let the decoding error propagate unchanged.
                res.raise_for_status()
                raise

            job_status = res_json.get("job_status")
            if job_status == "exited":
                logger.info("KGE job completed!")
                return

            if job_status == "failed":
                errors = res_json.get("errors") or []
                error = f"KGE job failed with errors:{os.linesep}{os.linesep.join(errors)}"
                if res.status_code == 400:
                    raise ValueError(error)
                raise RuntimeError(error)

            # Not a terminal state. Fail loudly on HTTP errors instead of
            # polling a broken endpoint forever.
            res.raise_for_status()
+-----END RSA PUBLIC KEY----- From ebfc92fb2a5f5223f42483b98fe7c975b8e49a95 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 3 Jul 2024 10:08:56 +0100 Subject: [PATCH 03/24] Fix problem with KgeRunner --- examples/kge-distmult.py | 352 +++++++++++++++++++++++++ graphdatascience/graph_data_science.py | 67 ++++- graphdatascience/model/kge_runner.py | 7 +- 3 files changed, 419 insertions(+), 7 deletions(-) create mode 100644 examples/kge-distmult.py diff --git a/examples/kge-distmult.py b/examples/kge-distmult.py new file mode 100644 index 000000000..7cd9edf2a --- /dev/null +++ b/examples/kge-distmult.py @@ -0,0 +1,352 @@ +import collections +import os + +from neo4j.exceptions import ClientError +from tqdm import tqdm + +from graphdatascience import GraphDataScience + +NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") +NEO4J_AUTH = None +NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j") +if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"): + NEO4J_AUTH = ( + os.environ.get("NEO4J_USER"), + os.environ.get("NEO4J_PASSWORD"), + ) +gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True) + + +try: + _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE") +except ClientError: + print("CONSTRAINT entity_id already exists") + +import os +import zipfile +from collections import defaultdict + +from ogb.utils.url import download_url + +url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip" +raw_dir = "./data_from_zip" +download_url(f"{url}", raw_dir) + +raw_file_names = ["train.txt", "valid.txt", "test.txt"] +with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref: + for filename in raw_file_names: + zip_ref.extract(f"Release/{filename}", path=raw_dir) +data_dir = raw_dir + "/" + "Release" + +rel_types = { + "train.txt": "TRAIN", + "valid.txt": "VALID", + "test.txt": "TEST", +} +rel_id_to_text_dict = {} 
rel_type_dict = collections.defaultdict(list)
rel_dict = {}


def read_data():
    """Parse the FB15k-237 split files into per-split, per-type edge lists.

    Reads every file in ``raw_file_names`` from ``data_dir``. Each non-empty
    line is a tab-separated ``source, relation, target`` triple. Nodes and
    relation texts get consecutive integer ids in order of first appearance;
    relation ids are recorded in the module-level ``rel_dict`` and
    ``rel_id_to_text_dict``.

    Returns:
        A nested mapping ``dataset[split][rel_type]`` -> list of edge dicts
        with keys ``source``, ``source_text``, ``target`` and ``target_text``,
        where ``split`` is ``TRAIN``/``VALID``/``TEST`` and ``rel_type`` is
        ``REL_<relation id>``.
    """
    node_ids = {}
    dataset = defaultdict(lambda: defaultdict(list))
    for file_name in raw_file_names:
        rel_split = rel_types[file_name]

        with open(data_dir + "/" + file_name, "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                # Skip blank lines instead of failing on them; a final line
                # without a trailing newline is still read.
                if not line:
                    continue
                src_text, rel_text, dst_text = line.split("\t")

                source = node_ids.setdefault(src_text, len(node_ids))
                target = node_ids.setdefault(dst_text, len(node_ids))
                if rel_text not in rel_dict:
                    rel_dict[rel_text] = len(rel_dict)
                    rel_id_to_text_dict[rel_dict[rel_text]] = rel_text

                rel_type = "REL_" + str(rel_dict[rel_text])
                dataset[rel_split][rel_type].append(
                    {
                        "source": source,
                        "source_text": src_text,
                        "target": target,
                        "target_text": dst_text,
                    }
                )

    print("Number of nodes: ", len(node_ids))
    for rel_split, edges_by_type in dataset.items():
        print(
            f"Number of relationships of type {rel_split}: ",
            sum(len(edges) for edges in edges_by_type.values()),
        )
    return dataset


dataset = read_data()


def put_data_in_db(dataset):
    """Write all nodes and relationships of ``dataset`` to the database.

    For every edge, both ``Entity`` nodes are merged and connected by one
    relationship named after the split (``TRAIN``/``VALID``/``TEST``) and one
    named after the relation type (``REL_<id>``). Afterwards the number of
    relationships stored per split is printed.
    """
    for rel_split in tqdm(dataset, desc="Relationship"):
        for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False):
            edges = dataset[rel_split][rel_type]

            # Relationship types cannot be passed as Cypher parameters, so
            # they are interpolated; both values come from this script.
            gds.run_cypher(
                f"""
                UNWIND $ll as l
                MERGE (n:Entity {{id:l.source, text:l.source_text}})
                MERGE (m:Entity {{id:l.target, text:l.target_text}})
                MERGE (n)-[:{rel_split}]->(m)
                MERGE (n)-[:{rel_type}]->(m)
                """,
                params={"ll": edges},
            )

    for rel_split in dataset:
        res = gds.run_cypher(
            f"""
            MATCH ()-[r:{rel_split}]->()
            RETURN COUNT(r) AS numberOfRelationships
            """
        )
        # The query yields a single row; print the count, not the whole column.
        print(
            f"Number of relationships of type {rel_split} in db: ",
            res["numberOfRelationships"].iloc[0],
        )


# put_data_in_db(dataset)

ALL_RELS = dataset["TRAIN"].keys()
gds.graph.drop("trainGraph", failIfMissing=False)
G_train, result = gds.graph.cypher.project(
    """
    MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:"""
    + "|".join(ALL_RELS)
    + """]-(n)
    RETURN gds.graph.project($graph_name, n, m, {
        sourceNodeLabels: $label,
        targetNodeLabels: $label
    })
    """,  # Cypher query
    database=NEO4J_DB,  # Target database (same one the client connects to)
    graph_name="trainGraph",  # Query parameter
    label="Entity",  # Query parameter
)


def inspect_graph(G):
    """Print name, node/relationship counts, labels and types of graph ``G``."""
    func_names = (
        "name",
        "node_count",
        "relationship_count",
        "node_labels",
        "relationship_types",
    )
    for func_name in func_names:
        print(f"==={func_name}===: {getattr(G, func_name)()}")


inspect_graph(G_train)

gds.set_compute_cluster_ip("localhost")

gds.kge.model.train(
    G_train,
    scoring_function="distmult",
    num_epochs=10,
    embedding_dimension=100,
)
#
# node_projection = {"Entity": {"properties": "id"}}
# relationship_projection = [
#     {"TRAIN": {"orientation": "NATURAL", "properties": "rel_id"}},
#     {"TEST": {"orientation": "NATURAL", "properties": "rel_id"}},
#     {"VALID": {"orientation": "NATURAL", "properties": "rel_id"}},
# ]
#
# ttv_G, result = gds.graph.project(
#     "fb15k-graph-ttv",
#     node_projection,
#     relationship_projection,
# )
#
# node_properties = gds.graph.nodeProperties.stream(
#     ttv_G,
#     ["id"],
#     separate_property_columns=True,
# )
#
# nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id))
# id_to_nodeId = dict(zip(node_properties.id, node_properties.nodeId))
#
# def create_data_from_graph(relationship_type):
#     rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
#     topology = [
#         rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),
+# rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]), +# ] +# edge_index = torch.tensor(topology, dtype=torch.long) +# edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long) +# data = Data(edge_index=edge_index, edge_type=edge_type) +# data.num_nodes = len(nodeId_to_id) +# display(data) +# return data +# +# +# train_tensor_data = create_data_from_graph("TRAIN") +# test_tensor_data = create_data_from_graph("TEST") +# val_tensor_data = create_data_from_graph("VALID") +# +# gds.graph.drop(ttv_G) +# +# def train_model_with_pyg(): +# device = "cuda" if torch.cuda.is_available() else "cpu" +# +# model = TransE( +# num_nodes=train_tensor_data.num_nodes, +# num_relations=train_tensor_data.num_edge_types, +# hidden_channels=50, +# ).to(device) +# +# loader = model.loader( +# head_index=train_tensor_data.edge_index[0], +# rel_type=train_tensor_data.edge_type, +# tail_index=train_tensor_data.edge_index[1], +# batch_size=1000, +# shuffle=True, +# ) +# +# optimizer = optim.Adam(model.parameters(), lr=0.01) +# +# def train(): +# model.train() +# total_loss = total_examples = 0 +# for head_index, rel_type, tail_index in loader: +# optimizer.zero_grad() +# loss = model.loss(head_index, rel_type, tail_index) +# loss.backward() +# optimizer.step() +# total_loss += float(loss) * head_index.numel() +# total_examples += head_index.numel() +# return total_loss / total_examples +# +# @torch.no_grad() +# def test(data): +# model.eval() +# return model.test( +# head_index=data.edge_index[0], +# rel_type=data.edge_type, +# tail_index=data.edge_index[1], +# batch_size=1000, +# k=10, +# ) +# +# # Consider increasing the number of epochs +# epoch_count = 5 +# for epoch in range(1, epoch_count): +# loss = train() +# print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}") +# if epoch % 75 == 0: +# rank, hits = test(val_tensor_data) +# print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}") +# +# torch.save(model, f"./model_{epoch_count}.pt") +# +# 
mean_rank, mrr, hits_at_k = test(test_tensor_data) +# print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}") +# +# return model +# +# model = train_model_with_pyg() +# # The model can be loaded if it was trained before +# # model = torch.load("./model_501.pt") +# +# for i in tqdm(range(len(nodeId_to_id))): +# gds.run_cypher( +# "MATCH (n:Entity {id: $i}) SET n.emb=$EMBEDDING", +# params={"i": i, "EMBEDDING": model.node_emb.weight[i].tolist()}, +# ) +# +# relationship_to_predict = "/film/film/genre" +# rel_id_to_predict = rel_dict[relationship_to_predict] +# rel_label_to_predict = f"REL_{rel_id_to_predict}" +# +# G_test, result = gds.graph.project( +# "graph_to_predict_", +# {"Entity": {"properties": ["id", "emb"]}}, +# rel_label_to_predict, +# ) +# +# +# def print_graph_info(G): +# print(f"Graph '{G.name()}' node count: {G.node_count()}") +# print(f"Graph '{G.name()}' node labels: {G.node_labels()}") +# print(f"Graph '{G.name()}' relationship types: {G.relationship_types()}") +# print(f"Graph '{G.name()}' relationship count: {G.relationship_count()}") +# +# +# print_graph_info(G_test) +# +# target_emb = model.node_emb.weight[rel_id_to_predict].tolist() +# transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb}) +# +# source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"] +# source_ids_df = gds.run_cypher( +# "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId", +# params={"node_text_list": source_node_list}, +# ) +# +# result = transe_model.predict_stream( +# source_node_filter=source_ids_df.nodeId, +# target_node_filter="Entity", +# relationship_type=rel_label_to_predict, +# top_k=3, +# concurrency=4, +# ) +# print(result) +# +# ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId])) +# +# ids_to_text = gds.run_cypher( +# "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id", +# 
params={"ids": ids_in_result}, +# ) +# +# nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag)) +# nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id)) +# +# result.insert(1, "sourceTag", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x])) +# result.insert(2, "sourceId", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x])) +# result.insert(4, "targetTag", result.targetNodeId.map(lambda x: nodeId_to_text_res[x])) +# result.insert(5, "targetId", result.targetNodeId.map(lambda x: nodeId_to_id_res[x])) +# +# print(result) +# +# write_relationship_type = "PREDICTED_" + rel_label_to_predict +# result_write = transe_model.predict_write( +# source_node_filter=source_ids_df.nodeId, +# target_node_filter="Entity", +# relationship_type=rel_label_to_predict, +# write_relationship_type=write_relationship_type, +# write_property="transe_score", +# top_k=3, +# concurrency=4, +# ) +# +# gds.run_cypher( +# "MATCH (n)-[r:" +# + write_relationship_type +# + "]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score" +# ) +# +# gds.graph.drop(G_test) diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index d75d19330..c983006c6 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -1,13 +1,18 @@ from __future__ import annotations +import pathlib +import sys from typing import Any, Dict, Optional, Tuple, Type, Union +import rsa from neo4j import Driver from pandas import DataFrame from .call_builder import IndirectCallBuilder from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints from .error.uncallable_namespace import UncallableNamespace +from .model.fastpath_runner import FastPathRunner +from .model.kge_runner import KgeRunner from .query_runner.arrow_query_runner import ArrowQueryRunner from .query_runner.neo4j_query_runner import Neo4jQueryRunner from .query_runner.query_runner import 
QueryRunner @@ -82,16 +87,35 @@ def __init__( None if arrow is True else arrow, ) + if auth is not None: + with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f: + pub_key = rsa.PublicKey.load_pkcs1(f.read()) + self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex() + + self._compute_cluster_ip = None + super().__init__(self._query_runner, "gds", self._server_version) + def set_compute_cluster_ip(self, ip: str) -> None: + self._compute_cluster_ip = ip + + @staticmethod + def _path(package: str, resource: str) -> pathlib.Path: + if sys.version_info >= (3, 9): + from importlib.resources import files + + # files() returns a Traversable, but usages require a Path object + return pathlib.Path(str(files(package) / resource)) + else: + from importlib.resources import path + + # we dont want to use a context manager here, so we need to call __enter__ manually + return path(package, resource).__enter__() + @property def graph(self) -> GraphProcRunner: return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version) - @property - def util(self) -> UtilProcRunner: - return UtilProcRunner(self._query_runner, f"{self._namespace}.util", self._server_version) - @property def alpha(self) -> AlphaEndpoints: return AlphaEndpoints(self._query_runner, "gds.alpha", self._server_version) @@ -100,6 +124,41 @@ def alpha(self) -> AlphaEndpoints: def beta(self) -> BetaEndpoints: return BetaEndpoints(self._query_runner, "gds.beta", self._server_version) + @property + def fastpath(self) -> FastPathRunner: + if not isinstance(self._query_runner, ArrowQueryRunner): + raise ValueError("Running FastPath requires GDS with the Arrow server enabled") + if self._compute_cluster_ip is None: + raise ValueError( + "You must set a valid computer cluster ip with the method `set_compute_cluster_ip` to use this feature" + ) + return FastPathRunner( + self._query_runner, + "gds.fastpath", + self._server_version, + 
self._compute_cluster_ip, + self._encrypted_db_password, + self._query_runner.uri, + ) + + @property + def kge(self) -> KgeRunner: + print("!!!kge") + # if not isinstance(self._query_runner, ArrowQueryRunner): + # raise ValueError("Running FastPath requires GDS with the Arrow server enabled") + if self._compute_cluster_ip is None: + raise ValueError( + "You must set a valid computer cluster ip with the method `set_compute_cluster_ip` to use this feature" + ) + return KgeRunner( + self._query_runner, + "gds.kge", + self._server_version, + self._compute_cluster_ip, + self._encrypted_db_password, + self._query_runner.uri, + ) + def __getattr__(self, attr: str) -> IndirectCallBuilder: return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version) diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 328e5ce78..3d0a6c7d6 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -27,6 +27,7 @@ def __init__( encrypted_db_password: str, arrow_uri: str, ): + print("!init", flush=True) self._query_runner = query_runner self._namespace = namespace self._server_version = server_version @@ -36,13 +37,13 @@ def __init__( self._encrypted_db_password = encrypted_db_password self._arrow_uri = arrow_uri - # @compatible_with(min_inclusive=ServerVersion(2, 5, 0)) @client_only_endpoint("gds.kge") def model(self): - print("!!!model") + print("!model") return self - # @compatible_with(min_inclusive=ServerVersion(2, 5, 0)) + # @client_only_endpoint("gds.kge.model") and name is train + # @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0)) @client_only_endpoint("gds.kge.model") def train( self, From cea569ee768005e50c284c6002dc00fced0bb905 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 8 Jul 2024 11:34:30 +0100 Subject: [PATCH 04/24] Next --- examples/kge-distmult.py | 24 ++++++++++---- graphdatascience/graph_data_science.py | 18 +++++----- 
graphdatascience/model/kge_runner.py | 11 +++---- .../tests/integration/test_graph_construct.py | 33 +++++++++++++++++++ .../tests/integration/test_graph_ops.py | 4 ++- 5 files changed, 68 insertions(+), 22 deletions(-) diff --git a/examples/kge-distmult.py b/examples/kge-distmult.py index 7cd9edf2a..77a247efd 100644 --- a/examples/kge-distmult.py +++ b/examples/kge-distmult.py @@ -94,8 +94,13 @@ def read_data(): def put_data_in_db(dataset): - for rel_split in tqdm(dataset, desc="Relationship"): - for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False): + res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes") + if res['num_nodes'].values[0] > 0: + print("Data already in db, number of nodes: ", res['num_nodes'].values[0]) + return + pbar = tqdm(desc='Putting data in db', total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]])) + for rel_split in dataset: + for rel_type in dataset[rel_split]: edges = dataset[rel_split][rel_type] # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m) @@ -109,6 +114,8 @@ def put_data_in_db(dataset): """, params={"ll": edges}, ) + pbar.update(len(edges)) + pbar.close() for rel_split in dataset: res = gds.run_cypher( @@ -120,7 +127,7 @@ def put_data_in_db(dataset): print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships) -# put_data_in_db(dataset) +put_data_in_db(dataset) ALL_RELS = dataset["TRAIN"].keys() gds.graph.drop("trainGraph", failIfMissing=False) @@ -159,13 +166,18 @@ def inspect_graph(G): gds.set_compute_cluster_ip("localhost") kkge = gds.kge +kmodel = gds.kge.model + +print(gds.debug.arrow()) gds.kge.model.train( G_train, - scoring_function="distmult", - num_epochs=10, - embedding_dimension=100, + scoring_function="DistMult", + num_epochs=1, + embedding_dimension=10, ) + +print('Finished training') # # node_projection = {"Entity": {"properties": "id"}} # relationship_projection = [ diff --git 
a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index c983006c6..4599442fd 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -87,10 +87,10 @@ def __init__( None if arrow is True else arrow, ) - if auth is not None: - with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f: - pub_key = rsa.PublicKey.load_pkcs1(f.read()) - self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex() + # if auth is not None: + # with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f: + # pub_key = rsa.PublicKey.load_pkcs1(f.read()) + # self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex() self._compute_cluster_ip = None @@ -143,20 +143,20 @@ def fastpath(self) -> FastPathRunner: @property def kge(self) -> KgeRunner: - print("!!!kge") - # if not isinstance(self._query_runner, ArrowQueryRunner): - # raise ValueError("Running FastPath requires GDS with the Arrow server enabled") + print("!kge") + if not isinstance(self._query_runner, ArrowQueryRunner): + raise ValueError("Running FastPath requires GDS with the Arrow server enabled") if self._compute_cluster_ip is None: raise ValueError( "You must set a valid computer cluster ip with the method `set_compute_cluster_ip` to use this feature" ) return KgeRunner( self._query_runner, - "gds.kge", + "gds.kge.model", self._server_version, self._compute_cluster_ip, self._encrypted_db_password, - self._query_runner.uri, + None, ) def __getattr__(self, attr: str) -> IndirectCallBuilder: diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 3d0a6c7d6..8e6fe4a45 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -27,22 +27,21 @@ def __init__( encrypted_db_password: str, arrow_uri: str, ): - print("!init", flush=True) self._query_runner = query_runner self._namespace = namespace 
self._server_version = server_version self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005" - self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815" + # self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8491" self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080" self._encrypted_db_password = encrypted_db_password self._arrow_uri = arrow_uri + print("KgeRunner __dict__:") + print(self.__dict__) - @client_only_endpoint("gds.kge") + @property def model(self): - print("!model") return self - # @client_only_endpoint("gds.kge.model") and name is train # @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0)) @client_only_endpoint("gds.kge.model") def train( @@ -53,7 +52,6 @@ def train( embedding_dimension, mlflow_experiment_name: Optional[str] = None, ) -> Series: - print("!!!train") graph_config = {"name": G.name()} algo_config = { @@ -85,6 +83,7 @@ def train( return Series({"status": "finished"}) def _start_job(self, config: Dict[str, Any]) -> str: + print(config) res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config) res.raise_for_status() job_id = res.json()["job_id"] diff --git a/graphdatascience/tests/integration/test_graph_construct.py b/graphdatascience/tests/integration/test_graph_construct.py index 97dc85c1a..82a0da3ca 100644 --- a/graphdatascience/tests/integration/test_graph_construct.py +++ b/graphdatascience/tests/integration/test_graph_construct.py @@ -558,3 +558,36 @@ def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience) with pytest.warns(DeprecationWarning): gds.alpha.graph.construct("hello", nodes, relationships) + +@pytest.mark.enterprise +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 1, 0)) +def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience) -> None: + nodes = DataFrame({"nodeId": [0, 1, 2, 3]}) + relationships = DataFrame({"sourceNodeId": [0, 1, 2, 3], "targetNodeId": [1, 
2, 3, 0]}) + + with pytest.warns(DeprecationWarning): + gds.alpha.graph.construct("hello", nodes, relationships) + + +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) +def test_roundtrip_with_arrow(gds: GraphDataScience) -> None: + G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}}) + + rel_df = gds.graph.relationshipProperty.stream(G, "relX") + node_df = gds.graph.nodeProperty.stream(G, "x") + + G_2 = gds.graph.construct("arrowGraph", node_df, rel_df) + + res = gds.graph.list() + try: + assert set(res['graphName'].tolist()) == {'g', 'arrowGraph'} + assert G.node_count() == G_2.node_count() + assert G.relationship_count() == G_2.relationship_count() + finally: + G_2.drop() + +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) +def test_drop_list_warning_reproduction(gds: GraphDataScience) -> None: + G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}}) + res = gds.graph.list() + assert res['graphName'].tolist() == ['g'] diff --git a/graphdatascience/tests/integration/test_graph_ops.py b/graphdatascience/tests/integration/test_graph_ops.py index e2b077cf7..00d32eb8e 100644 --- a/graphdatascience/tests/integration/test_graph_ops.py +++ b/graphdatascience/tests/integration/test_graph_ops.py @@ -854,7 +854,7 @@ def test_graph_relationships_stream_without_arrow(gds_without_arrow: GraphDataSc @pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) def test_graph_relationships_stream_with_arrow(gds: GraphDataScience) -> None: - G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL", "REL2"]) + G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL_0", "REL2"]) if gds.server_version() >= ServerVersion(2, 5, 0): result = gds.graph.relationships.stream(G, ["REL", "REL2"]) @@ -1058,3 +1058,5 @@ def test_empty_relationships_stream(gds: GraphDataScience) -> None: result = gds.graph.relationships.stream(G, ["SIMILAR"]) assert result.empty 
+ + From 85bad98ad024321ce863ca77c42ead772e96d77a Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 8 Jul 2024 18:16:46 +0100 Subject: [PATCH 05/24] Works, dummy encrypted db password --- examples/kge-distmult.py | 347 ++++++------------------- graphdatascience/graph_data_science.py | 3 +- graphdatascience/model/kge_runner.py | 11 +- 3 files changed, 94 insertions(+), 267 deletions(-) diff --git a/examples/kge-distmult.py b/examples/kge-distmult.py index 77a247efd..5ae04bd4c 100644 --- a/examples/kge-distmult.py +++ b/examples/kge-distmult.py @@ -1,54 +1,64 @@ -import collections import os +import warnings +from collections import defaultdict +from graphdatascience import GraphDataScience from neo4j.exceptions import ClientError from tqdm import tqdm -from graphdatascience import GraphDataScience +warnings.filterwarnings("ignore", category=DeprecationWarning) -NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") -NEO4J_AUTH = None -NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j") -if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"): - NEO4J_AUTH = ( - os.environ.get("NEO4J_USER"), - os.environ.get("NEO4J_PASSWORD"), - ) -gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True) +def setup_connection(): + NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + NEO4J_AUTH = None + NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j") + if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"): + NEO4J_AUTH = ( + os.environ.get("NEO4J_USER"), + os.environ.get("NEO4J_PASSWORD"), + ) + gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True) -try: - _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE") -except ClientError: - print("CONSTRAINT entity_id already exists") + return gds + + +def create_constraint(gds): + try: + _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE") + except ClientError: + 
print("CONSTRAINT entity_id already exists") -import os -import zipfile -from collections import defaultdict -from ogb.utils.url import download_url +def download_data(raw_file_names): + import os + import zipfile -url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip" -raw_dir = "./data_from_zip" -download_url(f"{url}", raw_dir) + from ogb.utils.url import download_url -raw_file_names = ["train.txt", "valid.txt", "test.txt"] -with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref: - for filename in raw_file_names: - zip_ref.extract(f"Release/{filename}", path=raw_dir) -data_dir = raw_dir + "/" + "Release" + url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip" + raw_dir = "./data_from_zip" + download_url(f"{url}", raw_dir) -rel_types = { - "train.txt": "TRAIN", - "valid.txt": "VALID", - "test.txt": "TEST", -} -rel_id_to_text_dict = {} -rel_type_dict = collections.defaultdict(list) -rel_dict = {} + with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref: + for filename in raw_file_names: + zip_ref.extract(f"Release/{filename}", path=raw_dir) + data_dir = raw_dir + "/" + "Release" + return data_dir def read_data(): + rel_types = { + "train.txt": "TRAIN", + "valid.txt": "VALID", + "test.txt": "TEST", + } + raw_file_names = ["train.txt", "valid.txt", "test.txt"] + + data_dir = download_data(raw_file_names) + + rel_id_to_text_dict = {} + rel_dict = {} node_id_set = {} dataset = defaultdict(lambda: defaultdict(list)) for file_name in raw_file_names: @@ -90,15 +100,16 @@ def read_data(): return dataset -dataset = read_data() - - -def put_data_in_db(dataset): +def put_data_in_db(gds): res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes") - if res['num_nodes'].values[0] > 0: - print("Data already in db, number of nodes: ", res['num_nodes'].values[0]) + if res["num_nodes"].values[0] > 0: + print("Data already in db, 
number of nodes: ", res["num_nodes"].values[0]) return - pbar = tqdm(desc='Putting data in db', total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]])) + dataset = read_data() + pbar = tqdm( + desc="Putting data in db", + total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]), + ) for rel_split in dataset: for rel_type in dataset[rel_split]: edges = dataset[rel_split][rel_type] @@ -127,238 +138,50 @@ def put_data_in_db(dataset): print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships) -put_data_in_db(dataset) - -ALL_RELS = dataset["TRAIN"].keys() -gds.graph.drop("trainGraph", failIfMissing=False) -G_train, result = gds.graph.cypher.project( +def project_train_graph(gds): + all_rels = gds.run_cypher( + """ + CALL db.relationshipTypes() YIELD relationshipType """ - MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:""" - + "|".join(ALL_RELS) - + """]-(n) - RETURN gds.graph.project($graph_name, n, m, { - sourceNodeLabels: $label, - targetNodeLabels: $label - }) - """, # Cypher query - database="neo4j", # Target database - graph_name="trainGraph", # Query parameter - label="Entity", # Query parameter -) + ) + all_rels = all_rels["relationshipType"].to_list() + all_rels = [rel for rel in all_rels if rel.startswith("REL_")] + gds.graph.drop("trainGraph", failIfMissing=False) + + G_train, result = gds.graph.project("trainGraph", ["Entity"], all_rels) + + return G_train def inspect_graph(G): func_names = [ "name", - # "database", "node_count", "relationship_count", "node_labels", "relationship_types", - # "degree_distribution", "density", "size_in_bytes", "memory_usage", "exists", "configuration", "creation_time", "modification_time", ] for func_name in func_names: print(f"==={func_name}===: {getattr(G, func_name)()}") -inspect_graph(G_train) - -gds.set_compute_cluster_ip("localhost") - -kkge = gds.kge -kmodel = gds.kge.model - 
-print(gds.debug.arrow()) - -gds.kge.model.train( - G_train, - scoring_function="DistMult", - num_epochs=1, - embedding_dimension=10, -) - -print('Finished training') -# -# node_projection = {"Entity": {"properties": "id"}} -# relationship_projection = [ -# {"TRAIN": {"orientation": "NATURAL", "properties": "rel_id"}}, -# {"TEST": {"orientation": "NATURAL", "properties": "rel_id"}}, -# {"VALID": {"orientation": "NATURAL", "properties": "rel_id"}}, -# ] -# -# ttv_G, result = gds.graph.project( -# "fb15k-graph-ttv", -# node_projection, -# relationship_projection, -# ) -# -# node_properties = gds.graph.nodeProperties.stream( -# ttv_G, -# ["id"], -# separate_property_columns=True, -# ) -# -# nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id)) -# id_to_nodeId = dict(zip(node_properties.id, node_properties.nodeId)) -# -# def create_data_from_graph(relationship_type): -# rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type) -# topology = [ -# rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]), -# rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]), -# ] -# edge_index = torch.tensor(topology, dtype=torch.long) -# edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long) -# data = Data(edge_index=edge_index, edge_type=edge_type) -# data.num_nodes = len(nodeId_to_id) -# display(data) -# return data -# -# -# train_tensor_data = create_data_from_graph("TRAIN") -# test_tensor_data = create_data_from_graph("TEST") -# val_tensor_data = create_data_from_graph("VALID") -# -# gds.graph.drop(ttv_G) -# -# def train_model_with_pyg(): -# device = "cuda" if torch.cuda.is_available() else "cpu" -# -# model = TransE( -# num_nodes=train_tensor_data.num_nodes, -# num_relations=train_tensor_data.num_edge_types, -# hidden_channels=50, -# ).to(device) -# -# loader = model.loader( -# head_index=train_tensor_data.edge_index[0], -# rel_type=train_tensor_data.edge_type, -# tail_index=train_tensor_data.edge_index[1], -# 
batch_size=1000, -# shuffle=True, -# ) -# -# optimizer = optim.Adam(model.parameters(), lr=0.01) -# -# def train(): -# model.train() -# total_loss = total_examples = 0 -# for head_index, rel_type, tail_index in loader: -# optimizer.zero_grad() -# loss = model.loss(head_index, rel_type, tail_index) -# loss.backward() -# optimizer.step() -# total_loss += float(loss) * head_index.numel() -# total_examples += head_index.numel() -# return total_loss / total_examples -# -# @torch.no_grad() -# def test(data): -# model.eval() -# return model.test( -# head_index=data.edge_index[0], -# rel_type=data.edge_type, -# tail_index=data.edge_index[1], -# batch_size=1000, -# k=10, -# ) -# -# # Consider increasing the number of epochs -# epoch_count = 5 -# for epoch in range(1, epoch_count): -# loss = train() -# print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}") -# if epoch % 75 == 0: -# rank, hits = test(val_tensor_data) -# print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}") -# -# torch.save(model, f"./model_{epoch_count}.pt") -# -# mean_rank, mrr, hits_at_k = test(test_tensor_data) -# print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}") -# -# return model -# -# model = train_model_with_pyg() -# # The model can be loaded if it was trained before -# # model = torch.load("./model_501.pt") -# -# for i in tqdm(range(len(nodeId_to_id))): -# gds.run_cypher( -# "MATCH (n:Entity {id: $i}) SET n.emb=$EMBEDDING", -# params={"i": i, "EMBEDDING": model.node_emb.weight[i].tolist()}, -# ) -# -# relationship_to_predict = "/film/film/genre" -# rel_id_to_predict = rel_dict[relationship_to_predict] -# rel_label_to_predict = f"REL_{rel_id_to_predict}" -# -# G_test, result = gds.graph.project( -# "graph_to_predict_", -# {"Entity": {"properties": ["id", "emb"]}}, -# rel_label_to_predict, -# ) -# -# -# def print_graph_info(G): -# print(f"Graph '{G.name()}' node count: {G.node_count()}") -# print(f"Graph '{G.name()}' node labels: 
{G.node_labels()}") -# print(f"Graph '{G.name()}' relationship types: {G.relationship_types()}") -# print(f"Graph '{G.name()}' relationship count: {G.relationship_count()}") -# -# -# print_graph_info(G_test) -# -# target_emb = model.node_emb.weight[rel_id_to_predict].tolist() -# transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb}) -# -# source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"] -# source_ids_df = gds.run_cypher( -# "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId", -# params={"node_text_list": source_node_list}, -# ) -# -# result = transe_model.predict_stream( -# source_node_filter=source_ids_df.nodeId, -# target_node_filter="Entity", -# relationship_type=rel_label_to_predict, -# top_k=3, -# concurrency=4, -# ) -# print(result) -# -# ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId])) -# -# ids_to_text = gds.run_cypher( -# "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id", -# params={"ids": ids_in_result}, -# ) -# -# nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag)) -# nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id)) -# -# result.insert(1, "sourceTag", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x])) -# result.insert(2, "sourceId", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x])) -# result.insert(4, "targetTag", result.targetNodeId.map(lambda x: nodeId_to_text_res[x])) -# result.insert(5, "targetId", result.targetNodeId.map(lambda x: nodeId_to_id_res[x])) -# -# print(result) -# -# write_relationship_type = "PREDICTED_" + rel_label_to_predict -# result_write = transe_model.predict_write( -# source_node_filter=source_ids_df.nodeId, -# target_node_filter="Entity", -# relationship_type=rel_label_to_predict, -# write_relationship_type=write_relationship_type, -# write_property="transe_score", -# top_k=3, -# concurrency=4, -# ) -# -# gds.run_cypher( -# 
"MATCH (n)-[r:" -# + write_relationship_type -# + "]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score" -# ) -# -# gds.graph.drop(G_test) +if __name__ == "__main__": + gds = setup_connection() + create_constraint(gds) + put_data_in_db(gds) + G_train = project_train_graph(gds) + inspect_graph(G_train) + + gds.set_compute_cluster_ip("localhost") + + print(gds.debug.arrow()) + + gds.kge.model.train( + G_train, + scoring_function="DistMult", + num_epochs=1, + embedding_dimension=10, + epochs_per_checkpoint=0, + ) + + print('Finished training') diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 4599442fd..06cc6bc43 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -143,7 +143,6 @@ def fastpath(self) -> FastPathRunner: @property def kge(self) -> KgeRunner: - print("!kge") if not isinstance(self._query_runner, ArrowQueryRunner): raise ValueError("Running FastPath requires GDS with the Arrow server enabled") if self._compute_cluster_ip is None: @@ -156,7 +155,7 @@ def kge(self) -> KgeRunner: self._server_version, self._compute_cluster_ip, self._encrypted_db_password, - None, + self._query_runner._gds_arrow_client._host + ":" + str(self._query_runner._gds_arrow_client._port), ) def __getattr__(self, attr: str) -> IndirectCallBuilder: diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 8e6fe4a45..48b4b9746 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -31,7 +31,6 @@ def __init__( self._namespace = namespace self._server_version = server_version self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005" - # self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8491" self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080" self._encrypted_db_password = encrypted_db_password self._arrow_uri = 
arrow_uri @@ -50,6 +49,7 @@ def train( scoring_function, num_epochs, embedding_dimension, + epochs_per_checkpoint, mlflow_experiment_name: Optional[str] = None, ) -> Series: graph_config = {"name": G.name()} @@ -58,6 +58,7 @@ def train( "scoring_function": scoring_function, "num_epochs": num_epochs, "embedding_dimension": embedding_dimension, + "epochs_per_checkpoint": epochs_per_checkpoint, } config = { @@ -65,9 +66,10 @@ def train( "task": "KGE_TRAINING_PYG", "task_config": { "graph_config": graph_config, + "modelname": "dummmy_model_name", "task_config": algo_config, }, - "encrypted_db_password": self._encrypted_db_password, + # "encrypted_db_password": self._encrypted_db_password, "graph_arrow_uri": self._arrow_uri, } @@ -83,8 +85,11 @@ def train( return Series({"status": "finished"}) def _start_job(self, config: Dict[str, Any]) -> str: + print("_start_job") print(config) - res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config) + url = f"{self._compute_cluster_web_uri}/api/machine-learning/start" + print(url) + res = requests.post(url, json=config) res.raise_for_status() job_id = res.json()["job_id"] logging.info(f"Job with ID '{job_id}' started") From 7d7deffb7bd308912255f7729a52c238a70deef2 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 8 Jul 2024 18:23:07 +0100 Subject: [PATCH 06/24] Avoid usage of dummy password --- graphdatascience/model/kge_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 48b4b9746..1c0956035 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -66,12 +66,13 @@ def train( "task": "KGE_TRAINING_PYG", "task_config": { "graph_config": graph_config, - "modelname": "dummmy_model_name", + "modelname": "dummmy_model_name"+str(time.time()), "task_config": algo_config, }, - # "encrypted_db_password": self._encrypted_db_password, 
"graph_arrow_uri": self._arrow_uri, } + if self._encrypted_db_password is not None: + config["encrypted_db_password"] = self._encrypted_db_password if mlflow_experiment_name is not None: config["task_config"]["mlflow"] = { From c9101660809d34403a02c502050ce1c49f2a787b Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 17 Jul 2024 16:16:19 +0100 Subject: [PATCH 07/24] KGE distmult nations is working --- examples/kge-distmult-nations.py | 243 ++++++++++++++++++ examples/kge-distmult.py | 69 ++++- graphdatascience/model/kge_runner.py | 64 ++++- .../tests/integration/test_graph_construct.py | 28 +- .../tests/integration/test_graph_ops.py | 2 - 5 files changed, 369 insertions(+), 37 deletions(-) create mode 100644 examples/kge-distmult-nations.py diff --git a/examples/kge-distmult-nations.py b/examples/kge-distmult-nations.py new file mode 100644 index 000000000..daa40e4ec --- /dev/null +++ b/examples/kge-distmult-nations.py @@ -0,0 +1,243 @@ +import os +import time +import warnings +from collections import defaultdict + +from neo4j.exceptions import ClientError +from tqdm import tqdm + +from graphdatascience import GraphDataScience + +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +def setup_connection(): + NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + NEO4J_AUTH = None + NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j") + if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"): + NEO4J_AUTH = ( + os.environ.get("NEO4J_USER"), + os.environ.get("NEO4J_PASSWORD"), + ) + gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True) + + return gds + + +def create_constraint(gds): + try: + _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE") + except ClientError: + print("CONSTRAINT entity_id already exists") + + +def download_data(raw_file_names): + import os + import zipfile + + from ogb.utils.url import download_url + + url = 
"https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip" + raw_dir = "./data_from_zip" + download_url(f"{url}", raw_dir) + + with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref: + for filename in raw_file_names: + zip_ref.extract(f"Release/{filename}", path=raw_dir) + data_dir = raw_dir + "/" + "Release" + return data_dir + + +def get_text_to_id_map(data_dir, text_to_id_filename): + with open(data_dir + "/" + text_to_id_filename, "r") as f: + data = [x.split("\t") for x in f.read().split("\n")[:-1]] + text_to_id_map = {text: int(id) for text, id in data} + return text_to_id_map + + +def read_data(): + rel_types = { + "train.txt": "TRAIN", + "valid.txt": "VALID", + "test.txt": "TEST", + } + raw_file_names = ["train.txt", "valid.txt", "test.txt"] + node_id_filename = "entity2id.txt" + rel_id_filename = "relation2id.txt" + + data_dir = "/Users/olgarazvenskaia/work/datasets/KGDatasets/Nations" + node_map = get_text_to_id_map(data_dir, node_id_filename) + rel_map = get_text_to_id_map(data_dir, rel_id_filename) + dataset = defaultdict(lambda: defaultdict(list)) + + rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2} + + for file_name in raw_file_names: + file_name_path = data_dir + "/" + file_name + + with open(file_name_path, "r") as f: + data = [x.split("\t") for x in f.read().split("\n")[:-1]] + + for i, (src_text, rel_text, dst_text) in enumerate(data): + source = node_map[src_text] + target = node_map[dst_text] + rel_type = "REL_" + rel_text.upper() + rel_split = rel_types[file_name] + + dataset[rel_split][rel_type].append( + { + "source": source, + "source_text": src_text, + "target": target, + "target_text": dst_text, + "rel_type": rel_type, + "rel_id": rel_map[rel_text], + "rel_split": rel_split, + "rel_split_id": rel_split_id[rel_split], + } + ) + + print("Number of nodes: ", len(node_map)) + for rel_split in dataset: + print( + f"Number of relationships of type {rel_split}: ", + 
sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]), + ) + return dataset + + +def put_data_in_db(gds): + res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes") + if res["num_nodes"].values[0] > 0: + print("Data already in db, number of nodes: ", res["num_nodes"].values[0]) + return + dataset = read_data() + pbar = tqdm( + desc="Putting data in db", + total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]), + ) + + for rel_split in dataset: + for rel_type in dataset[rel_split]: + edges = dataset[rel_split][rel_type] + + gds.run_cypher( + f""" + UNWIND $ll as l + MERGE (n:Entity {{id:l.source, text:l.source_text}}) + MERGE (m:Entity {{id:l.target, text:l.target_text}}) + MERGE (n)-[:{rel_type} {{split: l.rel_split_id, rel_id: l.rel_id}}]->(m) + """, + params={"ll": edges}, + ) + pbar.update(len(edges)) + pbar.close() + + for rel_split in dataset: + res = gds.run_cypher( + f""" + MATCH ()-[r:{rel_split}]->() + RETURN COUNT(r) AS numberOfRelationships + """ + ) + print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships) + + +def project_graphs(gds): + all_rels = gds.run_cypher( + """ + CALL db.relationshipTypes() YIELD relationshipType + """ + ) + all_rels = all_rels["relationshipType"].to_list() + all_rels = {rel: {"properties": "split"} for rel in all_rels if rel.startswith("REL_")} + gds.graph.drop("fullGraph", failIfMissing=False) + gds.graph.drop("trainGraph", failIfMissing=False) + gds.graph.drop("validGraph", failIfMissing=False) + gds.graph.drop("testGraph", failIfMissing=False) + + G_full, _ = gds.graph.project("fullGraph", ["Entity"], all_rels) + inspect_graph(G_full) + + G_train, _ = gds.graph.filter("trainGraph", G_full, "*", "r.split = 0.0") + G_valid, _ = gds.graph.filter("validGraph", G_full, "*", "r.split = 1.0") + G_test, _ = gds.graph.filter("testGraph", G_full, "*", "r.split = 2.0") + + inspect_graph(G_train) + inspect_graph(G_valid) + 
inspect_graph(G_test) + + gds.graph.drop("fullGraph", failIfMissing=False) + + return G_train, G_valid, G_test + + +def inspect_graph(G): + func_names = [ + "name", + "node_count", + "relationship_count", + "node_labels", + "relationship_types", + ] + for func_name in func_names: + print(f"==={func_name}===: {getattr(G, func_name)()}") + + +if __name__ == "__main__": + gds = setup_connection() + create_constraint(gds) + put_data_in_db(gds) + G_train, G_valid, G_test = project_graphs(gds) + inspect_graph(G_train) + inspect_graph(G_valid) + inspect_graph(G_test) + + gds.set_compute_cluster_ip("localhost") + + print(gds.debug.arrow()) + + model_name = "dummyModelName_" + str(time.time()) + + gds.kge.model.train( + G_train, + model_name=model_name, + scoring_function="DistMult", + num_epochs=1, + embedding_dimension=10, + epochs_per_checkpoint=0, + ) + + df = gds.kge.model.predict( + G_train, + model_name=model_name, + top_k=10, + node_ids=[ + gds.find_node_id(["Entity"], {"text": "brazil"}), + gds.find_node_id(["Entity"], {"text": "uk"}), + gds.find_node_id(["Entity"], {"text": "jordan"}), + ], + rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"], + ) + + print(df) + # + # gds.kge.model.predict_tail( + # G_train, + # model_name=model_name, + # top_k=10, + # node_ids=[gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), gds.find_node_id(["Entity"], {"id": 2})], + # rel_types=["REL_1", "REL_2"], + # ) + # + # gds.kge.model.score_triples( + # G_train, + # model_name=model_name, + # triples=[ + # (gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), "REL_1", gds.find_node_id(["Entity"], {"id": 2})), + # (gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})), + # ], + # ) + + print("Finished training") diff --git a/examples/kge-distmult.py b/examples/kge-distmult.py index 5ae04bd4c..20b7f15e4 100644 --- a/examples/kge-distmult.py +++ b/examples/kge-distmult.py @@ -1,11 +1,13 @@ import os +import time import warnings from collections import 
defaultdict -from graphdatascience import GraphDataScience from neo4j.exceptions import ClientError from tqdm import tqdm +from graphdatascience import GraphDataScience + warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -110,18 +112,19 @@ def put_data_in_db(gds): desc="Putting data in db", total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]), ) + rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2} for rel_split in dataset: for rel_type in dataset[rel_split]: edges = dataset[rel_split][rel_type] # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m) + # MERGE (n)-[:{rel_split}]->(m) gds.run_cypher( f""" UNWIND $ll as l MERGE (n:Entity {{id:l.source, text:l.source_text}}) MERGE (m:Entity {{id:l.target, text:l.target_text}}) - MERGE (n)-[:{rel_split}]->(m) - MERGE (n)-[:{rel_type}]->(m) + MERGE (n)-[:{rel_type} {{split: {rel_split_id[rel_split]}}}]->(m) """, params={"ll": edges}, ) @@ -153,6 +156,33 @@ def project_train_graph(gds): return G_train +def project_predict_graph(gds): + all_rels = gds.run_cypher( + """ + CALL db.relationshipTypes() YIELD relationshipType + """ + ) + all_rels = all_rels["relationshipType"].to_list() + rel_spec = {} + for rel in all_rels: + if rel.startswith("REL_"): + rel_spec[rel] = {"properties": ["split"]} + + gds.graph.drop("fullGraph", failIfMissing=False) + gds.graph.drop("predictGraph", failIfMissing=False) + + # {"REL": {"properties": ["relY"]}, "RELR": {"properties": ["relY"]}} + # print(rel_spec) + + G_full, result = gds.graph.project("fullGraph", ["Entity"], all_rels) + + G_full, result = gds.graph.project("fullGraph", ["Entity"], rel_spec) + # G_predict = gds.graph.filter('predictGraph', 'fullGraph', '*', 'r.split == 2') + + inspect_graph(G_full) + return G_full + + def inspect_graph(G): func_names = [ "name", @@ -170,18 +200,47 @@ def inspect_graph(G): create_constraint(gds) put_data_in_db(gds) G_train = project_train_graph(gds) - inspect_graph(G_train) + # 
G_predict = project_predict_graph(gds) + # inspect_graph(G_train) gds.set_compute_cluster_ip("localhost") print(gds.debug.arrow()) + model_name = "dummyModelName_" + str(time.time()) + gds.kge.model.train( G_train, + model_name=model_name, scoring_function="DistMult", num_epochs=1, embedding_dimension=10, epochs_per_checkpoint=0, ) - print('Finished training') + gds.kge.model.predict( + G_train, + model_name=model_name, + top_k=10, + node_ids=[1, 2, 3], + rel_types=["REL_1", "REL_2"], + ) + + gds.kge.model.predict_tail( + G_train, + model_name=model_name, + top_k=10, + node_ids=[gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), gds.find_node_id(["Entity"], {"id": 2})], + rel_types=["REL_1", "REL_2"], + ) + + gds.kge.model.score_triples( + G_train, + model_name=model_name, + triples=[ + (gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), "REL_1", gds.find_node_id(["Entity"], {"id": 2})), + (gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})), + ], + ) + + print("Finished training") diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 1c0956035..5cf99314b 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -3,15 +3,15 @@ import time from typing import Any, Dict, Optional +import pandas as pd import requests -from pandas import Series +from pandas import DataFrame, Series from ..error.client_only_endpoint import client_only_endpoint from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace from ..graph.graph_object import Graph from ..query_runner.query_runner import QueryRunner -from ..server_version.compatible_with import compatible_with from ..server_version.server_version import ServerVersion logging.basicConfig(level=logging.INFO) @@ -46,6 +46,7 @@ def model(self): def train( self, G: Graph, + model_name: str, scoring_function, num_epochs, embedding_dimension, @@ -66,7 
+67,7 @@ def train( "task": "KGE_TRAINING_PYG", "task_config": { "graph_config": graph_config, - "modelname": "dummmy_model_name"+str(time.time()), + "modelname": model_name, "task_config": algo_config, }, "graph_arrow_uri": self._arrow_uri, @@ -85,6 +86,63 @@ def train( return Series({"status": "finished"}) + @client_only_endpoint("gds.kge.model") + def predict( + self, + G: Graph, + model_name: str, + top_k: int, + node_ids: list[int], + rel_types: list[str], + mlflow_experiment_name: Optional[str] = None, + ) -> DataFrame: + graph_config = {"name": G.name()} + + algo_config = { + "top_k": top_k, + "node_ids": node_ids, + "rel_types": rel_types, + } + + config = { + "user_name": "DUMMY_USER", + "task": "KGE_PREDICT_PYG", + "task_config": { + "graph_config": graph_config, + "modelname": model_name, + "task_config": algo_config, + }, + "graph_arrow_uri": self._arrow_uri, + } + if self._encrypted_db_password is not None: + config["encrypted_db_password"] = self._encrypted_db_password + + if mlflow_experiment_name is not None: + config["task_config"]["mlflow"] = { + "config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name} + } + + print("predict config") + print(config) + job_id = self._start_job(config) + + self._wait_for_job(job_id) + + return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id) + + def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataFrame: + res = requests.get( + f"{self._compute_cluster_web_uri}/internal/fetch-result", + params={"user_name": user_name, "modelname": model_name, "job_id": job_id}, + ) + res.raise_for_status() + + with open("res.json", mode="wb+") as f: + f.write(res.content) + + df = pd.read_json("res.json", orient="records", lines=True) + return df + def _start_job(self, config: Dict[str, Any]) -> str: print("_start_job") print(config) diff --git a/graphdatascience/tests/integration/test_graph_construct.py 
b/graphdatascience/tests/integration/test_graph_construct.py index 82a0da3ca..4f2379256 100644 --- a/graphdatascience/tests/integration/test_graph_construct.py +++ b/graphdatascience/tests/integration/test_graph_construct.py @@ -559,35 +559,9 @@ def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience) with pytest.warns(DeprecationWarning): gds.alpha.graph.construct("hello", nodes, relationships) -@pytest.mark.enterprise -@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 1, 0)) -def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience) -> None: - nodes = DataFrame({"nodeId": [0, 1, 2, 3]}) - relationships = DataFrame({"sourceNodeId": [0, 1, 2, 3], "targetNodeId": [1, 2, 3, 0]}) - - with pytest.warns(DeprecationWarning): - gds.alpha.graph.construct("hello", nodes, relationships) - - -@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) -def test_roundtrip_with_arrow(gds: GraphDataScience) -> None: - G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}}) - - rel_df = gds.graph.relationshipProperty.stream(G, "relX") - node_df = gds.graph.nodeProperty.stream(G, "x") - - G_2 = gds.graph.construct("arrowGraph", node_df, rel_df) - - res = gds.graph.list() - try: - assert set(res['graphName'].tolist()) == {'g', 'arrowGraph'} - assert G.node_count() == G_2.node_count() - assert G.relationship_count() == G_2.relationship_count() - finally: - G_2.drop() @pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) def test_drop_list_warning_reproduction(gds: GraphDataScience) -> None: G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}}) res = gds.graph.list() - assert res['graphName'].tolist() == ['g'] + assert res["graphName"].tolist() == ["g"] diff --git a/graphdatascience/tests/integration/test_graph_ops.py b/graphdatascience/tests/integration/test_graph_ops.py index 00d32eb8e..3a505313a 
100644 --- a/graphdatascience/tests/integration/test_graph_ops.py +++ b/graphdatascience/tests/integration/test_graph_ops.py @@ -1058,5 +1058,3 @@ def test_empty_relationships_stream(gds: GraphDataScience) -> None: result = gds.graph.relationships.stream(G, ["SIMILAR"]) assert result.empty - - From b532f83a4a3e41b43d62e696cab346dd9c983eee Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 17 Jul 2024 17:04:02 +0100 Subject: [PATCH 08/24] Next --- examples/kge-distmult-nations.py | 2 +- graphdatascience/model/kge_runner.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/kge-distmult-nations.py b/examples/kge-distmult-nations.py index daa40e4ec..c8b27b45e 100644 --- a/examples/kge-distmult-nations.py +++ b/examples/kge-distmult-nations.py @@ -212,7 +212,7 @@ def inspect_graph(G): df = gds.kge.model.predict( G_train, model_name=model_name, - top_k=10, + top_k=3, node_ids=[ gds.find_node_id(["Entity"], {"text": "brazil"}), gds.find_node_id(["Entity"], {"text": "uk"}), diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 5cf99314b..a4b526cc3 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -137,10 +137,12 @@ def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataF ) res.raise_for_status() - with open("res.json", mode="wb+") as f: + res_file_name = f'res_{job_id}.json' + with open(res_file_name, mode="wb+") as f: f.write(res.content) - df = pd.read_json("res.json", orient="records", lines=True) + df = pd.read_json(res_file_name, orient="records", lines=True) + os.remove(res_file_name) return df def _start_job(self, config: Dict[str, Any]) -> str: From 365ec4d9ca1022499fc38c953bc9091760bb6d2f Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Thu, 18 Jul 2024 12:29:46 +0100 Subject: [PATCH 09/24] Remove fastpath related code --- examples/FastPathExamples.ipynb | 760 ------------------ 
graphdatascience/graph_data_science.py | 18 - graphdatascience/model/fastpath_runner.py | 115 --- graphdatascience/model/kge_runner.py | 2 +- .../query_runner/gds_arrow_client.py | 6 + .../resources/field-testing/pub.pem | 3 +- .../tests/integration/test_graph_construct.py | 7 - .../tests/integration/test_graph_ops.py | 2 +- 8 files changed, 9 insertions(+), 904 deletions(-) delete mode 100644 examples/FastPathExamples.ipynb delete mode 100644 graphdatascience/model/fastpath_runner.py diff --git a/examples/FastPathExamples.ipynb b/examples/FastPathExamples.ipynb deleted file mode 100644 index 2e304b218..000000000 --- a/examples/FastPathExamples.ipynb +++ /dev/null @@ -1,760 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0f03a290", - "metadata": {}, - "source": [ - "# Path embeddings with FastPATH - Examples" - ] - }, - { - "cell_type": "markdown", - "id": "68b6e21f", - "metadata": {}, - "source": [ - "In this notebook we will show you several examples of constructing path embeddings with the FastPATH algorithm.\n", - "The full documentation for the algorithm can be found [here](https://docs.google.com/document/d/1oCAz6ukn_r19H27ghxnGM_-UQP9rgYJRhLzNLHdQc8Y/edit#heading=h.ya70gurwgyt2)." 
- ] - }, - { - "cell_type": "markdown", - "id": "c3bf7590", - "metadata": {}, - "source": [ - "## The Dataset\n", - "\n", - "We will use a synthetic medical dataset containg `Patients`, `Encounters`, `Conditions`, `Observations` and more.\n", - "Using FastPATH we will construct (path) embeddings for patient journey in the dataset.\n", - "You need to replace the Neo4j URL and credentials to a database that contains the dataset.\n", - "Contact the GDS team if you're interested in that.\n", - "\n", - "Below is the schema of the database:" - ] - }, - { - "attachments": { - "image.png": { - "image/png": "" - } - }, - "cell_type": "markdown", - "id": "316f034f", - "metadata": {}, - "source": [ - "![image.png](attachment:image.png)" - ] - }, - { - "cell_type": "markdown", - "id": "a062d180", - "metadata": {}, - "source": [ - "## Import and Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4fe27541", - "metadata": {}, - "outputs": [], - "source": [ - "from graphdatascience import GraphDataScience\n", - "import numpy as np\n", - "from sklearn.manifold import TSNE\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "from matplotlib import pyplot as plt\n", - "from sklearn.linear_model import LogisticRegression\n", - "from sklearn.metrics import f1_score\n", - "from sklearn.utils._testing import ignore_warnings\n", - "from sklearn.exceptions import ConvergenceWarning\n", - "from sklearn.model_selection import train_test_split\n", - "\n", - "plt.rcParams[\"figure.figsize\"] = [15, 10]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "238525b8", - "metadata": {}, - "outputs": [], - "source": [ - "gds = GraphDataScience(\n", - " \"neo4j+s://eddb7e19.databases.neo4j.io\",\n", - " auth=(\"neo4j\", \"Oz4oBK--Sx4byHjgHgJuMf5VqQncGHG9mbgpy44rQTU\"),\n", - " database=\"neo4j\",\n", - ")\n", - "gds.set_compute_cluster_ip(\"localhost\")" - ] - }, - { - "cell_type": "markdown", - "id": "c1f7417c", - "metadata": {}, - "source": 
[ - "## Preprocessing\n", - "\n", - "In order to make our dataset amenable to our analysis using FastPATH and downstream machine learning, we must augment it slightly.\n", - "This entails writing some additional node properties to the database with the Cypher code below.\n", - "\n", - "**NOTE: Each preprocessing cell below must be run once, and only once.**" - ] - }, - { - "cell_type": "markdown", - "id": "97bf5fc6", - "metadata": {}, - "source": [ - "First we write a `has_diabetes` property (0 or 1) to each `Patient` node.\n", - "This will give us class labels that enable us to train a classification model on patient journeys later." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9b10f9b6", - "metadata": {}, - "outputs": [], - "source": [ - "gds.run_cypher(\"MATCH (p:Patient) SET p.has_diabetes=0\")\n", - "gds.run_cypher(\n", - " \"MATCH (p:Patient)-[:HAS_ENCOUNTER]->(n:Encounter)-[:HAS_CONDITION]-(c:Condition) WHERE c.description='Diabetes' SET p.has_diabetes=1\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "5f103fbb", - "metadata": {}, - "source": [ - "Then to each `Encounter` node, we write the number of days that has passed since 1 January 1970 (can be negative), based on the existing `start` node property.\n", - "We do this since the `start` property it already has is not an actual number, which is what the algorithm needs.\n", - "This is needed in the case where we don't rely on `NEXT` relationships for event timestamps, which is one of the examples below." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c307b6d", - "metadata": {}, - "outputs": [], - "source": [ - "gds.run_cypher(\n", - " \"MATCH (n:Encounter) WITH toInteger(datetime(n.start).epochseconds/(24 * 3600)) as days, n SET n.days=days\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "014e00e4", - "metadata": {}, - "source": [ - "Next we write two output time properties to each `Patient` based on the last `Encounter` before a diabetes diagnosis, or the last `Encounter` otherwise.\n", - "For the case where we are relying on the `days` node property on `Encounter`s (see above), the new `output_time` node property for `Patient`s will be equal to 1 + the `days` timestamp of their last encounter (before diabetes if they have it).\n", - "For the case where we are relying on `FIRST` and `NEXT` relationships to define the `Encounter`s belonging to a `Patient`, the new `output_time_stepwise` node property for `Patient`s will be equal to the number of encounters up to and including the last encounter (before diabetes if they have it).\n", - "\n", - "With these properties we can specify the point in time for which we want the path embeddings for each `Patient` node.\n", - "I.e. the paths that is embedded will continue up to that point, but not longer." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b95b2ba3", - "metadata": {}, - "outputs": [], - "source": [ - "# Writing the `output_time` `Patient` node property\n", - "gds.run_cypher(\"MATCH (p:Patient)-[:LAST]->(n:Encounter) SET p.output_time=n.days+1\")\n", - "gds.run_cypher(\n", - " \"MATCH (p:Patient)-[:HAS_ENCOUNTER]->(e1:Encounter)-[:NEXT]->(e2:Encounter)-[:HAS_CONDITION]->(c:Condition) WHERE c.description='Diabetes' SET p.output_time=e1.days + 1\"\n", - ")\n", - "\n", - "# Writing `output_time_stepwise` `Patient` node property\n", - "gds.run_cypher(\n", - " \"MATCH (p:Patient)-[:HAS_ENCOUNTER]->(e:Encounter) WHERE e.days <= p.output_time - 1 WITH p, count(*) as cc SET p.output_time_stepwise=cc\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "68dffdb0", - "metadata": {}, - "source": [ - "Lastly we write the `class` of each `Encounter` as an integer property `intClass`.\n", - "Doing so enables us to use the class property as input to the algorithm, impacting the internal embeddings of `Encounter` nodes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4b8e0f5b", - "metadata": {}, - "outputs": [], - "source": [ - "gds.run_cypher(\n", - " \"\"\"\n", - " MATCH (e:Encounter) with distinct e.class AS class\n", - " WITH collect(class) as clss\n", - " WITH apoc.map.fromLists(clss, range(0, size(clss) - 1)) as classMap\n", - " MATCH (e:Encounter) SET e.intClass = classMap[e.class]\n", - " \"\"\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "ba366637", - "metadata": {}, - "source": [ - "## Projection with Timestamps\n", - "\n", - "For the first examples, we rely on the `days` property of `Encounter` nodes for timestamp.\n", - "For this reason we don't need to project `FIRST` and `NEXT` relationships." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6cf05097", - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " G = gds.graph.get(\"medical\")\n", - " G.drop()\n", - "except:\n", - " pass\n", - "\n", - "G, _ = gds.graph.project(\n", - " \"medical\",\n", - " {\n", - " \"Patient\": {\"properties\": [\"output_time\", \"has_diabetes\"]},\n", - " \"Encounter\": {\"properties\": [\"days\", \"intClass\"]},\n", - " \"Observation\": {\"properties\": []},\n", - " \"Payer\": {\"properties\": []},\n", - " \"Provider\": {\"properties\": []},\n", - " \"Organization\": {\"properties\": []},\n", - " \"Speciality\": {\"properties\": []},\n", - " \"Allergy\": {\"properties\": []},\n", - " \"Reaction\": {\"properties\": []},\n", - " \"Condition\": {\"properties\": []},\n", - " \"Drug\": {\"properties\": []},\n", - " \"Procedure\": {\"properties\": []},\n", - " \"CarePlan\": {\"properties\": []},\n", - " \"Device\": {\"properties\": []},\n", - " \"ConditionDescription\": {\"properties\": []},\n", - " },\n", - " [\n", - " \"HAS_OBSERVATION\",\n", - " \"HAS_ENCOUNTER\",\n", - " \"HAS_PROVIDER\",\n", - " \"AT_ORGANIZATION\",\n", - " \"HAS_PAYER\",\n", - " \"HAS_SPECIALITY\",\n", - " \"BELONGS_TO\",\n", - " \"INSURANCE_START\",\n", - " \"INSURANCE_END\",\n", - " \"HAS_ALLERGY\",\n", - " \"ALLERGY_DETECTED\",\n", - " \"HAS_REACTION\",\n", - " \"CAUSES_REACTION\",\n", - " \"HAS_CONDITION\",\n", - " \"HAS_DRUG\",\n", - " \"HAS_PROCEDURE\",\n", - " \"HAS_CARE_PLAN\",\n", - " \"DEVICE_USED\",\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "e6aacc68", - "metadata": {}, - "source": [ - "## FastRP Features\n", - "\n", - "We should make use of the topological information we have around in `Encounter` node.\n", - "For example, what `Condition`s, `Drug`s, `Procedure`s, etc. 
(see schema above) are connected to it.\n", - "And perhaps one hop in the graph beyond that.\n", - "To do so, we make use of FastRP to create node embeddings.\n", - "Later we can input the node embeddings of the `Encounter` nodes to the FastPATH algorithm using the `event_features` parameter." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c7b0c35e", - "metadata": {}, - "outputs": [], - "source": [ - "gds.fastRP.mutate(\n", - " G,\n", - " embeddingDimension=256,\n", - " mutateProperty=\"emb\",\n", - " iterationWeights=[1, 1],\n", - " randomSeed=42,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "76cc51a7", - "metadata": {}, - "source": [ - "## Preparation of machine learning and visualization of embeddings\n", - "\n", - "Below we define a utility function that we can subsequently use to analyze the path embeddings we produce in each example below.\n", - "This function does three things:\n", - "1. Computes the average pairwise distances between embeddings of the different class combinations (no diabetes vs diabetes)\n", - "2. Plot the path embeddings in two dimensions with t-SNE\n", - "3. 
Train and evaluate a logistic regression diabetes classifier which takes path embeddings as input\n", - "\n", - "**NOTE: You don't have to read or understand this function, but can think of it as a black box in the context of this notebook.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f8d6247", - "metadata": {}, - "outputs": [], - "source": [ - "@ignore_warnings(category=ConvergenceWarning)\n", - "def explore(embeddings):\n", - "\n", - " # Compute pairwise distances between embeddings of healty<->healthy, sick<->sick and healthy<->sick.\n", - "\n", - " diabetes_by_nodeId = gds.graph.streamNodeProperties(G, [\"has_diabetes\"], [\"Patient\"]).set_index(\"nodeId\")[\n", - " [\"propertyValue\"]\n", - " ]\n", - " emb_and_diabetes = (\n", - " embeddings[[\"nodeId\", \"embeddings\"]]\n", - " .set_index(\"nodeId\")\n", - " .merge(diabetes_by_nodeId, left_index=True, right_index=True)\n", - " )\n", - " healthy_embs = np.array(emb_and_diabetes[emb_and_diabetes.propertyValue == 0][\"embeddings\"].tolist())\n", - " diabetes_embs = np.array(emb_and_diabetes[emb_and_diabetes.propertyValue == 1][\"embeddings\"].tolist())\n", - "\n", - " diabetes_distances = []\n", - " for i in range(diabetes_embs.shape[0]):\n", - " for j in range(i + 1, diabetes_embs.shape[0]):\n", - " x1 = diabetes_embs[i, :]\n", - " x2 = diabetes_embs[j, :]\n", - " diabetes_distances.append(np.linalg.norm(x1 - x2))\n", - "\n", - " print(f\"Avg diabetes<->diabetes L2-distances: {np.mean(diabetes_distances)}\")\n", - "\n", - " healthy_distances = []\n", - " for i in range(healthy_embs.shape[0]):\n", - " for j in range(i + 1, healthy_embs.shape[0]):\n", - " x1 = healthy_embs[i, :]\n", - " x2 = healthy_embs[j, :]\n", - " healthy_distances.append(np.linalg.norm(x1 - x2))\n", - "\n", - " print(f\"Avg healthy<->healthy L2-distances: {np.mean(healthy_distances)}\")\n", - "\n", - " mixed_distances = []\n", - " for i in range(diabetes_embs.shape[0]):\n", - " for j in 
range(healthy_embs.shape[0]):\n", - " x1 = diabetes_embs[i, :]\n", - " x2 = healthy_embs[j, :]\n", - " mixed_distances.append(np.linalg.norm(x1 - x2))\n", - "\n", - " print(f\"Avg healthy<->diabetes L2-distances: {np.mean(mixed_distances)}\")\n", - "\n", - " # TSNE time\n", - "\n", - " X = np.array(emb_and_diabetes[\"embeddings\"].tolist())\n", - " y = emb_and_diabetes.propertyValue.to_numpy()\n", - " tsne = TSNE(2)\n", - " tsne_result = tsne.fit_transform(X)\n", - " tsne_result_df = pd.DataFrame({\"tsne_1\": tsne_result[:, 0], \"tsne_2\": tsne_result[:, 1], \"label\": y})\n", - " fig, ax = plt.subplots(1)\n", - " sns.scatterplot(x=\"tsne_1\", y=\"tsne_2\", hue=\"label\", data=tsne_result_df, ax=ax, s=10)\n", - " lim = (tsne_result.min() - 5, tsne_result.max() + 5)\n", - " ax.set_xlim(lim)\n", - " ax.set_ylim(lim)\n", - " ax.set_aspect(\"equal\")\n", - " ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)\n", - "\n", - " # Train evaluate diabetes classifier :)\n", - "\n", - " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42, stratify=y)\n", - "\n", - " clf = LogisticRegression()\n", - " clf.fit(X_train, y_train)\n", - "\n", - " y_train_pred = clf.predict(X_train)\n", - " y_test_pred = clf.predict(X_test)\n", - "\n", - " train_f1_score = f1_score(y_train, y_train_pred, average=\"macro\")\n", - " test_f1_score = f1_score(y_test, y_test_pred, average=\"macro\")\n", - "\n", - " print(\"Diabetes classifier scores:\")\n", - " print(f\"Train set f1: {train_f1_score}\")\n", - " print(f\"Test set f1: {test_f1_score}\")" - ] - }, - { - "cell_type": "markdown", - "id": "0236dda1", - "metadata": {}, - "source": [ - "## Examples with timestamp node properties\n", - "\n", - "In the following few examples we will let the `days` node property on `Encounter` nodes dictate when an encounter has occured." 
- ] - }, - { - "cell_type": "markdown", - "id": "9f012dae", - "metadata": {}, - "source": [ - "### Global output time\n", - "\n", - "To use a single fixed output time, you can either\n", - "* Use the algorithm parameter `output_times` (and optionally use subgraph filtering to run only up to a certain time), or\n", - "* Use Cypher to write a output time property to the `Patient` nodes holding a fixed timestamp, and then provide it as `output_time_property`\n", - "\n", - "Here we will use the first option.\n", - "\n", - "Note that we are also using the FastRP embeddings for `Encounter` nodes as input features to the events (encounters)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c216bcf", - "metadata": {}, - "outputs": [], - "source": [ - "# try:\n", - "gds.graph.nodeProperties.drop(G, [\"embeddings\"], node_labels=[\"Patient\"])\n", - "# except:\n", - "# pass\n", - "\n", - "gds.fastpath.mutate(\n", - " G,\n", - " base_node_label=\"Patient\",\n", - " event_node_label=\"Encounter\",\n", - " event_features=\"emb\",\n", - " time_node_property=\"days\",\n", - " dimension=256,\n", - " num_elapsed_times=100,\n", - " output_time=365 * 50, # 50 years\n", - " max_elapsed_time=365 * 10, # 10 years\n", - " smoothing_rate=0.004,\n", - " smoothing_window=3,\n", - " decay_factor=1e-5,\n", - " random_seed=42,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c4def10f", - "metadata": {}, - "outputs": [], - "source": [ - "embeddings = gds.graph.nodeProperties.stream(G, [\"embeddings\"], node_labels=[\"Patient\"], separate_property_columns=True)\n", - "print(embeddings)\n", - "explore(embeddings)" - ] - }, - { - "cell_type": "markdown", - "id": "8f283d0f", - "metadata": {}, - "source": [ - "## Example with individual output time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "022a8096", - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " gds.graph.nodeProperties.drop(G, 
[\"embeddings\"], node_labels=[\"Patient\"])\n", - "except:\n", - " pass\n", - "\n", - "embeddings = gds.fastpath.mutate(\n", - " G,\n", - " base_node_label=\"Patient\",\n", - " event_node_label=\"Encounter\",\n", - " event_features=\"emb\",\n", - " time_node_property=\"days\",\n", - " dimension=256,\n", - " num_elapsed_times=100,\n", - " output_time_property=\"output_time\",\n", - " max_elapsed_time=365 * 10,\n", - " smoothing_rate=0.004,\n", - " smoothing_window=3,\n", - " decay_factor=1e-4,\n", - " random_seed=42,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dce1f527", - "metadata": {}, - "outputs": [], - "source": [ - "embeddings = gds.graph.nodeProperties.stream(G, [\"embeddings\"], node_labels=[\"Patient\"], separate_property_columns=True)\n", - "print(embeddings)\n", - "explore(embeddings)" - ] - }, - { - "cell_type": "markdown", - "id": "b629f2ab", - "metadata": {}, - "source": [ - "# Example with categorical event property and input event vectors\n", - "As the type (class) of encounter may be important to characterize patient journeys and to classify diabetes, we 'intClass' as a categorical event property." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fa929721", - "metadata": {}, - "outputs": [], - "source": [ - "embeddings = gds.fastpath.mutate(\n", - " G,\n", - " base_node_label=\"Patient\",\n", - " event_node_label=\"Encounter\",\n", - " event_features=\"emb\",\n", - " time_node_property=\"days\",\n", - " categorical_event_properties=[\"intClass\"],\n", - " dimension=256,\n", - " num_elapsed_times=100,\n", - " output_time_property=\"output_time\",\n", - " max_elapsed_time=365 * 10,\n", - " smoothing_rate=0.004,\n", - " smoothing_window=3,\n", - " decay_factor=1e-4,\n", - " random_seed=42,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "01a6f368", - "metadata": {}, - "outputs": [], - "source": [ - "embeddings = gds.graph.nodeProperties.stream(G, [\"embeddings\"], node_labels=[\"Patient\"], separate_property_columns=True)\n", - "print(embeddings)\n", - "explore(embeddings)" - ] - }, - { - "cell_type": "markdown", - "id": "9bd0798e", - "metadata": {}, - "source": [ - "# Example with context nodes and input event vectors\n", - "As the history of drugs may be important to characterize patient journeys and to classify diabetes, we add 'Drug' as a context_node_label." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d5c27d60", - "metadata": {}, - "outputs": [], - "source": [ - "embeddings = gds.fastpath.stream(\n", - " G,\n", - " base_node_label=\"Patient\",\n", - " context_node_label=\"Drug\",\n", - " event_node_label=\"Encounter\",\n", - " event_features=\"emb\",\n", - " time_node_property=\"days\",\n", - " dimension=256,\n", - " # num_elapsed_times=100,\n", - " num_elapsed_times=1,\n", - " output_time_property=\"output_time\",\n", - " # max_elapsed_time=365 * 10,\n", - " max_elapsed_time=1,\n", - " smoothing_rate=0.004,\n", - " smoothing_window=0,\n", - " # smoothing_window=3,\n", - " decay_factor=1e-4,\n", - " random_seed=43,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e3f46994", - "metadata": {}, - "outputs": [], - "source": [ - "embeddings = gds.graph.nodeProperties.stream(G, [\"embeddings\"], node_labels=[\"Patient\"], separate_property_columns=True)\n", - "print(embeddings)\n", - "explore(embeddings)" - ] - }, - { - "cell_type": "markdown", - "id": "f77b148a", - "metadata": {}, - "source": [ - "# Example with next and first relationship schema\n", - "We will now repeat one of the previous examples but use a different schema for the paths.\n", - "In this case it will give the same graph and embeddings, but the example is useful for illustrating the use of the next-first schema." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36067015", - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " G = gds.graph.get(\"medical\")\n", - " G.drop()\n", - "except:\n", - " pass\n", - "\n", - "G, _ = gds.graph.project(\n", - " \"medical\",\n", - " {\n", - " \"Patient\": {\"properties\": [\"output_time\", \"output_time_stepwise\", \"has_diabetes\"]},\n", - " \"Encounter\": {\"properties\": [\"days\", \"intClass\"]},\n", - " \"Observation\": {\"properties\": []},\n", - " \"Payer\": {\"properties\": []},\n", - " \"Provider\": {\"properties\": []},\n", - " \"Organization\": {\"properties\": []},\n", - " \"Speciality\": {\"properties\": []},\n", - " \"Allergy\": {\"properties\": []},\n", - " \"Reaction\": {\"properties\": []},\n", - " \"Condition\": {\"properties\": []},\n", - " \"Drug\": {\"properties\": []},\n", - " \"Procedure\": {\"properties\": []},\n", - " \"CarePlan\": {\"properties\": []},\n", - " \"Device\": {\"properties\": []},\n", - " \"ConditionDescription\": {\"properties\": []},\n", - " },\n", - " [\n", - " \"HAS_OBSERVATION\",\n", - " \"NEXT\",\n", - " \"FIRST\",\n", - " \"HAS_PROVIDER\",\n", - " \"AT_ORGANIZATION\",\n", - " \"HAS_PAYER\",\n", - " \"HAS_SPECIALITY\",\n", - " \"BELONGS_TO\",\n", - " \"INSURANCE_START\",\n", - " \"INSURANCE_END\",\n", - " \"HAS_ALLERGY\",\n", - " \"ALLERGY_DETECTED\",\n", - " \"HAS_REACTION\",\n", - " \"CAUSES_REACTION\",\n", - " \"HAS_CONDITION\",\n", - " \"HAS_DRUG\",\n", - " \"HAS_PROCEDURE\",\n", - " \"HAS_CARE_PLAN\",\n", - " \"DEVICE_USED\",\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5e27d608", - "metadata": {}, - "outputs": [], - "source": [ - "gds.fastRP.mutate(\n", - " G,\n", - " embeddingDimension=256,\n", - " mutateProperty=\"emb\",\n", - " iterationWeights=[1, 1],\n", - " randomSeed=42,\n", - " relationshipTypes=[\n", - " \"HAS_OBSERVATION\",\n", - " \"HAS_PROVIDER\",\n", - " \"AT_ORGANIZATION\",\n", - " 
\"HAS_PAYER\",\n", - " \"HAS_SPECIALITY\",\n", - " \"BELONGS_TO\",\n", - " \"INSURANCE_START\",\n", - " \"INSURANCE_END\",\n", - " \"HAS_ALLERGY\",\n", - " \"ALLERGY_DETECTED\",\n", - " \"HAS_REACTION\",\n", - " \"CAUSES_REACTION\",\n", - " \"HAS_CONDITION\",\n", - " \"HAS_DRUG\",\n", - " \"HAS_PROCEDURE\",\n", - " \"HAS_CARE_PLAN\",\n", - " \"DEVICE_USED\",\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1aaebe70", - "metadata": {}, - "outputs": [], - "source": [ - "embeddings = gds.fastpath.stream(\n", - " G,\n", - " base_node_label=\"Patient\",\n", - " context_node_label=\"Drug\",\n", - " event_node_label=\"Encounter\",\n", - " event_features=\"emb\",\n", - " next_relationship_type=\"NEXT\",\n", - " first_relationship_type=\"FIRST\",\n", - " time_node_property=\"days\",\n", - " dimension=256,\n", - " num_elapsed_times=100,\n", - " output_time_property=\"output_time\",\n", - " max_elapsed_time=365 * 10,\n", - " smoothing_rate=0.003701319681951021,\n", - " smoothing_window=3,\n", - " decay_factor=8.232744730741784e-05,\n", - " random_seed=43,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "faf3ee7a", - "metadata": {}, - "outputs": [], - "source": [ - "explore(embeddings, \"embeddings\")" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 06cc6bc43..7693dc29f 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -11,7 +11,6 @@ from .call_builder import IndirectCallBuilder from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints from .error.uncallable_namespace import UncallableNamespace -from .model.fastpath_runner import FastPathRunner from .model.kge_runner import KgeRunner from .query_runner.arrow_query_runner import ArrowQueryRunner from 
.query_runner.neo4j_query_runner import Neo4jQueryRunner @@ -124,23 +123,6 @@ def alpha(self) -> AlphaEndpoints: def beta(self) -> BetaEndpoints: return BetaEndpoints(self._query_runner, "gds.beta", self._server_version) - @property - def fastpath(self) -> FastPathRunner: - if not isinstance(self._query_runner, ArrowQueryRunner): - raise ValueError("Running FastPath requires GDS with the Arrow server enabled") - if self._compute_cluster_ip is None: - raise ValueError( - "You must set a valid computer cluster ip with the method `set_compute_cluster_ip` to use this feature" - ) - return FastPathRunner( - self._query_runner, - "gds.fastpath", - self._server_version, - self._compute_cluster_ip, - self._encrypted_db_password, - self._query_runner.uri, - ) - @property def kge(self) -> KgeRunner: if not isinstance(self._query_runner, ArrowQueryRunner): diff --git a/graphdatascience/model/fastpath_runner.py b/graphdatascience/model/fastpath_runner.py deleted file mode 100644 index 6455efcf2..000000000 --- a/graphdatascience/model/fastpath_runner.py +++ /dev/null @@ -1,115 +0,0 @@ -import logging -import os -import time -from typing import Any, Dict, Optional - -import requests -from pandas import Series - -from ..error.client_only_endpoint import client_only_endpoint -from ..error.illegal_attr_checker import IllegalAttrChecker -from ..error.uncallable_namespace import UncallableNamespace -from ..graph.graph_object import Graph -from ..query_runner.query_runner import QueryRunner -from ..server_version.compatible_with import compatible_with -from ..server_version.server_version import ServerVersion - -logging.basicConfig(level=logging.INFO) - - -class FastPathRunner(UncallableNamespace, IllegalAttrChecker): - def __init__( - self, - query_runner: QueryRunner, - namespace: str, - server_version: ServerVersion, - compute_cluster_ip: str, - encrypted_db_password: str, - arrow_uri: str, - ): - self._query_runner = query_runner - self._namespace = namespace - 
self._server_version = server_version - self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005" - self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815" - self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080" - self._encrypted_db_password = encrypted_db_password - self._arrow_uri = arrow_uri - - @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0)) - @client_only_endpoint("gds.fastpath") - def mutate( - self, - G: Graph, - graph_filter: Optional[Dict[str, Any]] = None, - mlflow_experiment_name: Optional[str] = None, - **algo_config: Any, - ) -> Series: - if graph_filter is None: - # Take full graph if no filter provided - node_filter = G.node_properties().to_dict() - rel_filter = G.relationship_properties().to_dict() - graph_filter = {"node_filter": node_filter, "rel_filter": rel_filter} - - graph_config = {"name": G.name()} - graph_config.update(graph_filter) - - config = { - "user_name": "DUMMY_USER", - "task": "FASTPATH", - "task_config": { - "graph_config": graph_config, - "task_config": algo_config, - "stream_node_results": True, - }, - "encrypted_db_password": self._encrypted_db_password, - "graph_arrow_uri": self._arrow_uri, - } - - if mlflow_experiment_name is not None: - config["task_config"]["mlflow"] = { - "config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name} - } - - job_id = self._start_job(config) - - self._wait_for_job(job_id) - - return Series({"status": "finished"}) - - # return self._stream_results(job_id) - - def _start_job(self, config: Dict[str, Any]) -> str: - res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config) - res.raise_for_status() - job_id = res.json()["job_id"] - logging.info(f"Job with ID '{job_id}' started") - - return job_id - - def _wait_for_job(self, job_id: str) -> None: - while True: - time.sleep(1) - - res = 
requests.get(f"{self._compute_cluster_web_uri}/api/machine-learning/status/{job_id}") - - res_json = res.json() - if res_json["job_status"] == "exited": - logging.info("FastPath job completed!") - return - elif res_json["job_status"] == "failed": - error = f"FastPath job failed with errors:{os.linesep}{os.linesep.join(res_json['errors'])}" - if res.status_code == 400: - raise ValueError(error) - else: - raise RuntimeError(error) - - # def _stream_results(self, job_id: str) -> DataFrame: - # client = pa.flight.connect(self._compute_cluster_arrow_uri) - - # upload_descriptor = pa.flight.FlightDescriptor.for_path(f"{job_id}.nodes") - # flight = client.get_flight_info(upload_descriptor) - # reader = client.do_get(flight.endpoints[0].ticket) - # read_table = reader.read_all() - - # return read_table.to_pandas() diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index a4b526cc3..8542cb0d7 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -137,7 +137,7 @@ def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataF ) res.raise_for_status() - res_file_name = f'res_{job_id}.json' + res_file_name = f"res_{job_id}.json" with open(res_file_name, mode="wb+") as f: f.write(res.content) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 6f80f57c1..f4e1a2766 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -94,6 +94,12 @@ def __init__( if tls_root_certs: client_options["tls_root_certs"] = tls_root_certs + print("location:") + print(location) + print("client_options:") + print(client_options) + print("auth:") + print(auth) self._flight_client = flight.FlightClient(location, **client_options) def connection_info(self) -> Tuple[str, int]: diff --git a/graphdatascience/resources/field-testing/pub.pem 
b/graphdatascience/resources/field-testing/pub.pem index 0a3519e2b..daf0828ca 100644 --- a/graphdatascience/resources/field-testing/pub.pem +++ b/graphdatascience/resources/field-testing/pub.pem @@ -1,4 +1,3 @@ -----BEGIN RSA PUBLIC KEY----- -MEgCQQDNfbk2/PGneqZO6Vx9VbPe6ZnQJ/F5kOOW07jGDU34NFfUI06Nw0HmwT2h -c9s3nZTUUlAVi/aUCl3b4NcB8vThAgMBAAE= +WRONGKEY -----END RSA PUBLIC KEY----- diff --git a/graphdatascience/tests/integration/test_graph_construct.py b/graphdatascience/tests/integration/test_graph_construct.py index 4f2379256..97dc85c1a 100644 --- a/graphdatascience/tests/integration/test_graph_construct.py +++ b/graphdatascience/tests/integration/test_graph_construct.py @@ -558,10 +558,3 @@ def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience) with pytest.warns(DeprecationWarning): gds.alpha.graph.construct("hello", nodes, relationships) - - -@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) -def test_drop_list_warning_reproduction(gds: GraphDataScience) -> None: - G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}}) - res = gds.graph.list() - assert res["graphName"].tolist() == ["g"] diff --git a/graphdatascience/tests/integration/test_graph_ops.py b/graphdatascience/tests/integration/test_graph_ops.py index 3a505313a..e2b077cf7 100644 --- a/graphdatascience/tests/integration/test_graph_ops.py +++ b/graphdatascience/tests/integration/test_graph_ops.py @@ -854,7 +854,7 @@ def test_graph_relationships_stream_without_arrow(gds_without_arrow: GraphDataSc @pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0)) def test_graph_relationships_stream_with_arrow(gds: GraphDataScience) -> None: - G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL_0", "REL2"]) + G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL", "REL2"]) if gds.server_version() >= ServerVersion(2, 5, 0): result = gds.graph.relationships.stream(G, ["REL", "REL2"]) From 
d552f8ec70475a3550e0fc622a7c60ba2560c28f Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Fri, 19 Jul 2024 12:51:19 +0100 Subject: [PATCH 10/24] Remove Graph from prediction --- examples/kge-distmult-nations.py | 13 +------------ graphdatascience/model/kge_runner.py | 9 --------- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/examples/kge-distmult-nations.py b/examples/kge-distmult-nations.py index c8b27b45e..0c66b56d3 100644 --- a/examples/kge-distmult-nations.py +++ b/examples/kge-distmult-nations.py @@ -158,15 +158,11 @@ def project_graphs(gds): gds.graph.drop("testGraph", failIfMissing=False) G_full, _ = gds.graph.project("fullGraph", ["Entity"], all_rels) - inspect_graph(G_full) G_train, _ = gds.graph.filter("trainGraph", G_full, "*", "r.split = 0.0") G_valid, _ = gds.graph.filter("validGraph", G_full, "*", "r.split = 1.0") G_test, _ = gds.graph.filter("testGraph", G_full, "*", "r.split = 2.0") - inspect_graph(G_train) - inspect_graph(G_valid) - inspect_graph(G_test) gds.graph.drop("fullGraph", failIfMissing=False) @@ -190,14 +186,9 @@ def inspect_graph(G): create_constraint(gds) put_data_in_db(gds) G_train, G_valid, G_test = project_graphs(gds) - inspect_graph(G_train) - inspect_graph(G_valid) - inspect_graph(G_test) gds.set_compute_cluster_ip("localhost") - print(gds.debug.arrow()) - model_name = "dummyModelName_" + str(time.time()) gds.kge.model.train( @@ -210,7 +201,6 @@ def inspect_graph(G): ) df = gds.kge.model.predict( - G_train, model_name=model_name, top_k=3, node_ids=[ @@ -221,7 +211,7 @@ def inspect_graph(G): rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"], ) - print(df) + print(df.to_string()) # # gds.kge.model.predict_tail( # G_train, @@ -240,4 +230,3 @@ def inspect_graph(G): # ], # ) - print("Finished training") diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 8542cb0d7..23a7c2de8 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -34,8 
+34,6 @@ def __init__( self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080" self._encrypted_db_password = encrypted_db_password self._arrow_uri = arrow_uri - print("KgeRunner __dict__:") - print(self.__dict__) @property def model(self): @@ -89,14 +87,12 @@ def train( @client_only_endpoint("gds.kge.model") def predict( self, - G: Graph, model_name: str, top_k: int, node_ids: list[int], rel_types: list[str], mlflow_experiment_name: Optional[str] = None, ) -> DataFrame: - graph_config = {"name": G.name()} algo_config = { "top_k": top_k, @@ -108,7 +104,6 @@ def predict( "user_name": "DUMMY_USER", "task": "KGE_PREDICT_PYG", "task_config": { - "graph_config": graph_config, "modelname": model_name, "task_config": algo_config, }, @@ -122,8 +117,6 @@ def predict( "config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name} } - print("predict config") - print(config) job_id = self._start_job(config) self._wait_for_job(job_id) @@ -146,8 +139,6 @@ def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataF return df def _start_job(self, config: Dict[str, Any]) -> str: - print("_start_job") - print(config) url = f"{self._compute_cluster_web_uri}/api/machine-learning/start" print(url) res = requests.post(url, json=config) From 0d29f977a8112e1902b7780597ff464325f1b9ff Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Fri, 19 Jul 2024 15:20:19 +0100 Subject: [PATCH 11/24] Pass all parameters from docs --- examples/kge-distmult-nations.py | 5 ++- graphdatascience/model/kge_runner.py | 50 +++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/examples/kge-distmult-nations.py b/examples/kge-distmult-nations.py index 0c66b56d3..7162ed066 100644 --- a/examples/kge-distmult-nations.py +++ b/examples/kge-distmult-nations.py @@ -163,7 +163,6 @@ def project_graphs(gds): G_valid, _ = gds.graph.filter("validGraph", G_full, "*", "r.split = 1.0") G_test, _ = 
gds.graph.filter("testGraph", G_full, "*", "r.split = 2.0") - gds.graph.drop("fullGraph", failIfMissing=False) return G_train, G_valid, G_test @@ -194,10 +193,11 @@ def inspect_graph(G): gds.kge.model.train( G_train, model_name=model_name, - scoring_function="DistMult", + scoring_function="distmult", num_epochs=1, embedding_dimension=10, epochs_per_checkpoint=0, + epochs_per_val=0, ) df = gds.kge.model.predict( @@ -229,4 +229,3 @@ def inspect_graph(G): # (gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})), # ], # ) - diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 23a7c2de8..0a1ff60c2 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -45,20 +45,52 @@ def train( self, G: Graph, model_name: str, - scoring_function, - num_epochs, - embedding_dimension, - epochs_per_checkpoint, + *, + num_epochs: int, + embedding_dimension: int, + epochs_per_checkpoint: Optional[int] = None, + load_from_checkpoint: Optional[tuple[str, int]] = None, + split_ratios=None, + scoring_function: str = "transe", + p_norm: float = 1.0, + batch_size: int = 512, + test_batch_size: int = 512, + optimizer: str = "adam", + optimizer_kwargs=None, + lr_scheduler: str = "ConstantLR", + lr_scheduler_kwargs=None, + loss_function: str = "MarginRanking", + loss_function_kwargs=None, + negative_sampling_size: int = 1, + use_node_type_aware_sampler: bool = False, + k_value: int = 10, + do_validation: bool = True, + do_test: bool = True, + filtered_metrics: bool = False, + epochs_per_val: int = 50, + inner_norm: bool = True, + init_bound: Optional[float] = None, mlflow_experiment_name: Optional[str] = None, ) -> Series: - graph_config = {"name": G.name()} + if epochs_per_checkpoint is None: + epochs_per_checkpoint = max(num_epochs / 10, 1) + if loss_function_kwargs is None: + loss_function_kwargs = dict(margin=1.0, adversarial_temperature=1.0, gamma=20.0) + if 
lr_scheduler_kwargs is None: + lr_scheduler_kwargs = dict(factor=1, total_iters=1000) + if optimizer_kwargs is None: + optimizer_kwargs = {"lr": 0.01, "weight_decay": 0.0005} + if split_ratios is None: + split_ratios = {"TRAIN": 0.8, "TEST": 0.2} algo_config = { - "scoring_function": scoring_function, - "num_epochs": num_epochs, - "embedding_dimension": embedding_dimension, - "epochs_per_checkpoint": epochs_per_checkpoint, + key: value + for key, value in locals().items() + if (key not in ["self", "G", "mlflow_experiment_name", "model_name"]) and (value is not None) } + print(algo_config) + + graph_config = {"name": G.name()} config = { "user_name": "DUMMY_USER", From 60be966f14288d61ac74c17b77c480f398a0e31d Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 22 Jul 2024 11:30:20 +0100 Subject: [PATCH 12/24] Add docs for prediction stage Add jupyter notebook for kge nations --- examples/kge-distmult-nations.ipynb | 306 ++++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 examples/kge-distmult-nations.ipynb diff --git a/examples/kge-distmult-nations.ipynb b/examples/kge-distmult-nations.ipynb new file mode 100644 index 000000000..3f2b0f4a9 --- /dev/null +++ b/examples/kge-distmult-nations.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "11d08c597a9fdbf3", + "metadata": { + "collapsed": false + }, + "source": [ + "# Knowledge Graph Embedding: DistMult embedding for Nation dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d9719b198c3fe8e", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "import warnings\n", + "from collections import defaultdict\n", + "from neo4j.exceptions import ClientError\n", + "from tqdm import tqdm\n", + "from graphdatascience import GraphDataScience" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4d82474217c5ca2", + "metadata": {}, + "outputs": [], + "source": [ + 
"warnings.filterwarnings(\"ignore\", category=DeprecationWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c522b3dba2a0c1c9", + "metadata": {}, + "outputs": [], + "source": [ + "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n", + "NEO4J_AUTH = None\n", + "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n", + "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n", + " NEO4J_AUTH = (\n", + " os.environ.get(\"NEO4J_USER\"),\n", + " os.environ.get(\"NEO4J_PASSWORD\"),\n", + " )\n", + "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "532f7596", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " _ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE\")\n", + "except ClientError:\n", + " print(\"CONSTRAINT entity_id already exists\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00757ac4", + "metadata": {}, + "outputs": [], + "source": [ + "def get_text_to_id_map(data_dir, text_to_id_filename):\n", + " with open(data_dir + \"/\" + text_to_id_filename, \"r\") as f:\n", + " data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n", + " text_to_id_map = {text: int(id) for text, id in data}\n", + " return text_to_id_map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c9a1c4d", + "metadata": {}, + "outputs": [], + "source": [ + "def read_data():\n", + " rel_types = {\n", + " \"train.txt\": \"TRAIN\",\n", + " \"valid.txt\": \"VALID\",\n", + " \"test.txt\": \"TEST\",\n", + " }\n", + " url = \"https://raw.githubusercontent.com/ZhenfengLei/KGDatasets/master/Nations\"\n", + " data_dir = \"./Nations\"\n", + "\n", + " raw_file_names = [\"train.txt\", \"valid.txt\", \"test.txt\"]\n", + " node_id_filename = \"entity2id.txt\"\n", + " rel_id_filename = \"relation2id.txt\"\n", + "\n", + " for file 
in raw_file_names + [node_id_filename, rel_id_filename]:\n", + " if not os.path.exists(f\"{data_dir}/{file}\"):\n", + " os.system(f\"wget {url}/{file} -P {data_dir}\")\n", + "\n", + " node_map = get_text_to_id_map(data_dir, node_id_filename)\n", + " rel_map = get_text_to_id_map(data_dir, rel_id_filename)\n", + " dataset = defaultdict(lambda: defaultdict(list))\n", + "\n", + " rel_split_id = {\"TRAIN\": 0, \"VALID\": 1, \"TEST\": 2}\n", + "\n", + " for file_name in raw_file_names:\n", + " file_name_path = data_dir + \"/\" + file_name\n", + "\n", + " with open(file_name_path, \"r\") as f:\n", + " data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n", + "\n", + " for i, (src_text, rel_text, dst_text) in enumerate(data):\n", + " source = node_map[src_text]\n", + " target = node_map[dst_text]\n", + " rel_type = \"REL_\" + rel_text.upper()\n", + " rel_split = rel_types[file_name]\n", + "\n", + " dataset[rel_split][rel_type].append(\n", + " {\n", + " \"source\": source,\n", + " \"source_text\": src_text,\n", + " \"target\": target,\n", + " \"target_text\": dst_text,\n", + " \"rel_type\": rel_type,\n", + " \"rel_id\": rel_map[rel_text],\n", + " \"rel_split\": rel_split,\n", + " \"rel_split_id\": rel_split_id[rel_split],\n", + " }\n", + " )\n", + "\n", + " print(\"Number of nodes: \", len(node_map))\n", + " for rel_split in dataset:\n", + " print(\n", + " f\"Number of relationships of type {rel_split}: \",\n", + " sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),\n", + " )\n", + " return dataset\n", + "\n", + "\n", + "dataset = read_data()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1cb98e4", + "metadata": {}, + "outputs": [], + "source": [ + "def put_data_in_db():\n", + " res = gds.run_cypher(\"MATCH (m) RETURN count(m) as num_nodes\")\n", + " if res[\"num_nodes\"].values[0] > 0:\n", + " print(\"Data already in db, number of nodes: \", res[\"num_nodes\"].values[0])\n", + " return\n", + " dataset = 
read_data()\n", + " pbar = tqdm(\n", + " desc=\"Putting data in db\",\n", + " total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),\n", + " )\n", + "\n", + " for rel_split in dataset:\n", + " for rel_type in dataset[rel_split]:\n", + " edges = dataset[rel_split][rel_type]\n", + "\n", + " gds.run_cypher(\n", + " f\"\"\"\n", + " UNWIND $ll as l\n", + " MERGE (n:Entity {{id:l.source, text:l.source_text}})\n", + " MERGE (m:Entity {{id:l.target, text:l.target_text}})\n", + " MERGE (n)-[:{rel_type} {{split: l.rel_split_id, rel_id: l.rel_id}}]->(m)\n", + " \"\"\",\n", + " params={\"ll\": edges},\n", + " )\n", + " pbar.update(len(edges))\n", + " pbar.close()\n", + "\n", + " for rel_split in dataset:\n", + " res = gds.run_cypher(\n", + " f\"\"\"\n", + " MATCH ()-[r:{rel_split}]->()\n", + " RETURN COUNT(r) AS numberOfRelationships\n", + " \"\"\"\n", + " )\n", + " print(f\"Number of relationships of type {rel_split} in db: \", res.numberOfRelationships)\n", + "\n", + "\n", + "put_data_in_db()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fceb15b", + "metadata": {}, + "outputs": [], + "source": [ + "def project_graphs():\n", + " all_rels = gds.run_cypher(\n", + " \"\"\"\n", + " CALL db.relationshipTypes() YIELD relationshipType\n", + " \"\"\"\n", + " )\n", + " all_rels = all_rels[\"relationshipType\"].to_list()\n", + " all_rels = {rel: {\"properties\": \"split\"} for rel in all_rels if rel.startswith(\"REL_\")}\n", + " gds.graph.drop(\"fullGraph\", failIfMissing=False)\n", + " gds.graph.drop(\"trainGraph\", failIfMissing=False)\n", + " gds.graph.drop(\"validGraph\", failIfMissing=False)\n", + " gds.graph.drop(\"testGraph\", failIfMissing=False)\n", + "\n", + " G_full, _ = gds.graph.project(\"fullGraph\", [\"Entity\"], all_rels)\n", + "\n", + " G_train, _ = gds.graph.filter(\"trainGraph\", G_full, \"*\", \"r.split = 0.0\")\n", + " G_valid, _ = gds.graph.filter(\"validGraph\", G_full, \"*\", 
\"r.split = 1.0\")\n", + " G_test, _ = gds.graph.filter(\"testGraph\", G_full, \"*\", \"r.split = 2.0\")\n", + "\n", + " gds.graph.drop(\"fullGraph\", failIfMissing=False)\n", + "\n", + " return G_train, G_valid, G_test\n", + "\n", + "\n", + "G_train, G_valid, G_test = project_graphs()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4e2825a", + "metadata": {}, + "outputs": [], + "source": [ + "gds.set_compute_cluster_ip(\"localhost\")\n", + "\n", + "model_name = \"dummyModelName_\" + str(time.time())\n", + "\n", + "gds.kge.model.train(\n", + " G_train,\n", + " model_name=model_name,\n", + " scoring_function=\"distmult\",\n", + " num_epochs=1,\n", + " embedding_dimension=10,\n", + " epochs_per_checkpoint=0,\n", + " epochs_per_val=0,\n", + ")\n", + "\n", + "predict_result = gds.kge.model.predict(\n", + " model_name=model_name,\n", + " top_k=3,\n", + " node_ids=[\n", + " gds.find_node_id([\"Entity\"], {\"text\": \"brazil\"}),\n", + " gds.find_node_id([\"Entity\"], {\"text\": \"uk\"}),\n", + " gds.find_node_id([\"Entity\"], {\"text\": \"jordan\"}),\n", + " ],\n", + " rel_types=[\"REL_RELDIPLOMACY\", \"REL_RELNGO\"],\n", + ")\n", + "\n", + "print(predict_result.to_string())\n", + "#\n", + "# gds.kge.model.predict_tail(\n", + "# G_train,\n", + "# model_name=model_name,\n", + "# top_k=10,\n", + "# node_ids=[gds.find_node_id([\"Entity\"], {\"text\": \"/m/016wzw\"}), gds.find_node_id([\"Entity\"], {\"id\": 2})],\n", + "# rel_types=[\"REL_1\", \"REL_2\"],\n", + "# )\n", + "#\n", + "# gds.kge.model.score_triples(\n", + "# G_train,\n", + "# model_name=model_name,\n", + "# triples=[\n", + "# (gds.find_node_id([\"Entity\"], {\"text\": \"/m/016wzw\"}), \"REL_1\", gds.find_node_id([\"Entity\"], {\"id\": 2})),\n", + "# (gds.find_node_id([\"Entity\"], {\"id\": 0}), \"REL_123\", gds.find_node_id([\"Entity\"], {\"id\": 3})),\n", + "# ],\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "786eda29280ed31f", + "metadata": {}, + 
"outputs": [], + "source": [ + "# Create the dictionary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74c501f8fcb411eb", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} From 4151dc5d25eec247489338f71d6947662b49d28e Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 22 Jul 2024 16:43:55 +0100 Subject: [PATCH 13/24] Fix log wording --- graphdatascience/model/kge_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 0a1ff60c2..b001e8d05 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -188,7 +188,7 @@ def _wait_for_job(self, job_id: str) -> None: res_json = res.json() if res_json["job_status"] == "exited": - logging.info("KGE job completed!") + logging.info(f"Job with ID '{job_id}' completed") return elif res_json["job_status"] == "failed": error = f"KGE job failed with errors:{os.linesep}{os.linesep.join(res_json['errors'])}" From bb1cd0a1a816009b2ec8e75213de541549947c03 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Tue, 23 Jul 2024 15:49:47 +0100 Subject: [PATCH 14/24] Added score_triplets function --- examples/kge-distmult-nations.ipynb | 77 ++++++++++++++------------ examples/kge-distmult-nations.py | 39 ++++++++++++- graphdatascience/graph_data_science.py | 1 - graphdatascience/model/kge_runner.py | 38 ++++++++++++- 4 files changed, 115 insertions(+), 40 deletions(-) diff --git a/examples/kge-distmult-nations.ipynb b/examples/kge-distmult-nations.ipynb index 3f2b0f4a9..859f9e8ed 100644 --- a/examples/kge-distmult-nations.ipynb +++ b/examples/kge-distmult-nations.ipynb @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d9719b198c3fe8e", + "id": "9135277efcde2800", "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"d4d82474217c5ca2", + "id": "1551fddc3a67fa5b", "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c522b3dba2a0c1c9", + "id": "2f05ee7fdb496f84", "metadata": {}, "outputs": [], "source": [ @@ -57,7 +57,7 @@ { "cell_type": "code", "execution_count": null, - "id": "532f7596", + "id": "658c9f8369fff77e", "metadata": {}, "outputs": [], "source": [ @@ -70,7 +70,7 @@ { "cell_type": "code", "execution_count": null, - "id": "00757ac4", + "id": "bdbf4f91da4b9934", "metadata": {}, "outputs": [], "source": [ @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6c9a1c4d", + "id": "485869468ad5ad2e", "metadata": {}, "outputs": [], "source": [ @@ -142,16 +142,16 @@ " f\"Number of relationships of type {rel_split}: \",\n", " sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),\n", " )\n", - " return dataset\n", + " return dataset, node_map\n", "\n", "\n", - "dataset = read_data()" + "dataset, node_map = read_data()" ] }, { "cell_type": "code", "execution_count": null, - "id": "e1cb98e4", + "id": "2032a4e1aed1bd5", "metadata": {}, "outputs": [], "source": [ @@ -160,7 +160,6 @@ " if res[\"num_nodes\"].values[0] > 0:\n", " print(\"Data already in db, number of nodes: \", res[\"num_nodes\"].values[0])\n", " return\n", - " dataset = read_data()\n", " pbar = tqdm(\n", " desc=\"Putting data in db\",\n", " total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),\n", @@ -198,7 +197,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0fceb15b", + "id": "5c4f1523a225fa3c", "metadata": {}, "outputs": [], "source": [ @@ -232,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b4e2825a", + "id": "5d518e67375f6ab3", "metadata": {}, "outputs": [], "source": [ @@ -261,43 +260,53 @@ " rel_types=[\"REL_RELDIPLOMACY\", \"REL_RELNGO\"],\n", ")\n", "\n", - "print(predict_result.to_string())\n", - "#\n", - "# 
gds.kge.model.predict_tail(\n", - "# G_train,\n", - "# model_name=model_name,\n", - "# top_k=10,\n", - "# node_ids=[gds.find_node_id([\"Entity\"], {\"text\": \"/m/016wzw\"}), gds.find_node_id([\"Entity\"], {\"id\": 2})],\n", - "# rel_types=[\"REL_1\", \"REL_2\"],\n", - "# )\n", - "#\n", - "# gds.kge.model.score_triples(\n", - "# G_train,\n", - "# model_name=model_name,\n", - "# triples=[\n", - "# (gds.find_node_id([\"Entity\"], {\"text\": \"/m/016wzw\"}), \"REL_1\", gds.find_node_id([\"Entity\"], {\"id\": 2})),\n", - "# (gds.find_node_id([\"Entity\"], {\"id\": 0}), \"REL_123\", gds.find_node_id([\"Entity\"], {\"id\": 3})),\n", - "# ],\n", - "# )" + "print(predict_result.to_string())" ] }, { "cell_type": "code", "execution_count": null, - "id": "786eda29280ed31f", + "id": "83b75194c69259a2", "metadata": {}, "outputs": [], "source": [ - "# Create the dictionary" + "for index, row in predict_result.iterrows():\n", + " h = row[\"head\"]\n", + " r = row[\"rel\"]\n", + " gds.run_cypher(\n", + " f\"\"\"\n", + " UNWIND $tt as t\n", + " MATCH (a:Entity WHERE id(a) = {h})\n", + " MATCH (b:Entity WHERE id(b) = t)\n", + " MERGE (a)-[:NEW_REL_{r}]->(b)\n", + " \"\"\",\n", + " params={\"tt\": row[\"tail\"]},\n", + " )" ] }, { "cell_type": "code", "execution_count": null, - "id": "74c501f8fcb411eb", + "id": "b4e2825a", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "brazil_node = gds.find_node_id([\"Entity\"], {\"text\": \"brazil\"})\n", + "uk_node = gds.find_node_id([\"Entity\"], {\"text\": \"uk\"})\n", + "jordan_node = gds.find_node_id([\"Entity\"], {\"text\": \"jordan\"})\n", + "\n", + "triplets = [\n", + " (brazil_node, \"REL_RELNGO\", uk_node),\n", + " (brazil_node, \"REL_RELDIPLOMACY\", jordan_node),\n", + "]\n", + "\n", + "scores = gds.kge.model.score_triplets(\n", + " model_name=model_name,\n", + " triplets=triplets,\n", + ")\n", + "\n", + "print(scores)" + ] } ], "metadata": {}, diff --git a/examples/kge-distmult-nations.py 
b/examples/kge-distmult-nations.py index 7162ed066..1ee5c256b 100644 --- a/examples/kge-distmult-nations.py +++ b/examples/kge-distmult-nations.py @@ -186,6 +186,8 @@ def inspect_graph(G): put_data_in_db(gds) G_train, G_valid, G_test = project_graphs(gds) + inspect_graph(G_train) + gds.set_compute_cluster_ip("localhost") model_name = "dummyModelName_" + str(time.time()) @@ -197,10 +199,11 @@ def inspect_graph(G): num_epochs=1, embedding_dimension=10, epochs_per_checkpoint=0, - epochs_per_val=0, + epochs_per_val=5, + split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1}, ) - df = gds.kge.model.predict( + predict_result = gds.kge.model.predict( model_name=model_name, top_k=3, node_ids=[ @@ -211,7 +214,37 @@ def inspect_graph(G): rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"], ) - print(df.to_string()) + print(predict_result.to_string()) + + print(predict_result.to_string()) + for index, row in predict_result.iterrows(): + h = row["head"] + r = row["rel"] + gds.run_cypher( + f""" + UNWIND $tt as t + MATCH (a:Entity WHERE id(a) = {h}) + MATCH (b:Entity WHERE id(b) = t) + MERGE (a)-[:NEW_REL_{r}]->(b) + """, + params={"tt": row["tail"]}, + ) + + brazil_node = gds.find_node_id(["Entity"], {"text": "brazil"}) + uk_node = gds.find_node_id(["Entity"], {"text": "uk"}) + jordan_node = gds.find_node_id(["Entity"], {"text": "jordan"}) + + triplets = [ + (brazil_node, "REL_RELNGO", uk_node), + (brazil_node, "REL_RELDIPLOMACY", jordan_node), + ] + + scores = gds.kge.model.score_triplets( + model_name=model_name, + triplets=triplets, + ) + + print(scores) # # gds.kge.model.predict_tail( # G_train, diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 7693dc29f..266babc72 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -17,7 +17,6 @@ from .query_runner.query_runner import QueryRunner from .server_version.server_version import ServerVersion from graphdatascience.graph.graph_proc_runner 
import GraphProcRunner -from graphdatascience.utils.util_proc_runner import UtilProcRunner class GraphDataScience(DirectEndpoints, UncallableNamespace): diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index b001e8d05..03f8f49fb 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -155,6 +155,41 @@ def predict( return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id) + @client_only_endpoint("gds.kge.model") + def score_triplets( + self, + model_name: str, + triplets: list[tuple[int, str, int]], + mlflow_experiment_name: Optional[str] = None, + ) -> DataFrame: + + algo_config = { + "triplets": triplets, + } + + config = { + "user_name": "DUMMY_USER", + "task": "KGE_SCORE_TRIPLETS_PYG", + "task_config": { + "modelname": model_name, + "task_config": algo_config, + }, + "graph_arrow_uri": self._arrow_uri, + } + if self._encrypted_db_password is not None: + config["encrypted_db_password"] = self._encrypted_db_password + + if mlflow_experiment_name is not None: + config["task_config"]["mlflow"] = { + "config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name} + } + + job_id = self._start_job(config) + + self._wait_for_job(job_id) + + return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id) + def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataFrame: res = requests.get( f"{self._compute_cluster_web_uri}/internal/fetch-result", @@ -172,11 +207,10 @@ def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataF def _start_job(self, config: Dict[str, Any]) -> str: url = f"{self._compute_cluster_web_uri}/api/machine-learning/start" - print(url) res = requests.post(url, json=config) res.raise_for_status() job_id = res.json()["job_id"] - logging.info(f"Job with ID '{job_id}' started") + logging.info(f"Job '{config['task']}' with ID 
'{job_id}' started") return job_id From 39f0aa0d4156a629d258f71c2a098a21c43ef9ce Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 24 Jul 2024 21:15:39 +0100 Subject: [PATCH 15/24] Added doc about triplet scoring --- examples/kge-distmult-nations.ipynb | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/examples/kge-distmult-nations.ipynb b/examples/kge-distmult-nations.ipynb index 859f9e8ed..dfccad8e9 100644 --- a/examples/kge-distmult-nations.ipynb +++ b/examples/kge-distmult-nations.ipynb @@ -228,6 +228,16 @@ "G_train, G_valid, G_test = project_graphs()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "21da1ea76d247803", + "metadata": {}, + "outputs": [], + "source": [ + "G_train.relationship_types()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -242,11 +252,10 @@ "gds.kge.model.train(\n", " G_train,\n", " model_name=model_name,\n", - " scoring_function=\"distmult\",\n", - " num_epochs=1,\n", - " embedding_dimension=10,\n", - " epochs_per_checkpoint=0,\n", - " epochs_per_val=0,\n", + " scoring_function=\"transe\",\n", + " num_epochs=30,\n", + " embedding_dimension=64,\n", + " split_ratios={\"TRAIN\": 0.8, \"VALID\": 0.1, \"TEST\": 0.1},\n", ")\n", "\n", "predict_result = gds.kge.model.predict(\n", From 141cfdb93e6a77e7290fed89019a22fbfcbeb1ad Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 24 Jul 2024 21:39:18 +0100 Subject: [PATCH 16/24] Report metrics from training stage --- examples/kge-distmult-nations.py | 4 ++-- graphdatascience/model/kge_runner.py | 27 ++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/examples/kge-distmult-nations.py b/examples/kge-distmult-nations.py index 1ee5c256b..a3011ac25 100644 --- a/examples/kge-distmult-nations.py +++ b/examples/kge-distmult-nations.py @@ -192,7 +192,7 @@ def inspect_graph(G): model_name = "dummyModelName_" + str(time.time()) - gds.kge.model.train( + res = gds.kge.model.train( G_train, 
model_name=model_name, scoring_function="distmult", @@ -202,6 +202,7 @@ def inspect_graph(G): epochs_per_val=5, split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1}, ) + print(res["metrics"]) predict_result = gds.kge.model.predict( model_name=model_name, @@ -216,7 +217,6 @@ def inspect_graph(G): print(predict_result.to_string()) - print(predict_result.to_string()) for index, row in predict_result.iterrows(): h = row["head"] r = row["rel"] diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 03f8f49fb..8ba46ae95 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -1,3 +1,4 @@ +import json import logging import os import time @@ -114,7 +115,12 @@ def train( self._wait_for_job(job_id) - return Series({"status": "finished"}) + return Series( + { + "status": "finished", + "metrics": self._get_metrics(config["user_name"], config["task_config"]["modelname"], job_id), + } + ) @client_only_endpoint("gds.kge.model") def predict( @@ -205,6 +211,25 @@ def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataF os.remove(res_file_name) return df + def _get_metrics(self, user_name: str, model_name: str, job_id: str) -> DataFrame: + res = requests.get( + f"{self._compute_cluster_web_uri}/internal/fetch-model-metadata", + params={"user_name": user_name, "modelname": model_name}, + ) + res.raise_for_status() + + res_file_name = f"metadata_{job_id}.json" + + with open(res_file_name, mode="wb+") as f: + f.write(res.content) + + with open(res_file_name, mode="r") as f: + metadata = json.load(f) + + os.remove(res_file_name) + + return metadata["metrics"] + def _start_job(self, config: Dict[str, Any]) -> str: url = f"{self._compute_cluster_web_uri}/api/machine-learning/start" res = requests.post(url, json=config) From 18f5a61b817e5451067a7ab71856f7025ce1e15f Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 24 Jul 2024 21:41:07 +0100 Subject: [PATCH 17/24] Remove files 
--- graphdatascience/resources/field-testing/__init__.py | 0 graphdatascience/resources/field-testing/pub.pem | 3 --- 2 files changed, 3 deletions(-) delete mode 100644 graphdatascience/resources/field-testing/__init__.py delete mode 100644 graphdatascience/resources/field-testing/pub.pem diff --git a/graphdatascience/resources/field-testing/__init__.py b/graphdatascience/resources/field-testing/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/graphdatascience/resources/field-testing/pub.pem b/graphdatascience/resources/field-testing/pub.pem deleted file mode 100644 index daf0828ca..000000000 --- a/graphdatascience/resources/field-testing/pub.pem +++ /dev/null @@ -1,3 +0,0 @@ ------BEGIN RSA PUBLIC KEY----- -WRONGKEY ------END RSA PUBLIC KEY----- From b2f7f69b9c6fdd8108d7db8eb59999e74bbfc741 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 24 Jul 2024 21:48:38 +0100 Subject: [PATCH 18/24] Move back util data runner --- graphdatascience/graph_data_science.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 266babc72..a696e59b5 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -17,6 +17,7 @@ from .query_runner.query_runner import QueryRunner from .server_version.server_version import ServerVersion from graphdatascience.graph.graph_proc_runner import GraphProcRunner +from graphdatascience.utils.util_proc_runner import UtilProcRunner class GraphDataScience(DirectEndpoints, UncallableNamespace): @@ -114,6 +115,10 @@ def _path(package: str, resource: str) -> pathlib.Path: def graph(self) -> GraphProcRunner: return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version) + @property + def util(self) -> UtilProcRunner: + return UtilProcRunner(self._query_runner, f"{self._namespace}.util", self._server_version) + @property def alpha(self) -> AlphaEndpoints: return 
AlphaEndpoints(self._query_runner, "gds.alpha", self._server_version) From 9aa6ed0c8a41d4258451ee5ae6e7e697d61a1307 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 24 Jul 2024 22:54:34 +0100 Subject: [PATCH 19/24] Update reqs and printings --- examples/kge-distmult.py | 65 ++++++++++---------------------------- requirements/base/base.txt | 1 + 2 files changed, 17 insertions(+), 49 deletions(-) diff --git a/examples/kge-distmult.py b/examples/kge-distmult.py index 20b7f15e4..121c4d8b0 100644 --- a/examples/kge-distmult.py +++ b/examples/kge-distmult.py @@ -117,8 +117,6 @@ def put_data_in_db(gds): for rel_type in dataset[rel_split]: edges = dataset[rel_split][rel_type] - # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m) - # MERGE (n)-[:{rel_split}]->(m) gds.run_cypher( f""" UNWIND $ll as l @@ -156,33 +154,6 @@ def project_train_graph(gds): return G_train -def project_predict_graph(gds): - all_rels = gds.run_cypher( - """ - CALL db.relationshipTypes() YIELD relationshipType - """ - ) - all_rels = all_rels["relationshipType"].to_list() - rel_spec = {} - for rel in all_rels: - if rel.startswith("REL_"): - rel_spec[rel] = {"properties": ["split"]} - - gds.graph.drop("fullGraph", failIfMissing=False) - gds.graph.drop("predictGraph", failIfMissing=False) - - # {"REL": {"properties": ["relY"]}, "RELR": {"properties": ["relY"]}} - # print(rel_spec) - - G_full, result = gds.graph.project("fullGraph", ["Entity"], all_rels) - - G_full, result = gds.graph.project("fullGraph", ["Entity"], rel_spec) - # G_predict = gds.graph.filter('predictGraph', 'fullGraph', '*', 'r.split == 2') - - inspect_graph(G_full) - return G_full - - def inspect_graph(G): func_names = [ "name", @@ -200,8 +171,6 @@ def inspect_graph(G): create_constraint(gds) put_data_in_db(gds) G_train = project_train_graph(gds) - # G_predict = project_predict_graph(gds) - # inspect_graph(G_train) gds.set_compute_cluster_ip("localhost") @@ -209,38 +178,36 @@ def inspect_graph(G): model_name = "dummyModelName_" 
+ str(time.time()) - gds.kge.model.train( + node_id_text = gds.find_node_id(["Entity"], {"text": "/m/016wzw"}) + node_id_2 = gds.find_node_id(["Entity"], {"id": 2}) + node_id_3 = gds.find_node_id(["Entity"], {"id": 3}) + node_id_0 = gds.find_node_id(["Entity"], {"id": 0}) + + res = gds.kge.model.train( G_train, model_name=model_name, - scoring_function="DistMult", + scoring_function="distmult", num_epochs=1, embedding_dimension=10, epochs_per_checkpoint=0, ) + print(res['metrics']) - gds.kge.model.predict( - G_train, + res = gds.kge.model.predict( model_name=model_name, top_k=10, - node_ids=[1, 2, 3], + node_ids=[node_id_3, node_id_2, node_id_text], rel_types=["REL_1", "REL_2"], ) + print(res.to_string()) - gds.kge.model.predict_tail( - G_train, - model_name=model_name, - top_k=10, - node_ids=[gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), gds.find_node_id(["Entity"], {"id": 2})], - rel_types=["REL_1", "REL_2"], - ) - - gds.kge.model.score_triples( - G_train, + scores = gds.kge.model.score_triplets( model_name=model_name, - triples=[ - (gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), "REL_1", gds.find_node_id(["Entity"], {"id": 2})), - (gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})), + triplets=[ + (node_id_2, "REL_1", node_id_text), + (node_id_0, "REL_123", node_id_3), ], ) + print(scores) print("Finished training") diff --git a/requirements/base/base.txt b/requirements/base/base.txt index 3ca82b153..a3655cf56 100644 --- a/requirements/base/base.txt +++ b/requirements/base/base.txt @@ -7,3 +7,4 @@ textdistance >= 4.0, < 5.0 tqdm >= 4.0, < 5.0 typing-extensions >= 4.0, < 5.0 requests +rsa From 8ff016a6dc1384be1c2efce5b0ea356801ec9c8c Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Thu, 25 Jul 2024 12:45:30 +0100 Subject: [PATCH 20/24] Add notebook for constructed graph --- examples/kge-distmult.py | 2 +- examples/kge-transe-construct.ipynb | 160 ++++++++++++++++++++++++++++ 2 files changed, 161 
insertions(+), 1 deletion(-) create mode 100644 examples/kge-transe-construct.ipynb diff --git a/examples/kge-distmult.py b/examples/kge-distmult.py index 121c4d8b0..db408ad5c 100644 --- a/examples/kge-distmult.py +++ b/examples/kge-distmult.py @@ -191,7 +191,7 @@ def inspect_graph(G): embedding_dimension=10, epochs_per_checkpoint=0, ) - print(res['metrics']) + print(res["metrics"]) res = gds.kge.model.predict( model_name=model_name, diff --git a/examples/kge-transe-construct.ipynb b/examples/kge-transe-construct.ipynb new file mode 100644 index 000000000..1fb367008 --- /dev/null +++ b/examples/kge-transe-construct.ipynb @@ -0,0 +1,160 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "11d08c597a9fdbf3", + "metadata": { + "collapsed": false + }, + "source": [ + "# Knowledge Graph Embedding: Transe embedding for constructed dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1652ca866f022d69", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "import warnings\n", + "from neo4j.exceptions import ClientError\n", + "from graphdatascience import GraphDataScience" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e46dc2dd1419e518", + "metadata": {}, + "outputs": [], + "source": [ + "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "691981a7e5372ad2", + "metadata": {}, + "outputs": [], + "source": [ + "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n", + "NEO4J_AUTH = None\n", + "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n", + "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n", + " NEO4J_AUTH = (\n", + " os.environ.get(\"NEO4J_USER\"),\n", + " os.environ.get(\"NEO4J_PASSWORD\"),\n", + " )\n", + "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "id": "43cbb7c743877929", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " _ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE\")\n", + "except ClientError:\n", + " print(\"CONSTRAINT entity_id already exists\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be889ec11e9b5759", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas\n", + "\n", + "nodes = pandas.DataFrame(\n", + " {\n", + " \"nodeId\": [0, 1, 2, 3, 7, 10],\n", + " \"labels\": [\"A\", \"B\", \"C\", \"A\", \"B\", \"C\"],\n", + " \"prop1\": [42, 1337, 8, 0, 1, 2],\n", + " \"otherProperty\": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],\n", + " }\n", + ")\n", + "\n", + "relationships = pandas.DataFrame(\n", + " {\n", + " \"sourceNodeId\": [0, 1, 2, 7],\n", + " \"targetNodeId\": [1, 2, 3, 10],\n", + " \"relationshipType\": [\"REL1\", \"REL1\", \"REL2\", \"REL2\"],\n", + " \"weight\": [0.0, 0.0, 0.1, 42.0],\n", + " }\n", + ")\n", + "\n", + "gds.graph.drop(\"my-graph\", failIfMissing=False)\n", + "G_train = gds.graph.construct(\n", + " \"my-graph\", # Graph name\n", + " nodes, # One or more dataframes containing node data\n", + " relationships, # One or more dataframes containing relationship data\n", + ")\n", + "\n", + "assert \"REL1\" in G_train.relationship_types()\n", + "assert \"REL2\" in G_train.relationship_types()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1638be48a275563f", + "metadata": {}, + "outputs": [], + "source": [ + "G_train.relationship_types()\n", + "G_train.node_labels()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "faae17b0f4551e7", + "metadata": {}, + "outputs": [], + "source": [ + "gds.set_compute_cluster_ip(\"localhost\")\n", + "\n", + "model_name = \"dummyModelName_\" + str(time.time())\n", + "\n", + "gds.kge.model.train(\n", + " G_train,\n", + " model_name=model_name,\n", + " scoring_function=\"transe\",\n", + " num_epochs=1,\n", + " 
embedding_dimension=16,\n", + " epochs_per_checkpoint=0,\n", + " split_ratios={\"TRAIN\": 0.75, \"TEST\": 0.25},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d88ba3d372525a6", + "metadata": {}, + "outputs": [], + "source": [ + "predict_result = gds.kge.model.predict(\n", + " model_name=model_name,\n", + " top_k=3,\n", + " node_ids=[1, 2, 0, 10, 7],\n", + " rel_types=[\"REL1\", \"REL2\"],\n", + ")\n", + "\n", + "print(predict_result.to_string())" + ] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} From d614307aa34f3234be38d2ae38ad6c2fd30e4406 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 29 Jul 2024 16:27:09 +0100 Subject: [PATCH 21/24] Notebook for DistMult --- examples/kge-distmult-nations.ipynb | 129 ++++++++++++++++++++++------ 1 file changed, 103 insertions(+), 26 deletions(-) diff --git a/examples/kge-distmult-nations.ipynb b/examples/kge-distmult-nations.ipynb index dfccad8e9..095b55bfd 100644 --- a/examples/kge-distmult-nations.ipynb +++ b/examples/kge-distmult-nations.ipynb @@ -7,7 +7,24 @@ "collapsed": false }, "source": [ - "# Knowledge Graph Embedding: DistMult embedding for Nation dataset" + "# Knowledge Graph Embedding: DistMult embedding for Nation dataset\n", + "\n", + "In this notebook, we will use the DistMult embedding model to make predictions on the Nations dataset.\n", + "The Nations dataset is a simple dataset that contains relationships between countries.\n", + "\n", + "The dataset contains three files: `train.txt`, `valid.txt`, and `test.txt`.\n", + "Each file contains triplets of the form `source_country relation target_country`.\n", + "The `entity2id.txt` file contains the mapping of country names to ids, and the `relation2id.txt` file contains the mapping of relation names to ids." 
+ ] + }, + { + "cell_type": "markdown", + "id": "f9529174", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "We start by installing and importing our dependencies, and setting up our GDS client connection to the database." ] }, { @@ -54,6 +71,14 @@ "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)" ] }, + { + "cell_type": "markdown", + "id": "98a7f9b7", + "metadata": {}, + "source": [ + "Create constraints to ensure that the `Entity` nodes have unique `text` properties." + ] + }, { "cell_type": "code", "execution_count": null, @@ -62,23 +87,19 @@ "outputs": [], "source": [ "try:\n", - " _ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE\")\n", + " _ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.text IS UNIQUE\")\n", "except ClientError:\n", " print(\"CONSTRAINT entity_id already exists\")" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "bdbf4f91da4b9934", + "cell_type": "markdown", + "id": "320f3ded", "metadata": {}, - "outputs": [], "source": [ - "def get_text_to_id_map(data_dir, text_to_id_filename):\n", - " with open(data_dir + \"/\" + text_to_id_filename, \"r\") as f:\n", - " data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n", - " text_to_id_map = {text: int(id) for text, id in data}\n", - " return text_to_id_map" + "## Download and read the data\n", + "\n", + "Let's download the Nations dataset and read the data." 
] }, { @@ -88,6 +109,13 @@ "metadata": {}, "outputs": [], "source": [ + "def get_text_to_id_map(data_dir, text_to_id_filename):\n", + " with open(data_dir + \"/\" + text_to_id_filename, \"r\") as f:\n", + " data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n", + " text_to_id_map = {text: int(id) for text, id in data}\n", + " return text_to_id_map\n", + "\n", + "\n", "def read_data():\n", " rel_types = {\n", " \"train.txt\": \"TRAIN\",\n", @@ -148,6 +176,20 @@ "dataset, node_map = read_data()" ] }, + { + "cell_type": "markdown", + "id": "5c97b4df", + "metadata": {}, + "source": [ + "## Put data in the database\n", + "\n", + "We will put the data in the database, creating `Entity` nodes and relationships between them.\n", + "\n", + "Each node will have a `text` property. We will use `text` to identify the node later.\n", + "\n", + "Each relationship will have a `split` property to indicate whether it is part of the training, validation, or test set." + ] + }, { "cell_type": "code", "execution_count": null, @@ -194,6 +236,16 @@ "put_data_in_db()" ] }, + { + "cell_type": "markdown", + "id": "9f270636", + "metadata": {}, + "source": [ + "## Project graphs\n", + "\n", + "First, we will project the full graph, then we will filter the graph to create the training graph based on the `split` property." 
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -210,22 +262,13 @@ " all_rels = all_rels[\"relationshipType\"].to_list()\n", " all_rels = {rel: {\"properties\": \"split\"} for rel in all_rels if rel.startswith(\"REL_\")}\n", " gds.graph.drop(\"fullGraph\", failIfMissing=False)\n", - " gds.graph.drop(\"trainGraph\", failIfMissing=False)\n", - " gds.graph.drop(\"validGraph\", failIfMissing=False)\n", - " gds.graph.drop(\"testGraph\", failIfMissing=False)\n", "\n", " G_full, _ = gds.graph.project(\"fullGraph\", [\"Entity\"], all_rels)\n", "\n", - " G_train, _ = gds.graph.filter(\"trainGraph\", G_full, \"*\", \"r.split = 0.0\")\n", - " G_valid, _ = gds.graph.filter(\"validGraph\", G_full, \"*\", \"r.split = 1.0\")\n", - " G_test, _ = gds.graph.filter(\"testGraph\", G_full, \"*\", \"r.split = 2.0\")\n", - "\n", - " gds.graph.drop(\"fullGraph\", failIfMissing=False)\n", + " return G_full\n", "\n", - " return G_train, G_valid, G_test\n", "\n", - "\n", - "G_train, G_valid, G_test = project_graphs()" + "G = project_graphs()" ] }, { @@ -235,7 +278,21 @@ "metadata": {}, "outputs": [], "source": [ - "G_train.relationship_types()" + "G.relationship_types()" + ] + }, + { + "cell_type": "markdown", + "id": "88b243ea", + "metadata": {}, + "source": [ + "We will train a knowledge graph embedding model using the Graph Data Science library. The model will be trained on the `G` graph.\n", + "\n", + "We will use the DistMult scoring function and set the embedding dimension to 64. The model will be trained for 30 epochs with a split ratio of 80% for training, 10% for validation, and 10% for testing.\n", + "\n", + "After training the model, we will use it to make predictions on three specific nodes: \"brazil\", \"uk\", and \"jordan\". We will predict the top 3 relationships for each node and print the results.\n", + "\n", + "Finally, we will create new relationships in the graph based on the predicted relationships. 
For each predicted relationship, we will create a new relationship between the corresponding nodes." ] }, { @@ -250,9 +307,9 @@ "model_name = \"dummyModelName_\" + str(time.time())\n", "\n", "gds.kge.model.train(\n", - " G_train,\n", + " G,\n", " model_name=model_name,\n", - " scoring_function=\"transe\",\n", + " scoring_function=\"DistMult\",\n", " num_epochs=30,\n", " embedding_dimension=64,\n", " split_ratios={\"TRAIN\": 0.8, \"VALID\": 0.1, \"TEST\": 0.1},\n", @@ -272,6 +329,14 @@ "print(predict_result.to_string())" ] }, + { + "cell_type": "markdown", + "id": "aa583359", + "metadata": {}, + "source": [ + "In the next cell we will add these top-scored relationships to the database." + ] + }, { "cell_type": "code", "execution_count": null, @@ -293,6 +358,14 @@ " )" ] }, + { + "cell_type": "markdown", + "id": "c00579a8", + "metadata": {}, + "source": [ + "There is also an API that can be used to score a list of triplets. In the next cell we will use a call to score the triplets `(brazil, REL_RELNGO, uk)` and `(brazil, REL_RELDIPLOMACY, jordan)`." 
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -318,7 +391,11 @@ ] } ], - "metadata": {}, + "metadata": { + "language_info": { + "name": "python" + } + }, "nbformat": 4, "nbformat_minor": 5 } From efd51a347c8e6886a4c7c34b7e7983bc58ead171 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Tue, 30 Jul 2024 11:44:35 +0100 Subject: [PATCH 22/24] Enable mlflow --- examples/kge-distmult-nations.ipynb | 1 + examples/kge-distmult.ipynb | 316 +++++---------------------- graphdatascience/model/kge_runner.py | 13 +- 3 files changed, 63 insertions(+), 267 deletions(-) diff --git a/examples/kge-distmult-nations.ipynb b/examples/kge-distmult-nations.ipynb index 095b55bfd..1e570da8b 100644 --- a/examples/kge-distmult-nations.ipynb +++ b/examples/kge-distmult-nations.ipynb @@ -313,6 +313,7 @@ " num_epochs=30,\n", " embedding_dimension=64,\n", " split_ratios={\"TRAIN\": 0.8, \"VALID\": 0.1, \"TEST\": 0.1},\n", + " mlflow_experiment_name=\"Nations-train\",\n", ")\n", "\n", "predict_result = gds.kge.model.predict(\n", diff --git a/examples/kge-distmult.ipynb b/examples/kge-distmult.ipynb index 6686b85b5..05456b9f0 100644 --- a/examples/kge-distmult.ipynb +++ b/examples/kge-distmult.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Knowledge graph embeddings: DistMult" + "# Knowledge graph embeddings: TransE" ] }, { @@ -25,11 +25,10 @@ "source": [ "import os\n", "from graphdatascience import GraphDataScience\n", - "import torch\n", - "import torch.optim as optim\n", "import collections\n", "from tqdm import tqdm\n", - "import pandas as pd" + "import pandas as pd\n", + "from neo4j.exceptions import ClientError" ] }, { @@ -75,7 +74,10 @@ "metadata": {}, "outputs": [], "source": [ - "_ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE\")" + "try:\n", + " _ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.text IS UNIQUE\")\n", + "except ClientError:\n", + " print(\"CONSTRAINT entity_id 
already exists\")" ] }, { @@ -179,32 +181,22 @@ "metadata": {}, "outputs": [], "source": [ - "def put_data_in_db(dataset):\n", - " for rel_split in tqdm(dataset, desc=\"Relationship\"):\n", - " for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False):\n", - " edges = dataset[rel_split][rel_type]\n", + "def put_data_in_db(data):\n", + " for rel_split in tqdm(data, desc=\"Relationship\"):\n", + " for rel_type in tqdm(data[rel_split], mininterval=1, leave=False):\n", + " edges = data[rel_split][rel_type]\n", "\n", - " # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)\n", " gds.run_cypher(\n", " f\"\"\"\n", " UNWIND $ll as l\n", - " MERGE (n:Entity {{id:l.source, text:l.source_text}})\n", - " MERGE (m:Entity {{id:l.target, text:l.target_text}})\n", - " MERGE (n)-[:{rel_split}]->(m)\n", + " MERGE (n:Entity {{text:l.source_text}})\n", + " MERGE (m:Entity {{text:l.target_text}})\n", " MERGE (n)-[:{rel_type}]->(m)\n", + " MERGE (n)-[:{rel_split}]->(m)\n", " \"\"\",\n", " params={\"ll\": edges},\n", " )\n", "\n", - " for rel_split in dataset:\n", - " res = gds.run_cypher(\n", - " f\"\"\"\n", - " MATCH ()-[r:{rel_split}]->()\n", - " RETURN COUNT(r) AS numberOfRelationships\n", - " \"\"\"\n", - " )\n", - " print(f\"Number of relationships of type {rel_split} in db: \", res.numberOfRelationships)\n", - "\n", "\n", "put_data_in_db(dataset)" ] @@ -225,18 +217,18 @@ "outputs": [], "source": [ "ALL_RELS = dataset[\"TRAIN\"].keys()\n", - "G_train, result = gds.graph.cypher.project(\n", + "G, result = gds.graph.cypher.project(\n", " \"\"\"\n", " MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:\"\"\"\n", " + \"|\".join(ALL_RELS)\n", - " + \"\"\"]-(n)\n", + " + \"\"\"]-(n:Entity)\n", " RETURN gds.graph.project($graph_name, n, m, {\n", " sourceNodeLabels: $label,\n", " targetNodeLabels: $label\n", " })\n", " \"\"\", # Cypher query\n", " database=\"neo4j\", # Target database\n", - " graph_name=\"trainGraph\", # Query parameter\n", + " graph_name=\"G_full\", # Query parameter\n", " 
label=\"Entity\", # Query parameter\n", ")" ] @@ -261,7 +253,7 @@ " print(f\"==={func_name}===: {getattr(G, func_name)()}\")\n", "\n", "\n", - "inspect_graph(G_train)" + "inspect_graph(G)" ] }, { @@ -279,226 +271,25 @@ "metadata": {}, "outputs": [], "source": [ + "import time\n", + "\n", + "model_name = \"fb15k-TransE-128-model-\" + str(time.time())\n", "gds.kge.model.train(\n", - " G_train,\n", - " scoring_function=\"distmult\",\n", - " num_epochs=10,\n", - " embedding_dimension=100,\n", + " G,\n", + " model_name=model_name,\n", + " scoring_function=\"TransE\",\n", + " embedding_dimension=128,\n", + " num_epochs=100,\n", + " filtered_metrics=False,\n", + " batch_size=32_768,\n", + " optimizer=\"Adam\",\n", + " optimizer_kwargs={\"lr\": 0.0003},\n", + " epochs_per_val=0,\n", + " do_validation=False,\n", + " do_test=False,\n", ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "node_projection = {\"Entity\": {\"properties\": \"id\"}}\n", - "relationship_projection = [\n", - " {\"TRAIN\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n", - " {\"TEST\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n", - " {\"VALID\": {\"orientation\": \"NATURAL\", \"properties\": \"rel_id\"}},\n", - "]\n", - "\n", - "ttv_G, result = gds.graph.project(\n", - " \"fb15k-graph-ttv\",\n", - " node_projection,\n", - " relationship_projection,\n", - ")\n", - "\n", - "node_properties = gds.graph.nodeProperties.stream(\n", - " ttv_G,\n", - " [\"id\"],\n", - " separate_property_columns=True,\n", - ")\n", - "\n", - "nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id))\n", - "id_to_nodeId = dict(zip(node_properties.id, node_properties.nodeId))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## Training the TransE Model with PyG\n", - "\n", - "Retrieve data from the database, convert it into torch tensors, and format it into a `Data` 
structure suitable for training with PyG." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def create_data_from_graph(relationship_type):\n", - " rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, \"rel_id\", relationship_type)\n", - " topology = [\n", - " rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),\n", - " rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),\n", - " ]\n", - " edge_index = torch.tensor(topology, dtype=torch.long)\n", - " edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)\n", - " data = Data(edge_index=edge_index, edge_type=edge_type)\n", - " data.num_nodes = len(nodeId_to_id)\n", - " display(data)\n", - " return data\n", - "\n", - "\n", - "train_tensor_data = create_data_from_graph(\"TRAIN\")\n", - "test_tensor_data = create_data_from_graph(\"TEST\")\n", - "val_tensor_data = create_data_from_graph(\"VALID\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "Drop the projected graph to save memory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gds.graph.drop(ttv_G)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "The training process of the TransE model follows the corresponding PyG [example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def train_model_with_pyg():\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "\n", - " model = TransE(\n", - " num_nodes=train_tensor_data.num_nodes,\n", - " num_relations=train_tensor_data.num_edge_types,\n", - " hidden_channels=50,\n", - " ).to(device)\n", - "\n", - " loader = model.loader(\n", - " head_index=train_tensor_data.edge_index[0],\n", - " rel_type=train_tensor_data.edge_type,\n", - " tail_index=train_tensor_data.edge_index[1],\n", - " batch_size=1000,\n", - " shuffle=True,\n", - " )\n", - "\n", - " optimizer = optim.Adam(model.parameters(), lr=0.01)\n", - "\n", - " def train():\n", - " model.train()\n", - " total_loss = total_examples = 0\n", - " for head_index, rel_type, tail_index in loader:\n", - " optimizer.zero_grad()\n", - " loss = model.loss(head_index, rel_type, tail_index)\n", - " loss.backward()\n", - " optimizer.step()\n", - " total_loss += float(loss) * head_index.numel()\n", - " total_examples += head_index.numel()\n", - " return total_loss / total_examples\n", - "\n", - " @torch.no_grad()\n", - " def test(data):\n", - " model.eval()\n", - " return model.test(\n", - " head_index=data.edge_index[0],\n", - " rel_type=data.edge_type,\n", - " tail_index=data.edge_index[1],\n", - " batch_size=1000,\n", - " k=10,\n", - " )\n", - "\n", - " # Consider increasing the number of epochs\n", - " epoch_count = 5\n", - " for epoch in range(1, epoch_count):\n", - " loss = train()\n", - " print(f\"Epoch: {epoch:03d}, Loss: {loss:.4f}\")\n", - " if epoch % 75 == 0:\n", - " rank, hits = test(val_tensor_data)\n", - " print(f\"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, \" f\"Val Hits@10: {hits:.4f}\")\n", - "\n", - " torch.save(model, f\"./model_{epoch_count}.pt\")\n", - "\n", - " mean_rank, mrr, hits_at_k = test(test_tensor_data)\n", - " print(f\"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: 
{mrr:.4f}\")\n", - "\n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = train_model_with_pyg()\n", - "# The model can be loaded if it was trained before\n", - "# model = torch.load(\"./model_501.pt\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "Extract node embeddings from the trained model and put them into database." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in tqdm(range(len(nodeId_to_id))):\n", - " gds.run_cypher(\n", - " \"MATCH (n:Entity {id: $i}) SET n.emb=$EMBEDDING\",\n", - " params={\"i\": i, \"EMBEDDING\": model.node_emb.weight[i].tolist()},\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## Predict Using GDS Knowledge Graph Edge Embeddings Functionality\n", - "\n", - "Select a relationship type for which to make predictions." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "relationship_to_predict = \"/film/film/genre\"\n", - "rel_id_to_predict = rel_dict[relationship_to_predict]\n", - "rel_label_to_predict = f\"REL_{rel_id_to_predict}\"" - ] - }, { "cell_type": "markdown", "metadata": { @@ -514,21 +305,24 @@ "metadata": {}, "outputs": [], "source": [ - "G_test, result = gds.graph.project(\n", - " \"graph_to_predict_\",\n", - " {\"Entity\": {\"properties\": [\"id\", \"emb\"]}},\n", - " rel_label_to_predict,\n", - ")\n", + "source_node_list = [\"/m/07l450\", \"/m/0ds2l81\", \"/m/0jvt9\"]\n", "\n", + "source_ids_df = gds.run_cypher(\n", + " \"UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId\",\n", + " params={\"node_text_list\": source_node_list},\n", + ")\n", + "node_ids = source_ids_df[\"nodeId\"].to_list()\n", "\n", - "def print_graph_info(G):\n", - " print(f\"Graph '{G.name()}' node count: {G.node_count()}\")\n", - " print(f\"Graph '{G.name()}' node labels: {G.node_labels()}\")\n", - " print(f\"Graph '{G.name()}' relationship types: {G.relationship_types()}\")\n", - " print(f\"Graph '{G.name()}' relationship count: {G.relationship_count()}\")\n", + "rel_label_to_predict = \"REL_\" + str(rel_dict[\"/film/film/genre\"])\n", "\n", + "predict_result = gds.kge.model.predict(\n", + " model_name=model_name,\n", + " top_k=3,\n", + " node_ids=node_ids,\n", + " rel_types=[rel_label_to_predict],\n", + ")\n", "\n", - "print_graph_info(G_test)" + "print(predict_result.to_string())" ] }, { @@ -546,8 +340,12 @@ "metadata": {}, "outputs": [], "source": [ - "target_emb = model.node_emb.weight[rel_id_to_predict].tolist()\n", - "transe_model = gds.model.transe.create(G_test, \"emb\", {rel_label_to_predict: target_emb})" + "source_node_list = [\"/m/07l450\", \"/m/0ds2l81\", \"/m/0jvt9\"]\n", + "source_ids_df = gds.run_cypher(\n", + " \"UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN 
id(n) as nodeId\",\n", + " params={\"node_text_list\": source_node_list},\n", + ")\n", + "source_ids_df[\"nodeId\"].to_list()" ] }, { @@ -555,13 +353,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "source_node_list = [\"/m/07l450\", \"/m/0ds2l81\", \"/m/0jvt9\"]\n", - "source_ids_df = gds.run_cypher(\n", - " \"UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId\",\n", - " params={\"node_text_list\": source_node_list},\n", - ")" - ] + "source": [] }, { "cell_type": "markdown", diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 8ba46ae95..b42950df2 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -68,7 +68,7 @@ def train( do_validation: bool = True, do_test: bool = True, filtered_metrics: bool = False, - epochs_per_val: int = 50, + epochs_per_val: int = 0, inner_norm: bool = True, init_bound: Optional[float] = None, mlflow_experiment_name: Optional[str] = None, @@ -108,7 +108,8 @@ def train( if mlflow_experiment_name is not None: config["task_config"]["mlflow"] = { - "config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name} + "tracking_uri": self._compute_cluster_mlflow_uri, + "experiment_name": mlflow_experiment_name, } job_id = self._start_job(config) @@ -152,7 +153,8 @@ def predict( if mlflow_experiment_name is not None: config["task_config"]["mlflow"] = { - "config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name} + "tracking_uri": self._compute_cluster_mlflow_uri, + "experiment_name": mlflow_experiment_name, } job_id = self._start_job(config) @@ -187,7 +189,8 @@ def score_triplets( if mlflow_experiment_name is not None: config["task_config"]["mlflow"] = { - "config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name} + "tracking_uri": self._compute_cluster_mlflow_uri, + 
"experiment_name": mlflow_experiment_name, } job_id = self._start_job(config) @@ -228,7 +231,7 @@ def _get_metrics(self, user_name: str, model_name: str, job_id: str) -> DataFram os.remove(res_file_name) - return metadata["metrics"] + return metadata.get("metrics", None) def _start_job(self, config: Dict[str, Any]) -> str: url = f"{self._compute_cluster_web_uri}/api/machine-learning/start" From 961906d20616d5188e9af339feae3636c1649474 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Wed, 31 Jul 2024 12:38:15 +0100 Subject: [PATCH 23/24] Add random seed --- examples/kge-distmult-nations.ipynb | 1 + graphdatascience/model/kge_runner.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/kge-distmult-nations.ipynb b/examples/kge-distmult-nations.ipynb index 1e570da8b..0cc161df7 100644 --- a/examples/kge-distmult-nations.ipynb +++ b/examples/kge-distmult-nations.ipynb @@ -314,6 +314,7 @@ " embedding_dimension=64,\n", " split_ratios={\"TRAIN\": 0.8, \"VALID\": 0.1, \"TEST\": 0.1},\n", " mlflow_experiment_name=\"Nations-train\",\n", + " random_seed=42,\n", ")\n", "\n", "predict_result = gds.kge.model.predict(\n", diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index b42950df2..73e93f177 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -70,6 +70,7 @@ def train( filtered_metrics: bool = False, epochs_per_val: int = 0, inner_norm: bool = True, + random_seed: Optional[int] = None, init_bound: Optional[float] = None, mlflow_experiment_name: Optional[str] = None, ) -> Series: From a709d22533695eedcec454248ce6d7b5bec52a58 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Fri, 2 Aug 2024 13:08:12 +0100 Subject: [PATCH 24/24] Make field notebook working again --- ...ipynb => kge-distmult-nations-field.ipynb} | 4 +- examples/kge-distmult-nations.py | 12 +++--- graphdatascience/graph_data_science.py | 17 ++++---- graphdatascience/model/kge_runner.py | 41 ++++++++++--------- 4 
files changed, 38 insertions(+), 36 deletions(-) rename examples/{kge-distmult-nations.ipynb => kge-distmult-nations-field.ipynb} (99%) diff --git a/examples/kge-distmult-nations.ipynb b/examples/kge-distmult-nations-field.ipynb similarity index 99% rename from examples/kge-distmult-nations.ipynb rename to examples/kge-distmult-nations-field.ipynb index 0cc161df7..d6f609ce0 100644 --- a/examples/kge-distmult-nations.ipynb +++ b/examples/kge-distmult-nations-field.ipynb @@ -347,7 +347,7 @@ "outputs": [], "source": [ "for index, row in predict_result.iterrows():\n", - " h = row[\"head\"]\n", + " h = row[\"sourceNodeId\"]\n", " r = row[\"rel\"]\n", " gds.run_cypher(\n", " f\"\"\"\n", @@ -356,7 +356,7 @@ " MATCH (b:Entity WHERE id(b) = t)\n", " MERGE (a)-[:NEW_REL_{r}]->(b)\n", " \"\"\",\n", - " params={\"tt\": row[\"tail\"]},\n", + " params={\"tt\": row[\"targetNodeIdTopK\"]},\n", " )" ] }, diff --git a/examples/kge-distmult-nations.py b/examples/kge-distmult-nations.py index a3011ac25..910e500b1 100644 --- a/examples/kge-distmult-nations.py +++ b/examples/kge-distmult-nations.py @@ -195,11 +195,11 @@ def inspect_graph(G): res = gds.kge.model.train( G_train, model_name=model_name, - scoring_function="distmult", - num_epochs=1, - embedding_dimension=10, + scoring_function="TransE", + num_epochs=30, + embedding_dimension=64, epochs_per_checkpoint=0, - epochs_per_val=5, + epochs_per_val=0, split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1}, ) print(res["metrics"]) @@ -218,7 +218,7 @@ def inspect_graph(G): print(predict_result.to_string()) for index, row in predict_result.iterrows(): - h = row["head"] + h = row["sourceNodeId"] r = row["rel"] gds.run_cypher( f""" @@ -227,7 +227,7 @@ def inspect_graph(G): MATCH (b:Entity WHERE id(b) = t) MERGE (a)-[:NEW_REL_{r}]->(b) """, - params={"tt": row["tail"]}, + params={"tt": row["targetNodeIdTopK"]}, ) brazil_node = gds.find_node_id(["Entity"], {"text": "brazil"}) diff --git a/graphdatascience/graph_data_science.py 
b/graphdatascience/graph_data_science.py index a696e59b5..e9b7aa152 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -4,10 +4,12 @@ import sys from typing import Any, Dict, Optional, Tuple, Type, Union -import rsa from neo4j import Driver from pandas import DataFrame +from graphdatascience.graph.graph_proc_runner import GraphProcRunner +from graphdatascience.utils.util_proc_runner import UtilProcRunner + from .call_builder import IndirectCallBuilder from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints from .error.uncallable_namespace import UncallableNamespace @@ -16,8 +18,6 @@ from .query_runner.neo4j_query_runner import Neo4jQueryRunner from .query_runner.query_runner import QueryRunner from .server_version.server_version import ServerVersion -from graphdatascience.graph.graph_proc_runner import GraphProcRunner -from graphdatascience.utils.util_proc_runner import UtilProcRunner class GraphDataScience(DirectEndpoints, UncallableNamespace): @@ -53,11 +53,11 @@ def __init__( database: Optional[str], default None The Neo4j database to query against. arrow : Union[str, bool], default True - Arrow connection information. This is either a bool or a string. - If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server. - If it is a bool, - True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint, - while False will make the client use Bolt for all operations. + Arrow connection information. This is either a string or a bool. + - If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server. + - If it is a bool: + - True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint. + - False will make the client use Bolt for all operations. 
arrow_disable_server_verification : bool, default True A flag that overrides other TLS settings and disables server verification for TLS connections. arrow_tls_root_certs : Optional[bytes], default None @@ -91,6 +91,7 @@ def __init__( # pub_key = rsa.PublicKey.load_pkcs1(f.read()) # self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex() + self._encrypted_db_password = None self._compute_cluster_ip = None super().__init__(self._query_runner, "gds", self._server_version) diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py index 73e93f177..70e3a2b6f 100644 --- a/graphdatascience/model/kge_runner.py +++ b/graphdatascience/model/kge_runner.py @@ -4,7 +4,7 @@ import time from typing import Any, Dict, Optional -import pandas as pd +import pyarrow import requests from pandas import DataFrame, Series @@ -32,12 +32,13 @@ def __init__( self._namespace = namespace self._server_version = server_version self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005" + self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815" self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080" self._encrypted_db_password = encrypted_db_password self._arrow_uri = arrow_uri @property - def model(self): + def model(self) -> "KgeRunner": return self # @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0)) @@ -75,7 +76,7 @@ def train( mlflow_experiment_name: Optional[str] = None, ) -> Series: if epochs_per_checkpoint is None: - epochs_per_checkpoint = max(num_epochs / 10, 1) + epochs_per_checkpoint = max(int(num_epochs / 10), 1) if loss_function_kwargs is None: loss_function_kwargs = dict(margin=1.0, adversarial_temperature=1.0, gamma=20.0) if lr_scheduler_kwargs is None: @@ -92,7 +93,7 @@ def train( } print(algo_config) - graph_config = {"name": G.name()} + graph_config = {"name": G.name(), "config_type": "GdsGraphConfig"} config = { "user_name": "DUMMY_USER", @@ -133,7 +134,6 @@ def predict( 
rel_types: list[str], mlflow_experiment_name: Optional[str] = None, ) -> DataFrame: - algo_config = { "top_k": top_k, "node_ids": node_ids, @@ -144,8 +144,10 @@ def predict( "user_name": "DUMMY_USER", "task": "KGE_PREDICT_PYG", "task_config": { + "graph_config": {"config_type": "GdsGraphConfig", "name": "NOGRAPH"}, "modelname": model_name, "task_config": algo_config, + "stream_rel_results": True, }, "graph_arrow_uri": self._arrow_uri, } @@ -162,7 +164,7 @@ def predict( self._wait_for_job(job_id) - return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id) + return self._stream_results(config, job_id) @client_only_endpoint("gds.kge.model") def score_triplets( @@ -171,7 +173,6 @@ def score_triplets( triplets: list[tuple[int, str, int]], mlflow_experiment_name: Optional[str] = None, ) -> DataFrame: - algo_config = { "triplets": triplets, } @@ -180,8 +181,10 @@ def score_triplets( "user_name": "DUMMY_USER", "task": "KGE_SCORE_TRIPLETS_PYG", "task_config": { + "graph_config": {"config_type": "GdsGraphConfig", "name": "NOGRAPH"}, "modelname": model_name, "task_config": algo_config, + "stream_rel_results": True, }, "graph_arrow_uri": self._arrow_uri, } @@ -198,22 +201,20 @@ def score_triplets( self._wait_for_job(job_id) - return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id) + return self._stream_results(config, job_id) - def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataFrame: - res = requests.get( - f"{self._compute_cluster_web_uri}/internal/fetch-result", - params={"user_name": user_name, "modelname": model_name, "job_id": job_id}, - ) - res.raise_for_status() + def _stream_results(self, config: dict, job_id: str) -> DataFrame: + client = pyarrow.flight.connect(self._compute_cluster_arrow_uri) - res_file_name = f"res_{job_id}.json" - with open(res_file_name, mode="wb+") as f: - f.write(res.content) + if config["task_config"].get("stream_rel_results", False): + 
upload_descriptor = pyarrow.flight.FlightDescriptor.for_path(f"{job_id}.relationships") + else: + raise ValueError("No results to fetch: need to set stream_rel_results or stream_graph_results to True") + flight = client.get_flight_info(upload_descriptor) + reader = client.do_get(flight.endpoints[0].ticket) + read_table = reader.read_all() - df = pd.read_json(res_file_name, orient="records", lines=True) - os.remove(res_file_name) - return df + return read_table.to_pandas() def _get_metrics(self, user_name: str, model_name: str, job_id: str) -> DataFrame: res = requests.get(