Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Fix a few things here & there #740

Merged
merged 8 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def _data_diff(

else:
for op, values in diff_iter:
color = COLOR_SCHEME[op]
color = COLOR_SCHEME.get(op, "grey62")

if json_output:
jsonl = json.dumps([op, list(values)])
Expand Down
19 changes: 17 additions & 2 deletions data_diff/abcs/database_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import decimal
from abc import ABC, abstractmethod
from typing import Tuple, Union
from typing import List, Optional, Tuple, Type, TypeVar, Union
from datetime import datetime

import attrs
Expand All @@ -12,9 +12,24 @@
DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric]
DbTime = datetime

N = TypeVar("N")

@attrs.define(frozen=True)

@attrs.define(frozen=True, kw_only=True)
class ColType:
# Arbitrary metadata added and fetched at runtime.
_notes: List[N] = attrs.field(factory=list, init=False, hash=False, eq=False)

def add_note(self, note: N) -> None:
self._notes.append(note)

def get_note(self, cls: Type[N]) -> Optional[N]:
"""Get the latest added note of type ``cls`` or its descendants."""
for note in reversed(self._notes):
if isinstance(note, cls):
return note
return None

@property
def supported(self) -> bool:
return True
Expand Down
9 changes: 7 additions & 2 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def apply_queries(self, callback: Callable[[str], Any]):
q: Expr = next(self.gen)
while True:
sql = self.compiler.database.dialect.compile(self.compiler, q)
logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql)
logger.debug("Running SQL (%s-TL):\n%s", self.compiler.database.name, sql)
try:
try:
res = callback(sql) if sql is not SKIP else SKIP
Expand Down Expand Up @@ -267,6 +267,8 @@ def _compile(self, compiler: Compiler, elem) -> str:
return "NULL"
elif isinstance(elem, Compilable):
return self.render_compilable(attrs.evolve(compiler, root=False), elem)
elif isinstance(elem, ColType):
return self.render_coltype(attrs.evolve(compiler, root=False), elem)
elif isinstance(elem, str):
return f"'{elem}'"
elif isinstance(elem, (int, float)):
Expand Down Expand Up @@ -359,6 +361,9 @@ def render_compilable(self, c: Compiler, elem: Compilable) -> str:
raise RuntimeError(f"Cannot render AST of type {elem.__class__}")
# return elem.compile(compiler.replace(root=False))

def render_coltype(self, c: Compiler, elem: ColType) -> str:
return self.type_repr(elem)

def render_column(self, c: Compiler, elem: Column) -> str:
if c._table_context:
if len(c._table_context) > 1:
Expand Down Expand Up @@ -876,7 +881,7 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
if sql_code is SKIP:
return SKIP

logger.debug("Running SQL (%s): %s", self.name, sql_code)
logger.debug("Running SQL (%s):\n%s", self.name, sql_code)

