Skip to content

Commit d898adb

Browse files
Mats-SXbreakanalysisbrs96
committed
Extend notebook to include logging feedback feature
Co-authored-by: Jacob Sznajdman <breakanalysis@gmail.com> Co-authored-by: Brian Shi <brian.shi@neotechnology.com>
1 parent ce06fc7 commit d898adb

File tree

1 file changed

+55
-13
lines changed

1 file changed

+55
-13
lines changed

examples/python-runtime.ipynb

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,19 +71,41 @@
7171
"It happens asynchronously, so it will return immediately (unless there's an unexpected error 😱).\n",
7272
"Of course, the training does not complete instantly, so you will have to wait for it to finish.\n",
7373
"\n",
74-
"TODO: instructions for inspecting the log\n",
74+
"## Observing the training progress\n",
75+
"\n",
76+
"You can observe the training progress by watching the logs.\n",
77+
"This is done in the subsequent cell.\n",
78+
"The watching doesn't automatically stop, so you will have to stop it manually.\n",
79+
"Once you see the message 'Training Done', you can interrupt the cell and continue.\n",
80+
"\n",
81+
"## Graph and training parameters\n",
82+
"\n",
83+
"\n",
84+
"\n",
85+
"\n",
86+
"| Parameter | Default | Type | Description |\n",
87+
"|--------------------|----------------|----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
88+
"| graph_name | - | str | The name of the graph to train on. |\n",
89+
"| model_name | - | str | The name of the model. Must be unique per database and username combination. Models cannot be cleaned up at this time. |\n",
90+
"| feature_properties | - | List[str] | The node properties to use as model features. |\n",
91+
"| target_property | - | str | The node property that contains the target class values. |\n",
92+
"| node_labels | None | List[str] | The node labels to use for training. By default, all labels are used. |\n",
93+
"| relationship_types | None | List[str] | The relationship types to use for training. By default, all types are used. |\n",
94+
"| target_node_label | None | str | Indicates the nodes used for training. Only nodes with this label need to have the `target_property` defined. Other nodes are used for context. By default, all nodes are considered. |\n",
95+
"| graph_sage_config | None | dict | Configuration for the GraphSAGE training. See below. |\n",
96+
"\n",
7597
"\n",
7698
"## GraphSAGE parameters\n",
7799
"\n",
78100
"We have exposed several parameters of the PyG GraphSAGE model.\n",
79101
"\n",
80-
"| Parameter | Default | Description |\n",
81-
"|-----------------|----------|-------------|\n",
82-
"| layer_config | {} | ??? |\n",
83-
"| num_neighbors | [25, 10] | ??? |\n",
84-
"| dropout | 0.5 | ??? |\n",
85-
"| hidden_channels | 256 | ??? |\n",
86-
"| learning_rate | 0.003 | ??? |\n",
102+
"| Parameter | Default | Description |\n",
103+
"|-----------------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
104+
"| layer_config | {} | Configuration of the GraphSAGE layers. It supports `aggr`, `normalize`, `root_weight`, `project`, `bias` from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html). Additionally, you can provide message passing configuration from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MessagePassing.html#torch_geometric.nn.conv.MessagePassing). |\n",
105+
"| num_neighbors | [25, 10] | Sample sizes for each layer. The length of this list is the number of layers used. All numbers must be >0. |\n",
106+
"| dropout | 0.5 | Probability of dropping out neurons during training. Must be between 0 and 1. |\n",
107+
"| hidden_channels | 256 | The dimension of each hidden layer. Higher value means more expensive training, but higher level of representation. Must be >0. |\n",
108+
"| learning_rate | 0.003 | The learning rate. Must be >0. |\n",
87109
"\n",
88110
"Please try to use any of them with any useful values.\n"
89111
]
@@ -95,11 +117,21 @@
95117
"outputs": [],
96118
"source": [
97119
"# Let's train!\n",
98-
"train_response = gds.gnn.nodeClassification.train(\n",
120+
"job_id = gds.gnn.nodeClassification.train(\n",
99121
" \"cora\", \"myModel\", [\"features\"], \"subject\", [\"CITES\"], target_node_label=\"Paper\", node_labels=[\"Paper\"]\n",
100122
")"
101123
]
102124
},
125+
{
126+
"cell_type": "code",
127+
"execution_count": null,
128+
"metadata": {},
129+
"outputs": [],
130+
"source": [
131+
"# And let's follow the progress by watching the logs\n",
132+
"gds.gnn.nodeClassification.watch_logs(job_id)"
133+
]
134+
},
103135
{
104136
"cell_type": "code",
105137
"execution_count": null,
@@ -134,8 +166,7 @@
134166
"In this case, we will use it to predict the subject of papers in the Cora dataset.\n",
135167
"\n",
136168
"Again, this call is asynchronous, so it will return immediately.\n",
137-
"\n",
138-
"TODO: instructions for inspecting the log\n",
169+
"Observe the progress by watching the logs.\n",
139170
"\n",
140171
"Once the prediction is completed, the predicted classes are added to GDS Graph Catalog (as per normal).\n",
141172
"We can retrieve the prediction result (the predictions themselves) by streaming from the graph.\n"
@@ -148,7 +179,17 @@
148179
"outputs": [],
149180
"source": [
150181
"# Let's trigger prediction!\n",
151-
"predict_result = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")"
182+
"job_id = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")"
183+
]
184+
},
185+
{
186+
"cell_type": "code",
187+
"execution_count": null,
188+
"metadata": {},
189+
"outputs": [],
190+
"source": [
191+
"# And let's follow progress by watching the logs\n",
192+
"gds.gnn.nodeClassification.watch_logs(job_id)"
152193
]
153194
},
154195
{
@@ -157,7 +198,7 @@
157198
"metadata": {},
158199
"outputs": [],
159200
"source": [
160-
"# Let's get a graph object\n",
201+
"# Now that prediction is done, let's see the predictions\n",
161202
"cora = gds.graph.get(\"cora\")"
162203
]
163204
},
@@ -194,6 +235,7 @@
194235
"Thank you very much for participating in the testing.\n",
195236
"We hope you enjoyed it.\n",
196237
"If you've run the notebook for the first time, now's the time to experiment and changing graph, training parameters, etc.\n",
238+
"For example, try out a heterogeneous graph problem? Or whether performance can be improved by changing some parameter? Run training jobs in parallel, on multiple databases?\n",
197239
"If you're feeling like you're done, please reach back to the Google Document and fill in our feedback form.\n",
198240
"\n",
199241
"Thank you!"

0 commit comments

Comments
 (0)