Skip to content

Commit fd3ca63

Browse files
committed
Remove Graph from prediction
1 parent 6ef275d commit fd3ca63

File tree

2 files changed

+1
-21
lines changed

2 files changed

+1
-21
lines changed

examples/kge-distmult-nations.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,11 @@ def project_graphs(gds):
158158
gds.graph.drop("testGraph", failIfMissing=False)
159159

160160
G_full, _ = gds.graph.project("fullGraph", ["Entity"], all_rels)
161-
inspect_graph(G_full)
162161

163162
G_train, _ = gds.graph.filter("trainGraph", G_full, "*", "r.split = 0.0")
164163
G_valid, _ = gds.graph.filter("validGraph", G_full, "*", "r.split = 1.0")
165164
G_test, _ = gds.graph.filter("testGraph", G_full, "*", "r.split = 2.0")
166165

167-
inspect_graph(G_train)
168-
inspect_graph(G_valid)
169-
inspect_graph(G_test)
170166

171167
gds.graph.drop("fullGraph", failIfMissing=False)
172168

@@ -190,14 +186,9 @@ def inspect_graph(G):
190186
create_constraint(gds)
191187
put_data_in_db(gds)
192188
G_train, G_valid, G_test = project_graphs(gds)
193-
inspect_graph(G_train)
194-
inspect_graph(G_valid)
195-
inspect_graph(G_test)
196189

197190
gds.set_compute_cluster_ip("localhost")
198191

199-
print(gds.debug.arrow())
200-
201192
model_name = "dummyModelName_" + str(time.time())
202193

203194
gds.kge.model.train(
@@ -210,7 +201,6 @@ def inspect_graph(G):
210201
)
211202

212203
df = gds.kge.model.predict(
213-
G_train,
214204
model_name=model_name,
215205
top_k=3,
216206
node_ids=[
@@ -221,7 +211,7 @@ def inspect_graph(G):
221211
rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"],
222212
)
223213

224-
print(df)
214+
print(df.to_string())
225215
#
226216
# gds.kge.model.predict_tail(
227217
# G_train,
@@ -240,4 +230,3 @@ def inspect_graph(G):
240230
# ],
241231
# )
242232

243-
print("Finished training")

graphdatascience/model/kge_runner.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ def __init__(
3434
self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
3535
self._encrypted_db_password = encrypted_db_password
3636
self._arrow_uri = arrow_uri
37-
print("KgeRunner __dict__:")
38-
print(self.__dict__)
3937

4038
@property
4139
def model(self):
@@ -89,14 +87,12 @@ def train(
8987
@client_only_endpoint("gds.kge.model")
9088
def predict(
9189
self,
92-
G: Graph,
9390
model_name: str,
9491
top_k: int,
9592
node_ids: list[int],
9693
rel_types: list[str],
9794
mlflow_experiment_name: Optional[str] = None,
9895
) -> DataFrame:
99-
graph_config = {"name": G.name()}
10096

10197
algo_config = {
10298
"top_k": top_k,
@@ -108,7 +104,6 @@ def predict(
108104
"user_name": "DUMMY_USER",
109105
"task": "KGE_PREDICT_PYG",
110106
"task_config": {
111-
"graph_config": graph_config,
112107
"modelname": model_name,
113108
"task_config": algo_config,
114109
},
@@ -122,8 +117,6 @@ def predict(
122117
"config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name}
123118
}
124119

125-
print("predict config")
126-
print(config)
127120
job_id = self._start_job(config)
128121

129122
self._wait_for_job(job_id)
@@ -146,8 +139,6 @@ def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataF
146139
return df
147140

148141
def _start_job(self, config: Dict[str, Any]) -> str:
149-
print("_start_job")
150-
print(config)
151142
url = f"{self._compute_cluster_web_uri}/api/machine-learning/start"
152143
print(url)
153144
res = requests.post(url, json=config)

0 commit comments

Comments
 (0)