|
71 | 71 | "It happens asynchronously, so it will return immediately (unless there's an unexpected error 😱).\n",
|
72 | 72 | "Of course, the training does not complete instantly, so you will have to wait for it to finish.\n",
|
73 | 73 | "\n",
|
74 |
| - "TODO: instructions for inspecting the log\n", |
| 74 | + "## Observing the training progress\n", |
| 75 | + "\n", |
| 76 | + "You can observe the training progress by watching the logs.\n", |
| 77 | + "This is done in the subsequent cell.\n", |
| 78 | + "Watching the logs does not stop automatically, so you will have to stop it manually.\n", |
| 79 | + "Once you see the message 'Training Done', you can interrupt the cell and continue.\n", |
| 80 | + "\n", |
| 81 | + "## Graph and training parameters\n", |
| 82 | + "\n", |
| 83 | + "\n", |
| 84 | + "\n", |
| 85 | + "\n", |
| 86 | + "| Parameter | Default | Type | Description |\n", |
| 87 | + "|--------------------|----------------|----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", |
| 88 | + "| graph_name | - | str | The name of the graph to train on. |\n", |
| 89 | + "| model_name | - | str | The name of the model. Must be unique per database and username combination. Models cannot be cleaned up at this time. |\n", |
| 90 | + "| feature_properties | - | List[str] | The node properties to use as model features. |\n", |
| 91 | + "| target_property | - | str | The node property that contains the target class values. |\n", |
| 92 | + "| node_labels | None | List[str] | The node labels to use for training. By default, all labels are used. |\n", |
| 93 | + "| relationship_types | None | List[str] | The relationship types to use for training. By default, all types are used. |\n", |
| 94 | + "| target_node_label | None | str | Indicates the nodes used for training. Only nodes with this label need to have the `target_property` defined. Other nodes are used for context. By default, all nodes are considered. |\n", |
| 95 | + "| graph_sage_config | None | dict | Configuration for the GraphSAGE training. See below. |\n", |
| 96 | + "\n", |
75 | 97 | "\n",
|
76 | 98 | "## GraphSAGE parameters\n",
|
77 | 99 | "\n",
|
78 | 100 | "We have exposed several parameters of the PyG GraphSAGE model.\n",
|
79 | 101 | "\n",
|
80 |
| - "| Parameter | Default | Description |\n", |
81 |
| - "|-----------------|----------|-------------|\n", |
82 |
| - "| layer_config | {} | ??? |\n", |
83 |
| - "| num_neighbors | [25, 10] | ??? |\n", |
84 |
| - "| dropout | 0.5 | ??? |\n", |
85 |
| - "| hidden_channels | 256 | ??? |\n", |
86 |
| - "| learning_rate | 0.003 | ??? |\n", |
| 102 | + "| Parameter | Default | Description |\n", |
| 103 | + "|-----------------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", |
| 104 | + "| layer_config | {} | Configuration of the GraphSAGE layers. It supports `aggr`, `normalize`, `root_weight`, `project`, `bias` from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html). Additionally, you can provide message passing configuration from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MessagePassing.html#torch_geometric.nn.conv.MessagePassing). |\n", |
| 105 | + "| num_neighbors | [25, 10] | Sample sizes for each layer. The length of this list is the number of layers used. All numbers must be >0. |\n", |
| 106 | + "| dropout | 0.5 | Probability of dropping out neurons during training. Must be between 0 and 1. |\n", |
| 107 | + "| hidden_channels | 256 | The dimension of each hidden layer. Higher value means more expensive training, but higher level of representation. Must be >0. |\n", |
| 108 | + "| learning_rate | 0.003 | The learning rate. Must be >0. |\n", |
87 | 109 | "\n",
|
88 | 110 | "Feel free to experiment with any of these parameters, using whichever values seem useful.\n"
|
89 | 111 | ]
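A minimal sketch of how the two tables above combine, assuming the `gds` client object, graph name, and property names used elsewhere in this notebook (the model name `"myModel2"` and the specific option values are hypothetical, chosen only to illustrate the shape of `graph_sage_config`):

```python
# Hypothetical GraphSAGE configuration using the parameters from the table above.
graph_sage_config = {
    "layer_config": {"aggr": "mean", "normalize": True},  # SAGEConv options
    "num_neighbors": [25, 10],  # two layers: sample 25, then 10 neighbors
    "dropout": 0.3,             # must be between 0 and 1
    "hidden_channels": 128,     # must be > 0
    "learning_rate": 0.001,     # must be > 0
}

# The train call itself needs a live GDS session, so it is left commented here:
# job_id = gds.gnn.nodeClassification.train(
#     "cora", "myModel2", ["features"], "subject", ["CITES"],
#     target_node_label="Paper", node_labels=["Paper"],
#     graph_sage_config=graph_sage_config,
# )
```

Note that the length of `num_neighbors` determines the number of layers, so the two settings above produce a two-layer model.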
|
|
95 | 117 | "outputs": [],
|
96 | 118 | "source": [
|
97 | 119 | "# Let's train!\n",
|
98 |
| - "train_response = gds.gnn.nodeClassification.train(\n", |
| 120 | + "job_id = gds.gnn.nodeClassification.train(\n", |
99 | 121 | " \"cora\", \"myModel\", [\"features\"], \"subject\", [\"CITES\"], target_node_label=\"Paper\", node_labels=[\"Paper\"]\n",
|
100 | 122 | ")"
|
101 | 123 | ]
|
102 | 124 | },
|
| 125 | + { |
| 126 | + "cell_type": "code", |
| 127 | + "execution_count": null, |
| 128 | + "metadata": {}, |
| 129 | + "outputs": [], |
| 130 | + "source": [ |
| 131 | + "# And let's follow the progress by watching the logs\n", |
| 132 | + "gds.gnn.nodeClassification.watch_logs(job_id)" |
| 133 | + ] |
| 134 | + }, |
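Since `watch_logs` streams until you interrupt the kernel, a small hypothetical wrapper (not part of the client API) can make the manual stop explicit:

```python
# Sketch: wrap a log-watching call so a kernel interrupt stops it cleanly.
# `watch_fn` stands in for gds.gnn.nodeClassification.watch_logs.
def watch_until_interrupted(watch_fn, job_id):
    try:
        watch_fn(job_id)
        return True  # the watcher returned on its own
    except KeyboardInterrupt:
        print("Stopped watching logs.")
        return False

# Usage (requires a live session):
# watch_until_interrupted(gds.gnn.nodeClassification.watch_logs, job_id)
```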
103 | 135 | {
|
104 | 136 | "cell_type": "code",
|
105 | 137 | "execution_count": null,
|
|
134 | 166 | "In this case, we will use it to predict the subject of papers in the Cora dataset.\n",
|
135 | 167 | "\n",
|
136 | 168 | "Again, this call is asynchronous, so it will return immediately.\n",
|
137 |
| - "\n", |
138 |
| - "TODO: instructions for inspecting the log\n", |
| 169 | + "Observe the progress by watching the logs.\n", |
139 | 170 | "\n",
|
140 | 171 | "Once the prediction is completed, the predicted classes are added to GDS Graph Catalog (as per normal).\n",
|
141 | 172 | "We can retrieve the prediction result (the predictions themselves) by streaming from the graph.\n"
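A sketch of inspecting the streamed predictions, assuming the `cora` graph object retrieved later in the notebook; the streaming call needs a live session, so the illustration below builds a stand-in DataFrame with hypothetical values to show the shape of the result:

```python
import pandas as pd

# With a live session, something like this streams the predicted property:
# predictions = gds.graph.nodeProperties.stream(cora, ["myPredictions"])

# Stand-in for the streamed result (nodeId / property columns, values invented):
predictions = pd.DataFrame({"nodeId": [0, 1, 2], "myPredictions": [3, 0, 3]})

# Count how many papers were assigned each predicted class:
class_counts = predictions["myPredictions"].value_counts()
```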
|
|
148 | 179 | "outputs": [],
|
149 | 180 | "source": [
|
150 | 181 | "# Let's trigger prediction!\n",
|
151 |
| - "predict_result = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" |
| 182 | + "job_id = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" |
| 183 | + ] |
| 184 | + }, |
| 185 | + { |
| 186 | + "cell_type": "code", |
| 187 | + "execution_count": null, |
| 188 | + "metadata": {}, |
| 189 | + "outputs": [], |
| 190 | + "source": [ |
| 191 | + "# And let's follow progress by watching the logs\n", |
| 192 | + "gds.gnn.nodeClassification.watch_logs(job_id)" |
152 | 193 | ]
|
153 | 194 | },
|
154 | 195 | {
|
|
157 | 198 | "metadata": {},
|
158 | 199 | "outputs": [],
|
159 | 200 | "source": [
|
160 |
| - "# Let's get a graph object\n", |
| 201 | + "# Now that prediction is done, let's see the predictions\n", |
161 | 202 | "cora = gds.graph.get(\"cora\")"
|
162 | 203 | ]
|
163 | 204 | },
|
|
194 | 235 | "Thank you very much for participating in the testing.\n",
|
195 | 236 | "We hope you enjoyed it.\n",
|
196 | 237 | "If you've run the notebook for the first time, now's the time to experiment with changing the graph, training parameters, etc.\n",
|
| 238 | + "For example, you could try a heterogeneous graph problem, check whether performance improves with different parameters, or run training jobs in parallel on multiple databases.\n", |
197 | 239 | "If you're feeling like you're done, please reach back to the Google Document and fill in our feedback form.\n",
|
198 | 240 | "\n",
|
199 | 241 | "Thank you!"
|
|