Skip to content

Commit 88f2e2e

Browse files
authored
convert WITH clause to xgboost params (#780)
* with clause to xgboost params * with clause to xgboost parameters * remove unused code * update by comment
1 parent 176391c commit 88f2e2e

File tree

7 files changed

+238
-128
lines changed

7 files changed

+238
-128
lines changed

sql/codegen.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,27 +51,33 @@ type modelConfig struct {
5151
IsKerasModel bool
5252
}
5353

54-
type featureMeta struct {
54+
// FeatureMeta describes feature column meta data
55+
type FeatureMeta struct {
5556
FeatureName string
5657
Dtype string
5758
Delimiter string
5859
InputShape string
5960
IsSparse bool
6061
}
6162

62-
type filler struct {
63+
// Estimator describes estimator meta data
64+
type Estimator struct {
6365
IsTrain bool
6466
TrainingDatasetSQL string // IsTrain == true
6567
ValidationDatasetSQL string // IsTrain == true
6668
PredictionDatasetSQL string // IsTrain != true
67-
X []*featureMeta
68-
FeatureColumnsCode map[string][]string
69-
Y *featureMeta
69+
X []*FeatureMeta
70+
Y *FeatureMeta
7071
TableName string
71-
modelConfig
7272
*connectionConfig
7373
}
7474

75+
type tfFiller struct {
76+
Estimator
77+
modelConfig
78+
FeatureColumnsCode map[string][]string
79+
}
80+
7581
// parseModelURI returns isKerasModel, modelClassString
7682
func parseModelURI(modelString string) (bool, string) {
7783
if strings.HasPrefix(modelString, "sqlflow_models.") {
@@ -87,14 +93,16 @@ func trainingAndValidationDataset(pr *extendedSelect, ds *trainAndValDataset) (s
8793
return pr.standardSelect.String(), pr.standardSelect.String()
8894
}
8995

90-
func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*filler, error) {
96+
func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*tfFiller, error) {
9197
isKerasModel, modelClassString := parseModelURI(pr.estimator)
9298
training, validation := trainingAndValidationDataset(pr, ds)
93-
r := &filler{
94-
IsTrain: pr.train,
95-
TrainingDatasetSQL: training,
96-
ValidationDatasetSQL: validation,
97-
PredictionDatasetSQL: pr.standardSelect.String(),
99+
r := &tfFiller{
100+
Estimator: Estimator{
101+
IsTrain: pr.train,
102+
TrainingDatasetSQL: training,
103+
ValidationDatasetSQL: validation,
104+
PredictionDatasetSQL: pr.standardSelect.String(),
105+
},
98106
modelConfig: modelConfig{
99107
EstimatorCode: modelClassString,
100108
BatchSize: 1,
@@ -148,7 +156,7 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
148156
isSparse = true
149157
}
150158
}
151-
fm := &featureMeta{
159+
fm := &FeatureMeta{
152160
FeatureName: col.GetKey(),
153161
Dtype: col.GetDtype(),
154162
Delimiter: col.GetDelimiter(),
@@ -202,7 +210,7 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
202210
log.Fatalf("Unsupported label data type: %s", v)
203211
}
204212
}
205-
r.Y = &featureMeta{
213+
r.Y = &FeatureMeta{
206214
FeatureName: pr.label,
207215
Dtype: labelDtype,
208216
Delimiter: ",",

sql/codegen_analyze.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ import (
2121

2222
type analyzeFiller struct {
2323
*connectionConfig
24-
X []*featureMeta
24+
X []*FeatureMeta
2525
Label string
2626
AnalyzeDatasetSQL string
2727
ModelFile string // path/to/model_file
2828
}
2929

30-
func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*featureMeta, label, modelPath string) (*analyzeFiller, error) {
30+
func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, modelPath string) (*analyzeFiller, error) {
3131
conn, err := newConnectionConfig(db)
3232
if err != nil {
3333
return nil, err
@@ -43,22 +43,22 @@ func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*featureMeta, label, mod
4343
}, nil
4444
}
4545

46-
func readAntXGBFeatures(pr *extendedSelect, db *DB) ([]*featureMeta, string, error) {
46+
func readAntXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, error) {
4747
// TODO(weiguo): It's a quick way to read column and label names from
4848
// xgboost.*, but too heavy.
4949
fr, err := newAntXGBoostFiller(pr, nil, db)
5050
if err != nil {
5151
return nil, "", err
5252
}
5353

54-
xs := make([]*featureMeta, len(fr.X))
54+
xs := make([]*FeatureMeta, len(fr.X))
5555
for i := 0; i < len(fr.X); i++ {
5656
// FIXME(weiguo): we convert xgboost.X to normal(tf).X to reuse
5757
// the DB access API, but I don't think it is good practice.
5858
// As the number of AI engines grows (such as ALPS, EDL?),
5959
// we would have to write one such converter for each of them.
6060
// How about we unify all featureMetas?
61-
xs[i] = &featureMeta{
61+
xs[i] = &FeatureMeta{
6262
FeatureName: fr.X[i].FeatureName,
6363
Dtype: fr.X[i].Dtype,
6464
Delimiter: fr.X[i].Delimiter,

sql/codegen_xgboost.go

Lines changed: 93 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,110 @@
1414
package sql
1515

1616
import (
17+
"encoding/json"
1718
"fmt"
1819
"io"
20+
"strconv"
21+
"strings"
1922
"text/template"
23+
24+
"github.com/asaskevich/govalidator"
2025
)
2126

2227
type xgbTrainConfig struct {
23-
NumBoostRound int `json:"num_boost_round,omitempty"`
24-
Maximize bool `json:"maximize,omitempty"`
28+
NumBoostRound int
29+
Maximize bool
30+
EarlyStoppingRounds int
2531
}
2632

2733
type xgbFiller struct {
28-
IsTrain bool
29-
TrainingDatasetSQL string
30-
ValidationDatasetSQL string
31-
TrainCfg *xgbTrainConfig
32-
Features []*featureMeta
33-
Label *featureMeta
34-
Save string
35-
ParamsCfgJSON string
36-
TrainCfgJSON string
37-
*connectionConfig
34+
Estimator
35+
xgbTrainConfig
36+
Save string
37+
ParamsCfgJSON string
38+
}
39+
40+
func resolveTrainCfg(attrs map[string]*attribute) *xgbTrainConfig {
41+
return &xgbTrainConfig{
42+
NumBoostRound: getIntAttr(attrs, "train.num_boost_round", 10),
43+
Maximize: getBoolAttr(attrs, "train.maximize", false, false),
44+
EarlyStoppingRounds: getIntAttr(attrs, "train.early_stopping_rounds", -1),
45+
}
3846
}
3947

40-
func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*xgbFiller, error) {
48+
func resolveParamsCfg(attrs map[string]*attribute) (map[string]interface{}, error) {
49+
// extract the attributes without any prefix as the XGBoost Parameters
50+
params := make(map[string]interface{})
4151
var err error
52+
for k, v := range attrs {
53+
if !strings.Contains(k, ".") {
54+
var vStr string
55+
var ok bool
56+
if vStr, ok = v.Value.(string); !ok {
57+
return nil, fmt.Errorf("convert params %s to string failed, %v", vStr, err)
58+
}
59+
if govalidator.IsFloat(vStr) {
60+
floatVal, err := strconv.ParseFloat(vStr, 16)
61+
if err != nil {
62+
return nil, fmt.Errorf("convert params %s to float32 failed, %v", vStr, err)
63+
}
64+
params[k] = floatVal
65+
} else if govalidator.IsInt(vStr) {
66+
if params[k], err = strconv.ParseInt(vStr, 0, 32); err != nil {
67+
return nil, fmt.Errorf("convert params %s to int32 failed, %v", vStr, err)
68+
}
69+
} else if govalidator.IsASCII(vStr) {
70+
params[k] = vStr
71+
} else {
72+
return nil, fmt.Errorf("unsupported params type: %s", vStr)
73+
}
74+
}
75+
}
76+
return params, nil
77+
}
78+
79+
func resolveObjective(pr *extendedSelect) (string, error) {
80+
estimatorParts := strings.Split(pr.estimator, ".")
81+
if len(estimatorParts) != 3 {
82+
return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part")
83+
}
84+
return strings.Join(estimatorParts[1:], ":"), nil
85+
}
86+
87+
func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFiller, error) {
88+
attrs, err := resolveAttribute(&pr.trainAttrs)
89+
if err != nil {
90+
return nil, err
91+
}
4292
training, validation := trainingAndValidationDataset(pr, ds)
4393
r := &xgbFiller{
44-
IsTrain: pr.train,
45-
TrainingDatasetSQL: training,
46-
ValidationDatasetSQL: validation,
47-
Save: pr.save,
94+
Estimator: Estimator{
95+
IsTrain: pr.train,
96+
TrainingDatasetSQL: training,
97+
ValidationDatasetSQL: validation,
98+
},
99+
xgbTrainConfig: *resolveTrainCfg(attrs),
100+
Save: pr.save,
101+
}
102+
103+
// resolve the attribute keys without any prefix as the XGBoost Parameters
104+
params, err := resolveParamsCfg(attrs)
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
// fill learning target
110+
objective, err := resolveObjective(pr)
111+
if err != nil {
112+
return nil, err
113+
}
114+
params["objective"] = objective
115+
116+
paramsJSON, err := json.Marshal(params)
117+
if err != nil {
118+
return nil, err
48119
}
49-
// TODO(Yancey1989): fill the train_args and parameters by WITH statment
50-
r.TrainCfgJSON = ""
51-
r.ParamsCfgJSON = ""
120+
r.ParamsCfgJSON = string(paramsJSON)
52121

53122
if r.connectionConfig, err = newConnectionConfig(db); err != nil {
54123
return nil, err
@@ -63,17 +132,17 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db
63132
return nil, fmt.Errorf("newXGBoostFiller doesn't support DENSE/SPARSE")
64133
}
65134
for _, col := range feaCols {
66-
fm := &featureMeta{
135+
fm := &FeatureMeta{
67136
FeatureName: col.GetKey(),
68137
Dtype: col.GetDtype(),
69138
Delimiter: col.GetDelimiter(),
70139
InputShape: col.GetInputShape(),
71140
IsSparse: false,
72141
}
73-
r.Features = append(r.Features, fm)
142+
r.X = append(r.X, fm)
74143
}
75144
}
76-
r.Label = &featureMeta{
145+
r.Y = &FeatureMeta{
77146
FeatureName: pr.label,
78147
Dtype: "int32",
79148
Delimiter: ",",
@@ -85,7 +154,7 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db
85154
}
86155

87156
func genXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error {
88-
r, e := newXGBFiller(pr, ds, fts, db)
157+
r, e := newXGBFiller(pr, ds, db)
89158
if e != nil {
90159
return e
91160
}

sql/codegen_xgboost_test.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,41 @@
1313

1414
package sql
1515

16+
import (
17+
"encoding/json"
18+
"testing"
19+
20+
"github.com/stretchr/testify/assert"
21+
)
22+
1623
const testXGBoostTrainSelectIris = `
1724
SELECT *
1825
FROM iris.train
1926
TRAIN xgb.multi.softprob
2027
WITH
21-
train.num_boost_round = 30
28+
train.num_boost_round = 30,
29+
eta = 3.1,
30+
num_class = 3
2231
COLUMN sepal_length, sepal_width, petal_length, petal_width
2332
LABEL class
2433
INTO sqlflow_models.my_xgboost_model;
2534
`
35+
36+
func TestXGBFiller(t *testing.T) {
37+
a := assert.New(t)
38+
parser := newParser()
39+
r, e := parser.Parse(testXGBoostTrainSelectIris)
40+
a.NoError(e)
41+
filler, e := newXGBFiller(r, nil, testDB)
42+
a.NoError(e)
43+
a.True(filler.IsTrain)
44+
a.Equal(filler.NumBoostRound, 30)
45+
expectedParams := map[string]interface{}{
46+
"eta": 3.1,
47+
"num_class": 3,
48+
"objective": "multi:softprob",
49+
}
50+
paramsJSON, err := json.Marshal(expectedParams)
51+
a.NoError(err)
52+
a.Equal(filler.ParamsCfgJSON, string(paramsJSON))
53+
}

sql/executor_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ func TestExecutorTrainXGBoost(t *testing.T) {
110110
a.NotPanics(func() {
111111
stream := runExtendedSQL(testXGBoostTrainSelectIris, testDB, modelDir, nil)
112112
a.True(goodStream(stream.ReadAll()))
113-
114113
})
115114
}
116115

0 commit comments

Comments
 (0)