5
5
from decimal import Decimal
6
6
from functools import partial
7
7
import logging
8
- from typing import Dict , List
8
+ from typing import Dict , List , Optional
9
9
10
10
from runtype import dataclass
11
11
12
+ from data_diff .databases .database_types import DbPath , Schema
13
+
12
14
13
15
from .utils import safezip
14
16
from .databases .base import Database
17
19
from .diff_tables import TableDiffer , DiffResult
18
20
from .thread_utils import ThreadedYielder
19
21
20
- from .queries import table , sum_ , min_ , max_ , avg
22
+ from .queries import table , sum_ , min_ , max_ , avg , SKIP
21
23
from .queries .api import and_ , if_ , or_ , outerjoin , leftjoin , rightjoin , this , ITable
22
- from .queries .ast_classes import Concat , Count , Expr , Random
24
+ from .queries .ast_classes import Concat , Count , Expr , Random , TablePath
23
25
from .queries .compiler import Compiler
24
26
from .queries .extras import NormalizeAsString
25
27
26
-
27
28
logger = logging .getLogger ("joindiff_tables" )
28
29
30
+ WRITE_LIMIT = 1000
31
+
29
32
30
33
def merge_dicts (dicts ):
31
34
i = iter (dicts )
@@ -60,6 +63,18 @@ def create_temp_table(c: Compiler, name: str, expr: Expr):
60
63
return f"create temporary table { c .quote (name )} as { c .compile (expr )} "
61
64
62
65
66
def drop_table(db, name: DbPath):
    """Drop the table identified by *name* on *db*, tolerating a missing table."""
    target = TablePath(name)
    # `if_exists=True` makes the drop a no-op when the table is absent.
    drop_stmt = target.drop(if_exists=True)
    db.query(drop_stmt)
69
+
70
def append_to_table(name: DbPath, expr: Expr):
    """Yield the statement sequence that appends the rows of *expr* to table *name*.

    The destination table is created on first use (schema taken from *expr*),
    and each step is followed by a commit so progress is durable.
    """
    target = TablePath(name, expr.schema)
    yield target.create(if_not_exists=True)  # uses expr.schema
    yield 'commit'
    yield target.insert_expr(expr)
    yield 'commit'
76
+
77
+
63
78
def bool_to_int(x):
    """Build an expression mapping boolean *x* to the integer 1 (true) or 0 (false)."""
    as_int = if_(x, 1, 0)
    return as_int
