Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5cd424d

Browse files
committed
Joindiff: Added support for materializing diff results as tables (-m)
1 parent 9f404a0 commit 5cd424d

File tree

6 files changed

+104
-38
lines changed

6 files changed

+104
-38
lines changed

data_diff/__main__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import rich
1010
import click
1111

12-
from .utils import remove_password_from_url, safezip, match_like
12+
from data_diff.databases.base import parse_table_name
13+
14+
from .utils import eval_name_template, remove_password_from_url, safezip, match_like
1315
from .diff_tables import Algorithm
1416
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
1517
from .joindiff_tables import JoinDiffer
@@ -104,6 +106,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
104106
help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.",
105107
metavar="NUM",
106108
)
109+
@click.option("-m", "--materialize", default=None, metavar="TABLE_NAME", help="Materialize the diff results into a new table in the database.")
107110
@click.option(
108111
"--min-age",
109112
default=None,
@@ -126,6 +129,11 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
126129
is_flag=True,
127130
help="Column names are treated as case-sensitive. Otherwise, data-diff corrects their case according to schema.",
128131
)
132+
@click.option(
133+
"--assume-unique-key",
134+
is_flag=True,
135+
help="Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.",
136+
)
129137
@click.option(
130138
"-j",
131139
"--threads",
@@ -192,6 +200,8 @@ def _main(
192200
case_sensitive,
193201
json_output,
194202
where,
203+
assume_unique_key,
204+
materialize,
195205
threads1=None,
196206
threads2=None,
197207
__conf__=None,
@@ -256,6 +266,8 @@ def _main(
256266
differ = JoinDiffer(
257267
threaded=threaded,
258268
max_threadpool_size=threads and threads * 2,
269+
validate_unique_key = not assume_unique_key,
270+
materialize_to_table = materialize and parse_table_name(eval_name_template(materialize)),
259271
)
260272
else:
261273
assert algorithm == Algorithm.HASHDIFF

data_diff/databases/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class Database(AbstractDatabase):
107107
def name(self):
108108
return type(self).__name__
109109

110-
def query(self, sql_ast: Expr, res_type: type):
110+
def query(self, sql_ast: Expr, res_type: type = None):
111111
"Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
112112

113113
compiler = Compiler(self)

data_diff/diff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _threaded_call_as_completed(self, func, iterable):
6868
@contextmanager
6969
def _run_in_background(self, *funcs):
7070
with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool:
71-
futures = [task_pool.submit(f) for f in funcs]
71+
futures = [task_pool.submit(f) for f in funcs if f is not None]
7272
yield futures
7373
for f in futures:
7474
f.result()

data_diff/joindiff_tables.py

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from decimal import Decimal
66
from functools import partial
77
import logging
8-
from typing import Dict, List
8+
from typing import Dict, List, Optional
99

1010
from runtype import dataclass
1111

12+
from data_diff.databases.database_types import DbPath, Schema
13+
1214

1315
from .utils import safezip
1416
from .databases.base import Database
@@ -17,15 +19,16 @@
1719
from .diff_tables import TableDiffer, DiffResult
1820
from .thread_utils import ThreadedYielder
1921

20-
from .queries import table, sum_, min_, max_, avg
22+
from .queries import table, sum_, min_, max_, avg, SKIP
2123
from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable
22-
from .queries.ast_classes import Concat, Count, Expr, Random
24+
from .queries.ast_classes import Concat, Count, Expr, Random, TablePath
2325
from .queries.compiler import Compiler
2426
from .queries.extras import NormalizeAsString
2527

26-
2728
logger = logging.getLogger("joindiff_tables")
2829

30+
WRITE_LIMIT = 1000
31+
2932

3033
def merge_dicts(dicts):
3134
i = iter(dicts)
@@ -60,6 +63,18 @@ def create_temp_table(c: Compiler, name: str, expr: Expr):
6063
return f"create temporary table {c.quote(name)} as {c.compile(expr)}"
6164

6265

66+
def drop_table(db, name: DbPath):
67+
t = TablePath(name)
68+
db.query(t.drop(if_exists=True))
69+
70+
def append_to_table(name: DbPath, expr: Expr):
71+
t = TablePath(name, expr.schema)
72+
yield t.create(if_not_exists=True) # uses expr.schema
73+
yield 'commit'
74+
yield t.insert_expr(expr)
75+
yield 'commit'
76+
77+
6378
def bool_to_int(x):
6479
return if_(x, 1, 0)
6580

@@ -117,6 +132,8 @@ class JoinDiffer(TableDiffer):
117132
stats: dict = {}
118133
validate_unique_key: bool = True
119134
sample_exclusive_rows: bool = True
135+
materialize_to_table: DbPath = None
136+
write_limit: int = WRITE_LIMIT
120137

121138
def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
122139
db = table1.database
@@ -128,8 +145,12 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult
128145

129146

130147
bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else []
148+
if self.materialize_to_table:
149+
drop_table(db, self.materialize_to_table)
150+
db.query('COMMIT')
131151

132152
with self._run_in_background(*bg_funcs):
153+
133154
if isinstance(db, (Snowflake, BigQuery)):
134155
# Don't segment the table; let the database handling parallelization
135156
yield from self._diff_segments(None, table1, table2, None)
@@ -147,12 +168,29 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl
147168
f"size <= {max_rows}"
148169
)
149170

171+
db = table1.database
172+
diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2)
173+
150174
with self._run_in_background(
151175
partial(self._collect_stats, 1, table1),
152176
partial(self._collect_stats, 2, table2),
153177
partial(self._test_null_keys, table1, table2),
178+
partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
179+
partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
180+
partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) if self.materialize_to_table else None,
154181
):
155-
yield from self._outer_join(table1, table2)
182+
183+
logger.debug("Querying for different rows")
184+
for is_xa, is_xb, *x in db.query(diff_rows, list):
185+
if is_xa and is_xb:
186+
# Can't both be exclusive, meaning a pk is NULL
187+
# This can happen if the explicit null test didn't finish running yet
188+
raise ValueError(f"NULL values in one or more primary keys")
189+
is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols))
190+
if not is_xb:
191+
yield "-", tuple(a_row)
192+
if not is_xa:
193+
yield "+", tuple(b_row)
156194

