@@ -45,20 +45,52 @@ def train(
45
45
self ,
46
46
G : Graph ,
47
47
model_name : str ,
48
- scoring_function ,
49
- num_epochs ,
50
- embedding_dimension ,
51
- epochs_per_checkpoint ,
48
+ * ,
49
+ num_epochs : int ,
50
+ embedding_dimension : int ,
51
+ epochs_per_checkpoint : Optional [int ] = None ,
52
+ load_from_checkpoint : Optional [tuple [str , int ]] = None ,
53
+ split_ratios = None ,
54
+ scoring_function : str = "transe" ,
55
+ p_norm : float = 1.0 ,
56
+ batch_size : int = 512 ,
57
+ test_batch_size : int = 512 ,
58
+ optimizer : str = "adam" ,
59
+ optimizer_kwargs = None ,
60
+ lr_scheduler : str = "ConstantLR" ,
61
+ lr_scheduler_kwargs = None ,
62
+ loss_function : str = "MarginRanking" ,
63
+ loss_function_kwargs = None ,
64
+ negative_sampling_size : int = 1 ,
65
+ use_node_type_aware_sampler : bool = False ,
66
+ k_value : int = 10 ,
67
+ do_validation : bool = True ,
68
+ do_test : bool = True ,
69
+ filtered_metrics : bool = False ,
70
+ epochs_per_val : int = 50 ,
71
+ inner_norm : bool = True ,
72
+ init_bound : Optional [float ] = None ,
52
73
mlflow_experiment_name : Optional [str ] = None ,
53
74
) -> Series :
54
- graph_config = {"name" : G .name ()}
75
+ if epochs_per_checkpoint is None :
76
+ epochs_per_checkpoint = max (num_epochs / 10 , 1 )
77
+ if loss_function_kwargs is None :
78
+ loss_function_kwargs = dict (margin = 1.0 , adversarial_temperature = 1.0 , gamma = 20.0 )
79
+ if lr_scheduler_kwargs is None :
80
+ lr_scheduler_kwargs = dict (factor = 1 , total_iters = 1000 )
81
+ if optimizer_kwargs is None :
82
+ optimizer_kwargs = {"lr" : 0.01 , "weight_decay" : 0.0005 }
83
+ if split_ratios is None :
84
+ split_ratios = {"TRAIN" : 0.8 , "TEST" : 0.2 }
55
85
56
86
algo_config = {
57
- "scoring_function" : scoring_function ,
58
- "num_epochs" : num_epochs ,
59
- "embedding_dimension" : embedding_dimension ,
60
- "epochs_per_checkpoint" : epochs_per_checkpoint ,
87
+ key : value
88
+ for key , value in locals ().items ()
89
+ if (key not in ["self" , "G" , "mlflow_experiment_name" , "model_name" ]) and (value is not None )
61
90
}
91
+ print (algo_config )
92
+
93
+ graph_config = {"name" : G .name ()}
62
94
63
95
config = {
64
96
"user_name" : "DUMMY_USER" ,
0 commit comments