Skip to content

Commit 95a6d90

Browse files
authored
add additional hive config for xgboost codegen (#729)
fix issue #728
1 parent d766138 commit 95a6d90

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

sql/codegen.go

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,6 @@ func trainingAndValidationDataset(pr *extendedSelect, ds *trainAndValDataset) (s
8888

8989
func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*filler, error) {
9090
isKerasModel, modelClassString := parseModelURI(pr.estimator)
91-
auth := ""
92-
var sc map[string]string
93-
if db.driverName == "hive" {
94-
cfg, err := gohive.ParseDSN(db.dataSourceName)
95-
if err != nil {
96-
return nil, err
97-
}
98-
auth = cfg.Auth
99-
sc = cfg.SessionCfg
100-
}
10191
training, validation := trainingAndValidationDataset(pr, ds)
10292
r := &filler{
10393
IsTrain: pr.train,
@@ -111,10 +101,6 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
111101
Save: pr.save,
112102
IsKerasModel: isKerasModel,
113103
},
114-
connectionConfig: connectionConfig{
115-
Auth: auth,
116-
Session: sc,
117-
},
118104
}
119105

120106
trainResolved, err := resolveTrainClause(&pr.trainClause)
@@ -246,6 +232,8 @@ func fillDatabaseInfo(r *filler, db *DB) (*filler, error) {
246232
if err != nil {
247233
return nil, err
248234
}
235+
r.Auth = cfg.Auth
236+
r.Session = cfg.SessionCfg
249237
sa := strings.Split(cfg.Addr, ":")
250238
r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName
251239
r.User, r.Password = cfg.User, cfg.Passwd

sql/codegen_xgboost.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ type xgDataBaseField struct {
104104
Port string `json:"port"`
105105
Database string `json:"database"`
106106
Driver string `json:"driver"`
107+
xgDataBaseHiveField
108+
}
109+
110+
type xgDataBaseHiveField struct {
111+
HiveAuth string `json:"auth,omitempty"`
112+
HiveSession map[string]string `json:"session,omitempty"`
107113
}
108114

109115
type xgFeatureMeta struct {
@@ -659,6 +665,10 @@ func xgFillDatabaseInfo(r *xgDataSourceFields, db *DB) error {
659665
if err != nil {
660666
return err
661667
}
668+
r.HiveAuth = cfg.Auth
669+
if len(cfg.SessionCfg) > 0 {
670+
r.HiveSession = cfg.SessionCfg
671+
}
662672
sa := strings.Split(cfg.Addr, ":")
663673
r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName
664674
r.User, r.Password = cfg.User, cfg.Passwd

sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import inspect
1415
import json
1516
import typing
1617
from typing import Iterator
1718

19+
from launcher import DataSource, config_fields, XGBoostResult, XGBoostRecord
1820
from launcher.data_units import RecordBuilder
1921

2022
from .common import XGBoostError
2123
from ..db import connect, db_generator, buffered_db_writer
22-
from launcher import DataSource, config_fields, XGBoostResult, XGBoostRecord
2324

2425

2526
class FeatureMeta(typing.NamedTuple):
@@ -115,7 +116,10 @@ def __init__(self, rank: int, num_worker: int,
115116
self._result_schema = {'append_columns': column_conf.append_columns or []}
116117
self._result_schema.update(column_conf.result_columns._asdict())
117118

118-
conn = connect(**source_conf.db_config)
119+
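# NOTE: nothing is asserted here; this filtering assumes sqlflow_submitter.db.connect declares no *args/**kwargs, so only keys matching its named parameters are forwarded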
120+
conn_fields = set(inspect.getfullargspec(connect).args)
121+
conn_conf = {k: v for k, v in source_conf.db_config.items() if k in conn_fields}
122+
conn = connect(**conn_conf)
119123

120124
def writer_maker(table_schema):
121125
return buffered_db_writer(
@@ -131,8 +135,8 @@ def writer_maker(table_schema):
131135
self._reader = db_generator(
132136
driver=source_conf.db_config['driver'],
133137
conn=conn,
134-
# TODO(weiguo): support auth(connect)/session_cfg for hive
135-
session_cfg={},
138+
# Hive-specific session settings taken from db_config['session']; falls back to {} when the key is absent
139+
session_cfg=source_conf.db_config.get('session', {}),
136140
statement=source_conf.standard_select,
137141
feature_column_names=col_names,
138142
label_column_name=None,

0 commit comments

Comments
 (0)