Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 3bffe5b

Browse files
committed
General tests now include Presto, Trino & Vertica; Includes small fixes
1 parent aa74a22 commit 3bffe5b

File tree

6 files changed

+47
-26
lines changed

6 files changed

+47
-26
lines changed

data_diff/databases/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def apply_queries(self, callback: Callable[[str], Any]):
7979
q: Expr = next(self.gen)
8080
while True:
8181
sql = self.compiler.compile(q)
82+
logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql)
8283
try:
8384
try:
8485
res = callback(sql) if sql is not SKIP else SKIP
@@ -130,7 +131,8 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
130131
if sql_code is SKIP:
131132
return SKIP
132133

133-
logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
134+
logger.debug("Running SQL (%s): %s", self.name, sql_code)
135+
134136
if self._interactive and isinstance(sql_ast, Select):
135137
explained_sql = compiler.compile(Explain(sql_ast))
136138
explain = self._query(explained_sql)

data_diff/databases/presto.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Text,
1212
FractionalType,
1313
DbPath,
14+
DbTime,
1415
Decimal,
1516
ColType,
1617
ColType_UUID,
@@ -159,3 +160,6 @@ def type_repr(self, t) -> str:
159160
return {float: "REAL"}[t]
160161
except KeyError:
161162
return super().type_repr(t)
163+
164+
def timestamp_value(self, t: DbTime) -> str:
165+
return f"timestamp '{t.isoformat(' ')}'"

data_diff/query_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .databases import Oracle
99
from .queries import table, commit, Expr
1010

11+
1112
def _drop_table_oracle(name: DbPath):
1213
t = table(name)
1314
# Experience shows double drop is necessary
@@ -50,6 +51,7 @@ def _append_to_table(path: DbPath, expr: Expr):
5051
yield t.insert_expr(expr)
5152
yield commit
5253

54+
5355
def append_to_table(db, path, expr):
5456
f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table
5557
db.query(f(path, expr))

tests/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def str_to_checksum(str: str):
120120
return int(md5[half_pos:], 16)
121121

122122

123-
124123
class TestPerDatabase(unittest.TestCase):
125124
db_cls = None
126125

tests/test_diff_tables.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,17 @@
1515
from .common import str_to_checksum, test_each_database_in_list, TestPerDatabase
1616

1717

18-
TEST_DATABASES = {db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery}
18+
TEST_DATABASES = {
19+
db.MySQL,
20+
db.PostgreSQL,
21+
db.Oracle,
22+
db.Redshift,
23+
db.Snowflake,
24+
db.BigQuery,
25+
db.Presto,
26+
db.Trino,
27+
db.Vertica,
28+
}
1929

2030
test_each_database: Callable = test_each_database_in_list(TEST_DATABASES)
2131

@@ -383,9 +393,7 @@ def test_string_keys(self):
383393
diff = list(differ.diff_tables(self.a, self.b))
384394
self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))])
385395

386-
self.connection.query(
387-
self.src_table.insert_row('unexpected', '<-- this bad value should not break us')
388-
)
396+
self.connection.query(self.src_table.insert_row("unexpected", "<-- this bad value should not break us"))
389397

390398
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
391399

@@ -421,7 +429,7 @@ def setUp(self):
421429
src_table.create(),
422430
src_table.insert_rows(values),
423431
table(self.table_dst_path).create(src_table),
424-
src_table.insert_row(self.new_alphanum, 'This one is different'),
432+
src_table.insert_row(self.new_alphanum, "This one is different"),
425433
commit,
426434
]
427435

@@ -491,7 +499,7 @@ def test_varying_alphanum_keys(self):
491499
self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))])
492500

493501
self.connection.query(
494-
self.src_table.insert_row('@@@', '<-- this bad value should not break us'),
502+
self.src_table.insert_row("@@@", "<-- this bad value should not break us"),
495503
commit,
496504
)
497505

@@ -548,13 +556,15 @@ def setUp(self):
548556

549557
self.null_uuid = uuid.uuid1(32132131)
550558

551-
self.connection.query([
552-
src_table.create(),
553-
src_table.insert_rows(values),
554-
table(self.table_dst_path).create(src_table),
555-
src_table.insert_row(self.null_uuid, None),
556-
commit,
557-
])
559+
self.connection.query(
560+
[
561+
src_table.create(),
562+
src_table.insert_rows(values),
563+
table(self.table_dst_path).create(src_table),
564+
src_table.insert_row(self.null_uuid, None),
565+
commit,
566+
]
567+
)
558568

559569
self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False)
560570
self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False)
@@ -576,9 +586,9 @@ def setUp(self):
576586
self.connection.query(
577587
[
578588
src_table.create(),
579-
src_table.insert_row(uuid.uuid1(1), '1'),
589+
src_table.insert_row(uuid.uuid1(1), "1"),
580590
table(self.table_dst_path).create(src_table),
581-
src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value
591+
src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value
582592
commit,
583593
]
584594
)
@@ -685,24 +695,28 @@ class TestTableTableEmpty(TestPerDatabase):
685695
def setUp(self):
686696
super().setUp()
687697

688-
self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
689-
self.dst_table = dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str})
698+
self.src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
699+
self.dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str})
690700

691701
self.null_uuid = uuid.uuid1(1)
692702

693703
self.diffs = [(uuid.uuid1(i), str(i)) for i in range(100)]
694704

695-
self.connection.query([src_table.create(), dst_table.create(), src_table.insert_rows(self.diffs), commit])
696-
697705
self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False)
698706
self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False)
699707

700708
def test_right_table_empty(self):
709+
self.connection.query(
710+
[self.src_table.create(), self.dst_table.create(), self.src_table.insert_rows(self.diffs), commit]
711+
)
712+
701713
differ = HashDiffer()
702714
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
703715

704716
def test_left_table_empty(self):
705-
self.connection.query([self.dst_table.insert_expr(self.src_table), self.src_table.truncate(), commit])
717+
self.connection.query(
718+
[self.src_table.create(), self.dst_table.create(), self.dst_table.insert_rows(self.diffs), commit]
719+
)
706720

707721
differ = HashDiffer()
708722
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))

tests/test_joindiff.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
TEST_DATABASES = {
1919
db.PostgreSQL,
20-
db.Snowflake,
2120
db.MySQL,
21+
db.Snowflake,
2222
db.BigQuery,
23-
db.Presto,
24-
db.Vertica,
25-
db.Trino,
2623
db.Oracle,
2724
db.Redshift,
25+
db.Presto,
26+
db.Trino,
27+
db.Vertica,
2828
}
2929

3030
test_each_database = test_each_database_in_list(TEST_DATABASES)

0 commit comments

Comments
 (0)