diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index e1ac2208..c33f7ead 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -79,6 +79,7 @@ def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) while True: sql = self.compiler.compile(q) + logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) try: try: res = callback(sql) if sql is not SKIP else SKIP @@ -130,7 +131,8 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): if sql_code is SKIP: return SKIP - logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) + logger.debug("Running SQL (%s): %s", self.name, sql_code) + if self._interactive and isinstance(sql_ast, Select): explained_sql = compiler.compile(Explain(sql_ast)) explain = self._query(explained_sql) diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 2d69efc8..56a32d48 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -11,6 +11,7 @@ Text, FractionalType, DbPath, + DbTime, Decimal, ColType, ColType_UUID, @@ -159,3 +160,6 @@ def type_repr(self, t) -> str: return {float: "REAL"}[t] except KeyError: return super().type_repr(t) + + def timestamp_value(self, t: DbTime) -> str: + return f"timestamp '{t.isoformat(' ')}'" diff --git a/data_diff/query_utils.py b/data_diff/query_utils.py index dac22fe1..825dbdc3 100644 --- a/data_diff/query_utils.py +++ b/data_diff/query_utils.py @@ -8,6 +8,7 @@ from .databases import Oracle from .queries import table, commit, Expr + def _drop_table_oracle(name: DbPath): t = table(name) # Experience shows double drop is necessary @@ -50,6 +51,7 @@ def _append_to_table(path: DbPath, expr: Expr): yield t.insert_expr(expr) yield commit + def append_to_table(db, path, expr): f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table db.query(f(path, expr)) diff --git a/tests/common.py b/tests/common.py index a652e1c4..aad75074 100644 --- a/tests/common.py +++ b/tests/common.py @@ -120,7 +120,6 @@ def str_to_checksum(str: str): return int(md5[half_pos:], 16) - class TestPerDatabase(unittest.TestCase): db_cls = None diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index e615acc9..c87b23bf 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -15,7 +15,17 @@ from .common import str_to_checksum, test_each_database_in_list, TestPerDatabase -TEST_DATABASES = {db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery} +TEST_DATABASES = { + db.MySQL, + db.PostgreSQL, + db.Oracle, + db.Redshift, + db.Snowflake, + db.BigQuery, + db.Presto, + db.Trino, + db.Vertica, +} test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) @@ -383,9 +393,7 @@ def test_string_keys(self): diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) - self.connection.query( - self.src_table.insert_row('unexpected', '<-- this bad value should not break us') - ) + self.connection.query(self.src_table.insert_row("unexpected", "<-- this bad value should not break us")) self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) @@ -421,7 +429,7 @@ def setUp(self): src_table.create(), src_table.insert_rows(values), table(self.table_dst_path).create(src_table), - src_table.insert_row(self.new_alphanum, 'This one is different'), + src_table.insert_row(self.new_alphanum, "This one is different"), commit, ] @@ -491,7 +499,7 @@ def test_varying_alphanum_keys(self): self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) self.connection.query( - self.src_table.insert_row('@@@', '<-- this bad value should not break us'), + self.src_table.insert_row("@@@", "<-- this bad value should not break us"), commit, ) @@ -548,13 +556,15 @@ def setUp(self): self.null_uuid = uuid.uuid1(32132131) - self.connection.query([ - src_table.create(), - src_table.insert_rows(values), - table(self.table_dst_path).create(src_table), - src_table.insert_row(self.null_uuid, None), - commit, - ]) + self.connection.query( + [ + src_table.create(), + src_table.insert_rows(values), + table(self.table_dst_path).create(src_table), + src_table.insert_row(self.null_uuid, None), + commit, + ] + ) self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) @@ -576,9 +586,9 @@ def setUp(self): self.connection.query( [ src_table.create(), - src_table.insert_row(uuid.uuid1(1), '1'), + src_table.insert_row(uuid.uuid1(1), "1"), table(self.table_dst_path).create(src_table), - src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value + src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value commit, ] ) @@ -685,24 +695,28 @@ class TestTableTableEmpty(TestPerDatabase): def setUp(self): super().setUp() - self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) - self.dst_table = dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str}) + self.src_table = table(self.table_src_path, schema={"id": str, "text_comment": str}) + self.dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str}) self.null_uuid = uuid.uuid1(1) self.diffs = [(uuid.uuid1(i), str(i)) for i in range(100)] - self.connection.query([src_table.create(), dst_table.create(), src_table.insert_rows(self.diffs), commit]) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_right_table_empty(self): + self.connection.query( + [self.src_table.create(), self.dst_table.create(), self.src_table.insert_rows(self.diffs), commit] + ) + differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_left_table_empty(self): - self.connection.query([self.dst_table.insert_expr(self.src_table), self.src_table.truncate(), commit]) + self.connection.query( + [self.src_table.create(), self.dst_table.create(), self.dst_table.insert_rows(self.diffs), commit] + ) differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d3db82e0..3bf2246d 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -17,14 +17,14 @@ TEST_DATABASES = { db.PostgreSQL, - db.Snowflake, db.MySQL, + db.Snowflake, db.BigQuery, - db.Presto, - db.Vertica, - db.Trino, db.Oracle, db.Redshift, + db.Presto, + db.Trino, + db.Vertica, } test_each_database = test_each_database_in_list(TEST_DATABASES)