Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit bc28997

Browse files
authored
Merge pull request #740 from datafold/negotiate
Fix a few things here & there
2 parents f080ce7 + d268ff7 commit bc28997

File tree

7 files changed

+47
-23
lines changed

7 files changed

+47
-23
lines changed

data_diff/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def _data_diff(
519519

520520
else:
521521
for op, values in diff_iter:
522-
color = COLOR_SCHEME[op]
522+
color = COLOR_SCHEME.get(op, "grey62")
523523

524524
if json_output:
525525
jsonl = json.dumps([op, list(values)])

data_diff/abcs/database_types.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import decimal
22
from abc import ABC, abstractmethod
3-
from typing import Tuple, Union
3+
from typing import List, Optional, Tuple, Type, TypeVar, Union
44
from datetime import datetime
55

66
import attrs
@@ -12,9 +12,24 @@
1212
DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric]
1313
DbTime = datetime
1414

15+
N = TypeVar("N")
1516

16-
@attrs.define(frozen=True)
17+
18+
@attrs.define(frozen=True, kw_only=True)
1719
class ColType:
20+
# Arbitrary metadata added and fetched at runtime.
21+
_notes: List[N] = attrs.field(factory=list, init=False, hash=False, eq=False)
22+
23+
def add_note(self, note: N) -> None:
24+
self._notes.append(note)
25+
26+
def get_note(self, cls: Type[N]) -> Optional[N]:
27+
"""Get the latest added note of type ``cls`` or its descendants."""
28+
for note in reversed(self._notes):
29+
if isinstance(note, cls):
30+
return note
31+
return None
32+
1833
@property
1934
def supported(self) -> bool:
2035
return True

data_diff/databases/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def apply_queries(self, callback: Callable[[str], Any]):
182182
q: Expr = next(self.gen)
183183
while True:
184184
sql = self.compiler.database.dialect.compile(self.compiler, q)
185-
logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql)
185+
logger.debug("Running SQL (%s-TL):\n%s", self.compiler.database.name, sql)
186186
try:
187187
try:
188188
res = callback(sql) if sql is not SKIP else SKIP
@@ -267,6 +267,8 @@ def _compile(self, compiler: Compiler, elem) -> str:
267267
return "NULL"
268268
elif isinstance(elem, Compilable):
269269
return self.render_compilable(attrs.evolve(compiler, root=False), elem)
270+
elif isinstance(elem, ColType):
271+
return self.render_coltype(attrs.evolve(compiler, root=False), elem)
270272
elif isinstance(elem, str):
271273
return f"'{elem}'"
272274
elif isinstance(elem, (int, float)):
@@ -359,6 +361,9 @@ def render_compilable(self, c: Compiler, elem: Compilable) -> str:
359361
raise RuntimeError(f"Cannot render AST of type {elem.__class__}")
360362
# return elem.compile(compiler.replace(root=False))
361363

364+
def render_coltype(self, c: Compiler, elem: ColType) -> str:
365+
return self.type_repr(elem)
366+
362367
def render_column(self, c: Compiler, elem: Column) -> str:
363368
if c._table_context:
364369
if len(c._table_context) > 1:
@@ -876,7 +881,7 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
876881
if sql_code is SKIP:
877882
return SKIP
878883

879-
logger.debug("Running SQL (%s): %s", self.name, sql_code)
884+
logger.debug("Running SQL (%s):\n%s", self.name, sql_code)
880885

881886
if self._interactive and isinstance(sql_ast, Select):
882887
explained_sql = self.compile(Explain(sql_ast))

