Skip to content

fix write hive table with unordered row data #757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sql/python/sqlflow_submitter/db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def test_hive(self):

def _do_test(self, driver, conn):
table_name = "test_db"
table_schema = ["features", "label"]
values = [('5,6,1,2', 1)] * 10
table_schema = ["label", "features"]
values = [(1, '5,6,1,2')] * 10

execute(driver, conn, self.drop_statement)

Expand Down
31 changes: 29 additions & 2 deletions sql/python/sqlflow_submitter/db_writer/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,43 @@
import tempfile
import subprocess

CSV_DELIMITER = '\001'

class HiveDBWriter(BufferedDBWriter):
def __init__(self, conn, table_name, table_schema, buff_size=10000):
super().__init__(conn, table_name, table_schema, buff_size)
self.tmp_f = tempfile.NamedTemporaryFile(dir="./")
self.f = open(self.tmp_f.name, "w")
self.schema_idx = self._indexing_table_schema(table_schema)

def _indexing_table_schema(self, table_schema):
cursor = self.conn.cursor()
cursor.execute("describe %s" % self.table_name)
column_list = cursor.fetchall()
schema_idx = []
idx_map = {}
# column list: [(col1, type, desc), (col2, type, desc)...]
for i, e in enumerate(column_list):
idx_map[e[0]] = i

for s in table_schema:
if s not in idx_map:
raise ValueError("column: %s should be in table columns:%s" % (s, idx_map))
schema_idx.append(idx_map[s])

return schema_idx

def _ordered_row_data(self, row):
# Use NULL as the default value for hive columns
row_data = ["NULL" for i in range(len(self.table_schema))]
for idx, element in enumerate(row):
row_data[self.schema_idx[idx]] = str(element)
return CSV_DELIMITER.join(row_data)

def flush(self):
for row in self.rows:
line = "%s\n" % '\001'.join([str(v) for v in row])
self.f.write(line)
data = self._ordered_row_data(row)
self.f.write(data+'\n')
self.rows = []

def write_hive_table(self):
Expand Down