Skip to content

Commit 94a1339

Browse files
committed
Use result server for getting results
1 parent 78a3a23 commit 94a1339

File tree

3 files changed

+56
-24
lines changed

3 files changed

+56
-24
lines changed

doc/modules/ROOT/pages/gds-session-algorithms/knowledge-graph-embeddings.adoc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,26 @@ predict_result = gds.kge.model.predict(
424424
print(predict_result.to_string())
425425
----
426426

427+
Metrics will be printed after training stage which is computed on the test set.
428+
[source, python, role=no-test]
429+
----
430+
{'mean_rank': 6.062, 'mean_reciprocal_rank': 0.238, 'hits_at_k': 0.742}
431+
----
432+
433+
434+
Result will be a pandas DataFrame with top 3 tail entities and their scores for each head entity and relationship type.
435+
436+
[source, python, role=no-test]
437+
----
438+
sourceNodeId rel targetNodeIdTopK scoreTopK
439+
0 8115 REL_RELDIPLOMACY [8109, 8116, 8118] [-4.326232433319092, -4.508733749389648, -4.542135715484619]
440+
1 8115 REL_RELNGO [8109, 8116, 8117] [-4.3115034103393555, -4.3574066162109375, -4.5306196212768555]
441+
2 8116 REL_RELDIPLOMACY [8109, 8116, 8118] [-5.225207328796387, -5.367417335510254, -5.4092488288879395]
442+
3 8116 REL_RELNGO [8109, 8116, 8117] [-4.960464954376221, -4.990216255187988, -5.14272403717041]
443+
4 8119 REL_RELDIPLOMACY [8109, 8120, 8116] [-4.9556193351745605, -5.094477653503418, -5.164356708526611]
444+
5 8119 REL_RELNGO [8109, 8116, 8117] [-3.9914486408233643, -4.040783882141113, -4.112575054168701]
445+
----
446+
427447
There is also a function to score the triplets.
428448

429449
[source, python, role=no-test]
@@ -437,4 +457,13 @@ scores = gds.kge.model.score_triplets(
437457
model_name=model_name,
438458
triplets=triplets,
439459
)
460+
----
461+
462+
Result will be a dataframe with score for each triplet.
463+
464+
[source, python, role=no-test]
465+
----
466+
sourceNodeId rel targetNodeId score
467+
0 8115 REL_RELNGO 8116 -4.357407
468+
1 8115 REL_RELDIPLOMACY 8119 -5.142065
440469
----

examples/kge-distmult-nations.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,11 @@ def inspect_graph(G):
195195
res = gds.kge.model.train(
196196
G_train,
197197
model_name=model_name,
198-
scoring_function="distmult",
199-
num_epochs=1,
200-
embedding_dimension=10,
198+
scoring_function="TransE",
199+
num_epochs=30,
200+
embedding_dimension=64,
201201
epochs_per_checkpoint=0,
202-
epochs_per_val=5,
202+
epochs_per_val=0,
203203
split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1},
204204
)
205205
print(res["metrics"])
@@ -218,7 +218,7 @@ def inspect_graph(G):
218218
print(predict_result.to_string())
219219

220220
for index, row in predict_result.iterrows():
221-
h = row["head"]
221+
h = row["sourceNodeId"]
222222
r = row["rel"]
223223
gds.run_cypher(
224224
f"""
@@ -227,7 +227,7 @@ def inspect_graph(G):
227227
MATCH (b:Entity WHERE id(b) = t)
228228
MERGE (a)-[:NEW_REL_{r}]->(b)
229229
""",
230-
params={"tt": row["tail"]},
230+
params={"tt": row["targetNodeIdTopK"]},
231231
)
232232

233233
brazil_node = gds.find_node_id(["Entity"], {"text": "brazil"})

graphdatascience/model/kge_runner.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from typing import Any, Dict, Optional
66

7-
import pandas as pd
7+
import pyarrow
88
import requests
99
from pandas import DataFrame, Series
1010