65
80
@@ -117,6 +132,8 @@ class JoinDiffer(TableDiffer):
117
132
stats : dict = {}
118
133
validate_unique_key : bool = True
119
134
sample_exclusive_rows : bool = True
135
+ materialize_to_table : DbPath = None
136
+ write_limit : int = WRITE_LIMIT
120
137
121
138
def _diff_tables (self , table1 : TableSegment , table2 : TableSegment ) -> DiffResult :
122
139
db = table1 .database
@@ -128,8 +145,12 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult
128
145
129
146
130
147
bg_funcs = [partial (self ._test_duplicate_keys , table1 , table2 )] if self .validate_unique_key else []
148
+ if self .materialize_to_table :
149
+ drop_table (db , self .materialize_to_table )
150
+ db .query ('COMMIT' )
131
151
132
152
with self ._run_in_background (* bg_funcs ):
153
+
133
154
if isinstance (db , (Snowflake , BigQuery )):
134
155
# Don't segment the table; let the database handling parallelization
135
156
yield from self ._diff_segments (None , table1 , table2 , None )
@@ -147,12 +168,29 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl
147
168
f"size <= { max_rows } "
148
169
)
149
170
171
+ db = table1 .database
172
+ diff_rows , a_cols , b_cols , is_diff_cols = self ._create_outer_join (table1 , table2 )
173
+
150
174
with self ._run_in_background (
151
175
partial (self ._collect_stats , 1 , table1 ),
152
176
partial (self ._collect_stats , 2 , table2 ),
153
177
partial (self ._test_null_keys , table1 , table2 ),
178
+ partial (self ._sample_and_count_exclusive , db , diff_rows , a_cols , b_cols ),
179
+ partial (self ._count_diff_per_column , db , diff_rows , list (a_cols ), is_diff_cols ),
180
+ partial (self ._materialize_diff , db , diff_rows , segment_index = segment_index ) if self .materialize_to_table else None ,
154
181
):
155
- yield from self ._outer_join (table1 , table2 )
182
+
183
+ logger .debug ("Querying for different rows" )
184
+ for is_xa , is_xb , * x in db .query (diff_rows , list ):
185
+ if is_xa and is_xb :
186
+ # Can't both be exclusive, meaning a pk is NULL
187
+ # This can happen if the explicit null test didn't finish running yet
188
+ raise ValueError (f"NULL values in one or more primary keys" )
189
+ is_diff , a_row , b_row = _slice_tuple (x , len (is_diff_cols ), len (a_cols ), len (b_cols ))
190
+ if not is_xb :
191
+ yield "-" , tuple (a_row )
192
+ if not is_xa :
193
+ yield "+" , tuple (b_row )
156
194
157
195
def _test_duplicate_keys (self , table1 , table2 ):
158
196
logger .debug ("Testing for duplicate keys" )
@@ -162,7 +200,7 @@ def _test_duplicate_keys(self, table1, table2):
162
200
t = ts ._make_select ()
163
201
key_columns = [ts .key_column ] # XXX
164
202
165
- q = t .select (total = Count (), total_distinct = Count (Concat (key_columns ), distinct = True ))
203
+ q = t .select (total = Count (), total_distinct = Count (Concat (this [ key_columns ] ), distinct = True ))
166
204
total , total_distinct = ts .database .query (q , tuple )
167
205
if total != total_distinct :
168
206
raise ValueError ("Duplicate primary keys" )
@@ -175,7 +213,7 @@ def _test_null_keys(self, table1, table2):
175
213
t = ts ._make_select ()
176
214
key_columns = [ts .key_column ] # XXX
177
215
178
- q = t .select (* key_columns ).where (or_ (this [k ] == None for k in key_columns ))
216
+ q = t .select (* this [ key_columns ] ).where (or_ (this [k ] == None for k in key_columns ))
179
217
nulls = ts .database .query (q , list )
180
218
if nulls :
181
219
raise ValueError (f"NULL values in one or more primary keys" )
@@ -188,10 +226,10 @@ def _collect_stats(self, i, table):
188
226
# Metrics
189
227
col_exprs = merge_dicts (
190
228
{
191
- f"sum_{ c } " : sum_ (c ),
192
- f"avg_{ c } " : avg (c ),
193
- f"min_{ c } " : min_ (c ),
194
- f"max_{ c } " : max_ (c ),
229
+ f"sum_{ c } " : sum_ (this [ c ] ),
230
+ f"avg_{ c } " : avg (this [ c ] ),
231
+ f"min_{ c } " : min_ (this [ c ] ),
232
+ f"max_{ c } " : max_ (this [ c ] ),
195
233
}
196
234
for c in table ._relevant_columns
197
235
if c == "id" # TODO just if the right type
@@ -209,8 +247,7 @@ def _collect_stats(self, i, table):
209
247
# stats.diff_ratio_by_column = diff_stats
210
248
# stats.diff_ratio_total = diff_stats['total_diff']
211
249
212
-
213
- def _outer_join (self , table1 , table2 ):
250
+ def _create_outer_join (self , table1 , table2 ):
214
251
db = table1 .database
215
252
if db is not table2 .database :
216
253
raise ValueError ("Joindiff only applies to tables within the same database" )
@@ -239,23 +276,8 @@ def _outer_join(self, table1, table2):
239
276
_outerjoin (db , a , b , keys1 , keys2 , {** is_diff_cols , ** a_cols , ** b_cols })
240
277
.where (or_ (this [c ] == 1 for c in is_diff_cols ))
241
278
)
279
+ return diff_rows , a_cols , b_cols , is_diff_cols
242
280
243
- with self ._run_in_background (
244
- partial (self ._sample_and_count_exclusive , db , diff_rows , a_cols , b_cols ),
245
- partial (self ._count_diff_per_column , db , diff_rows , cols1 , is_diff_cols )
246
- ):
247
-
248
- logger .debug ("Querying for different rows" )
249
- for is_xa , is_xb , * x in db .query (diff_rows , list ):
250
- if is_xa and is_xb :
251
- # Can't both be exclusive, meaning a pk is NULL
252
- # This can happen if the explicit null test didn't finish running yet
253
- raise ValueError (f"NULL values in one or more primary keys" )
254
- is_diff , a_row , b_row = _slice_tuple (x , len (is_diff_cols ), len (a_cols ), len (b_cols ))
255
- if not is_xb :
256
- yield "-" , tuple (a_row )
257
- if not is_xa :
258
- yield "+" , tuple (b_row )
259
281
260
282
def _count_diff_per_column (self , db , diff_rows , cols , is_diff_cols ):
261
283
logger .info ("Counting differences per column" )
@@ -280,7 +302,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
280
302
def exclusive_rows (expr ):
281
303
c = Compiler (db )
282
304
name = c .new_unique_table_name ("temp_table" )
283
- yield create_temp_table (c , name , expr )
305
+ yield create_temp_table (c , name , expr . limit ( self . write_limit ) )
284
306
exclusive_rows = table (name , schema = expr .source_table .schema )
285
307
286
308
count = yield exclusive_rows .count ()
@@ -293,3 +315,10 @@ def exclusive_rows(expr):
293
315
294
316
# Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter)
295
317
db .query (exclusive_rows (exclusive_rows_query ), None )
318
+
319
+ def _materialize_diff (self , db , diff_rows , segment_index = None ):
320
+ assert self .materialize_to_table
321
+
322
+ db .query (append_to_table (self .materialize_to_table , diff_rows .limit (self .write_limit )))
323
+ logger .info (f"Materialized diff to table '{ '.' .join (self .materialize_to_table )} '." )
324
+
0 commit comments