157195
def _test_duplicate_keys(self, table1, table2):
158196
logger.debug("Testing for duplicate keys")
@@ -162,7 +200,7 @@ def _test_duplicate_keys(self, table1, table2):
162200
t = ts._make_select()
163201
key_columns = [ts.key_column] # XXX
164202

165-
q = t.select(total=Count(), total_distinct=Count(Concat(key_columns), distinct=True))
203+
q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True))
166204
total, total_distinct = ts.database.query(q, tuple)
167205
if total != total_distinct:
168206
raise ValueError("Duplicate primary keys")
@@ -175,7 +213,7 @@ def _test_null_keys(self, table1, table2):
175213
t = ts._make_select()
176214
key_columns = [ts.key_column] # XXX
177215

178-
q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns))
216+
q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
179217
nulls = ts.database.query(q, list)
180218
if nulls:
181219
raise ValueError(f"NULL values in one or more primary keys")
@@ -188,10 +226,10 @@ def _collect_stats(self, i, table):
188226
# Metrics
189227
col_exprs = merge_dicts(
190228
{
191-
f"sum_{c}": sum_(c),
192-
f"avg_{c}": avg(c),
193-
f"min_{c}": min_(c),
194-
f"max_{c}": max_(c),
229+
f"sum_{c}": sum_(this[c]),
230+
f"avg_{c}": avg(this[c]),
231+
f"min_{c}": min_(this[c]),
232+
f"max_{c}": max_(this[c]),
195233
}
196234
for c in table._relevant_columns
197235
if c == "id" # TODO just if the right type
@@ -209,8 +247,7 @@ def _collect_stats(self, i, table):
209247
# stats.diff_ratio_by_column = diff_stats
210248
# stats.diff_ratio_total = diff_stats['total_diff']
211249

212-
213-
def _outer_join(self, table1, table2):
250+
def _create_outer_join(self, table1, table2):
214251
db = table1.database
215252
if db is not table2.database:
216253
raise ValueError("Joindiff only applies to tables within the same database")
@@ -239,23 +276,8 @@ def _outer_join(self, table1, table2):
239276
_outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols})
240277
.where(or_(this[c] == 1 for c in is_diff_cols))
241278
)
279+
return diff_rows, a_cols, b_cols, is_diff_cols
242280

243-
with self._run_in_background(
244-
partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
245-
partial(self._count_diff_per_column, db, diff_rows, cols1, is_diff_cols)
246-
):
247-
248-
logger.debug("Querying for different rows")
249-
for is_xa, is_xb, *x in db.query(diff_rows, list):
250-
if is_xa and is_xb:
251-
# Can't both be exclusive, meaning a pk is NULL
252-
# This can happen if the explicit null test didn't finish running yet
253-
raise ValueError(f"NULL values in one or more primary keys")
254-
is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols))
255-
if not is_xb:
256-
yield "-", tuple(a_row)
257-
if not is_xa:
258-
yield "+", tuple(b_row)
259281

