
Commit 0a09dfd

XGBoost tutorial (#820)
* xgboost tutorial
* update by comment
* update by comment
* update
1 parent 432b584 commit 0a09dfd

File tree

2 files changed: +358 −0 lines


example/jupyter/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
.ipynb_checkpoints
Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# XGBoost on SQLFlow Tutorial\n",
"\n",
"This is a tutorial on how to train and predict with an XGBoost model in SQLFlow. You can find more SQLFlow usage in the [User Guide](https://github.com/sql-machine-learning/sqlflow/blob/develop/doc/user_guide.md). In this tutorial you will learn how to:\n",
"- Train an XGBoost model to fit the Boston Housing dataset; and\n",
"- Predict the housing price using the trained model.\n",
"\n",
"\n",
"## The Dataset\n",
"\n",
"This tutorial uses [Boston Housing](https://www.kaggle.com/c/boston-housing) as the demonstration dataset.\n",
"The dataset contains 506 rows and 14 columns; the meaning of each column is as follows:\n",
"\n",
"Column | Explanation \n",
"-- | -- \n",
"crim|per capita crime rate by town.\n",
"zn|proportion of residential land zoned for lots over 25,000 sq.ft.\n",
"indus|proportion of non-retail business acres per town.\n",
"chas|Charles River dummy variable (= 1 if tract bounds river; 0 otherwise).\n",
"nox|nitrogen oxides concentration (parts per 10 million).\n",
"rm|average number of rooms per dwelling.\n",
"age|proportion of owner-occupied units built prior to 1940.\n",
"dis|weighted mean of distances to five Boston employment centres.\n",
"rad|index of accessibility to radial highways.\n",
"tax|full-value property-tax rate per \$10,000.\n",
"ptratio|pupil-teacher ratio by town.\n",
"black|1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town.\n",
"lstat|lower status of the population (percent).\n",
"medv|median value of owner-occupied homes in $1000s.\n",
"\n",
"We split the dataset into a training table and a test table, which are used for training and prediction respectively. SQLFlow automatically splits the training data into train/validation sets during training."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"+---------+---------+------+-----+---------+-------+\n",
"| Field | Type | Null | Key | Default | Extra |\n",
"+---------+---------+------+-----+---------+-------+\n",
"| crim | float | YES | | None | |\n",
"| zn | float | YES | | None | |\n",
"| indus | float | YES | | None | |\n",
"| chas | int(11) | YES | | None | |\n",
"| nox | float | YES | | None | |\n",
"| rm | float | YES | | None | |\n",
"| age | float | YES | | None | |\n",
"| dis | float | YES | | None | |\n",
"| rad | int(11) | YES | | None | |\n",
"| tax | int(11) | YES | | None | |\n",
"| ptratio | float | YES | | None | |\n",
"| b | float | YES | | None | |\n",
"| lstat | float | YES | | None | |\n",
"| medv | float | YES | | None | |\n",
"+---------+---------+------+-----+---------+-------+"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sqlflow\n",
"describe boston.train;"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"+---------+---------+------+-----+---------+-------+\n",
"| Field | Type | Null | Key | Default | Extra |\n",
"+---------+---------+------+-----+---------+-------+\n",
"| crim | float | YES | | None | |\n",
"| zn | float | YES | | None | |\n",
"| indus | float | YES | | None | |\n",
"| chas | int(11) | YES | | None | |\n",
"| nox | float | YES | | None | |\n",
"| rm | float | YES | | None | |\n",
"| age | float | YES | | None | |\n",
"| dis | float | YES | | None | |\n",
"| rad | int(11) | YES | | None | |\n",
"| tax | int(11) | YES | | None | |\n",
"| ptratio | float | YES | | None | |\n",
"| b | float | YES | | None | |\n",
"| lstat | float | YES | | None | |\n",
"| medv | float | YES | | None | |\n",
"+---------+---------+------+-----+---------+-------+"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sqlflow\n",
"describe boston.test;"
]
},
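{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check (this cell is an addition to the original tutorial and has no recorded output), SQLFlow also executes standard SQL statements directly, so we can count the rows in the training table before training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sqlflow\n",
"SELECT COUNT(*) FROM boston.train;"
]
},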
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fit Boston Housing Dataset\n",
"\n",
"First, let's train an XGBoost regression model to fit the Boston Housing dataset. We train the model for `30 rounds`\n",
"with the `squarederror` loss function, so the SQLFlow extended SQL looks like:\n",
"\n",
"``` sql\n",
"TRAIN xgboost.gbtree\n",
"WITH\n",
" train.num_boost_round=30,\n",
" objective=\"reg:squarederror\"\n",
"```\n",
"\n",
"`xgboost.gbtree` is the estimator name; `gbtree` is one of the XGBoost boosters. You can find more information [here](https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters).\n",
"\n",
"We can specify the training data columns in the `COLUMN` clause, and the label with the `LABEL` keyword:\n",
"\n",
"``` sql\n",
"COLUMN crim, zn, indus, chas, nox, rm, age, dis, rad, tax, ptratio, b, lstat\n",
"LABEL medv\n",
"```\n",
"\n",
"To save the trained model, we can use the `INTO` clause to specify a model name:\n",
"\n",
"``` sql\n",
"INTO sqlflow_models.my_xgb_regression_model\n",
"```\n",
"\n",
"Second, let's use a standard SQL statement to fetch the training data from the table `boston.train`:\n",
"\n",
"``` sql\n",
"SELECT * FROM boston.train\n",
"```\n",
"\n",
"Finally, the following is the complete SQLFlow TRAIN statement for this regression task; you can run it in the cell below:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[03:44:56] 387x13 matrix with 5031 entries loaded from train.txt\n",
"\n",
"[03:44:56] 109x13 matrix with 1417 entries loaded from test.txt\n",
"\n",
"[0]\ttrain-rmse:17.0286\tvalidation-rmse:17.8089\n",
"\n",
"[1]\ttrain-rmse:12.285\tvalidation-rmse:13.2787\n",
"\n",
"[2]\ttrain-rmse:8.93071\tvalidation-rmse:9.87677\n",
"\n",
"[3]\ttrain-rmse:6.60757\tvalidation-rmse:7.64013\n",
"\n",
"[4]\ttrain-rmse:4.96022\tvalidation-rmse:6.0181\n",
"\n",
"[5]\ttrain-rmse:3.80725\tvalidation-rmse:4.95013\n",
"\n",
"[6]\ttrain-rmse:2.94382\tvalidation-rmse:4.2357\n",
"\n",
"[7]\ttrain-rmse:2.36361\tvalidation-rmse:3.74683\n",
"\n",
"[8]\ttrain-rmse:1.95236\tvalidation-rmse:3.43284\n",
"\n",
"[9]\ttrain-rmse:1.66604\tvalidation-rmse:3.20455\n",
"\n",
"[10]\ttrain-rmse:1.4738\tvalidation-rmse:3.08947\n",
"\n",
"[11]\ttrain-rmse:1.35336\tvalidation-rmse:3.0492\n",
"\n",
"[12]\ttrain-rmse:1.22835\tvalidation-rmse:2.99508\n",
"\n",
"[13]\ttrain-rmse:1.15615\tvalidation-rmse:2.98604\n",
"\n",
"[14]\ttrain-rmse:1.11082\tvalidation-rmse:2.96433\n",
"\n",
"[15]\ttrain-rmse:1.01666\tvalidation-rmse:2.96584\n",
"\n",
"[16]\ttrain-rmse:0.953761\tvalidation-rmse:2.94013\n",
"\n",
"[17]\ttrain-rmse:0.905753\tvalidation-rmse:2.91569\n",
"\n",
"[18]\ttrain-rmse:0.870137\tvalidation-rmse:2.89735\n",
"\n",
"[19]\ttrain-rmse:0.800778\tvalidation-rmse:2.87206\n",
"\n",
"[20]\ttrain-rmse:0.757704\tvalidation-rmse:2.86564\n",
"\n",
"[21]\ttrain-rmse:0.74058\tvalidation-rmse:2.86587\n",
"\n",
"[22]\ttrain-rmse:0.66901\tvalidation-rmse:2.86224\n",
"\n",
"[23]\ttrain-rmse:0.647195\tvalidation-rmse:2.87395\n",
"\n",
"[24]\ttrain-rmse:0.609025\tvalidation-rmse:2.86069\n",
"\n",
"[25]\ttrain-rmse:0.562925\tvalidation-rmse:2.87205\n",
"\n",
"[26]\ttrain-rmse:0.541676\tvalidation-rmse:2.86275\n",
"\n",
"[27]\ttrain-rmse:0.524815\tvalidation-rmse:2.87106\n",
"\n",
"[28]\ttrain-rmse:0.483566\tvalidation-rmse:2.86129\n",
"\n",
"[29]\ttrain-rmse:0.460363\tvalidation-rmse:2.85877\n",
"\n"
]
}
],
"source": [
"%%sqlflow\n",
"SELECT * FROM boston.train\n",
"TRAIN xgboost.gbtree\n",
"WITH\n",
" objective=\"reg:squarederror\",\n",
" train.num_boost_round = 30\n",
"COLUMN crim, zn, indus, chas, nox, rm, age, dis, rad, tax, ptratio, b, lstat\n",
"LABEL medv\n",
"INTO sqlflow_models.my_xgb_regression_model;"
]
},
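{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: other XGBoost booster parameters from the [parameter documentation](https://xgboost.readthedocs.io/en/latest/parameter.html) linked above should also be settable in the `WITH` clause using the same key=value style. For example (a hypothetical variation added for illustration; it is not run in this notebook and the values are not tuned):\n",
"\n",
"``` sql\n",
"-- eta and max_depth are hypothetical additions, not part of the original tutorial\n",
"WITH\n",
"    objective=\"reg:squarederror\",\n",
"    eta=0.1,\n",
"    max_depth=4,\n",
"    train.num_boost_round = 30\n",
"```"
]
},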
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predict the housing price\n",
"After training the regression model, let's predict the housing price using the trained model.\n",
"\n",
"First, we can specify the trained model with the `USING` clause:\n",
"\n",
"```sql\n",
"USING sqlflow_models.my_xgb_regression_model\n",
"```\n",
"\n",
"Then, we can specify the prediction result table and column with the `PREDICT` clause:\n",
"\n",
"``` sql\n",
"PREDICT boston.predict.medv\n",
"```\n",
"\n",
"This writes the prediction results into the table `boston.predict`, with the predicted value stored in the `medv` column.\n",
"\n",
"And we use a standard SQL statement to fetch the prediction data:\n",
"\n",
"``` sql\n",
"SELECT * FROM boston.test\n",
"```\n",
"\n",
"Finally, the following is the complete SQLFlow PREDICT statement:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[03:45:18] 10x13 matrix with 130 entries loaded from predict.txt\n",
"\n",
"Done predicting. Predict table : boston.predict\n",
"\n"
]
}
],
"source": [
"%%sqlflow\n",
"SELECT * FROM boston.test\n",
"PREDICT boston.predict.medv\n",
"USING sqlflow_models.my_xgb_regression_model;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's have a glance at the prediction results."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"+---------+-----+-------+------+-------+-------+------+--------+-----+-----+---------+--------+-------+---------+\n",
"| crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | b | lstat | medv |\n",
"+---------+-----+-------+------+-------+-------+------+--------+-----+-----+---------+--------+-------+---------+\n",
"| 0.2896 | 0.0 | 9.69 | 0 | 0.585 | 5.39 | 72.9 | 2.7986 | 6 | 391 | 19.2 | 396.9 | 21.14 | 21.9436 |\n",
"| 0.26838 | 0.0 | 9.69 | 0 | 0.585 | 5.794 | 70.6 | 2.8927 | 6 | 391 | 19.2 | 396.9 | 14.1 | 21.9667 |\n",
"| 0.23912 | 0.0 | 9.69 | 0 | 0.585 | 6.019 | 65.3 | 2.4091 | 6 | 391 | 19.2 | 396.9 | 12.92 | 22.9708 |\n",
"| 0.17783 | 0.0 | 9.69 | 0 | 0.585 | 5.569 | 73.5 | 2.3999 | 6 | 391 | 19.2 | 395.77 | 15.1 | 22.6373 |\n",
"| 0.22438 | 0.0 | 9.69 | 0 | 0.585 | 6.027 | 79.7 | 2.4982 | 6 | 391 | 19.2 | 396.9 | 14.33 | 21.9439 |\n",
"| 0.06263 | 0.0 | 11.93 | 0 | 0.573 | 6.593 | 69.1 | 2.4786 | 1 | 273 | 21.0 | 391.99 | 9.67 | 24.0095 |\n",
"| 0.04527 | 0.0 | 11.93 | 0 | 0.573 | 6.12 | 76.7 | 2.2875 | 1 | 273 | 21.0 | 396.9 | 9.08 | 25.0 |\n",
"| 0.06076 | 0.0 | 11.93 | 0 | 0.573 | 6.976 | 91.0 | 2.1675 | 1 | 273 | 21.0 | 396.9 | 5.64 | 31.6326 |\n",
"| 0.10959 | 0.0 | 11.93 | 0 | 0.573 | 6.794 | 89.3 | 2.3889 | 1 | 273 | 21.0 | 393.45 | 6.48 | 26.8375 |\n",
"| 0.04741 | 0.0 | 11.93 | 0 | 0.573 | 6.03 | 80.8 | 2.505 | 1 | 273 | 21.0 | 396.9 | 7.88 | 22.5877 |\n",
"+---------+-----+-------+------+-------+-------+------+--------+-----+-----+---------+--------+-------+---------+"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sqlflow\n",
"SELECT * FROM boston.predict;"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
