This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 00ee415

Joindiff: Ran black
1 parent 733972a · commit 00ee415

File tree

8 files changed: +107 -58 lines changed


data_diff/__main__.py

Lines changed: 22 additions & 8 deletions
@@ -48,10 +48,10 @@ def _get_schema(pair):
 
 
 def diff_schemas(schema1, schema2, columns):
-    logging.info('Diffing schemas...')
-    attrs = 'name', 'type', 'datetime_precision', 'numeric_precision', 'numeric_scale'
+    logging.info("Diffing schemas...")
+    attrs = "name", "type", "datetime_precision", "numeric_precision", "numeric_scale"
     for c in columns:
-        if c is None: # Skip for convenience
+        if c is None:  # Skip for convenience
             continue
         diffs = []
         for attr, v1, v2 in safezip(attrs, schema1[c], schema2[c]):
@@ -60,6 +60,7 @@ def diff_schemas(schema1, schema2, columns):
         if diffs:
             logging.warning(f"Schema mismatch in column '{c}': {', '.join(diffs)}")
 
+
 class MyHelpFormatter(click.HelpFormatter):
     def __init__(self, **kwargs):
         super().__init__(self, **kwargs)
@@ -106,7 +107,13 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
     help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.",
     metavar="NUM",
 )
-@click.option("-m", "--materialize", default=None, metavar="TABLE_NAME", help="Materialize the diff results into a new table in the database.")
+@click.option(
+    "-m",
+    "--materialize",
+    default=None,
+    metavar="TABLE_NAME",
+    help="Materialize the diff results into a new table in the database. (joindiff only)",
+)
 @click.option(
     "--min-age",
     default=None,
@@ -266,8 +273,8 @@ def _main(
         differ = JoinDiffer(
             threaded=threaded,
             max_threadpool_size=threads and threads * 2,
-            validate_unique_key = not assume_unique_key,
-            materialize_to_table = materialize and parse_table_name(eval_name_template(materialize)),
+            validate_unique_key=not assume_unique_key,
+            materialize_to_table=materialize and parse_table_name(eval_name_template(materialize)),
         )
     else:
         assert algorithm == Algorithm.HASHDIFF
@@ -326,8 +333,15 @@ def _main(
     columns = tuple(expanded_columns - {key_column, update_column})
 
     if db1 is db2:
-        diff_schemas(schema1, schema2, (key_column, update_column,) + columns)
-
+        diff_schemas(
+            schema1,
+            schema2,
+            (
+                key_column,
+                update_column,
+            )
+            + columns,
+        )
 
     logging.info(f"Diffing using columns: key={key_column} update={update_column} extra={columns}")
 
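
Note on the new --materialize option: in _main above it is passed to JoinDiffer as materialize_to_table=materialize and parse_table_name(eval_name_template(materialize)), so the table name is only parsed when the option was actually given. A minimal stand-alone sketch of that short-circuit idiom (parse_name here is a hypothetical stand-in for data-diff's parse_table_name, not its real implementation):

def parse_name(name: str) -> tuple:
    # hypothetical stand-in: split "schema.table" into a path tuple
    return tuple(name.split("."))

for materialize in (None, "analytics.diff_results"):
    print(materialize and parse_name(materialize))
# None
# ('analytics', 'diff_results')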

data_diff/databases/base.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,7 @@
     DbPath,
 )
 
-from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter
+from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter
 
 logger = logging.getLogger("database")
 
@@ -75,6 +75,7 @@ def _query_cursor(c, sql_code):
         logger.exception(e)
         raise
 
+
 def _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list:
     c = conn.cursor()
 
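
The signature _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) suggests the helper accepts either a single SQL string or an object that yields statements to run on one cursor. A stand-alone sketch of that dispatch pattern, using sqlite3 and a plain iterable in place of data-diff's ThreadLocalInterpreter (an assumption for illustration, not the library's code):

import sqlite3
from typing import Iterable, Union

def run_on_cursor(cursor, sql_code: Union[str, Iterable[str]]) -> list:
    # A plain string is executed directly; an iterable of statements is executed
    # in order on the same cursor, returning the last statement's result.
    if isinstance(sql_code, str):
        return cursor.execute(sql_code).fetchall()
    rows = []
    for stmt in sql_code:
        rows = cursor.execute(stmt).fetchall()
    return rows

cur = sqlite3.connect(":memory:").cursor()
print(run_on_cursor(cur, ["CREATE TABLE t (x int)", "INSERT INTO t VALUES (1)", "SELECT * FROM t"]))
# [(1,)]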

data_diff/databases/presto.py

Lines changed: 1 addition & 1 deletion
@@ -11,6 +11,7 @@
     TIMESTAMP_PRECISION_POS,
 )
 
+
 def query_cursor(c, sql_code):
     c.execute(sql_code)
     if sql_code.lower().startswith("select"):
@@ -87,7 +88,6 @@ def _query(self, sql_code: str) -> list:
 
         return query_cursor(c, sql_code)
 
-
     def close(self):
         self._conn.close()
 

data_diff/diff_tables.py

Lines changed: 16 additions & 7 deletions
@@ -20,6 +20,7 @@
 
 logger = getLogger(__name__)
 
+
 class Algorithm(Enum):
     AUTO = "auto"
     JOINDIFF = "joindiff"
@@ -28,8 +29,9 @@ class Algorithm(Enum):
 
 DiffResult = Iterator[Tuple[str, tuple]]  # Iterator[Tuple[Literal["+", "-"], tuple]]
 
+
 def truncate_error(error: str):
-    first_line = error.split('\n', 1)[0]
+    first_line = error.split("\n", 1)[0]
     return re.sub("'(.*?)'", "'***'", first_line)
 
 
@@ -137,12 +139,19 @@ def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegmen
     def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
         return self._bisect_and_diff_tables(table1, table2)
 
-
     @abstractmethod
-    def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None):
+    def _diff_segments(
+        self,
+        ti: ThreadedYielder,
+        table1: TableSegment,
+        table2: TableSegment,
+        max_rows: int,
+        level=0,
+        segment_index=None,
+        segment_count=None,
+    ):
         ...
 
-
     def _bisect_and_diff_tables(self, table1, table2):
         key_type = table1._schema[table1.key_column]
         key_type2 = table2._schema[table2.key_column]
@@ -183,7 +192,6 @@ def _bisect_and_diff_tables(self, table1, table2):
 
         return ti
 
-
     def _parse_key_range_result(self, key_type, key_range):
         mn, mx = key_range
         cls = key_type.make_value
@@ -193,8 +201,9 @@ def _parse_key_range_result(self, key_type, key_range):
         except (TypeError, ValueError) as e:
             raise type(e)(f"Cannot apply {key_type} to '{mn}', '{mx}'.") from e
 
-
-    def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None):
+    def _bisect_and_diff_segments(
+        self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None
+    ):
         assert table1.is_bounded and table2.is_bounded
 
         # Choose evenly spaced checkpoints (according to min_key and max_key)
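
truncate_error, reformatted above, keeps only the first line of an error message and masks anything in single quotes, which is how quoted key values from messages like "Cannot apply {key_type} to '{mn}', '{mx}'." stay out of the logs. A quick stand-alone check of the behavior as written in the hunk:

import re

def truncate_error(error: str):
    first_line = error.split("\n", 1)[0]
    return re.sub("'(.*?)'", "'***'", first_line)

print(truncate_error("Cannot apply Integer to '1001', '2048'.\ntraceback follows..."))
# Cannot apply Integer to '***', '***'.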

data_diff/hashdiff_tables.py

Lines changed: 13 additions & 5 deletions
@@ -66,8 +66,6 @@ def __post_init__(self):
         if self.bisection_factor < 2:
             raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")
 
-
-
     def _validate_and_adjust_columns(self, table1, table2):
         for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns):
             if c1 not in table1._schema:
@@ -115,8 +113,16 @@ def _validate_and_adjust_columns(self, table1, table2):
                 "If encoding/formatting differs between databases, it may result in false positives."
             )
 
-
-    def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None):
+    def _diff_segments(
+        self,
+        ti: ThreadedYielder,
+        table1: TableSegment,
+        table2: TableSegment,
+        max_rows: int,
+        level=0,
+        segment_index=None,
+        segment_count=None,
+    ):
         logger.info(
             ". " * level + f"Diffing segment {segment_index}/{segment_count}, "
             f"key-range: {table1.min_key}..{table2.max_key}, "
@@ -148,7 +154,9 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl
         if checksum1 != checksum2:
             return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max(count1, count2))
 
-    def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None):
+    def _bisect_and_diff_segments(
+        self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None
+    ):
         assert table1.is_bounded and table2.is_bounded
 
         if max_rows is None:
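
The hunks above only re-wrap signatures, but they show the hashdiff loop in passing: compare per-segment checksums and bisect into smaller key ranges only when they disagree. A stand-alone toy sketch of that idea over in-memory dicts (the checksum and threshold here are illustrative choices, not data-diff's):

import zlib

def checksum(items):
    return zlib.crc32(repr(sorted(items)).encode())

def diff_by_key_range(t1: dict, t2: dict, lo: int, hi: int, threshold: int = 4):
    seg1 = {k: v for k, v in t1.items() if lo <= k < hi}
    seg2 = {k: v for k, v in t2.items() if lo <= k < hi}
    if checksum(seg1.items()) == checksum(seg2.items()):
        return  # segments match; skip downloading their rows
    if hi - lo <= threshold:
        for k in range(lo, hi):  # small enough: compare rows directly
            if seg1.get(k) != seg2.get(k):
                yield k, seg1.get(k), seg2.get(k)
        return
    mid = (lo + hi) // 2  # checksums differ: bisect the key range
    yield from diff_by_key_range(t1, t2, lo, mid, threshold)
    yield from diff_by_key_range(t1, t2, mid, hi, threshold)

t1 = {i: f"row{i}" for i in range(1, 101)}
t2 = {**t1, 42: "row42_changed"}
del t2[77]
print(list(diff_by_key_range(t1, t2, 1, 101)))
# [(42, 'row42', 'row42_changed'), (77, 'row77', None)]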

data_diff/joindiff_tables.py

Lines changed: 33 additions & 30 deletions
@@ -50,6 +50,7 @@ class Stats:
 def sample(table):
     return table.order_by(Random()).limit(10)
 
+
 def create_temp_table(c: Compiler, name: str, expr: Expr):
     db = c.database
     if isinstance(db, BigQuery):
@@ -67,12 +68,13 @@ def drop_table(db, name: DbPath):
     t = TablePath(name)
     db.query(t.drop(if_exists=True))
 
+
 def append_to_table(name: DbPath, expr: Expr):
     t = TablePath(name, expr.schema)
     yield t.create(if_not_exists=True)  # uses expr.schema
-    yield 'commit'
+    yield "commit"
     yield t.insert_expr(expr)
-    yield 'commit'
+    yield "commit"
 
 
 def bool_to_int(x):
@@ -95,10 +97,7 @@ def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List
         r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields)
         return l.union(r)
 
-    return (
-        outerjoin(a, b).on(*on)
-        .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields)
-    )
+    return outerjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields)
 
 
 def _slice_tuple(t, *sizes):
@@ -115,7 +114,6 @@ def json_friendly_value(v):
     return v
 
 
-
 @dataclass
 class JoinDiffer(TableDiffer):
     """Finds the diff between two SQL tables in the same database, using JOINs.
@@ -143,11 +141,10 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult
 
         table1, table2 = self._threaded_call("with_schema", [table1, table2])
 
-
         bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else []
         if self.materialize_to_table:
             drop_table(db, self.materialize_to_table)
-            db.query('COMMIT')
+            db.query("COMMIT")
 
         with self._run_in_background(*bg_funcs):
 
@@ -158,7 +155,16 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult
             yield from self._bisect_and_diff_tables(table1, table2)
         logger.info("Diffing complete")
 
-    def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None):
+    def _diff_segments(
+        self,
+        ti: ThreadedYielder,
+        table1: TableSegment,
+        table2: TableSegment,
+        max_rows: int,
+        level=0,
+        segment_index=None,
+        segment_count=None,
+    ):
         assert table1.database is table2.database
 
         if segment_index or table1.min_key or max_rows:
@@ -172,13 +178,15 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl
         diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2)
 
         with self._run_in_background(
-                partial(self._collect_stats, 1, table1),
-                partial(self._collect_stats, 2, table2),
-                partial(self._test_null_keys, table1, table2),
-                partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
-                partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
-                partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) if self.materialize_to_table else None,
-            ):
+            partial(self._collect_stats, 1, table1),
+            partial(self._collect_stats, 2, table2),
+            partial(self._test_null_keys, table1, table2),
+            partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
+            partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
+            partial(self._materialize_diff, db, diff_rows, segment_index=segment_index)
+            if self.materialize_to_table
+            else None,
+        ):
 
             logger.debug("Querying for different rows")
             for is_xa, is_xb, *x in db.query(diff_rows, list):
@@ -218,7 +226,6 @@ def _test_null_keys(self, table1, table2):
             if nulls:
                 raise ValueError(f"NULL values in one or more primary keys")
 
-
     def _collect_stats(self, i, table):
         logger.info(f"Collecting stats for table #{i}")
         db = table.database
@@ -265,31 +272,27 @@ def _create_outer_join(self, table1, table2):
         a = table1._make_select()
         b = table2._make_select()
 
-        is_diff_cols = {
-            f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)
-        }
+        is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)}
 
         a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1}
        b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2}
 
-        diff_rows = (
-            _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols})
-            .where(or_(this[c] == 1 for c in is_diff_cols))
+        diff_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}).where(
+            or_(this[c] == 1 for c in is_diff_cols)
         )
         return diff_rows, a_cols, b_cols, is_diff_cols
 
-
     def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
         logger.info("Counting differences per column")
         is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple)
         diff_counts = {}
         for name, count in safezip(cols, is_diff_cols_counts):
             diff_counts[name] = diff_counts.get(name, 0) + (count or 0)
-        self.stats['diff_counts'] = diff_counts
+        self.stats["diff_counts"] = diff_counts
 
     def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
         if isinstance(db, Oracle):
-            exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1))
+            exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
         else:
             exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)
 
@@ -299,16 +302,17 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
             return
 
         logger.info("Counting and sampling exclusive rows")
+
         def exclusive_rows(expr):
             c = Compiler(db)
             name = c.new_unique_table_name("temp_table")
             yield create_temp_table(c, name, expr.limit(self.write_limit))
             exclusive_rows = table(name, schema=expr.source_table.schema)
 
             count = yield exclusive_rows.count()
-            self.stats["exclusive_count"] = self.stats.get('exclusive_count', 0) + count[0][0]
+            self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0]
             sample_rows = yield sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)]))
-            self.stats["exclusive_sample"] = self.stats.get('exclusive_sample', []) + sample_rows
+            self.stats["exclusive_sample"] = self.stats.get("exclusive_sample", []) + sample_rows
 
             # Only drops if create table succeeded (meaning, the table didn't already exist)
             yield f"drop table {c.quote(name)}"
@@ -321,4 +325,3 @@ def _materialize_diff(self, db, diff_rows, segment_index=None):
 
         db.query(append_to_table(self.materialize_to_table, diff_rows.limit(self.write_limit)))
         logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.")
-
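
In _create_outer_join above, each compared column becomes is_diff_<col> = bool_to_int(a[c1].is_distinct_from(b[c2])), a NULL-safe inequality flag, and diff_rows keeps only rows where at least one flag is 1. A stand-alone sketch of those semantics in plain Python, with None standing in for SQL NULL; this illustrates the meaning, not data-diff's query builder:

def is_distinct_from(x, y) -> bool:
    # NULL-safe comparison: NULL is "not distinct" only from NULL
    if x is None or y is None:
        return (x is None) != (y is None)
    return x != y

def bool_to_int(b) -> int:
    return int(bool(b))

row_a = {"id": 7, "amount": None, "status": "paid"}
row_b = {"id": 7, "amount": 0, "status": "paid"}
is_diff_cols = {f"is_diff_{c}": bool_to_int(is_distinct_from(row_a[c], row_b[c])) for c in row_a}
print(is_diff_cols)  # {'is_diff_id': 0, 'is_diff_amount': 1, 'is_diff_status': 0}
print(any(v == 1 for v in is_diff_cols.values()))  # True: this row would appear in diff_rows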

data_diff/utils.py

Lines changed: 5 additions & 3 deletions
@@ -295,9 +295,11 @@ def run_as_daemon(threadfunc, *args):
 
 
 def getLogger(name):
-    return logging.getLogger(name.rsplit('.', 1)[-1])
+    return logging.getLogger(name.rsplit(".", 1)[-1])
+
 
 def eval_name_template(name):
     def get_timestamp(m):
-        return datetime.now().isoformat('_', 'seconds').replace(':', '_')
-    return re.sub('%t', get_timestamp, name)
+        return datetime.now().isoformat("_", "seconds").replace(":", "_")
+
+    return re.sub("%t", get_timestamp, name)
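
eval_name_template, reformatted above, substitutes every "%t" in a name template with a second-resolution timestamp, using "_" in place of ":". Run as written, with a usage line added for illustration:

import re
from datetime import datetime

def eval_name_template(name):
    def get_timestamp(m):
        return datetime.now().isoformat("_", "seconds").replace(":", "_")

    return re.sub("%t", get_timestamp, name)

print(eval_name_template("diff_results_%t"))
# e.g. diff_results_2022-08-17_14_03_05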
