diff --git a/sql/codegen.go b/sql/codegen.go index a1daa81d69..17e500dcd5 100644 --- a/sql/codegen.go +++ b/sql/codegen.go @@ -88,16 +88,6 @@ func trainingAndValidationDataset(pr *extendedSelect, ds *trainAndValDataset) (s func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*filler, error) { isKerasModel, modelClassString := parseModelURI(pr.estimator) - auth := "" - var sc map[string]string - if db.driverName == "hive" { - cfg, err := gohive.ParseDSN(db.dataSourceName) - if err != nil { - return nil, err - } - auth = cfg.Auth - sc = cfg.SessionCfg - } training, validation := trainingAndValidationDataset(pr, ds) r := &filler{ IsTrain: pr.train, @@ -111,10 +101,6 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D Save: pr.save, IsKerasModel: isKerasModel, }, - connectionConfig: connectionConfig{ - Auth: auth, - Session: sc, - }, } trainResolved, err := resolveTrainClause(&pr.trainClause) @@ -246,6 +232,8 @@ func fillDatabaseInfo(r *filler, db *DB) (*filler, error) { if err != nil { return nil, err } + r.Auth = cfg.Auth + r.Session = cfg.SessionCfg sa := strings.Split(cfg.Addr, ":") r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName r.User, r.Password = cfg.User, cfg.Passwd diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index 7118a76899..39d000d524 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -104,6 +104,12 @@ type xgDataBaseField struct { Port string `json:"port"` Database string `json:"database"` Driver string `json:"driver"` + xgDataBaseHiveField +} + +type xgDataBaseHiveField struct { + HiveAuth string `json:"auth,omitempty"` + HiveSession map[string]string `json:"session,omitempty"` } type xgFeatureMeta struct { @@ -659,6 +665,10 @@ func xgFillDatabaseInfo(r *xgDataSourceFields, db *DB) error { if err != nil { return err } + r.HiveAuth = cfg.Auth + if len(cfg.SessionCfg) > 0 { + r.HiveSession = cfg.SessionCfg + } sa := strings.Split(cfg.Addr, ":") r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName r.User, r.Password = cfg.User, cfg.Passwd diff --git a/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py b/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py index 61e3723bc9..145681caaf 100644 --- a/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py +++ b/sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py @@ -11,15 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json import typing from typing import Iterator +from launcher import DataSource, config_fields, XGBoostResult, XGBoostRecord from launcher.data_units import RecordBuilder from .common import XGBoostError from ..db import connect, db_generator, buffered_db_writer -from launcher import DataSource, config_fields, XGBoostResult, XGBoostRecord class FeatureMeta(typing.NamedTuple): @@ -115,7 +116,10 @@ def __init__(self, rank: int, num_worker: int, self._result_schema = {'append_columns': column_conf.append_columns or []} self._result_schema.update(column_conf.result_columns._asdict()) - conn = connect(**source_conf.db_config) + # assert sqlflow_submitter.db.connect contains no varargs + conn_fields = set(inspect.getfullargspec(connect).args) + conn_conf = {k: v for k, v in source_conf.db_config.items() if k in conn_fields} + conn = connect(**conn_conf) def writer_maker(table_schema): return buffered_db_writer( @@ -131,8 +135,8 @@ def writer_maker(table_schema): self._reader = db_generator( driver=source_conf.db_config['driver'], conn=conn, - # TODO(weiguo): support auth(connect)/session_cfg for hive - session_cfg={}, + # specialized session_cfg for hive + session_cfg=source_conf.db_config.get('session', {}), statement=source_conf.standard_select, feature_column_names=col_names, label_column_name=None,