260282
def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
261283
logger.info("Counting differences per column")
@@ -280,7 +302,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
280302
def exclusive_rows(expr):
281303
c = Compiler(db)
282304
name = c.new_unique_table_name("temp_table")
283-
yield create_temp_table(c, name, expr)
305+
yield create_temp_table(c, name, expr.limit(self.write_limit))
284306
exclusive_rows = table(name, schema=expr.source_table.schema)
285307

286308
count = yield exclusive_rows.count()
@@ -293,3 +315,10 @@ def exclusive_rows(expr):
293315

294316
# Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter)
295317
db.query(exclusive_rows(exclusive_rows_query), None)
318+
319+
def _materialize_diff(self, db, diff_rows, segment_index=None):
320+
assert self.materialize_to_table
321+
322+
db.query(append_to_table(self.materialize_to_table, diff_rows.limit(self.write_limit)))
323+
logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.")
324+

data_diff/queries/ast_classes.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class Concat(ExprNode):
140140

141141
def compile(self, c: Compiler) -> str:
142142
# We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
143-
items = [f"coalesce({c.compile(c.database.to_string(expr))}, '<null>')" for expr in self.exprs]
143+
items = [f"coalesce({c.compile(c.database.to_string(c.compile(expr)))}, '<null>')" for expr in self.exprs]
144144
assert items
145145
if len(items) == 1:
146146
return items[0]
@@ -294,6 +294,9 @@ def create(self, if_not_exists=False):
294294
raise ValueError("Schema must have a value to create table")
295295
return CreateTable(self, if_not_exists=if_not_exists)
296296

297+
def drop(self, if_exists=False):
298+
return DropTable(self, if_exists=if_exists)
299+
297300
def insert_values(self, rows):
298301
raise NotImplementedError()
299302

@@ -513,13 +516,13 @@ def resolve_names(source_table, exprs):
513516
if isinstance(expr, ExprNode):
514517
for v in expr._dfs_values():
515518
if isinstance(v, _ResolveColumn):
516-
v.resolve(source_table._get_column(v.name))
519+
v.resolve(source_table._get_column(v.resolve_name))
517520
i += 1
518521

519522

520523
@dataclass(frozen=False, eq=False, order=False)
521524
class _ResolveColumn(ExprNode, LazyOps):
522-
name: str
525+
resolve_name: str
523526
resolved: Expr = None
524527

525528
def resolve(self, expr):
@@ -528,15 +531,22 @@ def resolve(self, expr):
528531

529532
def compile(self, c: Compiler) -> str:
530533
if self.resolved is None:
531-
raise RuntimeError(f"Column not resolved: {self.name}")
534+
raise RuntimeError(f"Column not resolved: {self.resolve_name}")
532535
return self.resolved.compile(c)
533536

534537
@property
535538
def type(self):
536539
if self.resolved is None:
537-
raise RuntimeError(f"Column not resolved: {self.name}")
540+
raise RuntimeError(f"Column not resolved: {self.resolve_name}")
538541
return self.resolved.type
539542

543+
@property
544+
def name(self):
545+
if self.resolved is None:
546+
raise RuntimeError(f"Column not resolved: {self.name}")
547+
return self.resolved.name
548+
549+
540550

541551
class This:
542552
def __getattr__(self, name):
@@ -606,6 +616,15 @@ def compile(self, c: Compiler) -> str:
606616
ne = 'IF NOT EXISTS ' if self.if_not_exists else ''
607617
return f'CREATE TABLE {ne}{c.compile(self.path)}({schema})'
608618

619+
@dataclass
620+
class DropTable(Statement):
621+
path: TablePath
622+
if_exists: bool = False
623+
624+
def compile(self, c: Compiler) -> str:
625+
ie = 'IF EXISTS ' if self.if_exists else ''
626+
return f'DROP TABLE {ie}{c.compile(self.path)}'
627+
609628
@dataclass
610629
class InsertToTable(Statement):
611630
# TODO Support insert for only some columns

data_diff/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import operator
1010
import string
1111
import threading
12+
from datetime import datetime
1213

1314
alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase
1415

@@ -295,3 +296,8 @@ def run_as_daemon(threadfunc, *args):
295296

296297
def getLogger(name):
297298
return logging.getLogger(name.rsplit('.', 1)[-1])
299+
300+
def eval_name_template(name):
301+
def get_timestamp(m):
302+
return datetime.now().isoformat('_', 'seconds').replace(':', '_')
303+
return re.sub('%t', get_timestamp, name)

0 commit comments

Comments (0)