Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit d1b9985

Browse files
committed
Many fixes; Added materialize tests;
Now works for : postgresql, mysql, bigquery, presto, trino, snowflake, oracle, redshift
1 parent 00ee415 commit d1b9985

18 files changed

+232
-123
lines changed

data_diff/__main__.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import rich
1010
import click
1111

12-
from data_diff.databases.base import parse_table_name
13-
1412
from .utils import eval_name_template, remove_password_from_url, safezip, match_like
1513
from .diff_tables import Algorithm
1614
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
@@ -269,22 +267,6 @@ def _main(
269267
logging.error(f"Error while parsing age expression: {e}")
270268
return
271269

272-
if algorithm == Algorithm.JOINDIFF:
273-
differ = JoinDiffer(
274-
threaded=threaded,
275-
max_threadpool_size=threads and threads * 2,
276-
validate_unique_key=not assume_unique_key,
277-
materialize_to_table=materialize and parse_table_name(eval_name_template(materialize)),
278-
)
279-
else:
280-
assert algorithm == Algorithm.HASHDIFF
281-
differ = HashDiffer(
282-
bisection_factor=bisection_factor,
283-
bisection_threshold=bisection_threshold,
284-
threaded=threaded,
285-
max_threadpool_size=threads and threads * 2,
286-
)
287-
288270
if database1 is None or database2 is None:
289271
logging.error(
290272
f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information."
@@ -307,6 +289,22 @@ def _main(
307289
for db in dbs:
308290
db.enable_interactive()
309291

292+
if algorithm == Algorithm.JOINDIFF:
293+
differ = JoinDiffer(
294+
threaded=threaded,
295+
max_threadpool_size=threads and threads * 2,
296+
validate_unique_key=not assume_unique_key,
297+
materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)),
298+
)
299+
else:
300+
assert algorithm == Algorithm.HASHDIFF
301+
differ = HashDiffer(
302+
bisection_factor=bisection_factor,
303+
bisection_threshold=bisection_threshold,
304+
threaded=threaded,
305+
max_threadpool_size=threads and threads * 2,
306+
)
307+
310308
table_names = table1, table2
311309
table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]
312310

data_diff/databases/base.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import math
22
import sys
33
import logging
4-
from typing import Dict, Generator, Tuple, Optional, Sequence, Type, List, Union
5-
from functools import wraps
4+
from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union
5+
from functools import partial, wraps
66
from concurrent.futures import ThreadPoolExecutor
77
import threading
88
from abc import abstractmethod
@@ -27,7 +27,7 @@
2727
DbPath,
2828
)
2929

30-
from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter
30+
from data_diff.queries import Expr, Compiler, table, Select, SKIP
3131

3232
logger = logging.getLogger("database")
3333

@@ -66,30 +66,39 @@ def _one(seq):
6666
return x
6767

6868

69-
def _query_cursor(c, sql_code):
70-
try:
71-
c.execute(sql_code)
72-
if sql_code.lower().startswith("select"):
73-
return c.fetchall()
74-
except Exception as e:
75-
logger.exception(e)
76-
raise
69+
class ThreadLocalInterpreter:
70+
"""An interpeter used to execute a sequence of queries within the same thread.
7771
72+
Useful for cursor-sensitive operations, such as creating a temporary table.
73+
"""
7874

79-
def _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list:
80-
c = conn.cursor()
75+
def __init__(self, compiler: Compiler, gen: Generator):
76+
self.gen = gen
77+
self.compiler = compiler
8178

82-
if isinstance(sql_code, ThreadLocalInterpreter):
83-
g = sql_code.interpret()
84-
q = next(g)
79+
def apply_queries(self, callback: Callable[[str], Any]):
    """Run every query produced by ``self.gen`` through ``callback``.

    Each expression yielded by the generator is compiled with
    ``self.compiler``; the compiled SQL is handed to ``callback`` and the
    callback's result is sent back into the generator, which then yields
    the next expression. The loop ends when the generator is exhausted.

    Args:
        callback: Callable that executes one compiled SQL string and
            returns its result.

    Returns:
        None. Results are only delivered back into the generator.

    Raises:
        Exception: whatever ``callback`` raised, if the generator does not
            handle it after it is thrown in.
    """
    try:
        q: Expr = next(self.gen)
    except StopIteration:
        # An empty generator simply has nothing to run.
        return

    while True:
        sql = self.compiler.compile(q)
        try:
            try:
                # SKIP means "nothing to execute"; pass the marker through
                # to the generator instead of calling the callback.
                res = callback(sql) if sql is not SKIP else SKIP
            except Exception as e:
                # Give the generator the chance to handle or clean up.
                # The single-argument form is used because the
                # (type, value) signature is deprecated since Python 3.12.
                q = self.gen.throw(e)
            else:
                q = self.gen.send(res)
        except StopIteration:
            break
