From ecf0e126258439047ec4e85727b66519acb55862 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 28 Aug 2019 14:21:58 +0800 Subject: [PATCH 1/6] add additional hive config for xgboost codegen --- sql/codegen.go | 16 ++-------------- sql/codegen_xgboost.go | 8 ++++++++ .../xgboost/sqlflow_data_source.py | 10 +++++++--- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/codegen.go b/sql/codegen.go index a1daa81d69..17e500dcd5 100644 --- a/sql/codegen.go +++ b/sql/codegen.go @@ -88,16 +88,6 @@ func trainingAndValidationDataset(pr *extendedSelect, ds *trainAndValDataset) (s func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*filler, error) { isKerasModel, modelClassString := parseModelURI(pr.estimator) - auth := "" - var sc map[string]string - if db.driverName == "hive" { - cfg, err := gohive.ParseDSN(db.dataSourceName) - if err != nil { - return nil, err - } - auth = cfg.Auth - sc = cfg.SessionCfg - } training, validation := trainingAndValidationDataset(pr, ds) r := &filler{ IsTrain: pr.train, @@ -111,10 +101,6 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D Save: pr.save, IsKerasModel: isKerasModel, }, - connectionConfig: connectionConfig{ - Auth: auth, - Session: sc, - }, } trainResolved, err := resolveTrainClause(&pr.trainClause) @@ -246,6 +232,8 @@ func fillDatabaseInfo(r *filler, db *DB) (*filler, error) { if err != nil { return nil, err } + r.Auth = cfg.Auth + r.Session = cfg.SessionCfg sa := strings.Split(cfg.Addr, ":") r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName r.User, r.Password = cfg.User, cfg.Passwd diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index 7118a76899..b4a1578c7c 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -104,6 +104,12 @@ type xgDataBaseField struct { Port string `json:"port"` Database string `json:"database"` Driver string `json:"driver"` + xgDataBaseHiveField +} + +type xgDataBaseHiveField struct { + HiveAuth string `json:"auth,omitempty"` + HiveSession map[string]string `json:"session,omitempty"` } type xgFeatureMeta struct { @@ -659,6 +665,8 @@ func xgFillDatabaseInfo(r *xgDataSourceFields, db *DB) error { if err != nil { return err } + r.HiveAuth = cfg.Auth + r.HiveSession = cfg.SessionCfg sa := strings.Split(cfg.Addr, ":") r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName r.User, r.Password = cfg.User, cfg.Passwd diff --git a/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py b/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py index 61e3723bc9..906f46e82a 100644 --- a/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py +++ b/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py @@ -115,7 +115,11 @@ def __init__(self, rank: int, num_worker: int, self._result_schema = {'append_columns': column_conf.append_columns or []} self._result_schema.update(column_conf.result_columns._asdict()) - conn = connect(**source_conf.db_config) + # assert sqlflow_submitter.db.connect contains no varargs + import inspect + conn_fields = set(inspect.getfullargspec(connect).args) + conn_conf = {k: v for k, v in source_conf.db_config.items() if k in conn_fields} + conn = connect(**conn_conf) def writer_maker(table_schema): return buffered_db_writer( @@ -131,8 +135,8 @@ def writer_maker(table_schema): self._reader = db_generator( driver=source_conf.db_config['driver'], conn=conn, - # TODO(weiguo): support auth(connect)/session_cfg for hive - session_cfg={}, + # specialized session_cfg for hive + session_cfg=source_conf.db_config.get('session', {}), statement=source_conf.standard_select, feature_column_names=col_names, label_column_name=None, From 32663853bfb4bd4ab7a0bf228fd8cb7f53aa4e51 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 28 Aug 2019 15:07:38 +0800 Subject: [PATCH 2/6] fix ut --- sql/codegen_xgboost_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go index 26db77c6a2..55705e2503 100644 --- a/sql/codegen_xgboost_test.go +++ b/sql/codegen_xgboost_test.go @@ -364,6 +364,9 @@ LABEL e INTO model_table; dsFields := &xgDataSourceFields{} e = json.Unmarshal([]byte(filler.DataSourceJSON), dsFields) a.NoError(e) + if dsFields.HiveSession == nil { + dsFields.HiveSession = make(map[string]string) + } a.EqualValues(filler.xgDataSourceFields, *dsFields) xgbFields := &xgLearningFields{} e = json.Unmarshal([]byte(filler.LearningJSON), xgbFields) From 1b62dcdb617138252fb7d3f2013edec457d4c6f3 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 28 Aug 2019 15:29:20 +0800 Subject: [PATCH 3/6] fix ut --- sql/codegen_xgboost_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go index 55705e2503..2d9c3af8d1 100644 --- a/sql/codegen_xgboost_test.go +++ b/sql/codegen_xgboost_test.go @@ -364,6 +364,9 @@ LABEL e INTO model_table; dsFields := &xgDataSourceFields{} e = json.Unmarshal([]byte(filler.DataSourceJSON), dsFields) a.NoError(e) + if filler.HiveSession == nil { + filler.HiveSession = make(map[string]string) + } if dsFields.HiveSession == nil { dsFields.HiveSession = make(map[string]string) } From 500e0a913ccdec1642108d13935ac1bf7ab27ef5 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 28 Aug 2019 16:05:57 +0800 Subject: [PATCH 4/6] fix ut --- sql/codegen_xgboost_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go index 2d9c3af8d1..9004a8ec64 100644 --- a/sql/codegen_xgboost_test.go +++ b/sql/codegen_xgboost_test.go @@ -388,6 +388,12 @@ LABEL e INTO model_table; vdsFields := &xgDataSourceFields{} e = json.Unmarshal([]byte(filler.ValidDataSourceJSON), vdsFields) a.NoError(e) + if filler.validDataSource.HiveSession == nil { + filler.validDataSource.HiveSession = make(map[string]string) + } + if vdsFields.HiveSession == nil { + vdsFields.HiveSession = make(map[string]string) + } a.EqualValues(filler.validDataSource, *vdsFields) filler.StandardSelect, filler.validDataSource.StandardSelect = "", "" From 008cfecf7be2c030773fb7ad0651d2751c1955db Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 28 Aug 2019 17:05:57 +0800 Subject: [PATCH 5/6] fix --- sql/codegen_xgboost.go | 4 +++- sql/codegen_xgboost_test.go | 12 ------------ 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index b4a1578c7c..39d000d524 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -666,7 +666,9 @@ func xgFillDatabaseInfo(r *xgDataSourceFields, db *DB) error { return err } r.HiveAuth = cfg.Auth - r.HiveSession = cfg.SessionCfg + if len(cfg.SessionCfg) > 0 { + r.HiveSession = cfg.SessionCfg + } sa := strings.Split(cfg.Addr, ":") r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName r.User, r.Password = cfg.User, cfg.Passwd diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go index 9004a8ec64..26db77c6a2 100644 --- a/sql/codegen_xgboost_test.go +++ b/sql/codegen_xgboost_test.go @@ -364,12 +364,6 @@ LABEL e INTO model_table; dsFields := &xgDataSourceFields{} e = json.Unmarshal([]byte(filler.DataSourceJSON), dsFields) a.NoError(e) - if filler.HiveSession == nil { - filler.HiveSession = make(map[string]string) - } - if dsFields.HiveSession == nil { - dsFields.HiveSession = make(map[string]string) - } a.EqualValues(filler.xgDataSourceFields, *dsFields) xgbFields := &xgLearningFields{} e = json.Unmarshal([]byte(filler.LearningJSON), xgbFields) @@ -388,12 +382,6 @@ LABEL e INTO model_table; vdsFields := &xgDataSourceFields{} e = json.Unmarshal([]byte(filler.ValidDataSourceJSON), vdsFields) a.NoError(e) - if filler.validDataSource.HiveSession == nil { - filler.validDataSource.HiveSession = make(map[string]string) - } - if vdsFields.HiveSession == nil { - vdsFields.HiveSession = make(map[string]string) - } a.EqualValues(filler.validDataSource, *vdsFields) filler.StandardSelect, filler.validDataSource.StandardSelect = "", "" From f76b2d0b445aa0cf576be4a5c119e6a3857b6e37 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 28 Aug 2019 18:41:17 +0800 Subject: [PATCH 6/6] fix python code style --- sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py b/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py index 906f46e82a..145681caaf 100644 --- a/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py +++ b/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py @@ -11,15 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json import typing from typing import Iterator +from launcher import DataSource, config_fields, XGBoostResult, XGBoostRecord from launcher.data_units import RecordBuilder from .common import XGBoostError from ..db import connect, db_generator, buffered_db_writer -from launcher import DataSource, config_fields, XGBoostResult, XGBoostRecord class FeatureMeta(typing.NamedTuple): @@ -116,7 +117,6 @@ def __init__(self, rank: int, num_worker: int, self._result_schema.update(column_conf.result_columns._asdict()) # assert sqlflow_submitter.db.connect contains no varargs - import inspect conn_fields = set(inspect.getfullargspec(connect).args) conn_conf = {k: v for k, v in source_conf.db_config.items() if k in conn_fields} conn = connect(**conn_conf)