17 | 17 | import tempfile
18 | 18 | import subprocess
19 | 19 |
   | 20 | +CSV_DELIMITER = '\001'
   | 21 | +
20 | 22 | class HiveDBWriter(BufferedDBWriter):
21 | 23 |     def __init__(self, conn, table_name, table_schema, buff_size=10000):
22 | 24 |         super().__init__(conn, table_name, table_schema, buff_size)
23 | 25 |         self.tmp_f = tempfile.NamedTemporaryFile(dir="./")
24 | 26 |         self.f = open(self.tmp_f.name, "w")
   | 27 | +        self.schema_idx = self._indexing_table_schema(table_schema)
   | 28 | +
   | 29 | +    def _indexing_table_schema(self, table_schema):
   | 30 | +        cursor = self.conn.cursor()
   | 31 | +        cursor.execute("describe %s" % self.table_name)
   | 32 | +        column_list = cursor.fetchall()
   | 33 | +        schema_idx = []
   | 34 | +        idx_map = {}
   | 35 | +        # column list: [(col1, type, desc), (col2, type, desc)...]
   | 36 | +        for i, e in enumerate(column_list):
   | 37 | +            idx_map[e[0]] = i
   | 38 | +
   | 39 | +        for s in table_schema:
   | 40 | +            if s not in idx_map:
   | 41 | +                raise ValueError("column: %s should be in table columns:%s" % (s, idx_map))
   | 42 | +            schema_idx.append(idx_map[s])
   | 43 | +
   | 44 | +        return schema_idx
   | 45 | +
   | 46 | +    def _ordered_row_data(self, row):
   | 47 | +        # Use NULL as the default value for hive columns
   | 48 | +        row_data = ["NULL" for i in range(len(self.table_schema))]
   | 49 | +        for idx, element in enumerate(row):
   | 50 | +            row_data[self.schema_idx[idx]] = str(element)
   | 51 | +        return CSV_DELIMITER.join(row_data)
25 | 52 |
26 | 53 |     def flush(self):
27 | 54 |         for row in self.rows:
28 |    | -            line = "%s\n" % '\001'.join([str(v) for v in row])
29 |    | -            self.f.write(line)
   | 55 | +            data = self._ordered_row_data(row)
   | 56 | +            self.f.write(data+'\n')
30 | 57 |         self.rows = []
31 | 58 |
32 | 59 |     def write_hive_table(self):
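The new helpers can be exercised without a live Hive connection. The sketch below is a minimal, self-contained rendering of the same indexing and ordering logic; the column names, the stubbed `describe` output, and the sample row are hypothetical and not part of this change.

```python
# Hypothetical columns: the Hive table stores (user_id, name, score),
# while the writer receives rows in (score, user_id, name) order.
CSV_DELIMITER = '\001'

# Stub of what `describe <table>` returns: (name, type, comment) per column.
column_list = [("user_id", "int", ""), ("name", "string", ""), ("score", "double", "")]
table_schema = ["score", "user_id", "name"]

# Same idea as _indexing_table_schema: position of each schema column in the table.
idx_map = {col[0]: i for i, col in enumerate(column_list)}
schema_idx = [idx_map[s] for s in table_schema]          # -> [2, 0, 1]

def ordered_row_data(row):
    # Same idea as _ordered_row_data: columns the row does not set stay NULL.
    row_data = ["NULL" for _ in range(len(table_schema))]
    for idx, element in enumerate(row):
        row_data[schema_idx[idx]] = str(element)
    return CSV_DELIMITER.join(row_data)

# A short row leaves the unmapped column ("name") as NULL:
print(ordered_row_data((0.9, 42)))   # 42\x01NULL\x010.9
```

The `\001` (Ctrl-A) delimiter matches Hive's default field terminator for text-format tables, which is presumably why the buffered rows are staged in a plain text file before being written to the table.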