Skip to content

Commit 59a4113

Browse files
authored
fix write hive table with unordered row data (#757)
1 parent f961e37 commit 59a4113

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

sql/python/sqlflow_submitter/db_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def test_hive(self):
8585

8686
def _do_test(self, driver, conn):
8787
table_name = "test_db"
88-
table_schema = ["features", "label"]
89-
values = [('5,6,1,2', 1)] * 10
88+
table_schema = ["label", "features"]
89+
values = [(1, '5,6,1,2')] * 10
9090

9191
execute(driver, conn, self.drop_statement)
9292

sql/python/sqlflow_submitter/db_writer/hive.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,43 @@
1717
import tempfile
1818
import subprocess
1919

20+
# Separator inserted between column values when a row is serialized to a line.
CSV_DELIMITER = '\001'
21+
2022
class HiveDBWriter(BufferedDBWriter):
2123
def __init__(self, conn, table_name, table_schema, buff_size=10000):
    """Create a buffered writer that stages rows in a local temp file.

    Args:
        conn: Connection object; its ``cursor()`` is used to describe
            the target table.
        table_name: Name of the table rows will be written to.
        table_schema: Column names, in the order values appear in each row.
        buff_size: Forwarded unchanged to the base class constructor.

    Raises:
        ValueError: If a name in ``table_schema`` is not a column of the
            table.
    """
    super().__init__(conn, table_name, table_schema, buff_size)
    # Resolve column positions before touching the filesystem, so that an
    # invalid schema fails fast instead of leaving a temp file and two open
    # handles behind.
    self.schema_idx = self._indexing_table_schema(table_schema)
    self.tmp_f = tempfile.NamedTemporaryFile(dir="./")
    # Second, text-mode handle on the same temp file; rows are written here.
    self.f = open(self.tmp_f.name, "w")
28+
29+
def _indexing_table_schema(self, table_schema):
    """Map each name in ``table_schema`` to its position in the table.

    Runs ``describe <table_name>`` on ``self.conn`` and looks every
    requested column up in the result.

    Args:
        table_schema: Column names in the order the caller supplies values.

    Returns:
        A list of ints; element ``k`` is the position of
        ``table_schema[k]`` among the table's columns.

    Raises:
        ValueError: If a requested column is not reported by ``describe``.
    """
    cursor = self.conn.cursor()
    try:
        cursor.execute("describe %s" % self.table_name)
        column_list = cursor.fetchall()
    finally:
        # Release the cursor even when the statement fails.
        cursor.close()

    # column list: [(col1, type, desc), (col2, type, desc)...]
    idx_map = {}
    for i, e in enumerate(column_list):
        name = e[0]
        if not name or name.startswith("#"):
            # For a partitioned table Hive's DESCRIBE output continues with
            # a blank row, "# ..." header rows and the partition columns
            # again. Stop at the first such row so a repeated name cannot
            # overwrite the real position of the column.
            break
        idx_map[name] = i

    schema_idx = []
    for s in table_schema:
        if s not in idx_map:
            raise ValueError("column: %s should be in table columns:%s" % (s, idx_map))
        schema_idx.append(idx_map[s])

    return schema_idx
45+
46+
def _ordered_row_data(self, row):
    """Serialize ``row`` as one delimited line in table-column order.

    ``row[k]`` is written at position ``self.schema_idx[k]``; every
    position that receives no value holds the text ``NULL``.

    Args:
        row: Sequence of values, ordered like ``self.table_schema``.

    Returns:
        The values converted with ``str`` and joined by ``CSV_DELIMITER``,
        without a trailing newline.

    Raises:
        IndexError: If ``row`` has more elements than ``self.schema_idx``.
    """
    # schema_idx holds positions within the table, which may have more
    # columns than table_schema names. Size the buffer by the furthest
    # target position so such a write cannot fall outside the list.
    width = max([len(self.table_schema)] + [pos + 1 for pos in self.schema_idx])
    # Use NULL as the default value for hive columns
    # NOTE(review): the literal text "NULL" is kept from the original;
    # confirm it matches the null format the target table expects.
    row_data = ["NULL"] * width
    for idx, element in enumerate(row):
        row_data[self.schema_idx[idx]] = str(element)
    return CSV_DELIMITER.join(row_data)
2552

2653
def flush(self):
    """Write every buffered row to ``self.f`` as one line, then empty the buffer.

    Each row is serialized by ``self._ordered_row_data`` and terminated
    with a newline.
    """
    serialized = (self._ordered_row_data(record) + '\n' for record in self.rows)
    self.f.writelines(serialized)
    self.rows = []
3158

3259
def write_hive_table(self):

0 commit comments

Comments
 (0)