Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

General tests now include Presto, Trino & Vertica; Includes small fixes #256

Merged
merged 1 commit into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def apply_queries(self, callback: Callable[[str], Any]):
q: Expr = next(self.gen)
while True:
sql = self.compiler.compile(q)
logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql)
try:
try:
res = callback(sql) if sql is not SKIP else SKIP
Expand Down Expand Up @@ -130,7 +131,8 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
if sql_code is SKIP:
return SKIP

logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
logger.debug("Running SQL (%s): %s", self.name, sql_code)

if self._interactive and isinstance(sql_ast, Select):
explained_sql = compiler.compile(Explain(sql_ast))
explain = self._query(explained_sql)
Expand Down
4 changes: 4 additions & 0 deletions data_diff/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Text,
FractionalType,
DbPath,
DbTime,
Decimal,
ColType,
ColType_UUID,
Expand Down Expand Up @@ -159,3 +160,6 @@ def type_repr(self, t) -> str:
return {float: "REAL"}[t]
except KeyError:
return super().type_repr(t)

def timestamp_value(self, t: DbTime) -> str:
    """Return ``t`` rendered as a SQL ``timestamp '...'`` literal.

    Args:
        t: The time value to render. It must provide ``isoformat(sep)``
            (presumably a ``datetime``; confirm against the ``DbTime``
            definition).

    Returns:
        The literal text, e.g. ``timestamp '2022-10-14 09:30:05'``.
    """
    # Use a space between the date and time parts rather than
    # isoformat()'s default "T" separator.
    literal = t.isoformat(" ")
    return f"timestamp '{literal}'"
2 changes: 2 additions & 0 deletions data_diff/query_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .databases import Oracle
from .queries import table, commit, Expr


def _drop_table_oracle(name: DbPath):
t = table(name)
# Experience shows double drop is necessary
Expand Down Expand Up @@ -50,6 +51,7 @@ def _append_to_table(path: DbPath, expr: Expr):
yield t.insert_expr(expr)
yield commit


def append_to_table(db, path, expr):
    """Run the append-to-table query sequence for ``path`` on ``db``.

    Args:
        db: Database object; its ``query`` method receives the generated
            query sequence.
        path: Path of the target table.
        expr: Expression handed to the query builder (presumably the rows
            to insert; verify against ``_append_to_table``).
    """
    # Oracle has its own builder; every other database shares the generic one.
    if isinstance(db, Oracle):
        build_queries = _append_to_table_oracle
    else:
        build_queries = _append_to_table
    db.query(build_queries(path, expr))
1 change: 0 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def str_to_checksum(str: str):
return int(md5[half_pos:], 16)



class TestPerDatabase(unittest.TestCase):
db_cls = None

Expand Down
54 changes: 34 additions & 20 deletions tests/test_diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
from .common import str_to_checksum, test_each_database_in_list, TestPerDatabase


TEST_DATABASES = {db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery}
TEST_DATABASES = {
db.MySQL,
db.PostgreSQL,
db.Oracle,
db.Redshift,
db.Snowflake,
db.BigQuery,
db.Presto,
db.Trino,
db.Vertica,
}

test_each_database: Callable = test_each_database_in_list(TEST_DATABASES)

Expand Down Expand Up @@ -383,9 +393,7 @@ def test_string_keys(self):
diff = list(differ.diff_tables(self.a, self.b))
self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))])

self.connection.query(
self.src_table.insert_row('unexpected', '<-- this bad value should not break us')
)
self.connection.query(self.src_table.insert_row("unexpected", "<-- this bad value should not break us"))

self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))

Expand Down Expand Up @@ -421,7 +429,7 @@ def setUp(self):
src_table.create(),
src_table.insert_rows(values),
table(self.table_dst_path).create(src_table),
src_table.insert_row(self.new_alphanum, 'This one is different'),
src_table.insert_row(self.new_alphanum, "This one is different"),
commit,
]

Expand Down Expand Up @@ -491,7 +499,7 @@ def test_varying_alphanum_keys(self):
self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))])

self.connection.query(
self.src_table.insert_row('@@@', '<-- this bad value should not break us'),
self.src_table.insert_row("@@@", "<-- this bad value should not break us"),
commit,
)

Expand Down Expand Up @@ -548,13 +556,15 @@ def setUp(self):

self.null_uuid = uuid.uuid1(32132131)

self.connection.query([
src_table.create(),
src_table.insert_rows(values),
table(self.table_dst_path).create(src_table),
src_table.insert_row(self.null_uuid, None),
commit,
])
self.connection.query(
[
src_table.create(),
src_table.insert_rows(values),
table(self.table_dst_path).create(src_table),
src_table.insert_row(self.null_uuid, None),
commit,
]
)

self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False)
self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False)
Expand All @@ -576,9 +586,9 @@ def setUp(self):
self.connection.query(
[
src_table.create(),
src_table.insert_row(uuid.uuid1(1), '1'),
src_table.insert_row(uuid.uuid1(1), "1"),
table(self.table_dst_path).create(src_table),
src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value
src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value
commit,
]
)
Expand Down Expand Up @@ -685,24 +695,28 @@ class TestTableTableEmpty(TestPerDatabase):
def setUp(self):
super().setUp()

self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
self.dst_table = dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str})
self.src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
self.dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str})

self.null_uuid = uuid.uuid1(1)

self.diffs = [(uuid.uuid1(i), str(i)) for i in range(100)]

self.connection.query([src_table.create(), dst_table.create(), src_table.insert_rows(self.diffs), commit])

self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False)
self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False)

def test_right_table_empty(self):
    """Diffing must raise ``ValueError`` when only the source table has rows."""
    # Create both tables but insert rows into the source table only.
    setup_queries = [
        self.src_table.create(),
        self.dst_table.create(),
        self.src_table.insert_rows(self.diffs),
        commit,
    ]
    self.connection.query(setup_queries)

    differ = HashDiffer()
    # The diff result is consumed by ``list`` inside ``assertRaises``, so an
    # error raised while iterating it is the one being checked.
    self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))

def test_left_table_empty(self):
self.connection.query([self.dst_table.insert_expr(self.src_table), self.src_table.truncate(), commit])
self.connection.query(
[self.src_table.create(), self.dst_table.create(), self.dst_table.insert_rows(self.diffs), commit]
)

differ = HashDiffer()
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
8 changes: 4 additions & 4 deletions tests/test_joindiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

TEST_DATABASES = {
db.PostgreSQL,
db.Snowflake,
db.MySQL,
db.Snowflake,
db.BigQuery,
db.Presto,
db.Vertica,
db.Trino,
db.Oracle,
db.Redshift,
db.Presto,
db.Trino,
db.Vertica,
}

test_each_database = test_each_database_in_list(TEST_DATABASES)
Expand Down