92+
93+
94+
def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list:
95+
if isinstance(sql_code, ThreadLocalInterpreter):
96+
return sql_code.apply_queries(callback)
9197
else:
92-
return _query_cursor(c, sql_code)
98+
return callback(sql_code)
99+
100+
101+
93102

94103

95104
class Database(AbstractDatabase):
@@ -108,11 +117,17 @@ class Database(AbstractDatabase):
108117
def name(self):
109118
return type(self).__name__
110119

111-
def query(self, sql_ast: Expr, res_type: type = None):
120+
def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
112121
"Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
113122

114123
compiler = Compiler(self)
115-
sql_code = compiler.compile(sql_ast)
124+
if isinstance(sql_ast, Generator):
125+
sql_code = ThreadLocalInterpreter(compiler, sql_ast)
126+
else:
127+
sql_code = compiler.compile(sql_ast)
128+
if sql_code is SKIP:
129+
return SKIP
130+
116131
logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
117132
if getattr(self, "_interactive", False) and isinstance(sql_ast, Select):
118133
explained_sql = compiler.compile(Explain(sql_ast))
@@ -311,6 +326,34 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
311326
def random(self) -> str:
312327
return "RANDOM()"
313328

329+
def type_repr(self, t) -> str:
330+
if isinstance(t, str):
331+
return t
332+
return {
333+
int: "INT",
334+
str: "VARCHAR",
335+
bool: "BOOLEAN",
336+
float: "FLOAT",
337+
}[t]
338+
339+
def _query_cursor(self, c, sql_code: str):
    """Execute a single SQL statement on cursor ``c``.

    Args:
        c: Cursor object exposing ``execute`` and ``fetchall``.
        sql_code: The SQL text to run; must already be a compiled string.

    Returns:
        The result of ``c.fetchall()`` when the statement text starts with
        "select" (case-insensitive); otherwise None.

    Raises:
        Any exception raised by the cursor propagates unchanged, so
        subclasses that wrap this method can translate driver errors.
    """
    assert isinstance(sql_code, str), sql_code
    c.execute(sql_code)
    if sql_code.lower().startswith("select"):
        return c.fetchall()
    return None
349+
350+
def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list:
351+
c = conn.cursor()
352+
callback = partial(self._query_cursor, c)
353+
return apply_query(callback, sql_code)
354+
355+
356+
314357

