Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit c9d1db0

Browse files
committed
redshift: normalize schema output
There are currently three possible mechanisms for getting schema info from Redshift. Make sure the output of each conforms to the same pattern.
1 parent 5a4b879 commit c9d1db0

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

data_diff/databases/redshift.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,27 +121,36 @@ def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]:
121121
if not rows:
122122
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
123123

124-
d = {r[0]: r for r in rows}
125-
assert len(d) == len(rows)
126-
return d
124+
schema_dict = self._normalize_schema_info(rows)
125+
126+
return schema_dict
127127

128128
def select_view_columns(self, path: DbPath) -> str:
129129
_, schema, table = self._normalize_table_path(path)
130130

131131
return """select * from pg_get_cols('{}.{}')
132-
cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int)
133-
""".format(schema, table)
132+
cols(col_name name, col_type varchar)
133+
""".format(
134+
schema, table
135+
)
134136

135137
def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
136138
rows = self.query(self.select_view_columns(path), list)
137139

138140
if not rows:
139141
raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns")
140142

141-
output = {}
143+
schema_dict = self._normalize_schema_info(rows)
144+
145+
return schema_dict
146+
147+
# when using a non-information_schema source, strip (N) from type(N) etc. to match
148+
# typical information_schema output
149+
def _normalize_schema_info(self, rows) -> Dict[str, tuple]:
150+
schema_dict = {}
142151
for r in rows:
143-
col_name = r[2]
144-
type_info = r[3].split("(")
152+
col_name = r[0]
153+
type_info = r[1].split("(")
145154
base_type = type_info[0]
146155
precision = None
147156
scale = None
@@ -153,9 +162,8 @@ def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
153162
scale = int(scale)
154163

155164
out = [col_name, base_type, None, precision, scale]
156-
output[col_name] = tuple(out)
157-
158-
return output
165+
schema_dict[col_name] = tuple(out)
166+
return schema_dict
159167

160168
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
161169
try:

0 commit comments

Comments
 (0)