Skip to content

Xgboost predict #789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions sql/codegen_xgboost.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func resolveParamsCfg(attrs map[string]*attribute) (map[string]interface{}, erro
func resolveObjective(pr *extendedSelect) (string, error) {
estimatorParts := strings.Split(pr.estimator, ".")
if len(estimatorParts) != 3 {
return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part")
return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part, current: %s", pr.estimator)
}
return strings.Join(estimatorParts[1:], ":"), nil
}
Expand All @@ -90,6 +90,7 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille
return nil, err
}
training, validation := trainingAndValidationDataset(pr, ds)
isTrain := pr.train
r := &xgbFiller{
Estimator: Estimator{
IsTrain: pr.train,
Expand All @@ -99,25 +100,34 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille
xgbTrainConfig: *resolveTrainCfg(attrs),
Save: pr.save,
}

	// resolve the attribute keys without any prefix as the XGBoost Parameters
params, err := resolveParamsCfg(attrs)
if err != nil {
return nil, err
if !isTrain && !pr.analyze {
r.PredictionDatasetSQL = pr.standardSelect.String()
if r.TableName, _, err = parseTableColumn(pr.into); err != nil {
return nil, err
}
r.Save = pr.model
}

// fill learning target
objective, err := resolveObjective(pr)
if err != nil {
return nil, err
}
params["objective"] = objective
if isTrain {
		// resolve the attribute keys without any prefix as the XGBoost Parameters
params, err := resolveParamsCfg(attrs)
if err != nil {
return nil, err
}

paramsJSON, err := json.Marshal(params)
if err != nil {
return nil, err
// fill learning target
objective, err := resolveObjective(pr)
if err != nil {
return nil, err
}
params["objective"] = objective

paramsJSON, err := json.Marshal(params)
if err != nil {
return nil, err
}
r.ParamsCfgJSON = string(paramsJSON)
}
r.ParamsCfgJSON = string(paramsJSON)

if r.connectionConfig, err = newConnectionConfig(db); err != nil {
return nil, err
Expand Down Expand Up @@ -161,7 +171,11 @@ func genXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fie
if pr.train {
return xgbTrainTemplate.Execute(w, r)
}
return fmt.Errorf("xgboost prediction codegen has not been implemented")
if e := createPredictionTable(pr, db); e != nil {
return fmt.Errorf("failed to create prediction table: %v", e)
}
return xgbPredictTemplate.Execute(w, r)
}

var xgbTrainTemplate = template.Must(template.New("codegenXGBTrain").Parse(xgbTrainTemplateText))
var xgbPredictTemplate = template.Must(template.New("codegenXGBPredict").Parse(xgbPredictTemplateText))
27 changes: 24 additions & 3 deletions sql/codegen_xgboost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ SELECT *
FROM iris.train
TRAIN xgb.multi.softprob
WITH
train.num_boost_round = 30,
eta = 3.1,
num_class = 3
train.num_boost_round = 30,
eta = 3.1,
num_class = 3
COLUMN sepal_length, sepal_width, petal_length, petal_width
LABEL class
INTO sqlflow_models.my_xgboost_model;
Expand All @@ -38,6 +38,13 @@ ANALYZE sqlflow_models.my_xgboost_model
USING TreeExplainer;
`

const testXGBoostPredictIris = `
SELECT *
FROM iris.test
PREDICT iris.predict.class
USING sqlflow_models.my_xgboost_model;
`

func TestXGBFiller(t *testing.T) {
a := assert.New(t)
parser := newParser()
Expand All @@ -56,3 +63,17 @@ func TestXGBFiller(t *testing.T) {
a.NoError(err)
a.Equal(filler.ParamsCfgJSON, string(paramsJSON))
}

func TestXGBFillerPredict(t *testing.T) {
a := assert.New(t)
parser := newParser()
r, e := parser.Parse(testXGBoostPredictIris)
a.NoError(e)
filler, e := newXGBFiller(r, nil, testDB)
a.NoError(e)
a.False(filler.IsTrain)
a.Equal(filler.TableName, "iris.predict")
a.Equal(filler.Save, "sqlflow_models.my_xgboost_model")
a.Equal(filler.PredictionDatasetSQL, `SELECT *
FROM iris.test`)
}
4 changes: 4 additions & 0 deletions sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,10 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
if e := genAntXGBoost(&buf, pr, nil, fts, db); e != nil {
return fmt.Errorf("genAntXGBoost %v", e)
}
} else if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGB.`) {
if e := genXGBoost(&buf, pr, nil, fts, db); e != nil {
return fmt.Errorf("genXGBoost %v", e)
}
} else {
if e := genTF(&buf, pr, nil, fts, db); e != nil {
return fmt.Errorf("genTF %v", e)
Expand Down
2 changes: 2 additions & 0 deletions sql/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ func TestExecutorXGBoost(t *testing.T) {
a.True(goodStream(stream.ReadAll()))
stream = runExtendedSQL(testAnalyzeTreeModelSelectIris, testDB, modelDir, nil)
a.True(goodStream(stream.ReadAll()))
stream = runExtendedSQL(testXGBoostPredictIris, testDB, modelDir, nil)
a.True(goodStream(stream.ReadAll()))
})
}

Expand Down
97 changes: 83 additions & 14 deletions sql/template_xgboost.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ num_boost_round = {{.NumBoostRound}}
maximize = True if "{{.Maximize}}" == "true" else False
early_stopping_rounds = {{.EarlyStoppingRounds}}
if early_stopping_rounds == -1:
early_stopping_rounds = None
early_stopping_rounds = None

{{if ne .ParamsCfgJSON ""}}
params = {{.ParamsCfgJSON}}
Expand All @@ -58,22 +58,20 @@ feature_specs["{{$value.FeatureName}}"] = {
}
{{end}}



conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")

def xgb_dataset(fn, dataset_sql):
gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs)
with open(fn, 'w') as f:
for item in gen():
features, label = item
row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
f.write("\t".join(row_data) + "\n")
# TODO(yancey1989): generate group and weight text file if necessary
return xgb.DMatrix(fn)

dtrain = xgb_dataset('train.txt', "{{.TrainingDatasetSQL}}")
dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}")
gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the indent.