@@ -32,12 +32,13 @@ def __init__(
3232
self._namespace = namespace
3333
self._server_version = server_version
3434
self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
35+
self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
3536
self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
3637
self._encrypted_db_password = encrypted_db_password
3738
self._arrow_uri = arrow_uri
3839

3940
@property
40-
def model(self):
41+
def model(self) -> "KgeRunner":
4142
return self
4243

4344
# @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
@@ -75,7 +76,7 @@ def train(
7576
mlflow_experiment_name: Optional[str] = None,
7677
) -> Series:
7778
if epochs_per_checkpoint is None:
78-
epochs_per_checkpoint = max(num_epochs / 10, 1)
79+
epochs_per_checkpoint = max(int(num_epochs / 10), 1)
7980
if loss_function_kwargs is None:
8081
loss_function_kwargs = dict(margin=1.0, adversarial_temperature=1.0, gamma=20.0)
8182
if lr_scheduler_kwargs is None:
@@ -92,7 +93,7 @@ def train(
9293
}
9394
print(algo_config)
9495

95-
graph_config = {"name": G.name()}
96+
graph_config = {"name": G.name(), "config_type": "GdsGraphConfig"}
9697

9798
config = {
9899
"user_name": "DUMMY_USER",
@@ -144,8 +145,10 @@ def predict(
144145
"user_name": "DUMMY_USER",
145146
"task": "KGE_PREDICT_PYG",
146147
"task_config": {
148+
"graph_config": {"config_type": "GdsGraphConfig", "name": "NOGRAPH"},
147149
"modelname": model_name,
148150
"task_config": algo_config,
151+
"stream_rel_results": True,
149152
},
150153
"graph_arrow_uri": self._arrow_uri,
151154
}
@@ -162,7 +165,7 @@ def predict(
162165

163166
self._wait_for_job(job_id)
164167

165-
return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id)
168+
return self._stream_results(config, job_id)
166169

167170
@client_only_endpoint("gds.kge.model")
168171
def score_triplets(
@@ -180,8 +183,10 @@ def score_triplets(
180183
"user_name": "DUMMY_USER",
181184
"task": "KGE_SCORE_TRIPLETS_PYG",
182185
"task_config": {
186+
"graph_config": {"config_type": "GdsGraphConfig", "name": "NOGRAPH"},
183187
"modelname": model_name,
184188
"task_config": algo_config,
189+
"stream_rel_results": True,
185190
},
186191
"graph_arrow_uri": self._arrow_uri,
187192
}
@@ -198,22 +203,20 @@ def score_triplets(
198203

199204
self._wait_for_job(job_id)
200205

201-
return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id)
206+
return self._stream_results(config, job_id)
202207

203-
def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataFrame:
204-
res = requests.get(
205-
f"{self._compute_cluster_web_uri}/internal/fetch-result",
206-
params={"user_name": user_name, "modelname": model_name, "job_id": job_id},
207-
)
208-
res.raise_for_status()
208+
def _stream_results(self, config: dict, job_id: str) -> DataFrame:
209+
client = pyarrow.flight.connect(self._compute_cluster_arrow_uri)
209210

210-
res_file_name = f"res_{job_id}.json"
211-
with open(res_file_name, mode="wb+") as f:
212-
f.write(res.content)
211+
if config["task_config"].get("stream_rel_results", False):
212+
upload_descriptor = pyarrow.flight.FlightDescriptor.for_path(f"{job_id}.relationships")
213+
else:
214+
raise ValueError("No results to fetch: need to set stream_rel_results or stream_graph_results to True")
215+
flight = client.get_flight_info(upload_descriptor)
216+
reader = client.do_get(flight.endpoints[0].ticket)
217+
read_table = reader.read_all()
213218

214-
df = pd.read_json(res_file_name, orient="records", lines=True)
215-
os.remove(res_file_name)
216-
return df
219+
return read_table.to_pandas()
217220

218221
def _get_metrics(self, user_name: str, model_name: str, job_id: str) -> DataFrame:
219222
res = requests.get(

0 commit comments

Comments
 (0)