Skip to content

add additional hive config for xgboost codegen #729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions sql/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,6 @@ func trainingAndValidationDataset(pr *extendedSelect, ds *trainAndValDataset) (s

func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*filler, error) {
isKerasModel, modelClassString := parseModelURI(pr.estimator)
auth := ""
var sc map[string]string
if db.driverName == "hive" {
cfg, err := gohive.ParseDSN(db.dataSourceName)
if err != nil {
return nil, err
}
auth = cfg.Auth
sc = cfg.SessionCfg
}
training, validation := trainingAndValidationDataset(pr, ds)
r := &filler{
IsTrain: pr.train,
Expand All @@ -111,10 +101,6 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
Save: pr.save,
IsKerasModel: isKerasModel,
},
connectionConfig: connectionConfig{
Auth: auth,
Session: sc,
},
}

trainResolved, err := resolveTrainClause(&pr.trainClause)
Expand Down Expand Up @@ -246,6 +232,8 @@ func fillDatabaseInfo(r *filler, db *DB) (*filler, error) {
if err != nil {
return nil, err
}
r.Auth = cfg.Auth
r.Session = cfg.SessionCfg
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice practice!

sa := strings.Split(cfg.Addr, ":")
r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName
r.User, r.Password = cfg.User, cfg.Passwd
Expand Down
10 changes: 10 additions & 0 deletions sql/codegen_xgboost.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ type xgDataBaseField struct {
Port string `json:"port"`
Database string `json:"database"`
Driver string `json:"driver"`
xgDataBaseHiveField
}

type xgDataBaseHiveField struct {
HiveAuth string `json:"auth,omitempty"`
HiveSession map[string]string `json:"session,omitempty"`
}

type xgFeatureMeta struct {
Expand Down Expand Up @@ -659,6 +665,10 @@ func xgFillDatabaseInfo(r *xgDataSourceFields, db *DB) error {
if err != nil {
return err
}
r.HiveAuth = cfg.Auth
if len(cfg.SessionCfg) > 0 {
r.HiveSession = cfg.SessionCfg
}
sa := strings.Split(cfg.Addr, ":")
r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName
r.User, r.Password = cfg.User, cfg.Passwd
Expand Down
12 changes: 8 additions & 4 deletions sql/python/sqlflow_submitter/xgboost/sqlflow_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import typing
from typing import Iterator

from launcher import DataSource, config_fields, XGBoostResult, XGBoostRecord
from launcher.data_units import RecordBuilder

from .common import XGBoostError
from ..db import connect, db_generator, buffered_db_writer
from launcher import DataSource, config_fields, XGBoostResult, XGBoostRecord


class FeatureMeta(typing.NamedTuple):
Expand Down Expand Up @@ -115,7 +116,10 @@ def __init__(self, rank: int, num_worker: int,
self._result_schema = {'append_columns': column_conf.append_columns or []}
self._result_schema.update(column_conf.result_columns._asdict())

conn = connect(**source_conf.db_config)
# assert sqlflow_submitter.db.connect contains no varargs
conn_fields = set(inspect.getfullargspec(connect).args)
conn_conf = {k: v for k, v in source_conf.db_config.items() if k in conn_fields}
conn = connect(**conn_conf)

def writer_maker(table_schema):
return buffered_db_writer(
Expand All @@ -131,8 +135,8 @@ def writer_maker(table_schema):
self._reader = db_generator(
driver=source_conf.db_config['driver'],
conn=conn,
# TODO(weiguo): support auth(connect)/session_cfg for hive
session_cfg={},
# specialized session_cfg for hive
session_cfg=source_conf.db_config.get('session', {}),
statement=source_conf.standard_select,
feature_column_names=col_names,
label_column_name=None,
Expand Down