with open(fn, 'w') as f:
for item in gen():
features, label = item
row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
f.write("\t".join(row_data) + "\n")
        # TODO(yancey1989): generate group and weight text file if necessary
return xgb.DMatrix(fn)

dtrain = xgb_dataset('train.txt', """{{.TrainingDatasetSQL}}""")
dtest = xgb_dataset('test.txt', """{{.ValidationDatasetSQL}}""")

train_args = {}
train_args["num_boost_round"] = num_boost_round
Expand All @@ -84,3 +82,74 @@ train_args["evals"] = [(dtrain, "train"), (dtest, "validation")]
bst = xgb.train(params, dtrain, **train_args)
bst.save_model("{{.Save}}")
`

const xgbPredictTemplateText = `
import xgboost as xgb
import numpy as np
from sqlflow_submitter.db import connect, db_generator, buffered_db_writer

driver="{{.Driver}}"

{{if ne .Database ""}}
database="{{.Database}}"
{{else}}
database=""
{{end}}

session_cfg = {}
{{ range $k, $v := .Session }}
session_cfg["{{$k}}"] = "{{$v}}"
{{end}}

feature_column_names = [{{range .X}}
"{{.FeatureName}}",
{{end}}]

{{/* Convert go side featureSpec to python dict for input_fn */}}
feature_specs = dict()
{{ range $value := .X }}
feature_specs["{{$value.FeatureName}}"] = {
"feature_name": "{{$value.FeatureName}}",
"dtype": "{{$value.Dtype}}",
"delimiter": "{{$value.Delimiter}}",
"shape": {{$value.InputShape}},
"is_sparse": "{{$value.IsSparse}}" == "true"
}
{{end}}

conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")

def xgb_dataset(fn, dataset_sql):
gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "", feature_specs)
with open(fn, 'w') as f:
for item in gen():
features, label = item
row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
f.write("\t".join(row_data) + "\n")
        # TODO(yancey1989): generate group and weight text file if necessary
return xgb.DMatrix(fn)

dpred = xgb_dataset('predict.txt', """{{.PredictionDatasetSQL}}""")

bst = xgb.Booster({'nthread': 4}) # init model
bst.load_model("{{.Save}}") # load data
preds = bst.predict(dpred)
# TODO(typhoonzero): regression models may have different behavior
pred_classes = np.argmax(np.array(preds), axis=1)

feature_file_read = open("predict.txt", "r")

result_column_names = feature_column_names
result_column_names.append("{{.Y.FeatureName}}")

line_no = 0
with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100) as w:
while True:
line = feature_file_read.readline()
if not line:
break
row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]]
row.append(pred_classes[line_no])
w.write(row)
line_no += 1
`