if self._interactive and isinstance(sql_ast, Select):
explained_sql = self.compile(Explain(sql_ast))
Expand Down
3 changes: 2 additions & 1 deletion data_diff/databases/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from data_diff.databases.base import Mixin_Schema
from data_diff.abcs.database_types import (
JSON,
NumericType,
Timestamp,
TimestampTZ,
DbPath,
Expand Down Expand Up @@ -50,7 +51,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:

return formatted_value

def normalize_number(self, value: str, coltype: FractionalType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
if coltype.precision == 0:
return f"CAST(FLOOR({value}) AS VARCHAR)"

Expand Down
26 changes: 13 additions & 13 deletions data_diff/hashdiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
from collections import defaultdict
from typing import Iterator
from operator import attrgetter

import attrs

Expand Down Expand Up @@ -71,7 +70,8 @@ class HashDiffer(TableDiffer):
"""

bisection_factor: int = DEFAULT_BISECTION_FACTOR
bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests
bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD
bisection_disabled: bool = False # i.e. always download the rows (used in tests)

stats: dict = attrs.field(factory=dict)

Expand All @@ -82,7 +82,7 @@ def __attrs_post_init__(self):
if self.bisection_factor < 2:
raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")

def _validate_and_adjust_columns(self, table1, table2):
def _validate_and_adjust_columns(self, table1, table2, *, strict: bool = True):
for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns):
if c1 not in table1._schema:
raise ValueError(f"Column '{c1}' not found in schema for table {table1}")
Expand All @@ -92,23 +92,23 @@ def _validate_and_adjust_columns(self, table1, table2):
# Update schemas to minimal mutual precision
col1 = table1._schema[c1]
col2 = table2._schema[c2]
if isinstance(col1, PrecisionType):
if not isinstance(col2, PrecisionType):
if isinstance(col1, PrecisionType) and isinstance(col2, PrecisionType):
if strict and not isinstance(col2, PrecisionType):
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

lowest = min(col1, col2, key=attrgetter("precision"))
lowest = min(col1, col2, key=lambda col: col.precision)

if col1.precision != col2.precision:
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")

table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision, rounds=lowest.rounds)
table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision, rounds=lowest.rounds)

elif isinstance(col1, (NumericType, Boolean)):
if not isinstance(col2, (NumericType, Boolean)):
elif isinstance(col1, (NumericType, Boolean)) and isinstance(col2, (NumericType, Boolean)):
if strict and not isinstance(col2, (NumericType, Boolean)):
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

lowest = min(col1, col2, key=attrgetter("precision"))
lowest = min(col1, col2, key=lambda col: col.precision)

if col1.precision != col2.precision:
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")
Expand All @@ -119,11 +119,11 @@ def _validate_and_adjust_columns(self, table1, table2):
table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision)

elif isinstance(col1, ColType_UUID):
if not isinstance(col2, ColType_UUID):
if strict and not isinstance(col2, ColType_UUID):
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

elif isinstance(col1, StringType):
if not isinstance(col2, StringType):
if strict and not isinstance(col2, StringType):
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

for t in [table1, table2]:
Expand Down Expand Up @@ -157,7 +157,7 @@ def _diff_segments(
# default, data-diff will checksum the section first (when it's below
# the threshold) and _then_ download it.
if BENCHMARK:
if max_rows < self.bisection_threshold:
if self.bisection_disabled or max_rows < self.bisection_threshold:
return self._bisect_and_diff_segments(ti, table1, table2, info_tree, level=level, max_rows=max_rows)

(count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2])
Expand Down Expand Up @@ -202,7 +202,7 @@ def _bisect_and_diff_segments(

# If count is below the threshold, just download and compare the columns locally
# This saves time, as bisection speed is limited by ping and query performance.
if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
if self.bisection_disabled or max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
rows1, rows2 = self._threaded_call("get_values", [table1, table2])
json_cols = {
i: colname
Expand Down
2 changes: 1 addition & 1 deletion data_diff/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class CaseWhen(ExprNode):
def type(self):
then_types = {_expr_type(case.then) for case in self.cases}
if self.else_expr:
then_types |= _expr_type(self.else_expr)
then_types |= {_expr_type(self.else_expr)}
if len(then_types) > 1:
raise QB_TypeError(f"Non-matching types in when: {then_types}")
(t,) = then_types
Expand Down
9 changes: 6 additions & 3 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sys
import unittest
import time
import json
import re
import math
import uuid
from datetime import datetime, timedelta, timezone
import logging
Expand Down Expand Up @@ -765,10 +765,13 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
# reasonable amount of rows each. These will then be downloaded in
# parallel, using the existing implementation.
dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2
dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf
dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else sys.maxsize
dl_threads = N_THREADS
differ = HashDiffer(
bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads
bisection_factor=dl_factor,
bisection_threshold=dl_threshold,
bisection_disabled=True,
max_threadpool_size=dl_threads,
)
start = time.monotonic()
diff = list(differ.diff_tables(self.table, self.table2))
Expand Down