data_diff/databases/mssql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from data_diff.databases.base import Mixin_Schema
1818
from data_diff.abcs.database_types import (
1919
JSON,
20+
NumericType,
2021
Timestamp,
2122
TimestampTZ,
2223
DbPath,
@@ -51,7 +52,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
5152

5253
return formatted_value
5354

54-
def normalize_number(self, value: str, coltype: FractionalType) -> str:
55+
def normalize_number(self, value: str, coltype: NumericType) -> str:
5556
if coltype.precision == 0:
5657
return f"CAST(FLOOR({value}) AS VARCHAR)"
5758

data_diff/hashdiff_tables.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
from collections import defaultdict
55
from typing import Iterator
6-
from operator import attrgetter
76

87
import attrs
98

@@ -71,7 +70,8 @@ class HashDiffer(TableDiffer):
7170
"""
7271

7372
bisection_factor: int = DEFAULT_BISECTION_FACTOR
74-
bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests
73+
bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD
74+
bisection_disabled: bool = False # i.e. always download the rows (used in tests)
7575

7676
stats: dict = attrs.field(factory=dict)
7777

@@ -82,7 +82,7 @@ def __attrs_post_init__(self):
8282
if self.bisection_factor < 2:
8383
raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")
8484

85-
def _validate_and_adjust_columns(self, table1, table2):
85+
def _validate_and_adjust_columns(self, table1, table2, *, strict: bool = True):
8686
for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns):
8787
if c1 not in table1._schema:
8888
raise ValueError(f"Column '{c1}' not found in schema for table {table1}")
@@ -92,23 +92,23 @@ def _validate_and_adjust_columns(self, table1, table2):
9292
# Update schemas to minimal mutual precision
9393
col1 = table1._schema[c1]
9494
col2 = table2._schema[c2]
95-
if isinstance(col1, PrecisionType):
96-
if not isinstance(col2, PrecisionType):
95+
if isinstance(col1, PrecisionType) and isinstance(col2, PrecisionType):
96+
if strict and not isinstance(col2, PrecisionType):
9797
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
9898

99-
lowest = min(col1, col2, key=attrgetter("precision"))
99+
lowest = min(col1, col2, key=lambda col: col.precision)
100100

101101
if col1.precision != col2.precision:
102102
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")
103103

104104
table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision, rounds=lowest.rounds)
105105
table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision, rounds=lowest.rounds)
106106

107-
elif isinstance(col1, (NumericType, Boolean)):
108-
if not isinstance(col2, (NumericType, Boolean)):
107+
elif isinstance(col1, (NumericType, Boolean)) and isinstance(col2, (NumericType, Boolean)):
108+
if strict and not isinstance(col2, (NumericType, Boolean)):
109109
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
110110

111-
lowest = min(col1, col2, key=attrgetter("precision"))
111+
lowest = min(col1, col2, key=lambda col: col.precision)
112112

113113
if col1.precision != col2.precision:
114114
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")
@@ -119,11 +119,11 @@ def _validate_and_adjust_columns(self, table1, table2):
119119
table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision)
120120

121121
elif isinstance(col1, ColType_UUID):
122-
if not isinstance(col2, ColType_UUID):
122+
if strict and not isinstance(col2, ColType_UUID):
123123
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
124124

125125
elif isinstance(col1, StringType):
126-
if not isinstance(col2, StringType):
126+
if strict and not isinstance(col2, StringType):
127127
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
128128

129129
for t in [table1, table2]:
@@ -157,7 +157,7 @@ def _diff_segments(
157157
# default, data-diff will checksum the section first (when it's below
158158
# the threshold) and _then_ download it.
159159
if BENCHMARK:
160-
if max_rows < self.bisection_threshold:
160+
if self.bisection_disabled or max_rows < self.bisection_threshold:
161161
return self._bisect_and_diff_segments(ti, table1, table2, info_tree, level=level, max_rows=max_rows)
162162

163163
(count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2])
@@ -202,7 +202,7 @@ def _bisect_and_diff_segments(
202202

203203
# If count is below the threshold, just download and compare the columns locally
204204
# This saves time, as bisection speed is limited by ping and query performance.
205-
if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
205+
if self.bisection_disabled or max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
206206
rows1, rows2 = self._threaded_call("get_values", [table1, table2])
207207
json_cols = {
208208
i: colname

data_diff/queries/ast_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ class CaseWhen(ExprNode):
302302
def type(self):
303303
then_types = {_expr_type(case.then) for case in self.cases}
304304
if self.else_expr:
305-
then_types |= _expr_type(self.else_expr)
305+
then_types |= {_expr_type(self.else_expr)}
306306
if len(then_types) > 1:
307307
raise QB_TypeError(f"Non-matching types in when: {then_types}")
308308
(t,) = then_types

tests/test_database_types.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import sys
12
import unittest
23
import time
34
import json
45
import re
5-
import math
66
import uuid
77
from datetime import datetime, timedelta, timezone
88
import logging
@@ -765,10 +765,13 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
765765
# reasonable amount of rows each. These will then be downloaded in
766766
# parallel, using the existing implementation.
767767
dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2
768-
dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf
768+
dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else sys.maxsize
769769
dl_threads = N_THREADS
770770
differ = HashDiffer(
771-
bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads
771+
bisection_factor=dl_factor,
772+
bisection_threshold=dl_threshold,
773+
bisection_disabled=True,
774+
max_threadpool_size=dl_threads,
772775
)
773776
start = time.monotonic()
774777
diff = list(differ.diff_tables(self.table, self.table2))

0 commit comments

Comments
 (0)