315358
class ThreadedDatabase(Database):
316359
"""Access the database through singleton threads.
@@ -339,7 +382,7 @@ def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
339382
"This method runs in a worker thread"
340383
if self._init_error:
341384
raise self._init_error
342-
return _query_conn(self.thread_local.conn, sql_code)
385+
return self._query_conn(self.thread_local.conn, sql_code)
343386

344387
@abstractmethod
345388
def create_connection(self):
@@ -348,6 +391,10 @@ def create_connection(self):
348391
def close(self):
349392
self._queue.shutdown()
350393

394+
@property
395+
def is_autocommit(self) -> bool:
396+
return False
397+
351398

352399
CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower
353400
MD5_HEXDIGITS = 32

data_diff/databases/bigquery.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .database_types import *
2-
from .base import Database, import_helper, parse_table_name, ConnectError
3-
from .base import TIMESTAMP_PRECISION_POS
2+
from .base import Database, import_helper, parse_table_name, ConnectError, apply_query
3+
from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
44

55

66
@import_helper(text="Please install BigQuery and configure your google-cloud access.")
@@ -47,7 +47,7 @@ def _normalize_returned_value(self, value):
4747
return value.decode()
4848
return value
4949

50-
def _query(self, sql_code: str):
50+
def _query_atom(self, sql_code: str):
5151
from google.cloud import bigquery
5252

5353
try:
@@ -60,6 +60,9 @@ def _query(self, sql_code: str):
6060
res = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in res]
6161
return res
6262

63+
def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
64+
return apply_query(self._query_atom, sql_code)
65+
6366
def to_string(self, s: str):
6467
return f"cast({s} as string)"
6568

@@ -98,3 +101,15 @@ def parse_table_name(self, name: str) -> DbPath:
98101

99102
def random(self) -> str:
100103
return "RAND()"
104+
105+
@property
106+
def is_autocommit(self) -> bool:
107+
return True
108+
109+
def type_repr(self, t) -> str:
110+
try:
111+
return {
112+
str: "STRING",
113+
}[t]
114+
except KeyError:
115+
return super().type_repr(t)

data_diff/databases/database_types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import decimal
3-
from abc import ABC, abstractmethod
3+
from abc import ABC, abstractmethod, abstractproperty
44
from typing import Sequence, Optional, Tuple, Union, Dict, List
55
from datetime import datetime
66

@@ -293,6 +293,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
293293
def _normalize_table_path(self, path: DbPath) -> DbPath:
294294
...
295295

296+
@property
@abstractmethod
def is_autocommit(self) -> bool:
    """Abstract read-only property that concrete databases must provide.

    Declared by stacking ``@property`` on ``@abstractmethod`` because
    ``abc.abstractproperty`` has been deprecated since Python 3.3.
    """
    ...
299+
296300

297301
Schema = CaseAwareMapping
298302

data_diff/databases/databricks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from .database_types import *
4-
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, _query_conn, parse_table_name
4+
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, parse_table_name
55

66

77
@import_helper(text="You can install it using 'pip install databricks-sql-connector'")
@@ -52,7 +52,7 @@ def __init__(
5252

5353
def _query(self, sql_code: str) -> list:
5454
"Uses the standard SQL cursor interface"
55-
return _query_conn(self._conn, sql_code)
55+
return self._query_conn(self._conn, sql_code)
5656

5757
def quote(self, s: str):
5858
return f"`{s}`"

data_diff/databases/mysql.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,11 @@ def is_distinct_from(self, a: str, b: str) -> str:
7676

7777
def random(self) -> str:
7878
return "RAND()"
79+
80+
def type_repr(self, t) -> str:
81+
try:
82+
return {
83+
str: "VARCHAR(1024)",
84+
}[t]
85+
except KeyError:
86+
return super().type_repr(t)

data_diff/databases/oracle.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def create_connection(self):
4343
except Exception as e:
4444
raise ConnectError(*e.args) from e
4545

46-
def _query(self, sql_code: str):
46+
def _query_cursor(self, c, sql_code: str):
4747
try:
48-
return super()._query(sql_code)
48+
return super()._query_cursor(c, sql_code)
4949
except self._oracle.DatabaseError as e:
5050
raise QueryError(e)
5151

@@ -130,3 +130,11 @@ def random(self) -> str:
130130

131131
def is_distinct_from(self, a: str, b: str) -> str:
132132
return f"DECODE({a}, {b}, 1, 0) = 0"
133+
134+
def type_repr(self, t) -> str:
135+
try:
136+
return {
137+
str: "VARCHAR(1024)",
138+
}[t]
139+
except KeyError:
140+
return super().type_repr(t)

data_diff/databases/presto.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
from functools import partial
12
import re
23

34
from data_diff.utils import match_regexps
4-
from data_diff.queries import ThreadLocalInterpreter
55

66
from .database_types import *
7-
from .base import Database, import_helper
7+
from .base import Database, import_helper, ThreadLocalInterpreter
88
from .base import (
99
MD5_HEXDIGITS,
1010
CHECKSUM_HEXDIGITS,
@@ -15,7 +15,7 @@
1515
def query_cursor(c, sql_code):
1616
c.execute(sql_code)
1717
if sql_code.lower().startswith("select"):
18-
return c.fetchall()
18+
return [tuple(x) for x in c.fetchall()]
1919
# Required for the query to actually run 🤯
2020
if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE):
2121
return c.fetchone()
@@ -75,16 +75,7 @@ def _query(self, sql_code: str) -> list:
7575
c = self._conn.cursor()
7676

7777
if isinstance(sql_code, ThreadLocalInterpreter):
78-
# TODO reuse code from base.py
79-
g = sql_code.interpret()
80-
q = next(g)
81-
while True:
82-
res = query_cursor(c, q)
83-
try:
84-
q = g.send(res)
85-
except StopIteration:
86-
break
87-
return
78+
return sql_code.apply_queries(partial(query_cursor, c))
8879

8980
return query_cursor(c, sql_code)
9081

@@ -142,3 +133,6 @@ def _parse_type(
142133
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
143134
# Trim doesn't work on CHAR type
144135
return f"TRIM(CAST({value} AS VARCHAR))"
136+
137+
@property
def is_autocommit(self) -> bool:
    """Report the autocommit flag for this driver (always False here).

    Defined as a property to agree with the abstract declaration of
    ``is_autocommit``. A plain method read as an attribute yields a bound
    method, which is always truthy, so the False value would never be seen.
    """
    return False

data_diff/databases/snowflake.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from .database_types import *
4-
from .base import ConnectError, Database, import_helper, _query_conn, CHECKSUM_MASK
4+
from .base import ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter
55

66

77
@import_helper("snowflake")
@@ -60,9 +60,9 @@ def __init__(self, *, schema: str, **kw):
6060
def close(self):
6161
self._conn.close()
6262

63-
def _query(self, sql_code: str) -> list:
63+
def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
6464
"Uses the standard SQL cursor interface"
65-
return _query_conn(self._conn, sql_code)
65+
return self._query_conn(self._conn, sql_code)
6666

6767
def quote(self, s: str):
6868
return f'"{s}"'
@@ -87,3 +87,6 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
8787

8888
def normalize_number(self, value: str, coltype: FractionalType) -> str:
8989
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
90+
91+
@property
def is_autocommit(self) -> bool:
    """Report the autocommit flag for this driver (always True here).

    Defined as a property to agree with the abstract declaration of
    ``is_autocommit``, so that attribute-style reads yield the bool itself
    rather than a bound method object.
    """
    return True

0 commit comments

Comments
 (0)