Skip to content

Commit 2385beb

Browse files
committed
Pass all parameters from docs
1 parent fd3ca63 commit 2385beb

File tree

2 files changed

+43
-12
lines changed

2 files changed

+43
-12
lines changed

examples/kge-distmult-nations.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ def project_graphs(gds):
163163
G_valid, _ = gds.graph.filter("validGraph", G_full, "*", "r.split = 1.0")
164164
G_test, _ = gds.graph.filter("testGraph", G_full, "*", "r.split = 2.0")
165165

166-
167166
gds.graph.drop("fullGraph", failIfMissing=False)
168167

169168
return G_train, G_valid, G_test
@@ -194,10 +193,11 @@ def inspect_graph(G):
194193
gds.kge.model.train(
195194
G_train,
196195
model_name=model_name,
197-
scoring_function="DistMult",
196+
scoring_function="distmult",
198197
num_epochs=1,
199198
embedding_dimension=10,
200199
epochs_per_checkpoint=0,
200+
epochs_per_val=0,
201201
)
202202

203203
df = gds.kge.model.predict(
@@ -229,4 +229,3 @@ def inspect_graph(G):
229229
# (gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})),
230230
# ],
231231
# )
232-

graphdatascience/model/kge_runner.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,52 @@ def train(
4545
self,
4646
G: Graph,
4747
model_name: str,
48-
scoring_function,
49-
num_epochs,
50-
embedding_dimension,
51-
epochs_per_checkpoint,
48+
*,
49+
num_epochs: int,
50+
embedding_dimension: int,
51+
epochs_per_checkpoint: Optional[int] = None,
52+
load_from_checkpoint: Optional[tuple[str, int]] = None,
53+
split_ratios=None,
54+
scoring_function: str = "transe",
55+
p_norm: float = 1.0,
56+
batch_size: int = 512,
57+
test_batch_size: int = 512,
58+
optimizer: str = "adam",
59+
optimizer_kwargs=None,
60+
lr_scheduler: str = "ConstantLR",
61+
lr_scheduler_kwargs=None,
62+
loss_function: str = "MarginRanking",
63+
loss_function_kwargs=None,
64+
negative_sampling_size: int = 1,
65+
use_node_type_aware_sampler: bool = False,
66+
k_value: int = 10,
67+
do_validation: bool = True,
68+
do_test: bool = True,
69+
filtered_metrics: bool = False,
70+
epochs_per_val: int = 50,
71+
inner_norm: bool = True,
72+
init_bound: Optional[float] = None,
5273
mlflow_experiment_name: Optional[str] = None,
5374
) -> Series:
54-
graph_config = {"name": G.name()}
75+
if epochs_per_checkpoint is None:
76+
epochs_per_checkpoint = max(num_epochs / 10, 1)
77+
if loss_function_kwargs is None:
78+
loss_function_kwargs = dict(margin=1.0, adversarial_temperature=1.0, gamma=20.0)
79+
if lr_scheduler_kwargs is None:
80+
lr_scheduler_kwargs = dict(factor=1, total_iters=1000)
81+
if optimizer_kwargs is None:
82+
optimizer_kwargs = {"lr": 0.01, "weight_decay": 0.0005}
83+
if split_ratios is None:
84+
split_ratios = {"TRAIN": 0.8, "TEST": 0.2}
5585

5686
algo_config = {
57-
"scoring_function": scoring_function,
58-
"num_epochs": num_epochs,
59-
"embedding_dimension": embedding_dimension,
60-
"epochs_per_checkpoint": epochs_per_checkpoint,
87+
key: value
88+
for key, value in locals().items()
89+
if (key not in ["self", "G", "mlflow_experiment_name", "model_name"]) and (value is not None)
6190
}
91+
print(algo_config)
92+
93+
graph_config = {"name": G.name()}
6294

6395
config = {
6496
"user_name": "DUMMY_USER",

0 commit comments

Comments
 (0)