diff --git a/examples/kge-distmult-nations-field.ipynb b/examples/kge-distmult-nations-field.ipynb
new file mode 100644
index 000000000..d6f609ce0
--- /dev/null
+++ b/examples/kge-distmult-nations-field.ipynb
@@ -0,0 +1,403 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "11d08c597a9fdbf3",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "# Knowledge Graph Embedding: DistMult embedding for the Nations dataset\n",
+    "\n",
+    "In this notebook, we will use the DistMult embedding model to make predictions on the Nations dataset.\n",
+    "The Nations dataset is a simple dataset that contains relationships between countries.\n",
+    "\n",
+    "The dataset contains three files: `train.txt`, `valid.txt`, and `test.txt`.\n",
+    "Each file contains triplets of the form `source_country relation target_country`.\n",
+    "The `entity2id.txt` file contains the mapping of country names to ids, and the `relation2id.txt` file contains the mapping of relation names to ids."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f9529174",
+   "metadata": {},
+   "source": [
+    "## Setup\n",
+    "\n",
+    "We start by installing and importing our dependencies, and setting up our GDS client connection to the database."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9135277efcde2800",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import time\n",
+    "import warnings\n",
+    "from collections import defaultdict\n",
+    "from neo4j.exceptions import ClientError\n",
+    "from tqdm import tqdm\n",
+    "from graphdatascience import GraphDataScience"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1551fddc3a67fa5b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2f05ee7fdb496f84",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n",
+    "NEO4J_AUTH = None\n",
+    "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n",
+    "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n",
+    "    NEO4J_AUTH = (\n",
+    "        os.environ.get(\"NEO4J_USER\"),\n",
+    "        os.environ.get(\"NEO4J_PASSWORD\"),\n",
+    "    )\n",
+    "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "98a7f9b7",
+   "metadata": {},
+   "source": [
+    "Create a constraint to ensure that the `Entity` nodes have unique `text` properties."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "658c9f8369fff77e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    _ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.text IS UNIQUE\")\n",
+    "except ClientError:\n",
+    "    print(\"CONSTRAINT entity_id already exists\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "320f3ded",
+   "metadata": {},
+   "source": [
+    "## Download and read the data\n",
+    "\n",
+    "Let's download the Nations dataset and read the data."
+   ]
+  },
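+  {
+   "cell_type": "markdown",
+   "id": "1a2b3c4d5e6f7081",
+   "metadata": {},
+   "source": [
+    "To make the file format concrete: each line of the `*2id.txt` mapping files is a tab-separated `text<TAB>id` pair, which `get_text_to_id_map` below turns into a dictionary. Here is a quick, self-contained illustration of the parsing (the two sample rows are made up):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2b3c4d5e6f708192",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Made-up rows in the entity2id.txt format (text<TAB>id), for illustration only.\n",
+    "sample = \"brazil\\t0\\nuk\\t1\\n\"\n",
+    "parsed = {text: int(id) for text, id in (line.split(\"\\t\") for line in sample.split(\"\\n\")[:-1])}\n",
+    "print(parsed)"
+   ]
+  },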
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "485869468ad5ad2e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_text_to_id_map(data_dir, text_to_id_filename):\n",
+    "    with open(data_dir + \"/\" + text_to_id_filename, \"r\") as f:\n",
+    "        data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n",
+    "    text_to_id_map = {text: int(id) for text, id in data}\n",
+    "    return text_to_id_map\n",
+    "\n",
+    "\n",
+    "def read_data():\n",
+    "    rel_types = {\n",
+    "        \"train.txt\": \"TRAIN\",\n",
+    "        \"valid.txt\": \"VALID\",\n",
+    "        \"test.txt\": \"TEST\",\n",
+    "    }\n",
+    "    url = \"https://raw.githubusercontent.com/ZhenfengLei/KGDatasets/master/Nations\"\n",
+    "    data_dir = \"./Nations\"\n",
+    "\n",
+    "    raw_file_names = [\"train.txt\", \"valid.txt\", \"test.txt\"]\n",
+    "    node_id_filename = \"entity2id.txt\"\n",
+    "    rel_id_filename = \"relation2id.txt\"\n",
+    "\n",
+    "    for file in raw_file_names + [node_id_filename, rel_id_filename]:\n",
+    "        if not os.path.exists(f\"{data_dir}/{file}\"):\n",
+    "            os.system(f\"wget {url}/{file} -P {data_dir}\")\n",
+    "\n",
+    "    node_map = get_text_to_id_map(data_dir, node_id_filename)\n",
+    "    rel_map = get_text_to_id_map(data_dir, rel_id_filename)\n",
+    "    dataset = defaultdict(lambda: defaultdict(list))\n",
+    "\n",
+    "    rel_split_id = {\"TRAIN\": 0, \"VALID\": 1, \"TEST\": 2}\n",
+    "\n",
+    "    for file_name in raw_file_names:\n",
+    "        file_name_path = data_dir + \"/\" + file_name\n",
+    "\n",
+    "        with open(file_name_path, \"r\") as f:\n",
+    "            data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n",
+    "\n",
+    "        for i, (src_text, rel_text, dst_text) in enumerate(data):\n",
+    "            source = node_map[src_text]\n",
+    "            target = node_map[dst_text]\n",
+    "            rel_type = \"REL_\" + rel_text.upper()\n",
+    "            rel_split = rel_types[file_name]\n",
+    "\n",
+    "            dataset[rel_split][rel_type].append(\n",
+    "                {\n",
+    "                    \"source\": source,\n",
+    "                    \"source_text\": src_text,\n",
+    "                    \"target\": target,\n",
+    "                    \"target_text\": dst_text,\n",
+    "                    \"rel_type\": rel_type,\n",
+    "                    \"rel_id\": rel_map[rel_text],\n",
+    "                    \"rel_split\": rel_split,\n",
+    "                    \"rel_split_id\": rel_split_id[rel_split],\n",
+    "                }\n",
+    "            )\n",
+    "\n",
+    "    print(\"Number of nodes: \", len(node_map))\n",
+    "    for rel_split in dataset:\n",
+    "        print(\n",
+    "            f\"Number of relationships of type {rel_split}: \",\n",
+    "            sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),\n",
+    "        )\n",
+    "    return dataset, node_map\n",
+    "\n",
+    "\n",
+    "dataset, node_map = read_data()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5c97b4df",
+   "metadata": {},
+   "source": [
+    "## Put data in the database\n",
+    "\n",
+    "We will put the data in the database, creating `Entity` nodes and relationships between them.\n",
+    "\n",
+    "Each node will have a `text` property. We will use `text` to identify the node later.\n",
+    "\n",
+    "Each relationship will have a `split` property to indicate whether it is part of the training, validation, or test set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2032a4e1aed1bd5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def put_data_in_db():\n",
+    "    res = gds.run_cypher(\"MATCH (m) RETURN count(m) as num_nodes\")\n",
+    "    if res[\"num_nodes\"].values[0] > 0:\n",
+    "        print(\"Data already in db, number of nodes: \", res[\"num_nodes\"].values[0])\n",
+    "        return\n",
+    "    pbar = tqdm(\n",
+    "        desc=\"Putting data in db\",\n",
+    "        total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),\n",
+    "    )\n",
+    "\n",
+    "    for rel_split in dataset:\n",
+    "        for rel_type in dataset[rel_split]:\n",
+    "            edges = dataset[rel_split][rel_type]\n",
+    "\n",
+    "            gds.run_cypher(\n",
+    "                f\"\"\"\n",
+    "                UNWIND $ll as l\n",
+    "                MERGE (n:Entity {{id:l.source, text:l.source_text}})\n",
+    "                MERGE (m:Entity {{id:l.target, text:l.target_text}})\n",
+    "                MERGE (n)-[:{rel_type} {{split: l.rel_split_id, rel_id: l.rel_id}}]->(m)\n",
+    "                \"\"\",\n",
+    "                params={\"ll\": edges},\n",
+    "            )\n",
+    "            pbar.update(len(edges))\n",
+    "    pbar.close()\n",
+    "\n",
+    "    # The relationships are typed REL_*, so we verify the splits via their `split` property.\n",
+    "    rel_split_id = {\"TRAIN\": 0, \"VALID\": 1, \"TEST\": 2}\n",
+    "    for rel_split in dataset:\n",
+    "        res = gds.run_cypher(\n",
+    "            \"\"\"\n",
+    "            MATCH ()-[r]->() WHERE r.split = $split\n",
+    "            RETURN COUNT(r) AS numberOfRelationships\n",
+    "            \"\"\",\n",
+    "            params={\"split\": rel_split_id[rel_split]},\n",
+    "        )\n",
+    "        print(f\"Number of relationships in split {rel_split} in db: \", res[\"numberOfRelationships\"].values[0])\n",
+    "\n",
+    "\n",
+    "put_data_in_db()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9f270636",
+   "metadata": {},
+   "source": [
+    "## Project graphs\n",
+    "\n",
+    "We will project the full graph, including all `REL_*` relationship types and their `split` property. The training call below performs its own train/validation/test split, based on the configured split ratios."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5c4f1523a225fa3c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def project_graphs():\n",
+    "    all_rels = gds.run_cypher(\n",
+    "        \"\"\"\n",
+    "        CALL db.relationshipTypes() YIELD relationshipType\n",
+    "        \"\"\"\n",
+    "    )\n",
+    "    all_rels = all_rels[\"relationshipType\"].to_list()\n",
+    "    all_rels = {rel: {\"properties\": \"split\"} for rel in all_rels if rel.startswith(\"REL_\")}\n",
+    "    gds.graph.drop(\"fullGraph\", failIfMissing=False)\n",
+    "\n",
+    "    G_full, _ = gds.graph.project(\"fullGraph\", [\"Entity\"], all_rels)\n",
+    "\n",
+    "    return G_full\n",
+    "\n",
+    "\n",
+    "G = project_graphs()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21da1ea76d247803",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "G.relationship_types()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "88b243ea",
+   "metadata": {},
+   "source": [
+    "We will train a knowledge graph embedding model using the Graph Data Science library. The model will be trained on the `G` graph.\n",
+    "\n",
+    "We will use the DistMult scoring function and set the embedding dimension to 64. The model will be trained for 30 epochs with a split ratio of 80% for training, 10% for validation, and 10% for testing.\n",
+    "\n",
+    "After training the model, we will use it to make predictions for three specific nodes: \"brazil\", \"uk\", and \"jordan\". For each of these nodes and each given relationship type, we will predict the top 3 target nodes and print the results.\n",
+    "\n",
+    "Finally, we will write these predictions back to the database: for each predicted relationship, we will create a new relationship between the corresponding nodes."
+   ]
+  },
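+  {
+   "cell_type": "markdown",
+   "id": "3c4d5e6f70819203",
+   "metadata": {},
+   "source": [
+    "As a quick aside before training: DistMult scores a triplet `(h, r, t)` as the trilinear product of its head, relation, and tail embeddings, `score = sum(h * r * t)`, where a higher score indicates a more plausible triplet. The toy vectors below are made up purely to illustrate the computation:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4d5e6f7081920314",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "h = np.array([0.1, 0.9, 0.2])  # head entity embedding (made up)\n",
+    "r = np.array([1.0, 0.5, 0.1])  # relation embedding (made up)\n",
+    "t = np.array([0.3, 0.8, 0.4])  # tail entity embedding (made up)\n",
+    "\n",
+    "print(\"DistMult score:\", float(np.sum(h * r * t)))"
+   ]
+  },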
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5d518e67375f6ab3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gds.set_compute_cluster_ip(\"localhost\")\n",
+    "\n",
+    "model_name = \"dummyModelName_\" + str(time.time())\n",
+    "\n",
+    "gds.kge.model.train(\n",
+    "    G,\n",
+    "    model_name=model_name,\n",
+    "    scoring_function=\"DistMult\",\n",
+    "    num_epochs=30,\n",
+    "    embedding_dimension=64,\n",
+    "    split_ratios={\"TRAIN\": 0.8, \"VALID\": 0.1, \"TEST\": 0.1},\n",
+    "    mlflow_experiment_name=\"Nations-train\",\n",
+    "    random_seed=42,\n",
+    ")\n",
+    "\n",
+    "predict_result = gds.kge.model.predict(\n",
+    "    model_name=model_name,\n",
+    "    top_k=3,\n",
+    "    node_ids=[\n",
+    "        gds.find_node_id([\"Entity\"], {\"text\": \"brazil\"}),\n",
+    "        gds.find_node_id([\"Entity\"], {\"text\": \"uk\"}),\n",
+    "        gds.find_node_id([\"Entity\"], {\"text\": \"jordan\"}),\n",
+    "    ],\n",
+    "    rel_types=[\"REL_RELDIPLOMACY\", \"REL_RELNGO\"],\n",
+    ")\n",
+    "\n",
+    "print(predict_result.to_string())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "aa583359",
+   "metadata": {},
+   "source": [
+    "In the next cell, we will add these top-scored relationships to the database."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "83b75194c69259a2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for index, row in predict_result.iterrows():\n",
+    "    h = row[\"sourceNodeId\"]\n",
+    "    r = row[\"rel\"]\n",
+    "    gds.run_cypher(\n",
+    "        f\"\"\"\n",
+    "        UNWIND $tt as t\n",
+    "        MATCH (a:Entity WHERE id(a) = {h})\n",
+    "        MATCH (b:Entity WHERE id(b) = t)\n",
+    "        MERGE (a)-[:NEW_REL_{r}]->(b)\n",
+    "        \"\"\",\n",
+    "        params={\"tt\": row[\"targetNodeIdTopK\"]},\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c00579a8",
+   "metadata": {},
+   "source": [
+    "There is also an API for scoring a list of triplets. In the next cell, we will score the triplets `(brazil, REL_RELNGO, uk)` and `(brazil, REL_RELDIPLOMACY, jordan)`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4e2825a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "brazil_node = gds.find_node_id([\"Entity\"], {\"text\": \"brazil\"})\n",
+    "uk_node = gds.find_node_id([\"Entity\"], {\"text\": \"uk\"})\n",
+    "jordan_node = gds.find_node_id([\"Entity\"], {\"text\": \"jordan\"})\n",
+    "\n",
+    "triplets = [\n",
+    "    (brazil_node, \"REL_RELNGO\", uk_node),\n",
+    "    (brazil_node, \"REL_RELDIPLOMACY\", jordan_node),\n",
+    "]\n",
+    "\n",
+    "scores = gds.kge.model.score_triplets(\n",
+    "    model_name=model_name,\n",
+    "    triplets=triplets,\n",
+    ")\n",
+    "\n",
+    "print(scores)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/kge-distmult-nations.py b/examples/kge-distmult-nations.py
new file mode 100644
index 000000000..910e500b1
--- /dev/null
+++ b/examples/kge-distmult-nations.py
@@ -0,0 +1,264 @@
+import os
+import time
+import warnings
+from collections import defaultdict
+
+from neo4j.exceptions import ClientError
+from tqdm import tqdm
+
+from graphdatascience import GraphDataScience
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+def setup_connection():
+    NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
+    NEO4J_AUTH = None
+    NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
+    if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
+        NEO4J_AUTH = (
+            os.environ.get("NEO4J_USER"),
+            os.environ.get("NEO4J_PASSWORD"),
+        )
+    gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)
+
+    return gds
+
+
+def create_constraint(gds):
+    try:
+        _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")
+    except ClientError:
+        print("CONSTRAINT entity_id already exists")
+
+
+def download_data(file_names):
+    # Download the Nations dataset files that are not present locally yet.
+    url = "https://raw.githubusercontent.com/ZhenfengLei/KGDatasets/master/Nations"
+    data_dir = "./Nations"
+
+    for file in file_names:
+        if not os.path.exists(f"{data_dir}/{file}"):
+            os.system(f"wget {url}/{file} -P {data_dir}")
+
+    return data_dir
+
+
+def get_text_to_id_map(data_dir, text_to_id_filename):
+    with open(data_dir + "/" + text_to_id_filename, "r") as f:
+        data = [x.split("\t") for x in f.read().split("\n")[:-1]]
+    text_to_id_map = {text: int(id) for text, id in data}
+    return text_to_id_map
+
+
+def read_data():
+    rel_types = {
+        "train.txt": "TRAIN",
+        "valid.txt": "VALID",
+        "test.txt": "TEST",
+    }
+    raw_file_names = ["train.txt", "valid.txt", "test.txt"]
+    node_id_filename = "entity2id.txt"
+    rel_id_filename = "relation2id.txt"
+
+    data_dir = download_data(raw_file_names + [node_id_filename, rel_id_filename])
+    node_map = get_text_to_id_map(data_dir, node_id_filename)
+    rel_map = get_text_to_id_map(data_dir, rel_id_filename)
+    dataset = defaultdict(lambda: defaultdict(list))
+
+    rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2}
+
+    for file_name in raw_file_names:
+        file_name_path = data_dir + "/" + file_name
+
+        with open(file_name_path, "r") as f:
+            data = [x.split("\t") for x in f.read().split("\n")[:-1]]
+
+        for i, (src_text, rel_text, dst_text) in enumerate(data):
+            source = node_map[src_text]
+            target = node_map[dst_text]
+            rel_type = "REL_" + rel_text.upper()
+            rel_split = rel_types[file_name]
+
+            dataset[rel_split][rel_type].append(
+                {
+                    "source": source,
+                    "source_text": src_text,
+                    "target": target,
+                    "target_text": dst_text,
+                    "rel_type": rel_type,
+                    "rel_id": rel_map[rel_text],
+                    "rel_split": rel_split,
+                    "rel_split_id": rel_split_id[rel_split],
+                }
+            )
+
+    print("Number of nodes: ", len(node_map))
+    for rel_split in dataset:
+        print(
+            f"Number of relationships of type {rel_split}: ",
+            sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),
+        )
+    return dataset
+
+
+def put_data_in_db(gds):
+    res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
+    if res["num_nodes"].values[0] > 0:
+        print("Data already in db, number of nodes: ", res["num_nodes"].values[0])
+        return
+    dataset = read_data()
+    pbar = tqdm(
+        desc="Putting data in db",
+        total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),
+    )
+
+    for rel_split in dataset:
+        for rel_type in dataset[rel_split]:
+            edges = dataset[rel_split][rel_type]
+
+            gds.run_cypher(
+                f"""
+                UNWIND $ll as l
+                MERGE (n:Entity {{id:l.source, text:l.source_text}})
+                MERGE (m:Entity {{id:l.target, text:l.target_text}})
+                MERGE (n)-[:{rel_type} {{split: l.rel_split_id, rel_id: l.rel_id}}]->(m)
+                """,
+                params={"ll": edges},
+            )
+            pbar.update(len(edges))
+    pbar.close()
+
+    # The relationships are typed REL_*, so we verify the splits via their `split` property.
+    rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2}
+    for rel_split in dataset:
+        res = gds.run_cypher(
+            """
+            MATCH ()-[r]->() WHERE r.split = $split
+            RETURN COUNT(r) AS numberOfRelationships
+            """,
+            params={"split": rel_split_id[rel_split]},
+        )
+        print(f"Number of relationships in split {rel_split} in db: ", res["numberOfRelationships"].values[0])
+
+
+def project_graphs(gds):
+    all_rels = gds.run_cypher(
+        """
+        CALL db.relationshipTypes() YIELD relationshipType
+        """
+    )
+    all_rels = all_rels["relationshipType"].to_list()
+    all_rels = {rel: {"properties": "split"} for rel in all_rels if rel.startswith("REL_")}
+    gds.graph.drop("fullGraph", failIfMissing=False)
+    gds.graph.drop("trainGraph", failIfMissing=False)
+    gds.graph.drop("validGraph", failIfMissing=False)
+    gds.graph.drop("testGraph", failIfMissing=False)
+
+    G_full, _ = gds.graph.project("fullGraph", ["Entity"], all_rels)
+
+    G_train, _ = gds.graph.filter("trainGraph", G_full, "*", "r.split = 0.0")
+    G_valid, _ = gds.graph.filter("validGraph", G_full, "*", "r.split = 1.0")
+    G_test, _ = gds.graph.filter("testGraph", G_full, "*", "r.split = 2.0")
+
+    gds.graph.drop("fullGraph", failIfMissing=False)
+
+    return G_train, G_valid, G_test
+
+
+def inspect_graph(G):
+    func_names = [
+        "name",
+        "node_count",
+        "relationship_count",
+        "node_labels",
+        "relationship_types",
+    ]
+    for func_name in func_names:
+        print(f"==={func_name}===: {getattr(G, func_name)()}")
+
+
+if __name__ == "__main__":
+    gds = setup_connection()
+    create_constraint(gds)
+    put_data_in_db(gds)
+    G_train, G_valid, G_test = project_graphs(gds)
+
+    inspect_graph(G_train)
+
+    gds.set_compute_cluster_ip("localhost")
+
+    model_name = "dummyModelName_" + str(time.time())
+
+    res = gds.kge.model.train(
+        G_train,
+        model_name=model_name,
+        scoring_function="DistMult",
+        num_epochs=30,
+        embedding_dimension=64,
+        epochs_per_checkpoint=0,
+        epochs_per_val=0,
+        split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1},
+    )
+    print(res["metrics"])
+
+    predict_result = gds.kge.model.predict(
+        model_name=model_name,
+        top_k=3,
+        node_ids=[
+            gds.find_node_id(["Entity"], {"text": "brazil"}),
+            gds.find_node_id(["Entity"], {"text": "uk"}),
+            gds.find_node_id(["Entity"], {"text": "jordan"}),
+        ],
+        rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"],
+    )
+
+    print(predict_result.to_string())
+
+    for index, row in predict_result.iterrows():
+        h = row["sourceNodeId"]
+        r = row["rel"]
+        gds.run_cypher(
+            f"""
+            UNWIND $tt as t
+            MATCH (a:Entity WHERE id(a) = {h})
+            MATCH (b:Entity WHERE id(b) = t)
+            MERGE (a)-[:NEW_REL_{r}]->(b)
+            """,
+            params={"tt": row["targetNodeIdTopK"]},
+        )
+
+    brazil_node = gds.find_node_id(["Entity"], {"text": "brazil"})
+    uk_node = gds.find_node_id(["Entity"], {"text": "uk"})
+    jordan_node = gds.find_node_id(["Entity"], {"text": "jordan"})
+
+    triplets = [
+        (brazil_node, "REL_RELNGO", uk_node),
+        (brazil_node, "REL_RELDIPLOMACY", jordan_node),
+    ]
+
+    scores = gds.kge.model.score_triplets(
+        model_name=model_name,
+        triplets=triplets,
+    )
+
+    print(scores)
diff --git a/examples/kge-distmult.ipynb b/examples/kge-distmult.ipynb
new file mode 100644
index 000000000..05456b9f0
--- /dev/null
+++ b/examples/kge-distmult.ipynb
@@ -0,0 +1,484 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Knowledge graph embeddings: TransE"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from graphdatascience import GraphDataScience\n",
+    "import collections\n",
+    "from tqdm import tqdm\n",
+    "import pandas as pd\n",
+    "from neo4j.exceptions import ClientError"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n",
+    "NEO4J_AUTH = None\n",
+    "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n",
+    "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n",
+    "    NEO4J_AUTH = (\n",
+    "        os.environ.get(\"NEO4J_USER\"),\n",
+    "        os.environ.get(\"NEO4J_PASSWORD\"),\n",
+    "    )\n",
+    "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "## Downloading and Storing the FB15k-237 Dataset in the Database\n",
+    "\n",
+    "Download the FB15k-237 dataset and extract the required files: `train.txt`, `valid.txt`, and `test.txt`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Set a constraint for unique `text` entries to speed up data uploads."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    _ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.text IS UNIQUE\")\n",
+    "except ClientError:\n",
+    "    print(\"CONSTRAINT entity_id already exists\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "**Creating Entity Nodes**:\n",
+    " Create a node with the label `Entity`. This node should have properties `id` and `text`.\n",
+    " - Syntax: `(:Entity {id: int, text: str})`\n",
+    "\n",
+    "**Creating Relationships for Training**:\n",
+    " Based on the dataset split, create relationships of type `TRAIN`, `VALID`, or `TEST` between the entities of each triplet.\n",
+    " - Example Syntax: `[:TRAIN]`\n",
+    "\n",
+    "**Creating Relationships for Prediction with GDS**:\n",
+    " For the prediction stage, create relationships of a specific type denoted as `REL_i`, where `i` is the id of the relation.\n",
+    " - Example Syntax: `[:REL_7]`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from collections import defaultdict\n",
+    "from ogb.utils.url import download_url\n",
+    "import os\n",
+    "import zipfile\n",
+    "\n",
+    "url = \"https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip\"\n",
+    "raw_dir = \"./data_from_zip\"\n",
+    "download_url(f\"{url}\", raw_dir)\n",
+    "\n",
+    "raw_file_names = [\"train.txt\", \"valid.txt\", \"test.txt\"]\n",
+    "with zipfile.ZipFile(raw_dir + \"/\" + os.path.basename(url), \"r\") as zip_ref:\n",
+    "    for filename in raw_file_names:\n",
+    "        zip_ref.extract(f\"Release/{filename}\", path=raw_dir)\n",
+    "data_dir = raw_dir + \"/\" + \"Release\"\n",
+    "\n",
+    "rel_types = {\n",
+    "    \"train.txt\": \"TRAIN\",\n",
+    "    \"valid.txt\": \"VALID\",\n",
+    "    \"test.txt\": \"TEST\",\n",
+    "}\n",
+    "rel_id_to_text_dict = {}\n",
+    "rel_type_dict = collections.defaultdict(list)\n",
+    "rel_dict = {}\n",
+    "\n",
+    "\n",
+    "def read_data():\n",
+    "    node_id_set = {}\n",
+    "    dataset = defaultdict(lambda: defaultdict(list))\n",
+    "    for file_name in raw_file_names:\n",
+    "        file_name_path = data_dir + \"/\" + file_name\n",
+    "\n",
+    "        with open(file_name_path, \"r\") as f:\n",
+    "            data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[:-1]]\n",
+    "\n",
+    "        for i, (src_text, rel_text, dst_text) in enumerate(data):\n",
+    "            if src_text not in node_id_set:\n",
+    "                node_id_set[src_text] = len(node_id_set)\n",
+    "            if dst_text not in node_id_set:\n",
+    "                node_id_set[dst_text] = len(node_id_set)\n",
+    "            if rel_text not in rel_dict:\n",
+    "                rel_dict[rel_text] = len(rel_dict)\n",
+    "                rel_id_to_text_dict[rel_dict[rel_text]] = rel_text\n",
+    "\n",
+    "            source = node_id_set[src_text]\n",
+    "            target = node_id_set[dst_text]\n",
+    "            rel_type = \"REL_\" + str(rel_dict[rel_text])\n",
+    "            rel_split = rel_types[file_name]\n",
+    "\n",
+    "            dataset[rel_split][rel_type].append(\n",
+    "                {\n",
+    "                    \"source\": source,\n",
+    "                    \"source_text\": src_text,\n",
+    "                    \"target\": target,\n",
+    "                    \"target_text\": dst_text,\n",
+    "                    # \"rel_text\": rel_text,\n",
+    "                }\n",
+    "            )\n",
+    "\n",
+    "    print(\"Number of nodes: \", len(node_id_set))\n",
+    "    for rel_split in dataset:\n",
+    "        print(\n",
+    "            f\"Number of relationships of type {rel_split}: \",\n",
+    "            sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),\n",
+    "        )\n",
+    "    return dataset\n",
+    "\n",
+    "\n",
+    "dataset = read_data()"
+   ]
+  },
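+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "As a quick sanity check, we can peek at one parsed record to see the structure that `read_data` produces for each triplet."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Inspect the first record of the first relationship type in the training split.\n",
+    "first_rel_type = next(iter(dataset[\"TRAIN\"]))\n",
+    "print(first_rel_type, dataset[\"TRAIN\"][first_rel_type][0])"
+   ]
+  },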
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def put_data_in_db(data):\n",
+    "    for rel_split in tqdm(data, desc=\"Relationship\"):\n",
+    "        for rel_type in tqdm(data[rel_split], mininterval=1, leave=False):\n",
+    "            edges = data[rel_split][rel_type]\n",
+    "\n",
+    "            gds.run_cypher(\n",
+    "                f\"\"\"\n",
+    "                UNWIND $ll as l\n",
+    "                MERGE (n:Entity {{id:l.source, text:l.source_text}})\n",
+    "                MERGE (m:Entity {{id:l.target, text:l.target_text}})\n",
+    "                MERGE (n)-[:{rel_type}]->(m)\n",
+    "                MERGE (n)-[:{rel_split}]->(m)\n",
+    "                \"\"\",\n",
+    "                params={\"ll\": edges},\n",
+    "            )\n",
+    "\n",
+    "\n",
+    "put_data_in_db(dataset)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Project all of the data to obtain the mapping between the `id` property and the database's internal `nodeId` field."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ALL_RELS = dataset[\"TRAIN\"].keys()\n",
+    "G, result = gds.graph.cypher.project(\n",
+    "    \"\"\"\n",
+    "    MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:\"\"\"\n",
+    "    + \"|\".join(ALL_RELS)\n",
+    "    + \"\"\"]-(n:Entity)\n",
+    "    RETURN gds.graph.project($graph_name, n, m, {\n",
+    "        sourceNodeLabels: $label,\n",
+    "        targetNodeLabels: $label\n",
+    "    })\n",
+    "    \"\"\",  # Cypher query\n",
+    "    database=\"neo4j\",  # Target database\n",
+    "    graph_name=\"G_full\",  # Query parameter\n",
+    "    label=\"Entity\",  # Query parameter\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def inspect_graph(G):\n",
+    "    func_names = [\n",
+    "        \"name\",\n",
+    "        # \"database\",\n",
+    "        \"node_count\",\n",
+    "        \"relationship_count\",\n",
+    "        \"node_labels\",\n",
+    "        \"relationship_types\",\n",
+    "        # \"degree_distribution\", \"density\", \"size_in_bytes\", \"memory_usage\", \"exists\", \"configuration\", \"creation_time\", \"modification_time\",\n",
+    "    ]\n",
+    "    for func_name in func_names:\n",
+    "        print(f\"==={func_name}===: {getattr(G, func_name)()}\")\n",
+    "\n",
+    "\n",
+    "inspect_graph(G)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gds.set_compute_cluster_ip(\"localhost\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "\n",
+    "model_name = \"fb15k-TransE-128-model-\" + str(time.time())\n",
+    "gds.kge.model.train(\n",
+    "    G,\n",
+    "    model_name=model_name,\n",
+    "    scoring_function=\"TransE\",\n",
+    "    embedding_dimension=128,\n",
+    "    num_epochs=100,\n",
+    "    filtered_metrics=False,\n",
+    "    batch_size=32_768,\n",
+    "    optimizer=\"Adam\",\n",
+    "    optimizer_kwargs={\"lr\": 0.0003},\n",
+    "    epochs_per_val=0,\n",
+    "    do_validation=False,\n",
+    "    do_test=False,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Select the source nodes and the relationship type for which we want to predict new target nodes."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "source_node_list = [\"/m/07l450\", \"/m/0ds2l81\", \"/m/0jvt9\"]\n",
+    "\n",
+    "source_ids_df = gds.run_cypher(\n",
+    "    \"UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId\",\n",
+    "    params={\"node_text_list\": source_node_list},\n",
+    ")\n",
+    "node_ids = source_ids_df[\"nodeId\"].to_list()\n",
+    "\n",
+    "rel_label_to_predict = \"REL_\" + str(rel_dict[\"/film/film/genre\"])\n",
+    "\n",
+    "predict_result = gds.kge.model.predict(\n",
+    "    model_name=model_name,\n",
+    "    top_k=3,\n",
+    "    node_ids=node_ids,\n",
+    "    rel_types=[rel_label_to_predict],\n",
+    ")\n",
+    "\n",
+    "print(predict_result.to_string())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Retrieve the embedding for the selected relationship from the trained model. Then, create a GDS TransE model using the graph, the node embedding property, and the embedding for the relationship to be predicted."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NOTE: This cell is a sketch. It assumes that node embeddings have been written\n",
+    "# to an `emb` node property on the projection, and that `rel_emb` holds the\n",
+    "# embedding vector of `rel_label_to_predict`; both are placeholders here.\n",
+    "rel_emb = [0.0] * 128  # placeholder relationship embedding of dimension 128\n",
+    "\n",
+    "transe_model = gds.model.transe(G, \"emb\", {rel_label_to_predict: rel_emb})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Now, we can use the model to make predictions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = transe_model.predict_stream(\n",
+    "    source_node_filter=source_ids_df.nodeId,\n",
+    "    target_node_filter=\"Entity\",\n",
+    "    relationship_type=rel_label_to_predict,\n",
+    "    top_k=3,\n",
+    "    concurrency=4,\n",
+    ")\n",
+    "print(result)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Augment the predicted result with node identifiers and their text values."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId]))\n",
+    "\n",
+    "ids_to_text = gds.run_cypher(\n",
+    "    \"UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id\",\n",
+    "    params={\"ids\": ids_in_result},\n",
+    ")\n",
+    "\n",
+    "nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag))\n",
+    "nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id))\n",
+    "\n",
+    "result.insert(1, \"sourceTag\", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x]))\n",
+    "result.insert(2, \"sourceId\", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x]))\n",
+    "result.insert(4, \"targetTag\", result.targetNodeId.map(lambda x: nodeId_to_text_res[x]))\n",
+    "result.insert(5, \"targetId\", result.targetNodeId.map(lambda x: nodeId_to_id_res[x]))\n",
+    "\n",
+    "print(result)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "## Using Write Mode\n",
+    "\n",
+    "Write mode allows you to write results directly to the database as a new relationship type. This approach helps to avoid mapping from `nodeId` to `id`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "write_relationship_type = \"PREDICTED_\" + rel_label_to_predict\n",
+    "result_write = transe_model.predict_write(\n",
+    "    source_node_filter=source_ids_df.nodeId,\n",
+    "    target_node_filter=\"Entity\",\n",
+    "    relationship_type=rel_label_to_predict,\n",
+    "    write_relationship_type=write_relationship_type,\n",
+    "    write_property=\"transe_score\",\n",
+    "    top_k=3,\n",
+    "    concurrency=4,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Extract the result from the database."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gds.run_cypher(\n",
+    "    \"MATCH (n)-[r:\"\n",
+    "    + write_relationship_type\n",
+    "    + \"]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gds.graph.drop(G)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/examples/kge-distmult.py b/examples/kge-distmult.py
new file mode 100644
index 000000000..db408ad5c
--- /dev/null
+++ b/examples/kge-distmult.py
@@ -0,0 +1,213 @@
+import os
+import time
+import warnings
+from collections import defaultdict
+
+from neo4j.exceptions import ClientError
+from tqdm import tqdm
+
+from graphdatascience import GraphDataScience
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+def setup_connection():
+    NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
+    NEO4J_AUTH = None
+    NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
+    if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
+        NEO4J_AUTH = (
+            os.environ.get("NEO4J_USER"),
+            os.environ.get("NEO4J_PASSWORD"),
+        )
+    gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)
+
+    return gds
+
+
+def create_constraint(gds):
+    try:
+        _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")
+    except ClientError:
+        print("CONSTRAINT entity_id already exists")
+
+
+def download_data(raw_file_names):
+    import os
+    import zipfile
+
+    from ogb.utils.url import download_url
+
+    url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
+    raw_dir = "./data_from_zip"
+    download_url(f"{url}", raw_dir)
+
+    with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref:
+        for filename in raw_file_names:
+            zip_ref.extract(f"Release/{filename}", path=raw_dir)
+    data_dir = raw_dir + "/" + "Release"
+    return data_dir
+
+
+def read_data():
+    rel_types = {
+        "train.txt": "TRAIN",
+        "valid.txt": "VALID",
+        "test.txt": "TEST",
+    }
+    raw_file_names = ["train.txt", "valid.txt", "test.txt"]
+
+    data_dir = download_data(raw_file_names)
+
+    rel_id_to_text_dict = {}
+    rel_dict = {}
+    node_id_set = {}
+    dataset = defaultdict(lambda: defaultdict(list))
+    for file_name in raw_file_names:
+        file_name_path = data_dir + "/" + file_name
+
+        with open(file_name_path, "r") as f:
+            data = [x.split("\t") for x in f.read().split("\n")[:-1]]
+
+        for i, (src_text, rel_text, dst_text) in enumerate(data):
+            if src_text not in node_id_set:
+                node_id_set[src_text] = len(node_id_set)
+            if dst_text not in node_id_set:
+                node_id_set[dst_text] = len(node_id_set)
+            if rel_text not in rel_dict:
+                rel_dict[rel_text] = len(rel_dict)
+                rel_id_to_text_dict[rel_dict[rel_text]] = rel_text
+
+            source = node_id_set[src_text]
+            target = node_id_set[dst_text]
+            rel_type = "REL_" + str(rel_dict[rel_text])
+            rel_split = rel_types[file_name]
+
+            dataset[rel_split][rel_type].append(
+                {
+                    "source": source,
+                    "source_text": src_text,
+                    "target": target,
+                    "target_text": dst_text,
+                    # "rel_text": rel_text,
+                }
+            )
+
+    print("Number of nodes: ", len(node_id_set))
+    for rel_split in dataset:
+        print(
+            f"Number of relationships of type {rel_split}: ",
+            sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),
+        )
+    return dataset
+
+
+def put_data_in_db(gds):
+    res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
+    if res["num_nodes"].values[0] > 0:
+        print("Data already in db, number of nodes: ", res["num_nodes"].values[0])
+        return
+    dataset = read_data()
+    pbar = tqdm(
+        desc="Putting data in db",
+        total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),
+    )
+    rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2}
+    for rel_split in dataset:
+        for rel_type in dataset[rel_split]:
+            edges = dataset[rel_split][rel_type]
+
+            gds.run_cypher(
+                f"""
+                UNWIND $ll as l
+                MERGE (n:Entity {{id:l.source, text:l.source_text}})
+                MERGE (m:Entity {{id:l.target, text:l.target_text}})
+                MERGE (n)-[:{rel_type} {{split: {rel_split_id[rel_split]}}}]->(m)
+                """,
+                params={"ll": edges},
+            )
+            pbar.update(len(edges))
+    pbar.close()
+
+    # The relationships are typed REL_*, so we verify the splits via their `split` property.
+    for rel_split in dataset:
+        res = gds.run_cypher(
+            """
+            MATCH ()-[r]->() WHERE r.split = $split
+            RETURN COUNT(r) AS numberOfRelationships
+            """,
+            params={"split": rel_split_id[rel_split]},
+        )
+        print(f"Number of relationships in split {rel_split} in db: ", res["numberOfRelationships"].values[0])
+
+
+def project_train_graph(gds):
+    all_rels = gds.run_cypher(
+        """
+        CALL db.relationshipTypes() YIELD relationshipType
+        """
+    )
+    all_rels = all_rels["relationshipType"].to_list()
+    all_rels = [rel for rel in all_rels if rel.startswith("REL_")]
+    gds.graph.drop("trainGraph", failIfMissing=False)
+
+    G_train, result = gds.graph.project("trainGraph", ["Entity"], all_rels)
+
+    return G_train
+
+
+def inspect_graph(G):
+    func_names = [
+        "name",
+        "node_count",
+        "relationship_count",
+        "node_labels",
+        "relationship_types",
+    ]
+    for func_name in func_names:
+        print(f"==={func_name}===: {getattr(G, func_name)()}")
+
+
+if __name__ == "__main__":
+    gds = setup_connection()
+    create_constraint(gds)
+    put_data_in_db(gds)
+    G_train = project_train_graph(gds)
+
+    gds.set_compute_cluster_ip("localhost")
+
+    print(gds.debug.arrow())
+
+    model_name = "dummyModelName_" + str(time.time())
+
+    node_id_text = gds.find_node_id(["Entity"], {"text": "/m/016wzw"})
+    node_id_2 = gds.find_node_id(["Entity"], {"id": 2})
+    node_id_3 = gds.find_node_id(["Entity"], {"id": 3})
+    node_id_0 = gds.find_node_id(["Entity"], {"id": 0})
+
+    res = gds.kge.model.train(
+        G_train,
+        model_name=model_name,
+        scoring_function="distmult",
+        num_epochs=1,
+        embedding_dimension=10,
+        epochs_per_checkpoint=0,
+    )
+    print(res["metrics"])
+
+    res = gds.kge.model.predict(
+        model_name=model_name,
+        top_k=10,
+        node_ids=[node_id_3, node_id_2, node_id_text],
+        rel_types=["REL_1", "REL_2"],
+    )
+    print(res.to_string())
+
+    scores = gds.kge.model.score_triplets(
+        model_name=model_name,
+        triplets=[
+            (node_id_2, "REL_1", node_id_text),
+            (node_id_0, "REL_123", node_id_3),
+        ],
+    )
+    print(scores)
+
+    print("Finished")
diff --git a/examples/kge-transe-construct.ipynb b/examples/kge-transe-construct.ipynb
new file mode 100644
index 000000000..1fb367008
--- /dev/null
+++ b/examples/kge-transe-construct.ipynb
@@ -0,0 +1,160 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "11d08c597a9fdbf3",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "# Knowledge Graph Embedding: TransE embedding for a constructed dataset"
+   ]
+  },
+  {
"cell_type": "code", + "execution_count": null, + "id": "e46dc2dd1419e518", + "metadata": {}, + "outputs": [], + "source": [ + "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "691981a7e5372ad2", + "metadata": {}, + "outputs": [], + "source": [ + "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n", + "NEO4J_AUTH = None\n", + "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n", + "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n", + " NEO4J_AUTH = (\n", + " os.environ.get(\"NEO4J_USER\"),\n", + " os.environ.get(\"NEO4J_PASSWORD\"),\n", + " )\n", + "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43cbb7c743877929", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " _ = gds.run_cypher(\"CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE\")\n", + "except ClientError:\n", + " print(\"CONSTRAINT entity_id already exists\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be889ec11e9b5759", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas\n", + "\n", + "nodes = pandas.DataFrame(\n", + " {\n", + " \"nodeId\": [0, 1, 2, 3, 7, 10],\n", + " \"labels\": [\"A\", \"B\", \"C\", \"A\", \"B\", \"C\"],\n", + " \"prop1\": [42, 1337, 8, 0, 1, 2],\n", + " \"otherProperty\": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],\n", + " }\n", + ")\n", + "\n", + "relationships = pandas.DataFrame(\n", + " {\n", + " \"sourceNodeId\": [0, 1, 2, 7],\n", + " \"targetNodeId\": [1, 2, 3, 10],\n", + " \"relationshipType\": [\"REL1\", \"REL1\", \"REL2\", \"REL2\"],\n", + " \"weight\": [0.0, 0.0, 0.1, 42.0],\n", + " }\n", + ")\n", + "\n", + "gds.graph.drop(\"my-graph\", failIfMissing=False)\n", + "G_train = gds.graph.construct(\n", + " \"my-graph\", # Graph name\n", + " nodes, # One or more dataframes containing node data\n", + " relationships, # One or more dataframes containing relationship data\n", + ")\n", + "\n", + "assert \"REL1\" in G_train.relationship_types()\n", + "assert \"REL2\" in G_train.relationship_types()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1638be48a275563f", + "metadata": {}, + "outputs": [], + "source": [ + "G_train.relationship_types()\n", + "G_train.node_labels()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "faae17b0f4551e7", + "metadata": {}, + "outputs": [], + "source": [ + "gds.set_compute_cluster_ip(\"localhost\")\n", + "\n", + "model_name = \"dummyModelName_\" + str(time.time())\n", + "\n", + "gds.kge.model.train(\n", + " G_train,\n", + " model_name=model_name,\n", + " scoring_function=\"transe\",\n", + " num_epochs=1,\n", + " embedding_dimension=16,\n", + " epochs_per_checkpoint=0,\n", + " split_ratios={\"TRAIN\": 0.75, \"TEST\": 0.25},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d88ba3d372525a6", + "metadata": {}, + "outputs": [], + "source": [ + "predict_result = gds.kge.model.predict(\n", + " model_name=model_name,\n", + " top_k=3,\n", + " node_ids=[1, 2, 0, 10, 7],\n", + " rel_types=[\"REL1\", \"REL2\"],\n", + ")\n", + "\n", + "print(predict_result.to_string())" + ] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index d280f1a35..e9b7aa152 100644 --- a/graphdatascience/graph_data_science.py 
+++ b/graphdatascience/graph_data_science.py
@@ -1,19 +1,23 @@
 from __future__ import annotations
 
+import pathlib
+import sys
 from typing import Any, Dict, Optional, Tuple, Type, Union
 
 from neo4j import Driver
 from pandas import DataFrame
 
+from graphdatascience.graph.graph_proc_runner import GraphProcRunner
+from graphdatascience.utils.util_proc_runner import UtilProcRunner
+
 from .call_builder import IndirectCallBuilder
 from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
 from .error.uncallable_namespace import UncallableNamespace
-from .graph.graph_proc_runner import GraphProcRunner
+from .model.kge_runner import KgeRunner
 from .query_runner.arrow_query_runner import ArrowQueryRunner
 from .query_runner.neo4j_query_runner import Neo4jQueryRunner
 from .query_runner.query_runner import QueryRunner
 from .server_version.server_version import ServerVersion
-from .utils.util_proc_runner import UtilProcRunner
 
 
 class GraphDataScience(DirectEndpoints, UncallableNamespace):
@@ -49,8 +53,7 @@ def __init__(
         database: Optional[str], default None
             The Neo4j database to query against.
         arrow : Union[str, bool], default True
-            Arrow connection information. This is either a string or a bool.
-
+            Arrow connection information. This is either a string or a bool.
             - If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server.
             - If it is a bool:
                 - True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint.
@@ -83,7 +86,31 @@ def __init__(
             None if arrow is True else arrow,
         )
 
-        super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
+        # if auth is not None:
+        #     with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
+        #         pub_key = rsa.PublicKey.load_pkcs1(f.read())
+        #     self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
+
+        self._encrypted_db_password = None
+        self._compute_cluster_ip = None
+
+        super().__init__(self._query_runner, "gds", self._server_version)
+
+    def set_compute_cluster_ip(self, ip: str) -> None:
+        self._compute_cluster_ip = ip
+
+    @staticmethod
+    def _path(package: str, resource: str) -> pathlib.Path:
+        if sys.version_info >= (3, 9):
+            from importlib.resources import files
+
+            # files() returns a Traversable, but usages require a Path object
+            return pathlib.Path(str(files(package) / resource))
+        else:
+            from importlib.resources import path
+
+            # we don't want to use a context manager here, so we need to call __enter__ manually
+            return path(package, resource).__enter__()
 
     @property
     def graph(self) -> GraphProcRunner:
@@ -101,6 +128,23 @@ def alpha(self) -> AlphaEndpoints:
     def beta(self) -> BetaEndpoints:
         return BetaEndpoints(self._query_runner, "gds.beta", self._server_version)
 
+    @property
+    def kge(self) -> KgeRunner:
+        if not isinstance(self._query_runner, ArrowQueryRunner):
+            raise ValueError("The KGE endpoints require GDS with the Arrow server enabled")
+        if self._compute_cluster_ip is None:
+            raise ValueError(
+                "You must set a valid compute cluster ip with the method `set_compute_cluster_ip` to use this feature"
+            )
+        return KgeRunner(
+            self._query_runner,
+            "gds.kge.model",
+            self._server_version,
+            self._compute_cluster_ip,
+            self._encrypted_db_password,
+            self._query_runner._gds_arrow_client._host + ":" + str(self._query_runner._gds_arrow_client._port),
+        )
+
     def __getattr__(self, attr: str) -> IndirectCallBuilder:
         return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version)
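For context, the intended end-to-end use of the new `kge` surface wired up above looks like this (a minimal sketch, assuming a GDS server with Arrow enabled and a reachable compute cluster; the URI, IP, and graph name are placeholders):

```python
from graphdatascience import GraphDataScience

gds = GraphDataScience("bolt://localhost:7687", auth=None, arrow=True)
gds.set_compute_cluster_ip("localhost")  # must be set before accessing gds.kge

G = gds.graph.get("trainGraph")  # placeholder: any existing projection with REL_* types

gds.kge.model.train(
    G,
    model_name="my-kge-model",
    scoring_function="DistMult",
    num_epochs=10,
    embedding_dimension=64,
)
```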
diff --git a/graphdatascience/model/kge_runner.py b/graphdatascience/model/kge_runner.py
new file mode 100644
index 000000000..70e3a2b6f
--- /dev/null
+++ b/graphdatascience/model/kge_runner.py
@@ -0,0 +1,262 @@
+import json
+import logging
+import os
+import time
+from typing import Any, Dict, Optional
+
+import pyarrow.flight
+import requests
+from pandas import DataFrame, Series
+
+from ..error.client_only_endpoint import client_only_endpoint
+from ..error.illegal_attr_checker import IllegalAttrChecker
+from ..error.uncallable_namespace import UncallableNamespace
+from ..graph.graph_object import Graph
+from ..query_runner.query_runner import QueryRunner
+from ..server_version.server_version import ServerVersion
+
+logging.basicConfig(level=logging.INFO)
+
+
+class KgeRunner(UncallableNamespace, IllegalAttrChecker):
+    def __init__(
+        self,
+        query_runner: QueryRunner,
+        namespace: str,
+        server_version: ServerVersion,
+        compute_cluster_ip: str,
+        encrypted_db_password: str,
+        arrow_uri: str,
+    ):
+        self._query_runner = query_runner
+        self._namespace = namespace
+        self._server_version = server_version
+        self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
+        self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
+        self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
+        self._encrypted_db_password = encrypted_db_password
+        self._arrow_uri = arrow_uri
+
+    @property
+    def model(self) -> "KgeRunner":
+        return self
+
+    # @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
+    @client_only_endpoint("gds.kge.model")
+    def train(
+        self,
+        G: Graph,
+        model_name: str,
+        *,
+        num_epochs: int,
+        embedding_dimension: int,
+        epochs_per_checkpoint: Optional[int] = None,
+        load_from_checkpoint: Optional[tuple[str, int]] = None,
+        split_ratios=None,
+        scoring_function: str = "transe",
+        p_norm: float = 1.0,
+        batch_size: int = 512,
+        test_batch_size: int = 512,
+        optimizer: str = "adam",
+        optimizer_kwargs=None,
+        lr_scheduler: str = "ConstantLR",
+        lr_scheduler_kwargs=None,
+        loss_function: str = "MarginRanking",
+        loss_function_kwargs=None,
+        negative_sampling_size: int = 1,
+        use_node_type_aware_sampler: bool = False,
+        k_value: int = 10,
+        do_validation: bool = True,
+        do_test: bool = True,
+        filtered_metrics: bool = False,
+        epochs_per_val: int = 0,
+        inner_norm: bool = True,
+        random_seed: Optional[int] = None,
+        init_bound: Optional[float] = None,
+        mlflow_experiment_name: Optional[str] = None,
+    ) -> Series:
+        if epochs_per_checkpoint is None:
+            epochs_per_checkpoint = max(int(num_epochs / 10), 1)
+        if loss_function_kwargs is None:
+            loss_function_kwargs = dict(margin=1.0, adversarial_temperature=1.0, gamma=20.0)
+        if lr_scheduler_kwargs is None:
+            lr_scheduler_kwargs = dict(factor=1, total_iters=1000)
+        if optimizer_kwargs is None:
+            optimizer_kwargs = {"lr": 0.01, "weight_decay": 0.0005}
+        if split_ratios is None:
+            split_ratios = {"TRAIN": 0.8, "TEST": 0.2}
+
+        # Collect all remaining algorithm parameters that were explicitly set.
+        algo_config = {
+            key: value
+            for key, value in locals().items()
+            if (key not in ["self", "G", "mlflow_experiment_name", "model_name"]) and (value is not None)
+        }
+        logging.debug("Training configuration: %s", algo_config)
+
+        graph_config = {"name": G.name(), "config_type": "GdsGraphConfig"}
+
+        config = {
+            "user_name": "DUMMY_USER",
+            "task": "KGE_TRAINING_PYG",
+            "task_config": {
+                "graph_config": graph_config,
+                "modelname": model_name,
+                "task_config": algo_config,
+            },
+            "graph_arrow_uri": self._arrow_uri,
+        }
+        if self._encrypted_db_password is not None:
+            config["encrypted_db_password"] = self._encrypted_db_password
+
+        if mlflow_experiment_name is not None:
+            config["task_config"]["mlflow"] = {
+                "tracking_uri": self._compute_cluster_mlflow_uri,
+                "experiment_name": mlflow_experiment_name,
+            }
+
+        job_id = self._start_job(config)
+
+        self._wait_for_job(job_id)
+
+        return Series(
+            {
+                "status": "finished",
+                "metrics": self._get_metrics(config["user_name"], config["task_config"]["modelname"], job_id),
+            }
+        )
+
+    @client_only_endpoint("gds.kge.model")
+    def predict(
+        self,
+        model_name: str,
+        top_k: int,
+        node_ids: list[int],
+        rel_types: list[str],
+        mlflow_experiment_name: Optional[str] = None,
+    ) -> DataFrame:
+        algo_config = {
+            "top_k": top_k,
+            "node_ids": node_ids,
+            "rel_types": rel_types,
+        }
+
+        config = {
+            "user_name": "DUMMY_USER",
+            "task": "KGE_PREDICT_PYG",
+            "task_config": {
+                "graph_config": {"config_type": "GdsGraphConfig", "name": "NOGRAPH"},
+                "modelname": model_name,
+                "task_config": algo_config,
+                "stream_rel_results": True,
+            },
+            "graph_arrow_uri": self._arrow_uri,
+        }
+        if self._encrypted_db_password is not None:
+            config["encrypted_db_password"] = self._encrypted_db_password
+
+        if mlflow_experiment_name is not None:
+            config["task_config"]["mlflow"] = {
+                "tracking_uri": self._compute_cluster_mlflow_uri,
+                "experiment_name": mlflow_experiment_name,
+            }
+
+        job_id = self._start_job(config)
+
+        self._wait_for_job(job_id)
+
+        return self._stream_results(config, job_id)
+
+    @client_only_endpoint("gds.kge.model")
+    def score_triplets(
+        self,
+        model_name: str,
+        triplets: list[tuple[int, str, int]],
+        mlflow_experiment_name: Optional[str] = None,
+    ) -> DataFrame:
+        algo_config = {
+            "triplets": triplets,
+        }
+
+        config = {
+            "user_name": "DUMMY_USER",
+            "task": "KGE_SCORE_TRIPLETS_PYG",
+            "task_config": {
+                "graph_config": {"config_type": "GdsGraphConfig", "name": "NOGRAPH"},
+                "modelname": model_name,
+                "task_config": algo_config,
+                "stream_rel_results": True,
+            },
+            "graph_arrow_uri": self._arrow_uri,
+        }
+        if self._encrypted_db_password is not None:
+            config["encrypted_db_password"] = self._encrypted_db_password
+
+        if mlflow_experiment_name is not None:
+            config["task_config"]["mlflow"] = {
+                "tracking_uri": self._compute_cluster_mlflow_uri,
+                "experiment_name": mlflow_experiment_name,
+            }
+
+        job_id = self._start_job(config)
+
+        self._wait_for_job(job_id)
+
+        return self._stream_results(config, job_id)
+
+    def _stream_results(self, config: dict, job_id: str) -> DataFrame:
+        client = pyarrow.flight.connect(self._compute_cluster_arrow_uri)
+
+        if config["task_config"].get("stream_rel_results", False):
+            upload_descriptor = pyarrow.flight.FlightDescriptor.for_path(f"{job_id}.relationships")
+        else:
+            raise ValueError("No results to fetch: need to set stream_rel_results or stream_graph_results to True")
+        flight = client.get_flight_info(upload_descriptor)
+        reader = client.do_get(flight.endpoints[0].ticket)
+        read_table = reader.read_all()
+
+        return read_table.to_pandas()
+
+    def _get_metrics(self, user_name: str, model_name: str, job_id: str) -> Any:
+        res = requests.get(
+            f"{self._compute_cluster_web_uri}/internal/fetch-model-metadata",
+            params={"user_name": user_name, "modelname": model_name},
+        )
+        res.raise_for_status()
+
+        res_file_name = f"metadata_{job_id}.json"
+
+        with open(res_file_name, mode="wb+") as f:
+            f.write(res.content)
+
+        with open(res_file_name, mode="r") as f:
+            metadata = json.load(f)
+
+        os.remove(res_file_name)
+
+        return metadata.get("metrics", None)
+
f"{self._compute_cluster_web_uri}/api/machine-learning/start" + res = requests.post(url, json=config) + res.raise_for_status() + job_id = res.json()["job_id"] + logging.info(f"Job '{config['task']}' with ID '{job_id}' started") + + return job_id + + def _wait_for_job(self, job_id: str) -> None: + while True: + time.sleep(1) + + res = requests.get(f"{self._compute_cluster_web_uri}/api/machine-learning/status/{job_id}") + + res_json = res.json() + if res_json["job_status"] == "exited": + logging.info(f"Job with ID '{job_id}' completed") + return + elif res_json["job_status"] == "failed": + error = f"KGE job failed with errors:{os.linesep}{os.linesep.join(res_json['errors'])}" + if res.status_code == 400: + raise ValueError(error) + else: + raise RuntimeError(error) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 6f80f57c1..f4e1a2766 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -94,6 +94,12 @@ def __init__( if tls_root_certs: client_options["tls_root_certs"] = tls_root_certs + print("location:") + print(location) + print("client_options:") + print(client_options) + print("auth:") + print(auth) self._flight_client = flight.FlightClient(location, **client_options) def connection_info(self) -> Tuple[str, int]: diff --git a/requirements/base/base.txt b/requirements/base/base.txt index 3ca82b153..a3655cf56 100644 --- a/requirements/base/base.txt +++ b/requirements/base/base.txt @@ -7,3 +7,4 @@ textdistance >= 4.0, < 5.0 tqdm >= 4.0, < 5.0 typing-extensions >= 4.0, < 5.0 requests +rsa