From 75ed605bb2125dec97f04aae0e68718e07c0d7d2 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 31 Aug 2022 17:29:37 +0200 Subject: [PATCH 01/33] Query builder package (still incomplete) --- data_diff/databases/base.py | 3 + data_diff/databases/database_types.py | 26 +- data_diff/queries/__init__.py | 4 + data_diff/queries/api.py | 58 +++ data_diff/queries/ast_classes.py | 493 ++++++++++++++++++++++++++ data_diff/queries/base.py | 18 + data_diff/queries/compiler.py | 60 ++++ data_diff/queries/extras.py | 61 ++++ data_diff/sql.py | 2 - tests/test_query.py | 130 +++++++ tests/test_sql.py | 1 - 11 files changed, 843 insertions(+), 13 deletions(-) create mode 100644 data_diff/queries/__init__.py create mode 100644 data_diff/queries/api.py create mode 100644 data_diff/queries/ast_classes.py create mode 100644 data_diff/queries/base.py create mode 100644 data_diff/queries/compiler.py create mode 100644 data_diff/queries/extras.py create mode 100644 tests/test_query.py diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index b114937a..fd6ec2c0 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -287,6 +287,9 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM({value})" return self.to_string(value) + def random(self) -> str: + return "RANDOM()" + class ThreadedDatabase(Database): """Access the database through singleton threads. diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index e93e380e..ca2734fc 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -140,7 +140,7 @@ class UnknownColType(ColType): supported = False -class AbstractDatabase(ABC): +class AbstractDialect(ABC): name: str @abstractmethod @@ -148,11 +148,6 @@ def quote(self, s: str): "Quote SQL name (implementation specific)" ... - @abstractmethod - def to_string(self, s: str) -> str: - "Provide SQL for casting a column to string" - ... - @abstractmethod def concat(self, l: List[str]) -> str: "Provide SQL for concatenating a bunch of column into a string" @@ -163,6 +158,21 @@ def is_distinct_from(self, a: str, b: str) -> str: "Provide SQL for a comparison where NULL = NULL is true" ... + @abstractmethod + def to_string(self, s: str) -> str: + "Provide SQL for casting a column to string" + ... + + @abstractmethod + def random(self) -> str: + "Provide SQL for generating a random number" + + @abstractmethod + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + "Provide SQL fragment for limit and offset inside a select" + ... + +class AbstractDatabase(AbstractDialect): @abstractmethod def timestamp_value(self, t: DbTime) -> str: "Provide SQL for the given timestamp value" @@ -173,10 +183,6 @@ def md5_to_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" ... - @abstractmethod - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): - "Provide SQL fragment for limit and offset inside a select" - ... 
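+ # Note: dialect-level SQL rendering (quote, concat, is_distinct_from,
+ # to_string, random, offset_limit) now lives on AbstractDialect above;
+ # AbstractDatabase extends it with query execution and type-specific SQL.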
@abstractmethod def _query(self, sql_code: str) -> list: diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py new file mode 100644 index 00000000..93299b26 --- /dev/null +++ b/data_diff/queries/__init__.py @@ -0,0 +1,4 @@ +from .compiler import Compiler +from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte +from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In +from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py new file mode 100644 index 00000000..7c617af4 --- /dev/null +++ b/data_diff/queries/api.py @@ -0,0 +1,58 @@ +from typing import Optional +from .ast_classes import * +from .base import args_as_tuple + + +this = This() + + +def join(*tables: ITable): + "Joins each table into a 'struct'" + return Join(tables) + + +def outerjoin(*tables: ITable): + "Outerjoins each table into a 'struct'" + return Join(tables, "FULL OUTER") + + +def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None): + return Cte(expr, name, params) + + +def table(*path: str, schema: Schema = None) -> ITable: + assert all(isinstance(i, str) for i in path), path + return TablePath(path, schema) + + +def or_(*exprs: Expr): + exprs = args_as_tuple(exprs) + if len(exprs) == 1: + return exprs[0] + return BinOp("OR", exprs) + +def and_(*exprs: Expr): + exprs = args_as_tuple(exprs) + if len(exprs) == 1: + return exprs[0] + return BinOp("AND", exprs) + + +def sum_(expr: Expr): + return Func("sum", [expr]) + + +def avg(expr: Expr): + return Func("avg", [expr]) + + +def min_(expr: Expr): + return Func("min", [expr]) + + +def max_(expr: Expr): + return Func("max", [expr]) + + +def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): + return CaseWhen([(cond, then)], else_=else_) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py new file mode 100644 index 00000000..a5c1008c --- /dev/null +++ b/data_diff/queries/ast_classes.py @@ -0,0 +1,493 @@ +from datetime import datetime +from typing import Any, Generator, Sequence, Tuple, Union + +from runtype import dataclass + +from data_diff.utils import ArithString, join_iter + +from .compiler import Compilable, Compiler +from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple + + +class ExprNode(Compilable): + type: Any = None + + def _dfs_values(self): + yield self + for k, vs in dict(self).items(): # __dict__ provided by runtype.dataclass + if k == "source_table": + # Skip data-sources, we're only interested in data-parameters + continue + if not isinstance(vs, (list, tuple)): + vs = [vs] + for v in vs: + if isinstance(v, ExprNode): + yield from v._dfs_values() + + def cast_to(self, to): + return Cast(self, to) + + +Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] + + +@dataclass +class Alias(ExprNode): + expr: Expr + name: str + + def compile(self, c: Compiler) -> str: + return f"{c.compile(self.expr)} AS {c.quote(self.name)}" + + +def _drop_skips(exprs): + return [e for e in exprs if e is not SKIP] + + +def _drop_skips_dict(exprs_dict): + return {k: v for k, v in exprs_dict.items() if v is not SKIP} + + +class ITable: + source_table: Any + schema: Schema = None + + def select(self, *exprs, **named_exprs): + exprs = args_as_tuple(exprs) + exprs = _drop_skips(exprs) + named_exprs = _drop_skips_dict(named_exprs) + exprs += _named_exprs_as_aliases(named_exprs) + resolve_names(self.source_table, exprs) + return Select.make(self, 
columns=exprs) + + def where(self, *exprs): + exprs = args_as_tuple(exprs) + exprs = _drop_skips(exprs) + if not exprs: + return self + + resolve_names(self.source_table, exprs) + return Select.make(self, where_exprs=exprs, _concat=True) + + def at(self, *exprs): + # TODO + exprs = _drop_skips(exprs) + if not exprs: + return self + + raise NotImplementedError() + + def join(self, target): + return Join(self, target) + + def group_by(self, *, keys=None, values=None): + # TODO + assert keys or values + raise NotImplementedError() + + def with_schema(self): + # TODO + raise NotImplementedError() + + def _get_column(self, name: str): + if self.schema: + name = self.schema.get_key(name) # Get the actual name. Might be case-insensitive. + return Column(self, name) + + # def __getattr__(self, column): + # return self._get_column(column) + + def __getitem__(self, column): + if not isinstance(column, str): + raise TypeError() + return self._get_column(column) + + def count(self): + return Select(self, [Count()]) + + +@dataclass +class Concat(ExprNode): + args: list + sep: str = None + + def compile(self, c: Compiler) -> str: + # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL + items = [f"coalesce({c.compile(c.database.to_string(expr))}, '')" for expr in self.exprs] + assert items + if len(items) == 1: + return items[0] + + if self.sep: + items = list(join_iter(f"'{self.sep}'", items)) + return c.database.concat(items) + +@dataclass +class Count(ExprNode): + expr: Expr = '*' + distinct: bool = False + + def compile(self, c: Compiler) -> str: + expr = c.compile(self.expr) + if self.distinct: + return f"count(distinct {expr})" + + return f"count({expr})" + + +@dataclass +class Func(ExprNode): + name: str + args: Sequence[Expr] + + def compile(self, c: Compiler) -> str: + args = ", ".join(c.compile(e) for e in self.args) + return f"{self.name}({args})" + + +@dataclass +class CaseWhen(ExprNode): + cases: Sequence[Tuple[Expr, Expr]] + else_: Expr = None + + def compile(self, c: Compiler) -> str: + assert self.cases + when_thens = " ".join(f"WHEN {c.compile(when)} THEN {c.compile(then)}" for when, then in self.cases) + else_ = (" " + c.compile(self.else_)) if self.else_ else "" + return f"CASE {when_thens}{else_} END" + + +class LazyOps: + def __add__(self, other): + return BinOp("+", [self, other]) + + def __gt__(self, other): + return BinOp(">", [self, other]) + + def __ge__(self, other): + return BinOp(">=", [self, other]) + + def __eq__(self, other): + if other is None: + return BinOp("IS", [self, None]) + return BinOp("=", [self, other]) + + def __lt__(self, other): + return BinOp("<", [self, other]) + + def __le__(self, other): + return BinOp("<=", [self, other]) + + def __or__(self, other): + return BinOp("OR", [self, other]) + + def is_distinct_from(self, other): + return IsDistinctFrom(self, other) + + def sum(self): + return Func("SUM", [self]) + + +@dataclass(eq=False, order=False) +class IsDistinctFrom(ExprNode, LazyOps): + a: Expr + b: Expr + + def compile(self, c: Compiler) -> str: + return c.database.is_distinct_from(c.compile(self.a), c.compile(self.b)) + + +@dataclass(eq=False, order=False) +class BinOp(ExprNode, LazyOps): + op: str + args: Sequence[Expr] + + def __post_init__(self): + assert len(self.args) == 2, self.args + + def compile(self, c: Compiler) -> str: + a, b = self.args + return f"({c.compile(a)} {self.op} {c.compile(b)})" + + +@dataclass(eq=False, order=False) +class Column(ExprNode, LazyOps): + source_table: ITable + name: str + + @property + def 
type(self): + if self.source_table.schema is None: + raise RuntimeError(f"Schema required for table {self.source_table}") + return self.source_table.schema[self.name] + + def compile(self, c: Compiler) -> str: + if c._table_context: + if len(c._table_context) > 1: + aliases = [ + t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table + ] + if not aliases: + raise CompileError(f"No aliased table found for column {self.name}") # TODO better error + elif len(aliases) > 1: + raise CompileError(f"Too many aliases for column {self.name}") + (alias,) = aliases + + return f"{c.quote(alias.name)}.{c.quote(self.name)}" + + return c.quote(self.name) + + +@dataclass +class TablePath(ExprNode, ITable): + path: DbPath + schema: Schema = None + + def insert_values(self, rows): + pass + + def insert_query(self, query): + pass + + @property + def source_table(self): + return self + + def compile(self, c: Compiler) -> str: + path = self.path # c.database._normalize_table_path(self.name) + return ".".join(map(c.quote, path)) + + +@dataclass +class TableAlias(ExprNode, ITable): + source_table: ITable + name: str + + def compile(self, c: Compiler) -> str: + return f"{c.compile(self.source_table)} {c.quote(self.name)}" + + +@dataclass +class Join(ExprNode, ITable): + source_tables: Sequence[ITable] + op: str = None + on_exprs: Sequence[Expr] = None + columns: Sequence[Expr] = None + + @property + def source_table(self): + return self # TODO is this right? + + @property + def schema(self): + # TODO combine both tables + return None + + def on(self, *exprs): + if len(exprs) == 1: + (e,) = exprs + if isinstance(e, Generator): + exprs = tuple(e) + + exprs = _drop_skips(exprs) + if not exprs: + return self + + return self.replace(on_exprs=(self.on_exprs or []) + exprs) + + def select(self, *exprs, **named_exprs): + if self.columns is not None: + # join-select already applied + return super().select(*exprs, **named_exprs) + + exprs = _drop_skips(exprs) + named_exprs = _drop_skips_dict(named_exprs) + exprs += _named_exprs_as_aliases(named_exprs) + # resolve_names(self.source_table, exprs) + # TODO Ensure exprs <= self.columns ? 
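+ # Storing `columns` is what turns the bare join into a projection:
+ # compile() renders it as "SELECT <columns> FROM a <op> JOIN b ON ...".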
+ return self.replace(columns=exprs) + + def compile(self, parent_c: Compiler) -> str: + tables = [ + t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables + ] + c = parent_c.add_table_context(*tables) + op = " JOIN " if self.op is None else f" {self.op} JOIN " + joined = op.join(c.compile(t) for t in tables) + + if self.on_exprs: + on = " AND ".join(c.compile(e) for e in self.on_exprs) + res = f"{joined} ON {on}" + else: + res = joined + + columns = "*" if self.columns is None else ", ".join(map(c.compile, self.columns)) + select = f"SELECT {columns} FROM {res}" + + if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + return select + + +class GroupBy(ITable): + def having(self): + pass + + +@dataclass +class Select(ExprNode, ITable): + table: Expr = None + columns: Sequence[Expr] = None + where_exprs: Sequence[Expr] = None + order_by_exprs: Sequence[Expr] = None + group_by_exprs: Sequence[Expr] = None + limit_expr: int = None + + @property + def source_table(self): + return self + + @property + def schema(self): + return self.table.schema + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(in_select=True).add_table_context(self.table) + + columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" + select = f"SELECT {columns}" + + if self.table: + select += " FROM " + c.compile(self.table) + + if self.where_exprs: + select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) + + if self.group_by_exprs: + select += " GROUP BY " + ", ".join(map(c.compile, self.group_by_exprs)) + + if self.order_by_exprs: + select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs)) + + if self.limit_expr is not None: + select += " " + c.database.offset_limit(0, self.limit_expr) + + if parent_c.in_select: + select = f"({select})" + return select + + @classmethod + def make(cls, table: ITable, _concat: bool = False, **kwargs): + if not isinstance(table, cls): + return cls(table, **kwargs) + + # Fill in missing attributes, instead of creating a new instance. 
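+ # With _concat=True a repeated attribute is appended (e.g. .where(a).where(b)
+ # ANDs both filters); without it, overriding an already-set attribute raises.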
+ for k, v in kwargs.items(): + if getattr(table, k) is not None: + if _concat: + kwargs[k] = getattr(table, k) + v + else: + raise ValueError("...") + + return table.replace(**kwargs) + + +@dataclass +class Cte(ExprNode, ITable): + source_table: Expr + name: str = None + params: Sequence[str] = None + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(_table_context=[], in_select=False) + compiled = c.compile(self.source_table) + + name = self.name or parent_c.new_unique_name() + name_params = f"{name}({', '.join(self.params)})" if self.params else name + parent_c._subqueries[name_params] = compiled + + return name + + @property + def schema(self): + # TODO add cte to schema + return self.source_table.schema + + +def _named_exprs_as_aliases(named_exprs): + return [Alias(expr, name) for name, expr in named_exprs.items()] + + +def resolve_names(source_table, exprs): + i = 0 + for expr in exprs: + # Iterate recursively and update _ResolveColumn with the right expression + if isinstance(expr, ExprNode): + for v in expr._dfs_values(): + if isinstance(v, _ResolveColumn): + v.resolve(source_table._get_column(v.name)) + i += 1 + + +@dataclass(frozen=False, eq=False, order=False) +class _ResolveColumn(ExprNode, LazyOps): + name: str + resolved: Expr = None + + def resolve(self, expr): + assert self.resolved is None + self.resolved = expr + + def compile(self, c: Compiler) -> str: + if self.resolved is None: + raise RuntimeError(f"Column not resolved: {self.name}") + return self.resolved.compile(c) + + @property + def type(self): + if self.resolved is None: + raise RuntimeError(f"Column not resolved: {self.name}") + return self.resolved.type + + +class This: + def __getattr__(self, name): + return _ResolveColumn(name) + + def __getitem__(self, name): + if isinstance(name, list): + return [_ResolveColumn(n) for n in name] + return _ResolveColumn(name) + + +@dataclass +class Explain(ExprNode): + sql: Select + + def compile(self, c: Compiler) -> str: + return f"EXPLAIN {c.compile(self.sql)}" + + +@dataclass +class In(ExprNode): + expr: Expr + list: Sequence[Expr] + + def compile(self, c: Compiler): + elems = ", ".join(map(c.compile, self.list)) + return f"({c.compile(self.expr)} IN ({elems}))" + + +@dataclass +class Cast(ExprNode): + expr: Expr + target_type: Expr + + def compile(self, c: Compiler) -> str: + return f"cast({c.compile(self.expr)} as {c.compile(self.target_type)})" + + +@dataclass +class Random(ExprNode): + def compile(self, c: Compiler) -> str: + return c.database.random() diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py new file mode 100644 index 00000000..50a57e2f --- /dev/null +++ b/data_diff/queries/base.py @@ -0,0 +1,18 @@ +from typing import Generator + +from data_diff.databases.database_types import DbPath, DbKey, Schema + + +SKIP = object() + + +class CompileError(Exception): + pass + + +def args_as_tuple(exprs): + if len(exprs) == 1: + (e,) = exprs + if isinstance(e, Generator): + return tuple(e) + return exprs diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py new file mode 100644 index 00000000..2a37d09f --- /dev/null +++ b/data_diff/queries/compiler.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, Sequence, List + +from runtype import dataclass + +from data_diff.databases.database_types import AbstractDialect + + +@dataclass +class Compiler: + database: AbstractDialect + in_select: bool = False # Compilation + + _table_context: List = [] # 
List[ITable] + _subqueries: Dict[str, Any] = {} # XXX not thread-safe + root: bool = True + + _counter: List = [0] + + def quote(self, s: str): + return self.database.quote(s) + + def compile(self, elem) -> str: + res = self._compile(elem) + if self.root and self._subqueries: + subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items()) + self._subqueries.clear() + return f"WITH {subq}\n{res}" + return res + + def _compile(self, elem) -> str: + if elem is None: + return "NULL" + elif isinstance(elem, Compilable): + return elem.compile(self.replace(root=False)) + elif isinstance(elem, str): + return elem + elif isinstance(elem, int): + return str(elem) + elif isinstance(elem, datetime): + return self.database.timestamp_value(elem) + elif isinstance(elem, bytes): + return f"b'{elem.decode()}'" + elif isinstance(elem, ArithString): + return f"'{elem}'" + assert False, elem + + def new_unique_name(self, prefix="tmp"): + self._counter[0] += 1 + return f"{prefix}{self._counter[0]}" + + def add_table_context(self, *tables: Sequence): + return self.replace(_table_context=self._table_context + list(tables)) + + +class Compilable(ABC): + @abstractmethod + def compile(self, c: Compiler) -> str: + ... diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py new file mode 100644 index 00000000..9b5189e1 --- /dev/null +++ b/data_diff/queries/extras.py @@ -0,0 +1,61 @@ +"Useful AST classes that don't quite fall within the scope of regular SQL" + +from typing import Callable, Sequence +from runtype import dataclass + +from data_diff.databases.database_types import ColType, Native_UUID + +from .compiler import Compiler +from .ast_classes import Expr, ExprNode, Concat + + +@dataclass +class NormalizeAsString(ExprNode): + expr: ExprNode + type: ColType = None + + def compile(self, c: Compiler) -> str: + expr = c.compile(self.expr) + return c.database.normalize_value_by_type(expr, self.type or self.expr.type) + + +@dataclass +class ApplyFuncAndNormalizeAsString(ExprNode): + expr: ExprNode + apply_func: Callable = None + + def compile(self, c: Compiler) -> str: + expr = self.expr + expr_type = expr.type + + if isinstance(expr_type, Native_UUID): + # Normalize first, apply template after (for uuids) + # Needed because min/max(uuid) fails in postgresql + expr = NormalizeAsString(expr, expr_type) + if self.apply_func is not None: + expr = self.apply_func(expr) # Apply template using Python's string formatting + + else: + # Apply template before normalizing (for ints) + if self.apply_func is not None: + expr = self.apply_func(expr) # Apply template using Python's string formatting + expr = NormalizeAsString(expr, expr_type) + + return c.compile(expr) + + +@dataclass +class Checksum(ExprNode): + exprs: Sequence[Expr] + + def compile(self, c: Compiler): + if len(self.exprs) > 1: + exprs = [f"coalesce({c.compile(expr)}, '')" for expr in self.exprs] + # exprs = [c.compile(e) for e in exprs] + expr = Concat(exprs, "|") + else: + # No need to coalesce - safe to assume that key cannot be null + (expr,) = self.exprs + expr = c.compile(expr) + md5 = c.database.md5_to_int(expr) + return f"sum({md5})" diff --git a/data_diff/sql.py b/data_diff/sql.py index 46332797..6240ca9b 100644 --- a/data_diff/sql.py +++ b/data_diff/sql.py @@ -17,8 +17,6 @@ class Sql: SqlOrStr = Union[Sql, str] -CONCAT_SEP = "|" - @dataclass class Compiler: diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 00000000..b6b90394 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,130 @@ +from cmath 
import exp +from typing import List, Optional +import unittest +from data_diff.databases.database_types import AbstractDialect, CaseInsensitiveDict, CaseSensitiveDict + +from data_diff.queries import this, table, Compiler, outerjoin, cte +from data_diff.queries.ast_classes import Random + + +def normalize_spaces(s: str): + return " ".join(s.split()) + + +class MockDialect(AbstractDialect): + def quote(self, s: str): + return s + + def concat(self, l: List[str]) -> str: + s = ", ".join(l) + return f"concat({s})" + + def to_string(self, s: str) -> str: + return f"cast({s} as varchar)" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} is distinct from {b}" + + def random(self) -> str: + return "random()" + + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + x = offset and f"offset {offset}", limit and f"limit {limit}" + return " ".join(filter(None, x)) + + +class TestQuery(unittest.TestCase): + def setUp(self): + pass + + def test_basic(self): + c = Compiler(MockDialect()) + + t = table("point") + t2 = t.select(x=this.x + 1, y=t["y"] + this.x) + assert c.compile(t2) == "SELECT (x + 1) AS x, (y + x) AS y FROM point" + + t = table("point").where(this.x == 1, this.y == 2) + assert c.compile(t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)" + + t = table("point").select("x", "y") + assert c.compile(t) == "SELECT x, y FROM point" + + def test_outerjoin(self): + c = Compiler(MockDialect()) + + a = table("a") + b = table("b") + keys = ["x", "y"] + cols = ["u", "v"] + + j = outerjoin(a, b).on(a[k] == b[k] for k in keys) + + self.assertEqual( + c.compile(j), "SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)" + ) + + # diffed = j.select("*", **{f"is_diff_col_{c}": a[c].is_distinct_from(b[c]) for c in cols}) + + # t = diffed.select( + # **{f"total_diff_col_{c}": diffed[f"is_diff_col_{c}"].sum() for c in cols}, + # total_diff=or_(diffed[f"is_diff_col_{c}"] for c in cols).sum(), + # ) + + # print(c.compile(t)) + + # t.group_by(keys=[this.x], values=[this.py]) + + def test_schema(self): + c = Compiler(MockDialect()) + schema = dict(id="int", comment="varchar") + + t = table("a", schema=CaseInsensitiveDict(schema)) + q = t.select(this.Id, t["COMMENT"]) + assert c.compile(q) == "SELECT id, comment FROM a" + + t = table("a", schema=CaseSensitiveDict(schema)) + self.assertRaises(KeyError, t.__getitem__, "Id") + self.assertRaises(KeyError, t.select, this.Id) + + def test_commutable_select(self): + # c = Compiler(MockDialect()) + + t = table("a") + q1 = t.select("a").where("b") + q2 = t.where("b").select("a") + assert q1 == q2, (q1, q2) + + def test_cte(self): + c = Compiler(MockDialect()) + + t = table("a") + + # single cte + t2 = cte(t.select(this.x)) + t3 = t2.select(this.x) + + expected = "WITH tmp1 AS (SELECT x FROM a) SELECT x FROM tmp1" + assert normalize_spaces(c.compile(t3)) == expected + + # nested cte + c = Compiler(MockDialect()) + t4 = cte(t3).select(this.x) + + expected = "WITH tmp1 AS (SELECT x FROM a), tmp2 AS (SELECT x FROM tmp1) SELECT x FROM tmp2" + assert normalize_spaces(c.compile(t4)) == expected + + # parameterized cte + c = Compiler(MockDialect()) + t2 = cte(t.select(this.x), params=["y"]) + t3 = t2.select(this.y) + + expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1" + assert normalize_spaces(c.compile(t3)) == expected + + def test_funcs(self): + c = Compiler(MockDialect()) + t = table("a") + + q = c.compile(t.order_by(Random()).limit(10)) + assert q == "SELECT * FROM a ORDER 
BY random() limit 10" diff --git a/tests/test_sql.py b/tests/test_sql.py index bc4828c0..67c5637d 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -5,7 +5,6 @@ from .common import TEST_MYSQL_CONN_STRING - class TestSQL(unittest.TestCase): def setUp(self): self.mysql = connect_to_uri(TEST_MYSQL_CONN_STRING) From 70c595210ab14346eda09f6c2627639b3678bb9b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 2 Sep 2022 17:42:38 +0200 Subject: [PATCH 02/33] data-diff now uses new 'data_diff.queries' modules instead of 'data_diff.sql' --- data_diff/databases/base.py | 8 +- data_diff/queries/api.py | 1 + data_diff/queries/ast_classes.py | 19 ++- data_diff/queries/compiler.py | 1 + data_diff/sql.py | 196 ------------------------------- data_diff/table_segment.py | 95 ++++----------- tests/test_sql.py | 45 ++++--- 7 files changed, 71 insertions(+), 294 deletions(-) delete mode 100644 data_diff/sql.py diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index fd6ec2c0..181a80e5 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -24,8 +24,10 @@ UnknownColType, Text, DbTime, + DbPath, ) -from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName + +from data_diff.queries import Expr, Compiler, table, Select, SKIP logger = logging.getLogger("database") @@ -87,7 +89,7 @@ class Database(AbstractDatabase): def name(self): return type(self).__name__ - def query(self, sql_ast: SqlOrStr, res_type: type): + def query(self, sql_ast: Expr, res_type: type): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" compiler = Compiler(self) @@ -213,7 +215,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe return fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns] - samples_by_row = self.query(Select(fields, TableName(table_path), limit=16, where=where and [where]), list) + samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(16), list) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 7c617af4..76aaf5d2 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -31,6 +31,7 @@ def or_(*exprs: Expr): return exprs[0] return BinOp("OR", exprs) + def and_(*exprs: Expr): exprs = args_as_tuple(exprs) if len(exprs) == 1: diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index a5c1008c..019227f7 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -69,6 +69,20 @@ def where(self, *exprs): resolve_names(self.source_table, exprs) return Select.make(self, where_exprs=exprs, _concat=True) + def order_by(self, *exprs): + exprs = _drop_skips(exprs) + if not exprs: + return self + + resolve_names(self.source_table, exprs) + return Select.make(self, order_by_exprs=exprs) + + def limit(self, limit: int): + if limit is SKIP: + return self + + return Select.make(self, limit_expr=limit) + def at(self, *exprs): # TODO exprs = _drop_skips(exprs) @@ -108,7 +122,7 @@ def count(self): @dataclass class Concat(ExprNode): - args: list + exprs: list sep: str = None def compile(self, c: Compiler) -> str: @@ -122,9 +136,10 @@ def compile(self, c: Compiler) -> str: items = list(join_iter(f"'{self.sep}'", items)) return c.database.concat(items) + @dataclass class Count(ExprNode): - expr: Expr = '*' + expr: Expr = "*" distinct: bool = False def compile(self, c: Compiler) -> str: diff 
--git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 2a37d09f..8ea0e7a5 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -4,6 +4,7 @@ from runtype import dataclass +from data_diff.utils import ArithString from data_diff.databases.database_types import AbstractDialect diff --git a/data_diff/sql.py b/data_diff/sql.py deleted file mode 100644 index 6240ca9b..00000000 --- a/data_diff/sql.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Provides classes for a pseudo-SQL AST that compiles to SQL code -""" - -from typing import Sequence, Union, Optional -from datetime import datetime - -from runtype import dataclass - -from .utils import join_iter, ArithString - -from .databases.database_types import AbstractDatabase, DbPath - - -class Sql: - pass - - -SqlOrStr = Union[Sql, str] - - -@dataclass -class Compiler: - """Provides a set of utility methods for compiling SQL - - For internal use. - """ - - database: AbstractDatabase - in_select: bool = False # Compilation - - def quote(self, s: str): - return self.database.quote(s) - - def compile(self, elem): - if isinstance(elem, Sql): - return elem.compile(self) - elif isinstance(elem, str): - return elem - elif isinstance(elem, int): - return str(elem) - assert False - - -@dataclass -class TableName(Sql): - name: DbPath - - def compile(self, c: Compiler): - path = c.database._normalize_table_path(self.name) - return ".".join(map(c.quote, path)) - - -@dataclass -class ColumnName(Sql): - name: str - - def compile(self, c: Compiler): - return c.quote(self.name) - - -@dataclass -class Value(Sql): - value: object # Primitive - - def compile(self, c: Compiler): - if isinstance(self.value, bytes): - return f"b'{self.value.decode()}'" - elif isinstance(self.value, str): - return f"'{self.value}'" % self.value - elif isinstance(self.value, ArithString): - return f"'{self.value}'" - return str(self.value) - - -@dataclass -class Select(Sql): - columns: Sequence[SqlOrStr] - table: SqlOrStr = None - where: Sequence[SqlOrStr] = None - order_by: Sequence[SqlOrStr] = None - group_by: Sequence[SqlOrStr] = None - limit: int = None - - def compile(self, parent_c: Compiler): - c = parent_c.replace(in_select=True) - columns = ", ".join(map(c.compile, self.columns)) - select = f"SELECT {columns}" - - if self.table: - select += " FROM " + c.compile(self.table) - - if self.where: - select += " WHERE " + " AND ".join(map(c.compile, self.where)) - - if self.group_by: - select += " GROUP BY " + ", ".join(map(c.compile, self.group_by)) - - if self.order_by: - select += " ORDER BY " + ", ".join(map(c.compile, self.order_by)) - - if self.limit is not None: - select += " " + c.database.offset_limit(0, self.limit) - - if parent_c.in_select: - select = "(%s)" % select - return select - - -@dataclass -class Enum(Sql): - table: DbPath - order_by: SqlOrStr - - def compile(self, c: Compiler): - table = ".".join(map(c.quote, self.table)) - order = c.compile(self.order_by) - return f"(SELECT *, (row_number() over (ORDER BY {order})) as idx FROM {table} ORDER BY {order}) tmp" - - -@dataclass -class Checksum(Sql): - exprs: Sequence[SqlOrStr] - - def compile(self, c: Compiler): - if len(self.exprs) > 1: - compiled_exprs = [f"coalesce({c.compile(expr)}, '')" for expr in self.exprs] - separated = list(join_iter(f"'|'", compiled_exprs)) - expr = c.database.concat(separated) - else: - # No need to coalesce - safe to assume that key cannot be null - (expr,) = self.exprs - expr = c.compile(expr) - md5 = c.database.md5_to_int(expr) - return 
f"sum({md5})" - - -@dataclass -class Compare(Sql): - op: str - a: SqlOrStr - b: SqlOrStr - - def compile(self, c: Compiler): - return f"({c.compile(self.a)} {self.op} {c.compile(self.b)})" - - -@dataclass -class In(Sql): - expr: SqlOrStr - list: Sequence # List[SqlOrStr] - - def compile(self, c: Compiler): - elems = ", ".join(map(c.compile, self.list)) - return f"({c.compile(self.expr)} IN ({elems}))" - - -@dataclass -class Count(Sql): - column: Optional[SqlOrStr] = None - - def compile(self, c: Compiler): - if self.column: - return f"count({c.compile(self.column)})" - return "count(*)" - - -@dataclass -class Min(Sql): - column: SqlOrStr - - def compile(self, c: Compiler): - return f"min({c.compile(self.column)})" - - -@dataclass -class Max(Sql): - column: SqlOrStr - - def compile(self, c: Compiler): - return f"max({c.compile(self.column)})" - - -@dataclass -class Time(Sql): - time: datetime - - def compile(self, c: Compiler): - return c.database.timestamp_value(self.time) - - -@dataclass -class Explain(Sql): - sql: Select - - def compile(self, c: Compiler): - return f"EXPLAIN {c.compile(self.sql)}" diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 8b95458f..761b3a74 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -4,11 +4,11 @@ from runtype import dataclass -from .utils import ArithString, split_space, ArithAlphanumeric - +from .utils import ArithString, split_space from .databases.base import Database -from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema, create_schema -from .sql import Select, Checksum, Compare, Count, TableName, Time, Value +from .databases.database_types import DbPath, DbKey, DbTime, Schema, create_schema +from .queries import Count, Checksum, SKIP, table, this, Expr, min_, max_ +from .queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString logger = logging.getLogger("table_segment") @@ -66,38 +66,6 @@ def __post_init__(self): f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})" ) - @property - def _update_column(self): - return self._quote_column(self.update_column) - - def _quote_column(self, c: str) -> str: - if self._schema: - c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive. - return self.database.quote(c) - - def _normalize_column(self, name: str, template: str = None) -> str: - if not self._schema: - raise RuntimeError( - "Cannot compile query when the schema is unknown. Please use TableSegment.with_schema()." 
- ) - - col_type = self._schema[name] - col = self._quote_column(name) - - if isinstance(col_type, Native_UUID): - # Normalize first, apply template after (for uuids) - # Needed because min/max(uuid) fails in postgresql - col = self.database.normalize_value_by_type(col, col_type) - if template is not None: - col = template % col # Apply template using Python's string formatting - return col - - # Apply template before normalizing (for ints) - if template is not None: - col = template % col # Apply template using Python's string formatting - - return self.database.normalize_value_by_type(col, col_type) - def _with_raw_schema(self, raw_schema: dict) -> "TableSegment": schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns, self.where) return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive)) @@ -111,37 +79,26 @@ def with_schema(self) -> "TableSegment": def _make_key_range(self): if self.min_key is not None: - yield Compare("<=", Value(self.min_key), self._quote_column(self.key_column)) + yield self.min_key <= this[self.key_column] if self.max_key is not None: - yield Compare("<", self._quote_column(self.key_column), Value(self.max_key)) + yield this[self.key_column] < self.max_key def _make_update_range(self): if self.min_update is not None: - yield Compare("<=", Time(self.min_update), self._update_column) + yield self.min_update <= this[self.update_column] if self.max_update is not None: - yield Compare("<", self._update_column, Time(self.max_update)) - - def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None): - if columns is None: - columns = [self._normalize_column(self.key_column)] - where = [ - *self._make_key_range(), - *self._make_update_range(), - *([] if where is None else [where]), - *([] if self.where is None else [self.where]), - ] - order_by = None if order_by is None else [order_by] - return Select( - table=table or TableName(self.table_path), - where=where, - columns=columns, - group_by=group_by, - order_by=order_by, - ) + yield this[self.update_column] < self.max_update + + @property + def source_table(self): + return table(*self.table_path, schema=self._schema) + + def _make_select(self): + return self.source_table.where(*self._make_key_range(), *self._make_update_range(), self.where or SKIP) def get_values(self) -> list: "Download all the relevant values of the segment from the database" - select = self._make_select(columns=self._relevant_columns_repr) + select = self._make_select().select(*self._relevant_columns_repr) return self.database.query(select, List[Tuple]) def choose_checkpoints(self, count: int) -> List[DbKey]: @@ -185,19 +142,18 @@ def _relevant_columns(self) -> List[str]: return [self.key_column] + extras @property - def _relevant_columns_repr(self) -> List[str]: - return [self._normalize_column(c) for c in self._relevant_columns] + def _relevant_columns_repr(self) -> List[Expr]: + return [NormalizeAsString(this[c]) for c in self._relevant_columns] def count(self) -> Tuple[int, int]: """Count how many rows are in the segment, in one pass.""" - return self.database.query(self._make_select(columns=[Count()]), int) + return self.database.query(self._make_select().select(Count()), int) def count_and_checksum(self) -> Tuple[int, int]: """Count and checksum the rows in the segment, in one pass.""" start = time.monotonic() - count, checksum = self.database.query( - self._make_select(columns=[Count(), Checksum(self._relevant_columns_repr)]), tuple 
- ) + q = self._make_select().select(Count(), Checksum(self._relevant_columns_repr)) + count, checksum = self.database.query(q, tuple) duration = time.monotonic() - start if duration > RECOMMENDED_CHECKSUM_DURATION: logger.warning( @@ -212,11 +168,10 @@ def count_and_checksum(self) -> Tuple[int, int]: def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation - select = self._make_select( - columns=[ - self._normalize_column(self.key_column, "min(%s)"), - self._normalize_column(self.key_column, "max(%s)"), - ] + # TODO better error if there is no schema + select = self._make_select().select( + ApplyFuncAndNormalizeAsString(this[self.key_column], min_), + ApplyFuncAndNormalizeAsString(this[self.key_column], max_), ) min_key, max_key = self.database.query(select, tuple) diff --git a/tests/test_sql.py b/tests/test_sql.py index 67c5637d..fe17940b 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,10 +1,11 @@ import unittest from data_diff.databases import connect_to_uri -from data_diff.sql import Checksum, Compare, Compiler, Count, Enum, Explain, In, Select, TableName - from .common import TEST_MYSQL_CONN_STRING +from data_diff.queries import Compiler, Count, Explain, Select, table, In, BinOp + + class TestSQL(unittest.TestCase): def setUp(self): self.mysql = connect_to_uri(TEST_MYSQL_CONN_STRING) @@ -17,7 +18,7 @@ def test_compile_int(self): self.assertEqual("1", self.compiler.compile(1)) def test_compile_table_name(self): - self.assertEqual("`marine_mammals`.`walrus`", self.compiler.compile(TableName(("marine_mammals", "walrus")))) + self.assertEqual("`marine_mammals`.`walrus`", self.compiler.compile(table("marine_mammals", "walrus"))) def test_compile_select(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus`" @@ -25,23 +26,23 @@ def test_compile_select(self): expected_sql, self.compiler.compile( Select( + table("marine_mammals", "walrus"), ["name"], - TableName(("marine_mammals", "walrus")), ) ), ) - def test_enum(self): - expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp" - self.assertEqual( - expected_sql, - self.compiler.compile( - Enum( - ("walrus",), - "id", - ) - ), - ) + # def test_enum(self): + # expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp" + # self.assertEqual( + # expected_sql, + # self.compiler.compile( + # Enum( + # ("walrus",), + # "id", + # ) + # ), + # ) # def test_checksum(self): # expected_sql = "SELECT name, sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) FROM `marine_mammals`.`walrus`" @@ -61,9 +62,9 @@ def test_compare(self): expected_sql, self.compiler.compile( Select( + table("marine_mammals", "walrus"), ["name"], - TableName(("marine_mammals", "walrus")), - [Compare("<=", "id", "1000"), Compare(">", "id", "1")], + [BinOp("<=", ["id", "1000"]), BinOp(">", ["id", "1"])], ) ), ) @@ -72,23 +73,21 @@ def test_in(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile(Select(["name"], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), ["name"], [In("id", [1, 2, 3])])), ) def test_count(self): expected_sql = "SELECT count(*) FROM `marine_mammals`.`walrus` WHERE (id IN 
(1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile(Select([Count()], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count()], [In("id", [1, 2, 3])])), ) def test_count_with_column(self): expected_sql = "SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile( - Select([Count("id")], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])]) - ), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count("id")], [In("id", [1, 2, 3])])), ) def test_explain(self): @@ -96,6 +95,6 @@ def test_explain(self): self.assertEqual( expected_sql, self.compiler.compile( - Explain(Select([Count("id")], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])) + Explain(Select(table("marine_mammals", "walrus"), [Count("id")], [In("id", [1, 2, 3])])) ), ) From e5ace37250ca4e8ab341d49c5954bafb9df91c54 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 8 Sep 2022 14:09:56 +0200 Subject: [PATCH 03/33] Join-diff implementation --- data_diff/__init__.py | 3 - data_diff/__main__.py | 19 +++- data_diff/diff_tables.py | 79 ++++++++----- data_diff/joindiff_tables.py | 207 +++++++++++++++++++++++++++++++++++ tests/test_diff_tables.py | 8 +- tests/test_joindiff.py | 168 ++++++++++++++++++++++++++++ 6 files changed, 441 insertions(+), 43 deletions(-) create mode 100644 data_diff/joindiff_tables.py create mode 100644 tests/test_joindiff.py diff --git a/data_diff/__init__.py b/data_diff/__init__.py index bc5677c8..2af199db 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -55,8 +55,6 @@ def diff_tables( # Maximum size of each threadpool. None = auto. Only relevant when threaded is True. # There may be many pools, so number of actual threads can be a lot higher. max_threadpool_size: Optional[int] = 1, - # Enable/disable debug prints - debug: bool = False, ) -> Iterator: """Efficiently finds the diff between table1 and table2. 
@@ -86,7 +84,6 @@ def diff_tables( differ = TableDiffer( bisection_factor=bisection_factor, bisection_threshold=bisection_threshold, - debug=debug, threaded=threaded, max_threadpool_size=max_threadpool_size, ) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index bccd132f..a5715397 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -192,12 +192,19 @@ def _main( logging.error(f"Error while parsing age expression: {e}") return - differ = TableDiffer( - bisection_factor=bisection_factor, - bisection_threshold=bisection_threshold, - threaded=threaded, - max_threadpool_size=threads and threads * 2, - ) + if algorithm == Algorithm.JOINDIFF: + differ = JoinDiffer( + threaded=threaded, + max_threadpool_size=threads and threads * 2, + ) + else: + assert algorithm == Algorithm.HASHDIFF + differ = TableDiffer( + bisection_factor=bisection_factor, + bisection_threshold=bisection_threshold, + threaded=threaded, + max_threadpool_size=threads and threads * 2, + ) if database1 is None or database2 is None: logging.error( diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 96c6c624..7f67cfc2 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,7 +1,9 @@ """Provides classes for performing a table diff """ +from contextlib import contextmanager import time +import threading import os from numbers import Number from operator import attrgetter, methodcaller @@ -44,7 +46,53 @@ def diff_sets(a: set, b: set) -> Iterator: @dataclass -class TableDiffer: +class ThreadBase: + "Provides utility methods for optional threading" + + threaded: bool = True + max_threadpool_size: Optional[int] = 1 + + def _thread_map(self, func, iterable): + if not self.threaded: + return map(func, iterable) + + with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: + return task_pool.map(func, iterable) + + def _threaded_call(self, func, iterable): + "Calls a method for each object in iterable." + return list(self._thread_map(methodcaller(func), iterable)) + + def _thread_as_completed(self, func, iterable): + if not self.threaded: + yield from map(func, iterable) + return + + with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: + futures = [task_pool.submit(func, item) for item in iterable] + for future in as_completed(futures): + yield future.result() + + def _threaded_call_as_completed(self, func, iterable): + "Calls a method for each object in iterable. Returned in order of completion." 
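+ # methodcaller(func) turns the method *name* into a callable, so each worker
+ # runs item.<func>() and results are yielded in completion order, not input order.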
+ return self._thread_as_completed(methodcaller(func), iterable) + + def _run_thread(self, threadfunc, *args, daemon=False) -> threading.Thread: + th = threading.Thread(target=threadfunc, args=args) + if daemon: + th.daemon = True + th.start() + return th + + @contextmanager + def _run_in_background(self, threadfunc, *args, daemon=False): + t = self._run_thread(threadfunc, *args, daemon=daemon) + yield t + t.join() + + +@dataclass +class TableDiffer(ThreadBase): """Finds the diff between two SQL tables The algorithm uses hashing to quickly check if the tables are different, and then applies a @@ -62,11 +110,6 @@ class TableDiffer: bisection_factor: int = DEFAULT_BISECTION_FACTOR bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests - threaded: bool = True - max_threadpool_size: Optional[int] = 1 - - # Enable/disable debug prints - debug: bool = False stats: dict = {} @@ -291,27 +334,3 @@ def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableS if checksum1 != checksum2: return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2)) - def _thread_map(self, func, iterable): - if not self.threaded: - return map(func, iterable) - - with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: - return task_pool.map(func, iterable) - - def _threaded_call(self, func, iterable): - "Calls a method for each object in iterable." - return list(self._thread_map(methodcaller(func), iterable)) - - def _thread_as_completed(self, func, iterable): - if not self.threaded: - yield from map(func, iterable) - return - - with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: - futures = [task_pool.submit(func, item) for item in iterable] - for future in as_completed(futures): - yield future.result() - - def _threaded_call_as_completed(self, func, iterable): - "Calls a method for each object in iterable. Returned in order of completion." 
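+ # one sum_/avg_/min_/max_ dict per column; merge_dicts() flattens the
+ # generator into a single mapping, queried below as one SELECT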
- return self._thread_as_completed(methodcaller(func), iterable) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py new file mode 100644 index 00000000..7f684ebb --- /dev/null +++ b/data_diff/joindiff_tables.py @@ -0,0 +1,207 @@ +"""Provides classes for performing a table diff using JOIN + +""" + +from decimal import Decimal +import logging +from contextlib import contextmanager +from typing import Dict, List + +from runtype import dataclass + +from .utils import safezip +from .databases.base import Database +from .table_segment import TableSegment +from .diff_tables import ThreadBase, DiffResult + +from .queries import table, sum_, min_, max_, avg +from .queries.api import and_, if_, or_, outerjoin, this +from .queries.ast_classes import Concat, Count, Expr, Random +from .queries.compiler import Compiler +from .queries.extras import NormalizeAsString + + +logger = logging.getLogger("joindiff_tables") + + +def merge_dicts(dicts): + i = iter(dicts) + res = next(i) + for d in i: + res.update(d) + return res + + +@dataclass(frozen=False) +class Stats: + exclusive_count: int + exclusive_sample: List[tuple] + diff_ratio_by_column: Dict[str, float] + diff_ratio_total: float + metrics: Dict[str, float] + + +def sample(table): + # TODO + return table.order_by(Random()).limit(10) + + +@contextmanager +def temp_table(db: Database, expr: Expr): + c = Compiler(db) + name = c.new_unique_name("tmp_table") + db.query(f"create temporary table {c.quote(name)} as {c.compile(expr)}", None) + try: + yield table(name, schema=expr.source_table.schema) + finally: + db.query(f"drop table {c.quote(name)}", None) + + +def _slice_tuple(t, *sizes): + i = 0 + for size in sizes: + yield t[i : i + size] + i += size + assert i == len(t) + + +def json_friendly_value(v): + if isinstance(v, Decimal): + return float(v) + return v + + +@dataclass +class JoinDifferBase(ThreadBase): + """Finds the diff between two SQL tables using JOINs""" + + stats: dict = {} + + def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + + if table1.database is not table2.database: + raise ValueError("Join-diff only works when both tables are in the same database") + + with self._run_in_background(self._test_null_or_duplicate_keys, table1, table2): + with self._run_in_background(self._collect_stats, 1, table1): + with self._run_in_background(self._collect_stats, 2, table2): + yield from self._outer_join(table1, table2) + + logger.info("Diffing complete") + + def _test_null_or_duplicate_keys(self, table1, table2): + logger.info("Testing for null or duplicate keys") + + # Test null or duplicate keys + for ts in [table1, table2]: + t = table(*ts.table_path, schema=ts._schema) + key_columns = [ts.key_column] # XXX + + q = t.select(total=Count(), total_distinct=Count(Concat(key_columns), distinct=True)) + total, total_distinct = ts.database.query(q, tuple) + if total != total_distinct: + raise ValueError("Duplicate primary keys") + + q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns)) + nulls = ts.database.query(q, list) + if nulls: + raise ValueError(f"NULL values in one or more primary keys: {nulls}") + + logger.debug("Done testing for null or duplicate keys") + + def _collect_stats(self, i, table): + logger.info(f"Collecting stats for table #{i}") + db = table.database + + # Metrics + col_exprs = merge_dicts( + { + f"sum_{c}": sum_(c), + f"avg_{c}": avg(c), + f"min_{c}": min_(c), + f"max_{c}": max_(c), + } 
+ for c in table._relevant_columns + if c == "id" # TODO just if the right type + ) + col_exprs["count"] = Count() + + res = db.query(table._make_select().select(**col_exprs), tuple) + res = dict(zip([f"table{i}_{n}" for n in col_exprs], map(json_friendly_value, res))) + self.stats.update(res) + + logger.debug(f"Done collecting stats for table #{i}") + + # stats.diff_ratio_by_column = diff_stats + # stats.diff_ratio_total = diff_stats['total_diff'] + + +def bool_to_int(x): + return if_(x, 1, 0) + + +class JoinDiffer(JoinDifferBase): + def _outer_join(self, table1, table2): + db = table1.database + if db is not table2.database: + raise ValueError("Joindiff only applies to tables within the same database") + + keys1 = [table1.key_column] # XXX + keys2 = [table2.key_column] # XXX + if len(keys1) != len(keys2): + raise ValueError("The provided key columns are of a different count") + + cols1 = table1._relevant_columns + cols2 = table2._relevant_columns + if len(cols1) != len(cols2): + raise ValueError("The provided columns are of a different count") + + a = table1._make_select() + b = table2._make_select() + + is_diff_cols = { + f"is_diff_col_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2) + } + + a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1} + b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} + + diff_rows = ( + outerjoin(a, b) + .on(a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)) + .select( + is_exclusive_a=and_(b[k] == None for k in keys2), + is_exclusive_b=and_(a[k] == None for k in keys1), + **is_diff_cols, + **a_cols, + **b_cols, + ) + .where(or_(this[c] == 1 for c in is_diff_cols)) + ) + + with self._run_in_background(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols): + with self._run_in_background(self._count_diff_per_column, db, diff_rows, is_diff_cols): + + logger.info("Querying for different rows") + for is_xa, is_xb, *x in db.query(diff_rows, list): + assert not (is_xa and is_xb) # Can't both be exclusive + is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) + if not is_xb: + yield "-", tuple(a_row) + if not is_xa: + yield "+", tuple(b_row) + + def _count_diff_per_column(self, db, diff_rows, is_diff_cols): + logger.info("Counting differences per column") + is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) + for name, count in safezip(is_diff_cols, is_diff_cols_counts): + self.stats[f"count_{name}"] = count + + def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): + logger.info("Counting and sampling exclusive rows") + exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) + with temp_table(db, exclusive_rows_query) as exclusive_rows: + self.stats["exclusive_count"] = db.query(exclusive_rows.count(), int) + sample_rows = db.query(sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])), list) + self.stats["exclusive_sample"] = sample_rows diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index de0cde5d..3ac37bd0 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -176,7 +176,7 @@ def test_init(self): ) def test_basic(self): - differ = TableDiffer(10, 100) + differ = TableDiffer(bisection_factor=10, bisection_threshold=100) a = TableSegment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) assert 
a.count() == 6 @@ -186,7 +186,7 @@ def test_basic(self): self.assertEqual(len(list(differ.diff_tables(a, b))), 1) def test_offset(self): - differ = TableDiffer(2, 10) + differ = TableDiffer(bisection_factor=2, bisection_threshold=10) sec1 = self.now.shift(seconds=-1).datetime a = TableSegment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) @@ -250,7 +250,7 @@ def setUp(self): self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) - self.differ = TableDiffer(3, 4) + self.differ = TableDiffer(bisection_factor=3, bisection_threshold=4) def test_properties_on_empty_table(self): table = self.table.with_schema() @@ -287,7 +287,7 @@ def test_diff_small_tables(self): self.assertEqual(1, self.differ.stats["table2_count"]) def test_non_threaded(self): - differ = TableDiffer(3, 4, threaded=False) + differ = TableDiffer(bisection_factor=3, bisection_threshold=4, threaded=False) time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py new file mode 100644 index 00000000..72d604cd --- /dev/null +++ b/tests/test_joindiff.py @@ -0,0 +1,168 @@ +from parameterized import parameterized_class + +from data_diff.databases.connect import connect +from data_diff.table_segment import TableSegment, split_space +from data_diff import databases as db +from data_diff.utils import ArithAlphanumeric +from data_diff.joindiff_tables import JoinDiffer + +from .test_diff_tables import TestPerDatabase, _get_float_type, _get_text_type, _commit, _insert_row, _insert_rows + +from .common import ( + str_to_checksum, + CONN_STRINGS, + N_THREADS, +) + +DATABASE_INSTANCES = None +DATABASE_URIS = {k.__name__: v for k, v in CONN_STRINGS.items()} + + +def init_instances(): + global DATABASE_INSTANCES + if DATABASE_INSTANCES is not None: + return + + DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} + + +TEST_DATABASES = {x.__name__ for x in (db.PostgreSQL,)} + +_class_per_db_dec = parameterized_class( + ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in TEST_DATABASES] +) + + +def test_per_database(cls): + return _class_per_db_dec(cls) + + +@test_per_database +class TestJoindiff(TestPerDatabase): + def setUp(self): + super().setUp() + + float_type = _get_float_type(self.connection) + + self.connection.query( + f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + None, + ) + self.connection.query( + f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + None, + ) + _commit(self.connection) + + self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + + self.differ = JoinDiffer() + + def test_diff_small_tables(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id userid movieid rating timestamp".split() + _insert_rows(self.connection, self.table_src, cols, [[1, 1, 1, 9, time_str], [2, 2, 2, 9, time_str]]) + _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) + 
_commit(self.connection) + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected = [("-", ("2", time + ".000000"))] + self.assertEqual(expected, diff) + self.assertEqual(2, self.differ.stats["table1_count"]) + self.assertEqual(1, self.differ.stats["table2_count"]) + + def test_diff_table_above_bisection_threshold(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id userid movieid rating timestamp".split() + _insert_rows( + self.connection, + self.table_src, + cols, + [ + [1, 1, 1, 9, time_str], + [2, 2, 2, 9, time_str], + [3, 3, 3, 9, time_str], + [4, 4, 4, 9, time_str], + [5, 5, 5, 9, time_str], + ], + ) + + _insert_rows( + self.connection, + self.table_dst, + cols, + [ + [1, 1, 1, 9, time_str], + [2, 2, 2, 9, time_str], + [3, 3, 3, 9, time_str], + [4, 4, 4, 9, time_str], + ], + ) + _commit(self.connection) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected = [("-", ("5", time + ".000000"))] + self.assertEqual(expected, diff) + self.assertEqual(5, self.differ.stats["table1_count"]) + self.assertEqual(4, self.differ.stats["table2_count"]) + + def test_return_empty_array_when_same(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id userid movieid rating timestamp".split() + + _insert_row(self.connection, self.table_src, cols, [1, 1, 1, 9, time_str]) + _insert_row(self.connection, self.table_dst, cols, [1, 1, 1, 9, time_str]) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + self.assertEqual([], diff) + + def test_diff_sorted_by_key(self): + time = "2022-01-01 00:00:00" + time2 = "2021-01-01 00:00:00" + + time_str = f"timestamp '{time}'" + time_str2 = f"timestamp '{time2}'" + + cols = "id userid movieid rating timestamp".split() + + _insert_rows( + self.connection, + self.table_src, + cols, + [ + [1, 1, 1, 9, time_str], + [2, 2, 2, 9, time_str2], + [3, 3, 3, 9, time_str], + [4, 4, 4, 9, time_str2], + [5, 5, 5, 9, time_str], + ], + ) + + _insert_rows( + self.connection, + self.table_dst, + cols, + [ + [1, 1, 1, 9, time_str], + [2, 2, 2, 9, time_str], + [3, 3, 3, 9, time_str], + [4, 4, 4, 9, time_str], + [5, 5, 5, 9, time_str], + ], + ) + _commit(self.connection) + + diff = list(self.differ.diff_tables(self.table, self.table2)) + expected = [ + ("-", ("2", time2 + ".000000")), + ("+", ("2", time + ".000000")), + ("-", ("4", time2 + ".000000")), + ("+", ("4", time + ".000000")), + ] + self.assertEqual(expected, diff) From b6170bf8c64916126173cc25a4c1c67691f1e314 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sat, 10 Sep 2022 11:51:07 +0300 Subject: [PATCH 04/33] Integrate joindiff into main --- data_diff/__main__.py | 76 ++++++++++++++++++++++++--- data_diff/databases/database_types.py | 2 +- data_diff/diff_tables.py | 10 ++-- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index a5715397..7ad156ea 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -1,13 +1,17 @@ from copy import deepcopy +from enum import Enum import sys import time import json import logging from itertools import islice +from typing import Optional import rich import click +from data_diff.joindiff_tables import JoinDiffer + from .utils import remove_password_from_url, safezip, match_like from .diff_tables import TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR @@ -28,6 +32,12 @@ } +class Algorithm(Enum): + AUTO = "auto" + JOINDIFF = "joindiff" + HASHDIFF = "hashdiff" + + def 
_remove_passwords_in_dict(d: dict):
     for k, v in d.items():
         if k == "password":
@@ -43,13 +53,30 @@ def _get_schema(pair):
     return db.query_table_schema(table_path)
 
 
-@click.command()
+class MyHelpFormatter(click.HelpFormatter):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.indent_increment = 6
+
+    def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -> None:
+        self.write("data-diff - efficiently diff rows across database tables.\n\n")
+        self.write("Usage:\n")
+        self.write(f"  * In-db diff: {prog} <database1> <table1> <table2> [OPTIONS]\n")
+        self.write(f"  * Cross-db diff: {prog} <database1> <table1> <database2> <table2> [OPTIONS]\n")
+        self.write(f"  * Using config: {prog} --conf PATH [--run NAME] [OPTIONS]\n")
+        # s = super().write_usage(prog, args, prefix)
+
+
+click.Context.formatter_class = MyHelpFormatter
+
+
+@click.command(no_args_is_help=True)
 @click.argument("database1", required=False)
 @click.argument("table1", required=False)
 @click.argument("database2", required=False)
 @click.argument("table2", required=False)
-@click.option("-k", "--key-column", default=None, help="Name of primary key column. Default='id'.")
-@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column")
+@click.option("-k", "--key-column", default=None, help="Name of primary key column. Default='id'.", metavar="NAME")
+@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column", metavar="NAME")
 @click.option(
     "-c",
     "--columns",
@@ -58,13 +85,20 @@ def _get_schema(pair):
     help="Names of extra columns to compare."
     "Can be used more than once in the same command. "
     "Accepts a name or a pattern like in SQL. Example: -c col% -c another_col",
+    metavar="NAME",
+)
+@click.option("-l", "--limit", default=None, help="Maximum number of differences to find", metavar="NUM")
+@click.option(
+    "--bisection-factor",
+    default=None,
+    help=f"Segments per iteration. Default={DEFAULT_BISECTION_FACTOR}.",
+    metavar="NUM",
 )
-@click.option("-l", "--limit", default=None, help="Maximum number of differences to find")
-@click.option("--bisection-factor", default=None, help=f"Segments per iteration. Default={DEFAULT_BISECTION_FACTOR}.")
 @click.option(
     "--bisection-threshold",
     default=None,
     help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.",
+    metavar="NUM",
 )
 @click.option(
     "--min-age",
     default=None,
     help="Considers only rows older than specified. Useful for specifying replication lag."
     "Example: --min-age=5min ignores rows from the last 5 minutes. "
     f"\nValid units: {UNITS_STR}",
+    metavar="AGE",
+)
+@click.option(
+    "--max-age", default=None, help="Considers only rows younger than specified. See --min-age.", metavar="AGE"
 )
-@click.option("--max-age", default=None, help="Considers only rows younger than specified. See --min-age.")
 @click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff")
 @click.option("-d", "--debug", is_flag=True, help="Print debug info")
 @click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability")
" "'serial' guarantees a single-threaded execution of the algorithm (useful for debugging).", + metavar="COUNT", +) +@click.option( + "-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.", metavar="EXPR" ) -@click.option("-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.") +@click.option("-a", "--algorithm", default=Algorithm.AUTO.value, type=click.Choice([i.value for i in Algorithm])) @click.option( "--conf", default=None, help="Path to a configuration.toml file, to provide a default configuration, and a list of possible runs.", + metavar="PATH", ) @click.option( "--run", default=None, help="Name of run-configuration to run. If used, CLI arguments for database and table must be omitted.", + metavar="NAME", ) def main(conf, run, **kw): + indb_syntax = False + if kw["table2"] is None and kw["database2"]: + # Use the "database table table" form + kw["table2"] = kw["database2"] + kw["database2"] = kw["database1"] + indb_syntax = True + if conf: kw = apply_config_from_file(conf, run, kw) + + kw["algorithm"] = Algorithm(kw["algorithm"]) + if kw["algorithm"] == Algorithm.AUTO: + kw["algorithm"] = Algorithm.JOINDIFF if indb_syntax else Algorithm.HASHDIFF + return _main(**kw) @@ -119,6 +174,7 @@ def _main( update_column, columns, limit, + algorithm, bisection_factor, bisection_threshold, min_age, @@ -214,7 +270,10 @@ def _main( try: db1 = connect(database1, threads1 or threads) - db2 = connect(database2, threads2 or threads) + if database1 == database2: + db2 = db1 + else: + db2 = connect(database2, threads2 or threads) except Exception as e: logging.error(e) return @@ -277,6 +336,7 @@ def _main( "different_+": plus, "different_-": minus, "total": max_table_count, + "stats": differ.stats, } print(json.dumps(json_output)) else: diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index ca2734fc..1e9c973e 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -172,6 +172,7 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None "Provide SQL fragment for limit and offset inside a select" ... + class AbstractDatabase(AbstractDialect): @abstractmethod def timestamp_value(self, t: DbTime) -> str: @@ -183,7 +184,6 @@ def md5_to_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" ... 
- @abstractmethod def _query(self, sql_code: str) -> list: "Send query to database and return result" diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 7f67cfc2..a98a6508 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -78,11 +78,11 @@ def _threaded_call_as_completed(self, func, iterable): return self._thread_as_completed(methodcaller(func), iterable) def _run_thread(self, threadfunc, *args, daemon=False) -> threading.Thread: - th = threading.Thread(target=threadfunc, args=args) - if daemon: - th.daemon = True - th.start() - return th + th = threading.Thread(target=threadfunc, args=args) + if daemon: + th.daemon = True + th.start() + return th @contextmanager def _run_in_background(self, threadfunc, *args, daemon=False): From becf36c6ddda0fbfc8f7a51c0077b4388657d428 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 14 Sep 2022 11:20:36 +0300 Subject: [PATCH 05/33] Refactor diff_tables.TableDiffer -> hashdiff_tables.HashDiffer --- data_diff/__init__.py | 33 ++-- data_diff/__main__.py | 17 +-- data_diff/diff_tables.py | 282 +--------------------------------- data_diff/hashdiff_tables.py | 283 +++++++++++++++++++++++++++++++++++ tests/common.py | 3 +- tests/test_database_types.py | 7 +- tests/test_diff_tables.py | 33 ++-- tests/test_postgresql.py | 5 +- 8 files changed, 341 insertions(+), 322 deletions(-) create mode 100644 data_diff/hashdiff_tables.py diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 2af199db..f22ab039 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -3,7 +3,10 @@ from .tracking import disable_tracking from .databases.connect import connect from .databases.database_types import DbKey, DbTime, DbPath -from .diff_tables import TableSegment, TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .diff_tables import Algorithm +from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR +from .joindiff_tables import JoinDiffer +from .table_segment import TableSegment def connect_to_table( @@ -46,9 +49,11 @@ def diff_tables( # Start/end update_column values, used to restrict the segment min_update: DbTime = None, max_update: DbTime = None, - # Into how many segments to bisect per iteration + # Algorithm + algorithm: Algorithm = Algorithm.HASHDIFF, + # Into how many segments to bisect per iteration (hashdiff only) bisection_factor: int = DEFAULT_BISECTION_FACTOR, - # When should we stop bisecting and compare locally (in row count) + # When should we stop bisecting and compare locally (in row count; hashdiff only) bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD, # Enable/disable threaded diffing. Needed to take advantage of database threads. 
threaded: bool = True,
@@ -81,10 +86,20 @@ def diff_tables(
 
     segments = [t.new(**override_attrs) for t in tables] if override_attrs else tables
 
-    differ = TableDiffer(
-        bisection_factor=bisection_factor,
-        bisection_threshold=bisection_threshold,
-        threaded=threaded,
-        max_threadpool_size=max_threadpool_size,
-    )
+    algorithm = Algorithm(algorithm)
+    if algorithm == Algorithm.HASHDIFF:
+        differ = HashDiffer(
+            bisection_factor=bisection_factor,
+            bisection_threshold=bisection_threshold,
+            threaded=threaded,
+            max_threadpool_size=max_threadpool_size,
+        )
+    elif algorithm == Algorithm.JOINDIFF:
+        differ = JoinDiffer(
+            threaded=threaded,
+            max_threadpool_size=max_threadpool_size,
+        )
+    else:
+        raise ValueError(f"Unknown algorithm: {algorithm}")
+
     return differ.diff_tables(*segments)
diff --git a/data_diff/__main__.py b/data_diff/__main__.py
index 7ad156ea..adb5bee9 100644
--- a/data_diff/__main__.py
+++ b/data_diff/__main__.py
@@ -1,5 +1,4 @@
 from copy import deepcopy
-from enum import Enum
 import sys
 import time
 import json
@@ -10,11 +9,10 @@
 import rich
 import click
 
-from data_diff.joindiff_tables import JoinDiffer
-
-
 from .utils import remove_password_from_url, safezip, match_like
-from .diff_tables import TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
+from .diff_tables import Algorithm
+from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
+from .joindiff_tables import JoinDiffer
 from .table_segment import TableSegment
 from .databases.database_types import create_schema
 from .databases.connect import connect
@@ -32,12 +30,6 @@
 }
 
 
-class Algorithm(Enum):
-    AUTO = "auto"
-    JOINDIFF = "joindiff"
-    HASHDIFF = "hashdiff"
-
-
 def _remove_passwords_in_dict(d: dict):
     for k, v in d.items():
         if k == "password":
@@ -64,7 +56,6 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
         self.write(f"  * In-db diff: {prog} <database1> <table1> <table2> [OPTIONS]\n")
         self.write(f"  * Cross-db diff: {prog} <database1> <table1> <database2> <table2> [OPTIONS]\n")
         self.write(f"  * Using config: {prog} --conf PATH [--run NAME] [OPTIONS]\n")
-        # s = super().write_usage(prog, args, prefix)
 
 
 click.Context.formatter_class = MyHelpFormatter
@@ -255,7 +246,7 @@ def _main(
         )
     else:
         assert algorithm == Algorithm.HASHDIFF
-        differ = TableDiffer(
+        differ = HashDiffer(
             bisection_factor=bisection_factor,
             bisection_threshold=bisection_threshold,
             threaded=threaded,
diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py
index a98a6508..04a95fe7 100644
--- a/data_diff/diff_tables.py
+++ b/data_diff/diff_tables.py
@@ -1,45 +1,20 @@
 """Provides classes for performing a table diff
 """
 
+from enum import Enum
 from contextlib import contextmanager
-import time
 import threading
-import os
-from numbers import Number
-from operator import attrgetter, methodcaller
-from collections import defaultdict
+from operator import methodcaller
 from typing import Tuple, Iterator, Optional
-import logging
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from runtype import dataclass
 
-from .utils import safezip, run_as_daemon
-from .thread_utils import ThreadedYielder
-from .databases.database_types import IKey, NumericType, PrecisionType, StringType
-from .table_segment import TableSegment
-from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
 
-logger = logging.getLogger("diff_tables")
-
-BENCHMARK = os.environ.get("BENCHMARK", False)
-DEFAULT_BISECTION_THRESHOLD = 1024 * 16
-DEFAULT_BISECTION_FACTOR = 32
-
-
-def diff_sets(a: set, b: set) -> Iterator:
-    
s1 = set(a) - s2 = set(b) - d = defaultdict(list) - - # The first item is always the key (see TableDiffer._relevant_columns) - for i in s1 - s2: - d[i[0]].append(("-", i)) - for i in s2 - s1: - d[i[0]].append(("+", i)) - - for _k, v in sorted(d.items(), key=lambda i: i[0]): - yield from v +class Algorithm(Enum): + AUTO = "auto" + JOINDIFF = "joindiff" + HASHDIFF = "hashdiff" DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] @@ -89,248 +64,3 @@ def _run_in_background(self, threadfunc, *args, daemon=False): t = self._run_thread(threadfunc, *args, daemon=daemon) yield t t.join() - - -@dataclass -class TableDiffer(ThreadBase): - """Finds the diff between two SQL tables - - The algorithm uses hashing to quickly check if the tables are different, and then applies a - bisection search recursively to find the differences efficiently. - - Works best for comparing tables that are mostly the same, with minor discrepencies. - - Parameters: - bisection_factor (int): Into how many segments to bisect per iteration. - bisection_threshold (Number): When should we stop bisecting and compare locally (in row count). - threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. - max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. - There may be many pools, so number of actual threads can be a lot higher. - """ - - bisection_factor: int = DEFAULT_BISECTION_FACTOR - bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests - - stats: dict = {} - - def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: - """Diff the given tables. - - Parameters: - table1 (TableSegment): The "before" table to compare. Or: source table - table2 (TableSegment): The "after" table to compare. Or: target table - - Returns: - An iterator that yield pair-tuples, representing the diff. Items can be either - ('-', columns) for items in table1 but not in table2 - ('+', columns) for items in table2 but not in table1 - Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra) - """ - # Validate options - if self.bisection_factor >= self.bisection_threshold: - raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") - if self.bisection_factor < 2: - raise ValueError("Must have at least two segments per iteration (i.e. 
bisection_factor >= 2)") - - if is_tracking_enabled(): - options = dict(self) - event_json = create_start_event_json(options) - run_as_daemon(send_event_json, event_json) - - self.stats["diff_count"] = 0 - start = time.monotonic() - error = None - try: - - # Query and validate schema - table1, table2 = self._threaded_call("with_schema", [table1, table2]) - self._validate_and_adjust_columns(table1, table2) - - key_type = table1._schema[table1.key_column] - key_type2 = table2._schema[table2.key_column] - if not isinstance(key_type, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type} as a key") - if not isinstance(key_type2, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") - assert key_type.python_type is key_type2.python_type - - # Query min/max values - key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) - - # Start with the first completed value, so we don't waste time waiting - min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) - - table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] - - logger.info( - f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" - ) - - ti = ThreadedYielder(self.max_threadpool_size) - # Bisect (split) the table into segments, and diff them recursively. - ti.submit(self._bisect_and_diff_tables, ti, table1, table2) - - # Now we check for the second min-max, to diff the portions we "missed". - min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) - - if min_key2 < min_key1: - pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *pre_tables) - - if max_key2 > max_key1: - post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *post_tables) - - yield from ti - - except BaseException as e: # Catch KeyboardInterrupt too - error = e - finally: - if is_tracking_enabled(): - runtime = time.monotonic() - start - table1_count = self.stats.get("table1_count") - table2_count = self.stats.get("table2_count") - diff_count = self.stats.get("diff_count") - err_message = str(error)[:20] # Truncate possibly sensitive information. 
- event_json = create_end_event_json( - error is None, - runtime, - table1.database.name, - table2.database.name, - table1_count, - table2_count, - diff_count, - err_message, - ) - send_event_json(event_json) - - if error: - raise error - - def _parse_key_range_result(self, key_type, key_range): - mn, mx = key_range - cls = key_type.make_value - # We add 1 because our ranges are exclusive of the end (like in Python) - try: - return cls(mn), cls(mx) + 1 - except (TypeError, ValueError) as e: - raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e - - def _validate_and_adjust_columns(self, table1, table2): - for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): - if c1 not in table1._schema: - raise ValueError(f"Column '{c1}' not found in schema for table {table1}") - if c2 not in table2._schema: - raise ValueError(f"Column '{c2}' not found in schema for table {table2}") - - # Update schemas to minimal mutual precision - col1 = table1._schema[c1] - col2 = table2._schema[c2] - if isinstance(col1, PrecisionType): - if not isinstance(col2, PrecisionType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - - lowest = min(col1, col2, key=attrgetter("precision")) - - if col1.precision != col2.precision: - logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") - - table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) - table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) - - elif isinstance(col1, NumericType): - if not isinstance(col2, NumericType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - - lowest = min(col1, col2, key=attrgetter("precision")) - - if col1.precision != col2.precision: - logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") - - table1._schema[c1] = col1.replace(precision=lowest.precision) - table2._schema[c2] = col2.replace(precision=lowest.precision) - - elif isinstance(col1, StringType): - if not isinstance(col2, StringType): - raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - - for t in [table1, table2]: - for c in t._relevant_columns: - ctype = t._schema[c] - if not ctype.supported: - logger.warning( - f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. " - "If encoding/formatting differs between databases, it may result in false positives." - ) - - def _bisect_and_diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): - assert table1.is_bounded and table2.is_bounded - - if max_rows is None: - # We can be sure that row_count <= max_rows - max_rows = max(table1.approximate_size(), table2.approximate_size()) - - # If count is below the threshold, just download and compare the columns locally - # This saves time, as bisection speed is limited by ping and query performance. - if max_rows < self.bisection_threshold: - rows1, rows2 = self._threaded_call("get_values", [table1, table2]) - diff = list(diff_sets(rows1, rows2)) - - # Initial bisection_threshold larger than count. Normally we always - # checksum and count segments, even if we get the values. At the - # first level, however, that won't be true. - if level == 0: - self.stats["table1_count"] = len(rows1) - self.stats["table2_count"] = len(rows2) - - self.stats["diff_count"] += len(diff) - - logger.info(". 
" * level + f"Diff found {len(diff)} different rows.") - self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) - return diff - - # Choose evenly spaced checkpoints (according to min_key and max_key) - checkpoints = table1.choose_checkpoints(self.bisection_factor - 1) - - # Create new instances of TableSegment between each checkpoint - segmented1 = table1.segment_by_checkpoints(checkpoints) - segmented2 = table2.segment_by_checkpoints(checkpoints) - - # Recursively compare each pair of corresponding segments between table1 and table2 - for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): - ti.submit(self._diff_tables, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) - - def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): - logger.info( - ". " * level + f"Diffing segment {segment_index}/{segment_count}, " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size <= {max_rows}" - ) - - # When benchmarking, we want the ability to skip checksumming. This - # allows us to download all rows for comparison in performance. By - # default, data-diff will checksum the section first (when it's below - # the threshold) and _then_ download it. - if BENCHMARK: - if max_rows < self.bisection_threshold: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows) - - (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) - - if count1 == 0 and count2 == 0: - # logger.warning( - # f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). " - # "For better performance, we recommend to increase the bisection-threshold." 
-            # )
-            assert checksum1 is None and checksum2 is None
-            return
-
-        if level == 1:
-            self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1
-            self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2
-
-        if checksum1 != checksum2:
-            return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2))
diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py
new file mode 100644
index 00000000..f4867a74
--- /dev/null
+++ b/data_diff/hashdiff_tables.py
@@ -0,0 +1,283 @@
+import os
+import time
+from numbers import Number
+import logging
+from collections import defaultdict
+from typing import Iterator
+from operator import attrgetter
+
+from runtype import dataclass
+
+from .utils import safezip, run_as_daemon
+from .thread_utils import ThreadedYielder
+from .databases.database_types import IKey, NumericType, PrecisionType, StringType
+from .table_segment import TableSegment
+from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
+
+from .diff_tables import ThreadBase, DiffResult
+
+BENCHMARK = os.environ.get("BENCHMARK", False)
+
+DEFAULT_BISECTION_THRESHOLD = 1024 * 16
+DEFAULT_BISECTION_FACTOR = 32
+
+logger = logging.getLogger("hashdiff_tables")
+
+
+def diff_sets(a: set, b: set) -> Iterator:
+    s1 = set(a)
+    s2 = set(b)
+    d = defaultdict(list)
+
+    # The first item is always the key (see TableDiffer._relevant_columns)
+    for i in s1 - s2:
+        d[i[0]].append(("-", i))
+    for i in s2 - s1:
+        d[i[0]].append(("+", i))
+
+    for _k, v in sorted(d.items(), key=lambda i: i[0]):
+        yield from v
bisection_factor >= 2)") + + if is_tracking_enabled(): + options = dict(self) + event_json = create_start_event_json(options) + run_as_daemon(send_event_json, event_json) + + self.stats["diff_count"] = 0 + start = time.monotonic() + error = None + try: + + # Query and validate schema + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + self._validate_and_adjust_columns(table1, table2) + + key_type = table1._schema[table1.key_column] + key_type2 = table2._schema[table2.key_column] + if not isinstance(key_type, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type} as a key") + if not isinstance(key_type2, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") + assert key_type.python_type is key_type2.python_type + + # Query min/max values + key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) + + # Start with the first completed value, so we don't waste time waiting + min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) + + table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] + + logger.info( + f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" + ) + + ti = ThreadedYielder(self.max_threadpool_size) + # Bisect (split) the table into segments, and diff them recursively. + ti.submit(self._bisect_and_diff_tables, ti, table1, table2) + + # Now we check for the second min-max, to diff the portions we "missed". + min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) + + if min_key2 < min_key1: + pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_tables, ti, *pre_tables) + + if max_key2 > max_key1: + post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_tables, ti, *post_tables) + + yield from ti + + except BaseException as e: # Catch KeyboardInterrupt too + error = e + finally: + if is_tracking_enabled(): + runtime = time.monotonic() - start + table1_count = self.stats.get("table1_count") + table2_count = self.stats.get("table2_count") + diff_count = self.stats.get("diff_count") + err_message = str(error)[:20] # Truncate possibly sensitive information. 
+ event_json = create_end_event_json( + error is None, + runtime, + table1.database.name, + table2.database.name, + table1_count, + table2_count, + diff_count, + err_message, + ) + send_event_json(event_json) + + if error: + raise error + + def _parse_key_range_result(self, key_type, key_range): + mn, mx = key_range + cls = key_type.make_value + # We add 1 because our ranges are exclusive of the end (like in Python) + try: + return cls(mn), cls(mx) + 1 + except (TypeError, ValueError) as e: + raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + + def _validate_and_adjust_columns(self, table1, table2): + for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): + if c1 not in table1._schema: + raise ValueError(f"Column '{c1}' not found in schema for table {table1}") + if c2 not in table2._schema: + raise ValueError(f"Column '{c2}' not found in schema for table {table2}") + + # Update schemas to minimal mutual precision + col1 = table1._schema[c1] + col2 = table2._schema[c2] + if isinstance(col1, PrecisionType): + if not isinstance(col2, PrecisionType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") + + table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) + table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + + elif isinstance(col1, NumericType): + if not isinstance(col2, NumericType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") + + table1._schema[c1] = col1.replace(precision=lowest.precision) + table2._schema[c2] = col2.replace(precision=lowest.precision) + + elif isinstance(col1, StringType): + if not isinstance(col2, StringType): + raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") + + for t in [table1, table2]: + for c in t._relevant_columns: + ctype = t._schema[c] + if not ctype.supported: + logger.warning( + f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. " + "If encoding/formatting differs between databases, it may result in false positives." + ) + + def _bisect_and_diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + assert table1.is_bounded and table2.is_bounded + + if max_rows is None: + # We can be sure that row_count <= max_rows + max_rows = max(table1.approximate_size(), table2.approximate_size()) + + # If count is below the threshold, just download and compare the columns locally + # This saves time, as bisection speed is limited by ping and query performance. + if max_rows < self.bisection_threshold: + rows1, rows2 = self._threaded_call("get_values", [table1, table2]) + diff = list(diff_sets(rows1, rows2)) + + # Initial bisection_threshold larger than count. Normally we always + # checksum and count segments, even if we get the values. At the + # first level, however, that won't be true. + if level == 0: + self.stats["table1_count"] = len(rows1) + self.stats["table2_count"] = len(rows2) + + self.stats["diff_count"] += len(diff) + + logger.info(". 
" * level + f"Diff found {len(diff)} different rows.") + self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) + return diff + + # Choose evenly spaced checkpoints (according to min_key and max_key) + checkpoints = table1.choose_checkpoints(self.bisection_factor - 1) + + # Create new instances of TableSegment between each checkpoint + segmented1 = table1.segment_by_checkpoints(checkpoints) + segmented2 = table2.segment_by_checkpoints(checkpoints) + + # Recursively compare each pair of corresponding segments between table1 and table2 + for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): + ti.submit(self._diff_tables, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) + + def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + logger.info( + ". " * level + f"Diffing segment {segment_index}/{segment_count}, " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size <= {max_rows}" + ) + + # When benchmarking, we want the ability to skip checksumming. This + # allows us to download all rows for comparison in performance. By + # default, data-diff will checksum the section first (when it's below + # the threshold) and _then_ download it. + if BENCHMARK: + if max_rows < self.bisection_threshold: + return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows) + + (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) + + if count1 == 0 and count2 == 0: + # logger.warning( + # f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). " + # "For better performance, we recommend to increase the bisection-threshold." 
+ # ) + assert checksum1 is None and checksum2 is None + return + + if level == 1: + self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1 + self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 + + if checksum1 != checksum2: + return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2)) diff --git a/tests/common.py b/tests/common.py index 2aad4be1..44a15cf2 100644 --- a/tests/common.py +++ b/tests/common.py @@ -43,7 +43,8 @@ def get_git_revision_short_hash() -> str: level = getattr(logging, os.environ["LOG_LEVEL"].upper()) logging.basicConfig(level=level) -logging.getLogger("diff_tables").setLevel(level) +logging.getLogger("hashdiff_tables").setLevel(level) +logging.getLogger("joindiff_tables").setLevel(level) logging.getLogger("table_segment").setLevel(level) logging.getLogger("database").setLevel(level) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index ce273182..4ac8d5f4 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -15,7 +15,8 @@ from data_diff import databases as db from data_diff.databases import postgresql, oracle from data_diff.utils import number_to_human, accumulate -from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD +from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD +from data_diff.table_segment import TableSegment from .common import ( CONN_STRINGS, N_SAMPLES, @@ -667,7 +668,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego ch_factor = min(max(int(N_SAMPLES / 250_000), 2), 128) if BENCHMARK else 2 ch_threshold = min(DEFAULT_BISECTION_THRESHOLD, int(N_SAMPLES / ch_factor)) if BENCHMARK else 3 ch_threads = N_THREADS - differ = TableDiffer( + differ = HashDiffer( bisection_threshold=ch_threshold, bisection_factor=ch_factor, max_threadpool_size=ch_threads, @@ -688,7 +689,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2 dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf dl_threads = N_THREADS - differ = TableDiffer( + differ = HashDiffer( bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads ) start = time.monotonic() diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 3ac37bd0..63195efb 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -7,7 +7,7 @@ import arrow # comes with preql from data_diff.databases.connect import connect -from data_diff.diff_tables import TableDiffer +from data_diff.hashdiff_tables import HashDiffer from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db from data_diff.utils import ArithAlphanumeric, numberToAlphanum @@ -176,7 +176,7 @@ def test_init(self): ) def test_basic(self): - differ = TableDiffer(bisection_factor=10, bisection_threshold=100) + differ = HashDiffer(bisection_factor=10, bisection_threshold=100) a = TableSegment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) assert a.count() == 6 @@ -186,7 +186,7 @@ def test_basic(self): self.assertEqual(len(list(differ.diff_tables(a, b))), 1) def test_offset(self): - differ = TableDiffer(bisection_factor=2, bisection_threshold=10) + differ = HashDiffer(bisection_factor=2, bisection_threshold=10) sec1 
= self.now.shift(seconds=-1).datetime a = TableSegment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) @@ -250,7 +250,7 @@ def setUp(self): self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) - self.differ = TableDiffer(bisection_factor=3, bisection_threshold=4) + self.differ = HashDiffer(bisection_factor=3, bisection_threshold=4) def test_properties_on_empty_table(self): table = self.table.with_schema() @@ -287,7 +287,7 @@ def test_diff_small_tables(self): self.assertEqual(1, self.differ.stats["table2_count"]) def test_non_threaded(self): - differ = TableDiffer(bisection_factor=3, bisection_threshold=4, threaded=False) + differ = HashDiffer(bisection_factor=3, bisection_threshold=4, threaded=False) time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" @@ -384,7 +384,7 @@ def test_diff_sorted_by_key(self): ) _commit(self.connection) - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.table, self.table2)) expected = [ ("-", ("2", time2 + ".000000")), @@ -444,7 +444,7 @@ def test_diff_column_names(self): table1 = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) table2 = TableSegment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(table1, table2)) assert diff == [] @@ -480,7 +480,7 @@ def setUp(self): self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_string_keys(self): - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) @@ -493,7 +493,7 @@ def test_string_keys(self): def test_where_sampling(self): a = self.a.replace(where="1=1") - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) @@ -534,7 +534,7 @@ def setUp(self): def test_alphanum_keys(self): - differ = TableDiffer(bisection_factor=2, bisection_threshold=3) + differ = HashDiffer(bisection_factor=2, bisection_threshold=3) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) @@ -590,8 +590,7 @@ def test_varying_alphanum_keys(self): for a in alphanums: assert a - a == 0 - # Test with the differ - differ = TableDiffer(threaded=False) + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))]) @@ -669,7 +668,7 @@ def setUp(self): self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_column_with_nulls(self): - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) @@ -719,7 +718,7 @@ def test_uuid_columns_with_nulls(self): diff results, but it's not. This test helps to detect such cases. 
""" - differ = TableDiffer(bisection_factor=2, bisection_threshold=3) + differ = HashDiffer(bisection_factor=2, bisection_threshold=3) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) @@ -783,7 +782,7 @@ def test_tables_are_different(self): value, it may lead that concat(pk_i, i, NULL) == concat(pk_i, i-diff, NULL). This test handle such cases. """ - differ = TableDiffer(bisection_factor=2, bisection_threshold=4) + differ = HashDiffer(bisection_factor=2, bisection_threshold=4) diff = list(differ.diff_tables(self.a, self.b)) self.assertEqual(diff, self.diffs) @@ -814,7 +813,7 @@ def setUp(self): self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_right_table_empty(self): - differ = TableDiffer() + differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_left_table_empty(self): @@ -827,5 +826,5 @@ def test_left_table_empty(self): _commit(self.connection) - differ = TableDiffer() + differ = HashDiffer() self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 2feecb02..529de055 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,7 +1,6 @@ import unittest -from data_diff.databases.connect import connect -from data_diff import TableSegment, TableDiffer +from data_diff import TableSegment, HashDiffer, connect from .common import TEST_POSTGRESQL_CONN_STRING, random_table_suffix @@ -40,7 +39,7 @@ def test_uuid(self): a = TableSegment(self.connection, (self.table_src,), "id", "comment") b = TableSegment(self.connection, (self.table_dst,), "id", "comment") - differ = TableDiffer() + differ = HashDiffer() diff = list(differ.diff_tables(a, b)) uuid = diff[0][1][0] self.assertEqual(diff, [("-", (uuid, "This one is different"))]) From 74f31e8e2387f374df2afa24dc5347239bb0eae7 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 14 Sep 2022 17:11:04 +0300 Subject: [PATCH 06/33] Adjustments to joindiff implementation --- data_diff/diff_tables.py | 17 +++++------- data_diff/joindiff_tables.py | 52 +++++++++++++++++++++--------------- tests/test_joindiff.py | 26 ++++++++++++++++++ 3 files changed, 63 insertions(+), 32 deletions(-) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 04a95fe7..430f027e 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -52,15 +52,10 @@ def _threaded_call_as_completed(self, func, iterable): "Calls a method for each object in iterable. Returned in order of completion." 
return self._thread_as_completed(methodcaller(func), iterable)
 
-    def _run_thread(self, threadfunc, *args, daemon=False) -> threading.Thread:
-        th = threading.Thread(target=threadfunc, args=args)
-        if daemon:
-            th.daemon = True
-        th.start()
-        return th
-
     @contextmanager
-    def _run_in_background(self, threadfunc, *args, daemon=False):
-        t = self._run_thread(threadfunc, *args, daemon=daemon)
-        yield t
-        t.join()
+    def _run_in_background(self, *funcs):
+        with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool:
+            futures = [task_pool.submit(f) for f in funcs]
+            yield futures
+            for f in futures:
+                f.result()
diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py
index 7f684ebb..52d055d3 100644
--- a/data_diff/joindiff_tables.py
+++ b/data_diff/joindiff_tables.py
@@ -3,6 +3,7 @@
 """
 
 from decimal import Decimal
+from functools import partial
 import logging
 from contextlib import contextmanager
 from typing import Dict, List
@@ -83,10 +84,12 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
         if table1.database is not table2.database:
             raise ValueError("Join-diff only works when both tables are in the same database")
 
-        with self._run_in_background(self._test_null_or_duplicate_keys, table1, table2):
-            with self._run_in_background(self._collect_stats, 1, table1):
-                with self._run_in_background(self._collect_stats, 2, table2):
-                    yield from self._outer_join(table1, table2)
+        with self._run_in_background(
+            partial(self._test_null_or_duplicate_keys, table1, table2),
+            partial(self._collect_stats, 1, table1),
+            partial(self._collect_stats, 2, table2)
+        ):
+            yield from self._outer_join(table1, table2)
 
         logger.info("Diffing complete")
 
@@ -106,7 +109,7 @@ def _test_null_or_duplicate_keys(self, table1, table2):
             q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns))
             nulls = ts.database.query(q, list)
             if nulls:
-                raise ValueError(f"NULL values in one or more primary keys: {nulls}")
+                raise ValueError("NULL values in one or more primary keys")
 
         logger.debug("Done testing for null or duplicate keys")
 
@@ -161,7 +164,7 @@ def _outer_join(self, table1, table2):
         b = table2._make_select()
 
         is_diff_cols = {
-            f"is_diff_col_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)
+            f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)
         }
 
explicit null test didn't finish running yet
+                    raise ValueError("NULL values in one or more primary keys")
+                is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols))
+                if not is_xb:
+                    yield "-", tuple(a_row)
+                if not is_xa:
+                    yield "+", tuple(b_row)
+
+    def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
         logger.info("Counting differences per column")
         is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple)
-        for name, count in safezip(is_diff_cols, is_diff_cols_counts):
-            self.stats[f"count_{name}"] = count
+        diff_counts = {}
+        for name, count in safezip(cols, is_diff_cols_counts):
+            diff_counts[name] = count
+        self.stats["diff_counts"] = diff_counts
 
     def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
         logger.info("Counting and sampling exclusive rows")
diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py
index 72d604cd..d37cea58 100644
--- a/tests/test_joindiff.py
+++ b/tests/test_joindiff.py
@@ -166,3 +166,29 @@ def test_diff_sorted_by_key(self):
             ("+", ("4", time + ".000000")),
         ]
         self.assertEqual(expected, diff)
+
+    def test_dup_pks(self):
+        time = "2022-01-01 00:00:00"
+        time_str = f"timestamp '{time}'"
+
+        cols = "id rating timestamp".split()
+
+        _insert_row(self.connection, self.table_src, cols, [1, 9, time_str])
+        _insert_row(self.connection, self.table_src, cols, [1, 10, time_str])
+        _insert_row(self.connection, self.table_dst, cols, [1, 9, time_str])
+
+        x = self.differ.diff_tables(self.table, self.table2)
+        self.assertRaises(ValueError, list, x)
+
+
+    def test_null_pks(self):
+        time = "2022-01-01 00:00:00"
+        time_str = f"timestamp '{time}'"
+
+        cols = "id rating timestamp".split()
+
+        _insert_row(self.connection, self.table_src, cols, ['null', 9, time_str])
+        _insert_row(self.connection, self.table_dst, cols, [1, 9, time_str])
+
+        x = self.differ.diff_tables(self.table, self.table2)
+        self.assertRaises(ValueError, list, x)
From 686b1f730e332cb97b054ab1dd22779614185ed5 Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Wed, 21 Sep 2022 14:48:21 +0300
Subject: [PATCH 07/33] refactor tablediffer

---
 data_diff/diff_tables.py     | 21 ++++++++++++++++++++-
 data_diff/hashdiff_tables.py | 16 ++--------------
 data_diff/joindiff_tables.py | 10 ++++++++--
 docs/python-api.rst          |  7 +++++--
 tests/test_query.py          |  2 +-
 5 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py
index 430f027e..3a92d708 100644
--- a/data_diff/diff_tables.py
+++ b/data_diff/diff_tables.py
@@ -1,13 +1,15 @@
 """Provides classes for performing a table diff
 """
 
+from abc import ABC, abstractmethod
 from enum import Enum
 from contextlib import contextmanager
-import threading
 from operator import methodcaller
 from typing import Tuple, Iterator, Optional
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
+from .table_segment import TableSegment
+
 from runtype import dataclass
 
 
@@ -59,3 +61,20 @@ def _run_in_background(self, *funcs):
             yield futures
             for f in futures:
                 f.result()
+
+
+class TableDiffer(ThreadBase, ABC):
+    @abstractmethod
+    def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
+        """Diff the given tables.
+
+        Parameters:
+            table1 (TableSegment): The "before" table to compare. Or: source table
+            table2 (TableSegment): The "after" table to compare. Or: target table
+
+        Returns:
+            An iterator that yields pair-tuples, representing the diff. Items can be either -
+            ('-', row) for items in table1 but not in table2.
+ ('+', row) for items in table2 but not in table1. + Where `row` is a tuple of values, corresponding to the diffed columns. + """ diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index f4867a74..0f2e8cb7 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -14,7 +14,7 @@ from .table_segment import TableSegment from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled -from .diff_tables import ThreadBase, DiffResult +from .diff_tables import TableDiffer, DiffResult BENCHMARK = os.environ.get("BENCHMARK", False) @@ -40,7 +40,7 @@ def diff_sets(a: set, b: set) -> Iterator: @dataclass -class HashDiffer(ThreadBase): +class HashDiffer(TableDiffer): """Finds the diff between two SQL tables The algorithm uses hashing to quickly check if the tables are different, and then applies a @@ -62,18 +62,6 @@ class HashDiffer(ThreadBase): stats: dict = {} def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: - """Diff the given tables. - - Parameters: - table1 (TableSegment): The "before" table to compare. Or: source table - table2 (TableSegment): The "after" table to compare. Or: target table - - Returns: - An iterator that yield pair-tuples, representing the diff. Items can be either - ('-', columns) for items in table1 but not in table2 - ('+', columns) for items in table2 but not in table1 - Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra) - """ # Validate options if self.bisection_factor >= self.bisection_threshold: raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 52d055d3..0099de6e 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -13,7 +13,7 @@ from .utils import safezip from .databases.base import Database from .table_segment import TableSegment -from .diff_tables import ThreadBase, DiffResult +from .diff_tables import TableDiffer, DiffResult from .queries import table, sum_, min_, max_, avg from .queries.api import and_, if_, or_, outerjoin, this @@ -73,7 +73,7 @@ def json_friendly_value(v): @dataclass -class JoinDifferBase(ThreadBase): +class JoinDifferBase(TableDiffer): """Finds the diff between two SQL tables using JOINs""" stats: dict = {} @@ -145,6 +145,12 @@ def bool_to_int(x): class JoinDiffer(JoinDifferBase): + """Finds the diff between two SQL tables in the same database. + + The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics. + + """ + def _outer_join(self, table1, table2): db = table1.database if db is not table2.database: diff --git a/docs/python-api.rst b/docs/python-api.rst index d2b18636..f28b18d1 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -5,11 +5,14 @@ Python API Reference .. autofunction:: connect -.. autoclass:: TableDiffer +.. autoclass:: HashDiffer + :members: __init__, diff_tables + +.. autoclass:: JoinDiffer :members: __init__, diff_tables .. autoclass:: TableSegment - :members: __init__, get_values, choose_checkpoints, segment_by_checkpoints, count, count_and_checksum, is_bounded, new + :members: __init__, get_values, choose_checkpoints, segment_by_checkpoints, count, count_and_checksum, is_bounded, new, with_schema .. 
autoclass:: data_diff.databases.database_types.AbstractDatabase :members: diff --git a/tests/test_query.py b/tests/test_query.py index b6b90394..f31f5417 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -12,7 +12,7 @@ def normalize_spaces(s: str): class MockDialect(AbstractDialect): - def quote(self, s: str): + def quote(self, s: str) -> str: return s def concat(self, l: List[str]) -> str: From b830afc8fceba1a0d47dcd237cd2ca1121bf8f3e Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 21 Sep 2022 10:24:59 +0300 Subject: [PATCH 08/33] joindiff now working for all major databases: new: - mysql - bigquery - presto - vertica - trino - oracle - redshift --- data_diff/databases/bigquery.py | 3 ++ data_diff/databases/mysql.py | 3 ++ data_diff/databases/oracle.py | 6 ++++ data_diff/databases/presto.py | 32 ++++++++++++++---- data_diff/databases/redshift.py | 3 ++ data_diff/databases/vertica.py | 3 ++ data_diff/joindiff_tables.py | 58 ++++++++++++++++++++++-------- data_diff/queries/api.py | 10 +++++- data_diff/queries/ast_classes.py | 21 +++++++++++- data_diff/queries/compiler.py | 5 +++ tests/test_diff_tables.py | 2 ++ tests/test_joindiff.py | 3 +- tests/test_query.py | 8 +++++ 13 files changed, 132 insertions(+), 25 deletions(-) diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 411ae795..218c9cb4 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -95,3 +95,6 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) + + def random(self) -> str: + return "RAND()" diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 7e89b184..07c34aaf 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -73,3 +73,6 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: def is_distinct_from(self, a: str, b: str) -> str: return f"not ({a} <=> {b})" + + def random(self) -> str: + return "RAND()" diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 76387010..79f7bf31 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -124,3 +124,9 @@ def timestamp_value(self, t: DbTime) -> str: def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Cast is necessary for correct MD5 (trimming not enough) return f"CAST(TRIM({value}) AS VARCHAR(36))" + + def random(self) -> str: + return "dbms_random.value" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"DECODE({a}, {b}, 1, 0) = 0" diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 5ee98770..c990e06e 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,6 +1,7 @@ import re -from ..utils import match_regexps +from data_diff.utils import match_regexps +from data_diff.queries import ThreadLocalInterpreter from .database_types import * from .base import Database, import_helper @@ -10,6 +11,14 @@ TIMESTAMP_PRECISION_POS, ) +def query_cursor(c, sql_code): + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + # Required for the query to actually run 🤯 + if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): + return c.fetchone() + @import_helper("presto") def import_presto(): @@ -63,12 +72,21 @@ def to_string(self, s: str): def _query(self, sql_code: str) -> list: "Uses the standard SQL cursor interface" c
= self._conn.cursor() - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() - # Required for the query to actually run 🤯 - if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): - return c.fetchone() + + if isinstance(sql_code, ThreadLocalInterpreter): + # TODO reuse code from base.py + g = sql_code.interpret() + q = next(g) + while True: + res = query_cursor(c, q) + try: + q = g.send(res) + except StopIteration: + break + return + + return query_cursor(c, sql_code) + def close(self): self._conn.close() diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index a512c123..f11b950c 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -46,3 +46,6 @@ def select_table_schema(self, path: DbPath) -> str: "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" ) + + def is_distinct_from(self, a: str, b: str) -> str: + return f"{a} IS NULL AND NOT {b} IS NULL OR {b} IS NULL OR {a}!={b}" diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 78a52363..cc606511 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -123,3 +123,6 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" + + def is_distinct_from(self, a: str, b: str) -> str: + return f"not ({a} <=> {b})" diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 0099de6e..53acd954 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,13 +10,15 @@ from runtype import dataclass + from .utils import safezip from .databases.base import Database +from .databases import MySQL, BigQuery, Presto, Oracle from .table_segment import TableSegment from .diff_tables import TableDiffer, DiffResult from .queries import table, sum_, min_, max_, avg -from .queries.api import and_, if_, or_, outerjoin, this +from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable from .queries.ast_classes import Concat, Count, Expr, Random from .queries.compiler import Compiler from .queries.extras import NormalizeAsString @@ -43,18 +45,29 @@ class Stats: def sample(table): - # TODO return table.order_by(Random()).limit(10) @contextmanager def temp_table(db: Database, expr: Expr): c = Compiler(db) - name = c.new_unique_name("tmp_table") - db.query(f"create temporary table {c.quote(name)} as {c.compile(expr)}", None) + + name = c.new_unique_table_name("temp_table") + + if isinstance(db, BigQuery): + name = f"{db.default_schema}.{name}" + db.query(f"create table {c.quote(name)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}", None) + elif isinstance(db, Presto): + db.query(f"create table {c.quote(name)} as {c.compile(expr)}", None) + elif isinstance(db, Oracle): + db.query(f"create global temporary table {c.quote(name)} as {c.compile(expr)}", None) + else: + db.query(f"create temporary table {c.quote(name)} as {c.compile(expr)}", None) + try: yield table(name, schema=expr.source_table.schema) finally: + # Only drops if create table succeeded (meaning, the table didn't already exist) db.query(f"drop table {c.quote(name)}", None) @@ -144,6 +157,28 @@ def bool_to_int(x): return if_(x, 
1, 0) +def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable: + on = [a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)] + + if isinstance(db, Oracle): + is_exclusive_a = and_(bool_to_int(b[k] == None) for k in keys2) + is_exclusive_b = and_(bool_to_int(a[k] == None) for k in keys1) + else: + is_exclusive_a = and_(b[k] == None for k in keys2) + is_exclusive_b = and_(a[k] == None for k in keys1) + + if isinstance(db, MySQL): + # No outer join + l = leftjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=False, **select_fields) + r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) + return l.union(r) + + return ( + outerjoin(a, b).on(*on) + .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) + ) + + class JoinDiffer(JoinDifferBase): """Finds the diff between two SQL tables in the same database. @@ -177,15 +212,7 @@ def _outer_join(self, table1, table2): b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} diff_rows = ( - outerjoin(a, b) - .on(a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)) - .select( - is_exclusive_a=and_(b[k] == None for k in keys2), - is_exclusive_b=and_(a[k] == None for k in keys1), - **is_diff_cols, - **a_cols, - **b_cols, - ) + _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) .where(or_(this[c] == 1 for c in is_diff_cols)) ) @@ -216,7 +243,10 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): logger.info("Counting and sampling exclusive rows") - exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) + if isinstance(db, Oracle): + exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1)) + else: + exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) with temp_table(db, exclusive_rows_query) as exclusive_rows: self.stats["exclusive_count"] = db.query(exclusive_rows.count(), int) sample_rows = db.query(sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])), list) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 76aaf5d2..136807eb 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -11,8 +11,16 @@ def join(*tables: ITable): return Join(tables) +def leftjoin(*tables: ITable): + "Left-joins each table into a 'struct'" + return Join(tables, "LEFT") + +def rightjoin(*tables: ITable): + "Right-joins each table into a 'struct'" + return Join(tables, "RIGHT") + def outerjoin(*tables: ITable): - "Outerjoins each table into a 'struct'" + "Outer-joins each table into a 'struct'" return Join(tables, "FULL OUTER") diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 019227f7..a3383ad2 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Generator, Sequence, Tuple, Union +from typing import Any, Generator, ItemsView, Sequence, Tuple, Union from runtype import dataclass @@ -119,6 +119,9 @@ def __getitem__(self, column): def count(self): return Select(self, [Count()]) + def union(self, other: 'ITable'): + return Union(self, other) + @dataclass class Concat(ExprNode): @@ -348,6 +351,22 @@ class GroupBy(ITable): def having(self): pass +@dataclass +class Union(ExprNode, ITable): + table1: ITable + table2: ITable 
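The MySQL branch of _outerjoin above exists because MySQL has no FULL OUTER JOIN; the patch emulates one by UNION-ing a LEFT JOIN with a RIGHT JOIN, and the Union node defined here is what compiles that. Schematically, with a made-up key column id and the selected fields elided, the generated SQL has roughly this shape (an illustration only, not the compiler's exact output):

    SELECT ..., (b.id IS NULL) AS is_exclusive_a, 0 AS is_exclusive_b FROM a LEFT JOIN b ON a.id = b.id
    UNION
    SELECT ..., 0 AS is_exclusive_a, (a.id IS NULL) AS is_exclusive_b FROM a RIGHT JOIN b ON a.id = b.id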
+ + @property + def source_table(self): + return self # TODO is this right? + + def compile(self, parent_c: Compiler) -> str: + c = parent_c.replace(in_select=False) + union_all = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" + if parent_c.in_select: + union_all = f"({union_all}) {c.new_unique_name()}" + return union_all + @dataclass class Select(ExprNode, ITable): diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 8ea0e7a5..64e24650 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,3 +1,4 @@ +import random from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, Sequence, List @@ -51,6 +52,10 @@ def new_unique_name(self, prefix="tmp"): self._counter[0] += 1 return f"{prefix}{self._counter[0]}" + def new_unique_table_name(self, prefix="tmp"): + self._counter[0] += 1 + return f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}" + def add_table_context(self, *tables: Sequence): return self.replace(_table_context=self._table_context + list(tables)) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 63195efb..7668ec61 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -81,6 +81,8 @@ def _get_text_type(conn): def _get_float_type(conn): if isinstance(conn, db.BigQuery): return "FLOAT64" + elif isinstance(conn, db.Presto): + return "REAL" return "float" diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d37cea58..e8db3167 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -3,7 +3,6 @@ from data_diff.databases.connect import connect from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db -from data_diff.utils import ArithAlphanumeric from data_diff.joindiff_tables import JoinDiffer from .test_diff_tables import TestPerDatabase, _get_float_type, _get_text_type, _commit, _insert_row, _insert_rows @@ -26,7 +25,7 @@ def init_instances(): DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} -TEST_DATABASES = {x.__name__ for x in (db.PostgreSQL,)} +TEST_DATABASES = {x.__name__ for x in (db.PostgreSQL, db.Snowflake, db.MySQL, db.BigQuery, db.Presto, db.Vertica, db.Trino, db.Oracle, db.Redshift)} _class_per_db_dec = parameterized_class( ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in TEST_DATABASES] diff --git a/tests/test_query.py b/tests/test_query.py index f31f5417..4ae4b82e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -128,3 +128,11 @@ def test_funcs(self): q = c.compile(t.order_by(Random()).limit(10)) assert q == "SELECT * FROM a ORDER BY random() limit 10" + + def test_union_all(self): + c = Compiler(MockDialect()) + a = table("a").select('x') + b = table("b").select('y') + + q = c.compile(a.union(b)) + assert q == "SELECT x FROM a UNION SELECT y FROM b" From 4f441f02caf4040149bff60d16cccedda68492b4 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 22 Sep 2022 17:11:30 +0300 Subject: [PATCH 09/33] Fix in queries --- data_diff/queries/ast_classes.py | 29 ++++++++++++++++------------- data_diff/queries/compiler.py | 3 ++- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index a3383ad2..56213dd3 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,5 +1,6 @@ +from dataclasses import field from datetime import datetime -from typing import Any, Generator, 
ItemsView, Sequence, Tuple, Union +from typing import Any, Generator, ItemsView, Optional, Sequence, Tuple, Union from runtype import dataclass @@ -246,7 +247,7 @@ def compile(self, c: Compiler) -> str: t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table ] if not aliases: - raise CompileError(f"No aliased table found for column {self.name}") # TODO better error + return c.quote(self.name) elif len(aliases) > 1: raise CompileError(f"Too many aliases for column {self.name}") (alias,) = aliases @@ -259,7 +260,7 @@ def compile(self, c: Compiler) -> str: @dataclass class TablePath(ExprNode, ITable): path: DbPath - schema: Schema = None + schema: Optional[Schema] = field(default=None, repr=False) def insert_values(self, rows): pass @@ -329,7 +330,7 @@ def compile(self, parent_c: Compiler) -> str: tables = [ t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables ] - c = parent_c.add_table_context(*tables) + c = parent_c.add_table_context(*tables).replace(in_join=True, in_select=False) op = " JOIN " if self.op is None else f" {self.op} JOIN " joined = op.join(c.compile(t) for t in tables) @@ -344,6 +345,8 @@ def compile(self, parent_c: Compiler) -> str: if parent_c.in_select: select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: + select = f"({select})" return select @@ -365,34 +368,32 @@ def compile(self, parent_c: Compiler) -> str: union_all = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" if parent_c.in_select: union_all = f"({union_all}) {c.new_unique_name()}" + elif parent_c.in_join: + union_all = f"({union_all})" return union_all @dataclass class Select(ExprNode, ITable): - table: Expr = None + source_table: Expr = None columns: Sequence[Expr] = None where_exprs: Sequence[Expr] = None order_by_exprs: Sequence[Expr] = None group_by_exprs: Sequence[Expr] = None limit_expr: int = None - @property - def source_table(self): - return self - @property def schema(self): - return self.table.schema + return self.source_table.schema def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(in_select=True).add_table_context(self.table) + c = parent_c.replace(in_select=True) #.add_table_context(self.table) columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" select = f"SELECT {columns}" - if self.table: - select += " FROM " + c.compile(self.table) + if self.source_table: + select += " FROM " + c.compile(self.source_table) if self.where_exprs: select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) @@ -407,6 +408,8 @@ def compile(self, parent_c: Compiler) -> str: select += " " + c.database.offset_limit(0, self.limit_expr) if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: select = f"({select})" return select diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 64e24650..5133301c 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -12,7 +12,8 @@ @dataclass class Compiler: database: AbstractDialect - in_select: bool = False # Compilation + in_select: bool = False # Compilation runtime flag + in_join: bool = False # Compilation runtime flag _table_context: List = [] # List[ITable] _subqueries: Dict[str, Any] = {} # XXX not thread-safe From de26e56dd86579cb903c492c3b9250efea2236fd Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 22 Sep 2022 17:12:59 +0300 Subject: [PATCH 10/33] Joindiff now support tracking and bisection --- 
data_diff/diff_tables.py | 130 ++++++++++++++++++++++++++- data_diff/hashdiff_tables.py | 164 ++++++++--------------------------- data_diff/joindiff_tables.py | 70 +++++++++++---- data_diff/table_segment.py | 4 + data_diff/utils.py | 8 ++ tests/common.py | 1 + 6 files changed, 228 insertions(+), 149 deletions(-) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 3a92d708..5cd21302 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,6 +1,7 @@ """Provides classes for performing a table diff """ +import time from abc import ABC, abstractmethod from enum import Enum from contextlib import contextmanager @@ -8,10 +9,15 @@ from typing import Tuple, Iterator, Optional from concurrent.futures import ThreadPoolExecutor, as_completed -from .table_segment import TableSegment - from runtype import dataclass +from .utils import run_as_daemon, safezip, getLogger +from .thread_utils import ThreadedYielder +from .table_segment import TableSegment +from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled +from .databases.database_types import IKey + +logger = getLogger(__name__) class Algorithm(Enum): AUTO = "auto" @@ -64,7 +70,8 @@ def _run_in_background(self, *funcs): class TableDiffer(ThreadBase, ABC): - @abstractmethod + bisection_factor = 32 + def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: """Diff the given tables. @@ -78,3 +85,120 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: ('+', row) for items in table2 but not in table1. Where `row` is a tuple of values, corresponding to the diffed columns. """ + + if is_tracking_enabled(): + options = dict(self) + event_json = create_start_event_json(options) + run_as_daemon(send_event_json, event_json) + + self.stats["diff_count"] = 0 + start = time.monotonic() + error = None + try: + + # Query and validate schema + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + self._validate_and_adjust_columns(table1, table2) + + yield from self._diff_tables(table1, table2) + + except BaseException as e: # Catch KeyboardInterrupt too + error = e + finally: + if is_tracking_enabled(): + runtime = time.monotonic() - start + table1_count = self.stats.get("table1_count") + table2_count = self.stats.get("table2_count") + diff_count = self.stats.get("diff_count") + err_message = str(error)[:20] # Truncate possibly sensitive information. + event_json = create_end_event_json( + error is None, + runtime, + table1.database.name, + table2.database.name, + table1_count, + table2_count, + diff_count, + err_message, + ) + send_event_json(event_json) + + if error: + raise error + + def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + pass + + def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + return self._bisect_and_diff_tables(table1, table2) + + + @abstractmethod + def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + ... 
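To make the bisection below concrete, a worked example with illustrative numbers only: assuming integer keys spanning 0..1000 and bisection_factor=4, choose_checkpoints() returns bisection_factor - 1 evenly spaced keys, and each corresponding pair of sub-segments is diffed recursively (in the hash-based differ, until a segment's estimated row count drops below bisection_threshold):

    checkpoints = [250, 500, 750]                                # bisection_factor - 1 evenly spaced keys
    segments = [(0, 250), (250, 500), (500, 750), (750, 1000)]   # one recursive diff task per key-range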
+ + + def _bisect_and_diff_tables(self, table1, table2): + key_type = table1._schema[table1.key_column] + key_type2 = table2._schema[table2.key_column] + if not isinstance(key_type, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type} as a key") + if not isinstance(key_type2, IKey): + raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") + assert key_type.python_type is key_type2.python_type + + # Query min/max values + key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) + + # Start with the first completed value, so we don't waste time waiting + min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) + + table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] + + logger.info( + # f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " + f"Diffing segments at key-range: {table1.min_key}..{table2.max_key}. " + f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" + ) + + ti = ThreadedYielder(self.max_threadpool_size) + # Bisect (split) the table into segments, and diff them recursively. + ti.submit(self._bisect_and_diff_segments, ti, table1, table2) + + # Now we check for the second min-max, to diff the portions we "missed". + min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) + + if min_key2 < min_key1: + pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_segments, ti, *pre_tables) + + if max_key2 > max_key1: + post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] + ti.submit(self._bisect_and_diff_segments, ti, *post_tables) + + return ti + + + def _parse_key_range_result(self, key_type, key_range): + mn, mx = key_range + cls = key_type.make_value + # We add 1 because our ranges are exclusive of the end (like in Python) + try: + return cls(mn), cls(mx) + 1 + except (TypeError, ValueError) as e: + raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + + + def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + assert table1.is_bounded and table2.is_bounded + + # Choose evenly spaced checkpoints (according to min_key and max_key) + checkpoints = table1.choose_checkpoints(self.bisection_factor - 1) + + # Create new instances of TableSegment between each checkpoint + segmented1 = table1.segment_by_checkpoints(checkpoints) + segmented2 = table2.segment_by_checkpoints(checkpoints) + + # Recursively compare each pair of corresponding segments between table1 and table2 + for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): + ti.submit(self._diff_segments, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 0f2e8cb7..64b05b67 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -1,5 +1,4 @@ import os -import time from numbers import Number import logging from collections import defaultdict @@ -8,13 +7,12 @@ from runtype import dataclass -from .utils import safezip, run_as_daemon +from .utils import safezip from .thread_utils import ThreadedYielder from .databases.database_types import IKey, NumericType, PrecisionType, StringType from .table_segment import TableSegment -from .tracking import create_end_event_json, create_start_event_json, send_event_json, 
is_tracking_enabled -from .diff_tables import TableDiffer, DiffResult +from .diff_tables import TableDiffer BENCHMARK = os.environ.get("BENCHMARK", False) @@ -61,98 +59,14 @@ class HashDiffer(TableDiffer): stats: dict = {} - def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + def __post_init__(self): # Validate options if self.bisection_factor >= self.bisection_threshold: raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") if self.bisection_factor < 2: raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") - if is_tracking_enabled(): - options = dict(self) - event_json = create_start_event_json(options) - run_as_daemon(send_event_json, event_json) - - self.stats["diff_count"] = 0 - start = time.monotonic() - error = None - try: - - # Query and validate schema - table1, table2 = self._threaded_call("with_schema", [table1, table2]) - self._validate_and_adjust_columns(table1, table2) - - key_type = table1._schema[table1.key_column] - key_type2 = table2._schema[table2.key_column] - if not isinstance(key_type, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type} as a key") - if not isinstance(key_type2, IKey): - raise NotImplementedError(f"Cannot use column of type {key_type2} as a key") - assert key_type.python_type is key_type2.python_type - - # Query min/max values - key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) - - # Start with the first completed value, so we don't waste time waiting - min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) - - table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] - - logger.info( - f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" - ) - - ti = ThreadedYielder(self.max_threadpool_size) - # Bisect (split) the table into segments, and diff them recursively. - ti.submit(self._bisect_and_diff_tables, ti, table1, table2) - - # Now we check for the second min-max, to diff the portions we "missed". - min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) - - if min_key2 < min_key1: - pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *pre_tables) - - if max_key2 > max_key1: - post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)] - ti.submit(self._bisect_and_diff_tables, ti, *post_tables) - - yield from ti - - except BaseException as e: # Catch KeyboardInterrupt too - error = e - finally: - if is_tracking_enabled(): - runtime = time.monotonic() - start - table1_count = self.stats.get("table1_count") - table2_count = self.stats.get("table2_count") - diff_count = self.stats.get("diff_count") - err_message = str(error)[:20] # Truncate possibly sensitive information. 
- event_json = create_end_event_json( - error is None, - runtime, - table1.database.name, - table2.database.name, - table1_count, - table2_count, - diff_count, - err_message, - ) - send_event_json(event_json) - - if error: - raise error - - def _parse_key_range_result(self, key_type, key_range): - mn, mx = key_range - cls = key_type.make_value - # We add 1 because our ranges are exclusive of the end (like in Python) - try: - return cls(mn), cls(mx) + 1 - except (TypeError, ValueError) as e: - raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + def _validate_and_adjust_columns(self, table1, table2): for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): @@ -201,44 +115,8 @@ def _validate_and_adjust_columns(self, table1, table2): "If encoding/formatting differs between databases, it may result in false positives." ) - def _bisect_and_diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): - assert table1.is_bounded and table2.is_bounded - - if max_rows is None: - # We can be sure that row_count <= max_rows - max_rows = max(table1.approximate_size(), table2.approximate_size()) - - # If count is below the threshold, just download and compare the columns locally - # This saves time, as bisection speed is limited by ping and query performance. - if max_rows < self.bisection_threshold: - rows1, rows2 = self._threaded_call("get_values", [table1, table2]) - diff = list(diff_sets(rows1, rows2)) - - # Initial bisection_threshold larger than count. Normally we always - # checksum and count segments, even if we get the values. At the - # first level, however, that won't be true. - if level == 0: - self.stats["table1_count"] = len(rows1) - self.stats["table2_count"] = len(rows2) - - self.stats["diff_count"] += len(diff) - - logger.info(". " * level + f"Diff found {len(diff)} different rows.") - self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) - return diff - - # Choose evenly spaced checkpoints (according to min_key and max_key) - checkpoints = table1.choose_checkpoints(self.bisection_factor - 1) - - # Create new instances of TableSegment between each checkpoint - segmented1 = table1.segment_by_checkpoints(checkpoints) - segmented2 = table2.segment_by_checkpoints(checkpoints) - # Recursively compare each pair of corresponding segments between table1 and table2 - for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)): - ti.submit(self._diff_tables, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level) - - def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): logger.info( ". " * level + f"Diffing segment {segment_index}/{segment_count}, " f"key-range: {table1.min_key}..{table2.max_key}, " @@ -251,7 +129,7 @@ def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableS # the threshold) and _then_ download it. 
if BENCHMARK: if max_rows < self.bisection_threshold: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows) + return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max_rows) (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) @@ -268,4 +146,32 @@ def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableS self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 if checksum1 != checksum2: - return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2)) + return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max(count1, count2)) + + def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + assert table1.is_bounded and table2.is_bounded + + if max_rows is None: + # We can be sure that row_count <= max_rows + max_rows = max(table1.approximate_size(), table2.approximate_size()) + + # If count is below the threshold, just download and compare the columns locally + # This saves time, as bisection speed is limited by ping and query performance. + if max_rows < self.bisection_threshold: + rows1, rows2 = self._threaded_call("get_values", [table1, table2]) + diff = list(diff_sets(rows1, rows2)) + + # Initial bisection_threshold larger than count. Normally we always + # checksum and count segments, even if we get the values. At the + # first level, however, that won't be true. + if level == 0: + self.stats["table1_count"] = len(rows1) + self.stats["table2_count"] = len(rows2) + + self.stats["diff_count"] += len(diff) + + logger.info(". " * level + f"Diff found {len(diff)} different rows.") + self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) + return diff + + return super()._bisect_and_diff_segments(ti, table1, table2, level, max_rows) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 53acd954..988cbaec 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -2,6 +2,7 @@ """ +from collections import defaultdict from decimal import Decimal from functools import partial import logging @@ -13,9 +14,10 @@ from .utils import safezip from .databases.base import Database -from .databases import MySQL, BigQuery, Presto, Oracle +from .databases import MySQL, BigQuery, Presto, Oracle, PostgreSQL, Snowflake from .table_segment import TableSegment from .diff_tables import TableDiffer, DiffResult +from .thread_utils import ThreadedYielder from .queries import table, sum_, min_, max_, avg from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable @@ -67,8 +69,10 @@ def temp_table(db: Database, expr: Expr): try: yield table(name, schema=expr.source_table.schema) finally: - # Only drops if create table succeeded (meaning, the table didn't already exist) - db.query(f"drop table {c.quote(name)}", None) + if isinstance(db, (BigQuery, Presto)): + # Only drops if create table succeeded (meaning, the table didn't already exist) + # And if the table won't delete itself + db.query(f"drop table {c.quote(name)}", None) def _slice_tuple(t, *sizes): @@ -90,28 +94,50 @@ class JoinDifferBase(TableDiffer): """Finds the diff between two SQL tables using JOINs""" stats: dict = {} + validate_unique_key: bool = True - def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: - table1, table2 = 
self._threaded_call("with_schema", [table1, table2]) + def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: + db = table1.database if table1.database is not table2.database: raise ValueError("Join-diff only works when both tables are in the same database") + table1, table2 = self._threaded_call("with_schema", [table1, table2]) + + + bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] + + with self._run_in_background(*bg_funcs): + if isinstance(db, (Snowflake, BigQuery)): + # Don't segment the table; let the database handle the parallelization + yield from self._diff_segments(None, table1, table2, None) + else: + yield from self._bisect_and_diff_tables(table1, table2) + def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + assert table1.database is table2.database + + logger.info( + ". " * level + f"Diffing segment {segment_index}/{segment_count}, " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size <= {max_rows}" + ) + with self._run_in_background( - partial(self._test_null_or_duplicate_keys, table1, table2), - partial(self._collect_stats, 1, table1), - partial(self._collect_stats, 2, table2) - ): + partial(self._collect_stats, 1, table1), + partial(self._collect_stats, 2, table2), + partial(self._test_null_keys, table1, table2), + ): yield from self._outer_join(table1, table2) logger.info("Diffing complete") - def _test_null_or_duplicate_keys(self, table1, table2): - logger.info("Testing for null or duplicate keys") + def _test_duplicate_keys(self, table1, table2): + logger.debug("Testing for duplicate keys") - # Test null or duplicate keys + # Test duplicate keys for ts in [table1, table2]: - t = table(*ts.table_path, schema=ts._schema) + t = ts._make_select() key_columns = [ts.key_column] # XXX q = t.select(total=Count(), total_distinct=Count(Concat(key_columns), distinct=True)) @@ -119,12 +145,19 @@ def _test_null_or_duplicate_keys(self, table1, table2): if total != total_distinct: raise ValueError("Duplicate primary keys") + def _test_null_keys(self, table1, table2): + logger.debug("Testing for null keys") + + # Test null keys + for ts in [table1, table2]: + t = ts._make_select() + key_columns = [ts.key_column] # XXX + q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) if nulls: raise ValueError(f"NULL values in one or more primary keys") - logger.debug("Done testing for null or duplicate keys") def _collect_stats(self, i, table): logger.info(f"Collecting stats for table #{i}") @@ -145,7 +178,9 @@ def _collect_stats(self, i, table): res = db.query(table._make_select().select(**col_exprs), tuple) res = dict(zip([f"table{i}_{n}" for n in col_exprs], map(json_friendly_value, res))) - self.stats.update(res) + for k, v in res.items(): + self.stats[k] = self.stats.get(k, 0) + (v or 0) + # self.stats.update(res) logger.debug(f"Done collecting stats for table #{i}") @@ -221,7 +256,7 @@ def _outer_join(self, table1, table2): partial(self._count_diff_per_column, db, diff_rows, cols1, is_diff_cols) ): - logger.info("Querying for different rows") + logger.debug("Querying for different rows") for is_xa, is_xb, *x in db.query(diff_rows, list): if is_xa and is_xb: # Can't both be exclusive, meaning a pk is NULL @@ -238,7 +273,7 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): is_diff_cols_counts =
db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) diff_counts = {} for name, count in safezip(cols, is_diff_cols_counts): - diff_counts[name] = count + diff_counts[name] = diff_counts.get(name, 0) + (count or 0) self.stats['diff_counts'] = diff_counts def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): @@ -247,6 +282,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) + with temp_table(db, exclusive_rows_query) as exclusive_rows: self.stats["exclusive_count"] = db.query(exclusive_rows.count(), int) sample_rows = db.query(sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])), list) diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 761b3a74..8b51e3f6 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -103,6 +103,10 @@ def get_values(self) -> list: def choose_checkpoints(self, count: int) -> List[DbKey]: "Suggests a bunch of evenly-spaced checkpoints to split by (not including start, end)" + + if self.max_key - self.min_key <= count: + count = 1 + assert self.is_bounded if isinstance(self.min_key, ArithString): assert type(self.min_key) is type(self.max_key) diff --git a/data_diff/utils.py b/data_diff/utils.py index 5911f8f8..2e346fa3 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -1,3 +1,4 @@ +import logging import re import math from typing import Iterable, Tuple, Union, Any, Sequence, Dict @@ -214,6 +215,9 @@ def __setitem__(self, key: str, value: V): def __contains__(self, key: str) -> bool: ... + def __repr__(self): + return repr(dict(self.items())) + class CaseInsensitiveDict(CaseAwareMapping): def __init__(self, initial): @@ -285,3 +289,7 @@ def run_as_daemon(threadfunc, *args): th.daemon = True th.start() return th + + +def getLogger(name): + return logging.getLogger(name.rsplit('.', 1)[-1]) diff --git a/tests/common.py b/tests/common.py index 44a15cf2..5cce3964 100644 --- a/tests/common.py +++ b/tests/common.py @@ -45,6 +45,7 @@ def get_git_revision_short_hash() -> str: logging.basicConfig(level=level) logging.getLogger("hashdiff_tables").setLevel(level) logging.getLogger("joindiff_tables").setLevel(level) +logging.getLogger("diff_tables").setLevel(level) logging.getLogger("table_segment").setLevel(level) logging.getLogger("database").setLevel(level) From 7c7e5bd963e8e94470aa91668dc48eaee2e06881 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 23 Sep 2022 11:48:20 +0300 Subject: [PATCH 11/33] Added diffing schemas (when same db, for mutual columns) --- data_diff/__main__.py | 19 ++++++++++++++++++- data_diff/joindiff_tables.py | 14 +++++++------- data_diff/utils.py | 4 +++- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index adb5bee9..39951b06 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -45,6 +45,19 @@ def _get_schema(pair): return db.query_table_schema(table_path) +def diff_schemas(schema1, schema2, columns): + logging.info('Diffing schemas...') + attrs = 'name', 'type', 'datetime_precision', 'numeric_precision', 'numeric_scale' + for c in columns: + if c is None: # Skip for convenience + continue + diffs = [] + for attr, v1, v2 in safezip(attrs, schema1[c], schema2[c]): + if v1 != v2: + diffs.append(f"{attr}:({v1} != {v2})") + if diffs: + logging.warning(f"Schema mismatch in 
column '{c}': {', '.join(diffs)}") + class MyHelpFormatter(click.HelpFormatter): def __init__(self, **kwargs): super().__init__(self, **kwargs) @@ -300,7 +313,11 @@ def _main( columns = tuple(expanded_columns - {key_column, update_column}) - logging.info(f"Diffing columns: key={key_column} update={update_column} extra={columns}") + if db1 is db2: + diff_schemas(schema1, schema2, (key_column, update_column,) + columns) + + + logging.info(f"Diffing using columns: key={key_column} update={update_column} extra={columns}") segments = [ TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 988cbaec..32b47a14 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -113,15 +113,17 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult yield from self._diff_segments(None, table1, table2, None) else: yield from self._bisect_and_diff_tables(table1, table2) + logger.info("Diffing complete") def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): assert table1.database is table2.database - logger.info( - ". " * level + f"Diffing segment {segment_index}/{segment_count}, " - f"key-range: {table1.min_key}..{table2.max_key}, " - f"size <= {max_rows}" - ) + if segment_index or table1.min_key or max_rows: + logger.info( + ". " * level + f"Diffing segment {segment_index}/{segment_count}, " + f"key-range: {table1.min_key}..{table2.max_key}, " + f"size <= {max_rows}" + ) with self._run_in_background( partial(self._collect_stats, 1, table1), @@ -130,8 +132,6 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl ): yield from self._outer_join(table1, table2) - logger.info("Diffing complete") - def _test_duplicate_keys(self, table1, table2): logger.debug("Testing for duplicate keys") diff --git a/data_diff/utils.py b/data_diff/utils.py index 2e346fa3..642a4b7b 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -15,7 +15,9 @@ def safezip(*args): "zip but makes sure all sequences are the same length" - assert len(set(map(len, args))) == 1 + lens = list(map(len, args)) + if len(set(lens)) != 1: + raise ValueError(f"Mismatching lengths in arguments to safezip: {lens}") return zip(*args) From bee5479d59f486efddb7a2571cf88cfdf24326be Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 23 Sep 2022 16:09:07 +0300 Subject: [PATCH 12/33] Joindiff: Added Interpreter; Fixed exclusive_rows to use temp_table in an interpreter. 
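The interpreter referred to here runs a generator of queries on a single thread-local cursor, so that statements which depend on connection state (such as temporary tables) see each other's effects, and intermediate results can be sent back into the generator. A minimal sketch of the calling convention (the table name and queries are illustrative only):

    def count_then_drop(name):
        yield f"create temporary table {name} as select 1 as x"  # result ignored
        count = yield f"select count(*) from {name}"             # result is sent back into the generator
        yield f"drop table {name}"

    # Database.query() detects the generator, compiles each yielded item, and
    # executes the whole sequence through a ThreadLocalInterpreter on one cursor.
    db.query(count_then_drop("tmp_example"), None)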
--- data_diff/databases/base.py | 34 +++++++++++++++++----- data_diff/joindiff_tables.py | 55 ++++++++++++++++++----------------- data_diff/queries/__init__.py | 2 +- data_diff/queries/compiler.py | 26 +++++++++++++++-- 4 files changed, 80 insertions(+), 37 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 181a80e5..288bf6fd 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,7 +1,7 @@ import math import sys import logging -from typing import Dict, Tuple, Optional, Sequence, Type, List +from typing import Dict, Generator, Tuple, Optional, Sequence, Type, List, Union from functools import wraps from concurrent.futures import ThreadPoolExecutor import threading @@ -27,7 +27,7 @@ DbPath, ) -from data_diff.queries import Expr, Compiler, table, Select, SKIP +from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter logger = logging.getLogger("database") @@ -66,11 +66,29 @@ def _one(seq): return x -def _query_conn(conn, sql_code: str) -> list: +def _query_cursor(c, sql_code): + try: + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + except Exception as e: + logger.exception(e) + raise + +def _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: c = conn.cursor() - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() + + if isinstance(sql_code, ThreadLocalInterpreter): + g = sql_code.interpret() + q = next(g) + while True: + res = _query_cursor(c, q) + try: + q = g.send(res) + except StopIteration: + break + else: + return _query_cursor(c, sql_code) class Database(AbstractDatabase): @@ -312,11 +330,11 @@ def set_conn(self): except ModuleNotFoundError as e: self._init_error = e - def _query(self, sql_code: str): + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): r = self._queue.submit(self._query_in_worker, sql_code) return r.result() - def _query_in_worker(self, sql_code: str): + def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]): "This method runs in a worker thread" if self._init_error: raise self._init_error diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 32b47a14..2b83201f 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -2,11 +2,9 @@ """ -from collections import defaultdict from decimal import Decimal from functools import partial import logging -from contextlib import contextmanager from typing import Dict, List from runtype import dataclass @@ -49,30 +47,17 @@ class Stats: def sample(table): return table.order_by(Random()).limit(10) - -@contextmanager -def temp_table(db: Database, expr: Expr): - c = Compiler(db) - - name = c.new_unique_table_name("temp_table") - +def create_temp_table(c: Compiler, name: str, expr: Expr): + db = c.database if isinstance(db, BigQuery): name = f"{db.default_schema}.{name}" - db.query(f"create table {c.quote(name)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}", None) + return f"create table {c.quote(name)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" elif isinstance(db, Presto): - db.query(f"create table {c.quote(name)} as {c.compile(expr)}", None) + return f"create table {c.quote(name)} as {c.compile(expr)}" elif isinstance(db, Oracle): - db.query(f"create global temporary table {c.quote(name)} as {c.compile(expr)}", None) + return f"create global temporary 
table {c.quote(name)} as {c.compile(expr)}" else: - db.query(f"create temporary table {c.quote(name)} as {c.compile(expr)}", None) - - try: - yield table(name, schema=expr.source_table.schema) - finally: - if isinstance(db, (BigQuery, Presto)): - # Only drops if create table succeeded (meaning, the table didn't already exist) - # And if the table won't delete itself - db.query(f"drop table {c.quote(name)}", None) + return f"create temporary table {c.quote(name)} as {c.compile(expr)}" def _slice_tuple(t, *sizes): @@ -95,6 +80,7 @@ class JoinDifferBase(TableDiffer): stats: dict = {} validate_unique_key: bool = True + sample_exclusive_rows: bool = True def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: db = table1.database @@ -277,13 +263,30 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): self.stats['diff_counts'] = diff_counts def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): - logger.info("Counting and sampling exclusive rows") if isinstance(db, Oracle): exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) - with temp_table(db, exclusive_rows_query) as exclusive_rows: - self.stats["exclusive_count"] = db.query(exclusive_rows.count(), int) - sample_rows = db.query(sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])), list) - self.stats["exclusive_sample"] = sample_rows + if not self.sample_exclusive_rows: + logger.info("Counting exclusive rows") + self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int) + return + + logger.info("Counting and sampling exclusive rows") + def exclusive_rows(expr): + c = Compiler(db) + name = c.new_unique_table_name("temp_table") + yield create_temp_table(c, name, expr) + exclusive_rows = table(name, schema=expr.source_table.schema) + + count = yield exclusive_rows.count() + self.stats["exclusive_count"] = self.stats.get('exclusive_count', 0) + count[0][0] + sample_rows = yield sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])) + self.stats["exclusive_sample"] = self.stats.get('exclusive_sample', []) + sample_rows + + # Only drops if create table succeeded (meaning, the table didn't already exist) + yield f"drop table {c.quote(name)}" + + # Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter) + db.query(exclusive_rows(exclusive_rows_query), None) diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py index 93299b26..64a6e60f 100644 --- a/data_diff/queries/__init__.py +++ b/data_diff/queries/__init__.py @@ -1,4 +1,4 @@ -from .compiler import Compiler +from .compiler import Compiler, ThreadLocalInterpreter from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 5133301c..62430880 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,7 +1,7 @@ import random from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Sequence, List +from typing import Any, Dict, Generator, Sequence, List, Union from runtype import dataclass @@ -32,7 +32,7 @@ def compile(self, elem) -> str: return f"WITH {subq}\n{res}" return res - def _compile(self, elem) 
-> str: + def _compile(self, elem) -> Union[str, 'ThreadLocalInterpreter']: if elem is None: return "NULL" elif isinstance(elem, Compilable): @@ -47,6 +47,8 @@ def _compile(self, elem) -> str: return f"b'{elem.decode()}'" elif isinstance(elem, ArithString): return f"'{elem}'" + elif isinstance(elem, Generator): + return ThreadLocalInterpreter(self, elem) assert False, elem def new_unique_name(self, prefix="tmp"): @@ -65,3 +67,23 @@ class Compilable(ABC): @abstractmethod def compile(self, c: Compiler) -> str: ... + + +class ThreadLocalInterpreter: + """An interpreter used to execute a sequence of queries within the same thread. + + Useful for cursor-sensitive operations, such as creating a temporary table. + """ + + def __init__(self, compiler: Compiler, gen: Generator): + self.gen = gen + self.compiler = compiler + + def interpret(self): + q = next(self.gen) + while True: + try: + res = yield self.compiler.compile(q) + q = self.gen.send(res) + except StopIteration: + break From 4c80e5d48156650a675b95eeb37c4f52899dd6a3 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 30 Sep 2022 15:09:04 +0300 Subject: [PATCH 13/33] Tracking: Errors now provide more info, with truncated values --- data_diff/__main__.py | 2 +- data_diff/diff_tables.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 39951b06..e6d37253 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -307,7 +307,7 @@ def _main( m1 = None if any(match_like(c, schema1.keys())) else f"{db1}/{table1}" m2 = None if any(match_like(c, schema2.keys())) else f"{db2}/{table2}" not_matched = ", ".join(m for m in [m1, m2] if m) - raise ValueError(f"Column {c} not found in: {not_matched}") + raise ValueError(f"Column '{c}' not found in: {not_matched}") expanded_columns |= match diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 5cd21302..78440a95 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -1,6 +1,7 @@ """Provides classes for performing a table diff """ +import re import time from abc import ABC, abstractmethod from enum import Enum from contextlib import contextmanager @@ -27,6 +28,10 @@ class Algorithm(Enum): DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] +def truncate_error(error: str): + first_line = error.split('\n', 1)[0] + return re.sub("'(.*?)'", "'***'", first_line) + @dataclass class ThreadBase: @@ -110,7 +115,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: table1_count = self.stats.get("table1_count") table2_count = self.stats.get("table2_count") diff_count = self.stats.get("diff_count") - err_message = str(error)[:20] # Truncate possibly sensitive information.
+ err_message = truncate_error(repr(error)) event_json = create_end_event_json( error is None, runtime, @@ -186,7 +191,7 @@ def _parse_key_range_result(self, key_type, key_range): try: return cls(mn), cls(mx) + 1 except (TypeError, ValueError) as e: - raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e + raise type(e)(f"Cannot apply {key_type} to '{mn}', '{mx}'.") from e def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): From 179ce547d3167f2e167a8f1e944fc22f52f5d38d Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 30 Sep 2022 17:38:30 +0300 Subject: [PATCH 14/33] Better docs and docstrings --- data_diff/__init__.py | 35 ++++++++++++++++++++++++++++++++--- data_diff/joindiff_tables.py | 10 +++++++++- data_diff/table_segment.py | 3 ++- docs/conf.py | 1 + docs/python-api.rst | 8 ++++++++ docs/requirements.txt | 2 +- 6 files changed, 53 insertions(+), 6 deletions(-) diff --git a/data_diff/__init__.py b/data_diff/__init__.py index f22ab039..3e8451ba 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -15,14 +15,17 @@ def connect_to_table( key_column: str = "id", thread_count: Optional[int] = 1, **kwargs, -): +) -> TableSegment: """Connects to the given database, and creates a TableSegment instance Parameters: db_info: Either a URI string, or a dict of connection options. table_name: Name of the table as a string, or a tuple that signifies the path. key_column: Name of the key column - thread_count: Number of threads for this connection (only if using a threadpooled implementation) + thread_count: Number of threads for this connection (only if using a threadpooled db implementation) + + See Also: + :meth:`connect` """ db = connect(db_info, thread_count=thread_count) @@ -61,13 +64,39 @@ def diff_tables( # There may be many pools, so number of actual threads can be a lot higher. max_threadpool_size: Optional[int] = 1, ) -> Iterator: - """Efficiently finds the diff between table1 and table2. + """Finds the diff between table1 and table2. + + Parameters: + key_column (str): Name of the key column, which uniquely identifies each row (usually id) + update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update). + Used by `min_update` and `max_update`. + extra_columns (Tuple[str, ...], optional): Extra columns to compare + min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment + max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment + min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment + max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment + algorithm (:class:`Algorithm`): Which diffing algorithm to use (`HASHDIFF` or `JOINDIFF`) + bisection_factor (int): Into how many segments to bisect per iteration. (when algorithm is `HASHDIFF`) + bisection_threshold (Number): When should we stop bisecting and compare locally (when algorithm is `HASHDIFF`; in row count). + threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + There may be many pools, so number of actual threads can be a lot higher. 
+ + Note: + The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances: + `key_column`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's + possible to omit them here, and instead set them directly when creating each :class:`TableSegment`. Example: >>> table1 = connect_to_table('postgresql:///', 'Rating', 'id') >>> list(diff_tables(table1, table1)) [] + See Also: + :class:`TableSegment` + :class:`HashDiffer` + :class:`JoinDiffer` + """ tables = [table1, table2] override_attrs = { diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 2b83201f..622d002f 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -76,7 +76,15 @@ def json_friendly_value(v): @dataclass class JoinDifferBase(TableDiffer): - """Finds the diff between two SQL tables using JOINs""" + """Finds the diff between two SQL tables using JOINs + + The two tables must reside in the same database, and their primary keys must be unique and not null. + + Parameters: + threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + There may be many pools, so number of actual threads can be a lot higher. + """ stats: dict = {} validate_unique_key: bool = True diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 8b51e3f6..3a4ddbe4 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -23,7 +23,8 @@ class TableSegment: database (Database): Database instance. See :meth:`connect` table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')` key_column (str): Name of the key column, which uniquely identifies each row (usually id) - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update) + update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update). + Used by `min_update` and `max_update`. extra_columns (Tuple[str, ...], optional): Extra columns to compare min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment diff --git a/docs/conf.py b/docs/conf.py index ef75ecc0..dc58fb90 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,6 +41,7 @@ "recommonmark", "sphinx_markdown_tables", "sphinx_copybutton", + "enum_tools.autoenum", # 'sphinx_gallery.gen_gallery' ] diff --git a/docs/python-api.rst b/docs/python-api.rst index f28b18d1..ada633d1 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -5,6 +5,10 @@ Python API Reference .. autofunction:: connect +.. autofunction:: connect_to_table + +.. autofunction:: diff_tables + .. autoclass:: HashDiffer :members: __init__, diff_tables @@ -17,6 +21,10 @@ Python API Reference .. autoclass:: data_diff.databases.database_types.AbstractDatabase :members: +.. autoclass:: data_diff.databases.database_types.AbstractDialect + :members: + .. autodata:: DbKey .. autodata:: DbTime .. autodata:: DbPath +.. 
autoenum:: Algorithm diff --git a/docs/requirements.txt b/docs/requirements.txt index 0d1d793a..252c7acb 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,6 +4,6 @@ sphinx_markdown_tables sphinx-copybutton sphinx-rtd-theme recommonmark +enum-tools[sphinx] -# Requirements. TODO Use poetry instead of this redundant list data_diff From 073333ce624a8f88dd5cb5972395cf61ed449039 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 3 Oct 2022 10:20:26 +0300 Subject: [PATCH 15/33] Refactor joindiff --- data_diff/joindiff_tables.py | 65 +++++++++++++++++------------------- 1 file changed, 30 insertions(+), 35 deletions(-) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 622d002f..d997730f 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -60,6 +60,32 @@ def create_temp_table(c: Compiler, name: str, expr: Expr): return f"create temporary table {c.quote(name)} as {c.compile(expr)}" +def bool_to_int(x): + return if_(x, 1, 0) + + +def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable: + on = [a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)] + + if isinstance(db, Oracle): + is_exclusive_a = and_(bool_to_int(b[k] == None) for k in keys2) + is_exclusive_b = and_(bool_to_int(a[k] == None) for k in keys1) + else: + is_exclusive_a = and_(b[k] == None for k in keys2) + is_exclusive_b = and_(a[k] == None for k in keys1) + + if isinstance(db, MySQL): + # No outer join + l = leftjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=False, **select_fields) + r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) + return l.union(r) + + return ( + outerjoin(a, b).on(*on) + .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) + ) + + def _slice_tuple(t, *sizes): i = 0 for size in sizes: @@ -74,10 +100,12 @@ def json_friendly_value(v): return v + @dataclass -class JoinDifferBase(TableDiffer): - """Finds the diff between two SQL tables using JOINs +class JoinDiffer(TableDiffer): + """Finds the diff between two SQL tables in the same database, using JOINs. + The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics. The two tables must reside in the same database, and their primary keys must be unique and not null. Parameters: @@ -182,39 +210,6 @@ def _collect_stats(self, i, table): # stats.diff_ratio_total = diff_stats['total_diff'] -def bool_to_int(x): - return if_(x, 1, 0) - - -def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List[str], select_fields: dict) -> ITable: - on = [a[k1] == b[k2] for k1, k2 in safezip(keys1, keys2)] - - if isinstance(db, Oracle): - is_exclusive_a = and_(bool_to_int(b[k] == None) for k in keys2) - is_exclusive_b = and_(bool_to_int(a[k] == None) for k in keys1) - else: - is_exclusive_a = and_(b[k] == None for k in keys2) - is_exclusive_b = and_(a[k] == None for k in keys1) - - if isinstance(db, MySQL): - # No outer join - l = leftjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=False, **select_fields) - r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) - return l.union(r) - - return ( - outerjoin(a, b).on(*on) - .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) - ) - - -class JoinDiffer(JoinDifferBase): - """Finds the diff between two SQL tables in the same database. 
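Since MySQL lacks FULL OUTER JOIN, _outerjoin above emulates it as the union of a LEFT and a RIGHT join. A hand-rolled sketch of the same shape with the query-builder API (hypothetical tables "a" and "b", joined on "id"; CaseInsensitiveDict's location is assumed from the query tests):

    from data_diff.queries.api import table, leftjoin, rightjoin
    from data_diff.utils import CaseInsensitiveDict

    s = CaseInsensitiveDict({"id": int})
    a, b = table("a", schema=s), table("b", schema=s)
    on = [a["id"] == b["id"]]
    l = leftjoin(a, b).on(*on).select(is_exclusive_a=(b["id"] == None), is_exclusive_b=False)
    r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=(a["id"] == None))
    full = l.union(r)  # UNION of LEFT and RIGHT stands in for FULL OUTER JOIN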
- - The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics. - - """ - def _outer_join(self, table1, table2): db = table1.database if db is not table2.database: From c1e171d74e3f1340289f4e2b172b894d93a00d51 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 3 Oct 2022 13:15:49 +0300 Subject: [PATCH 16/33] Queries: Derive schemas (WIP) --- data_diff/queries/ast_classes.py | 56 ++++++++++++++++++++++++-------- data_diff/queries/extras.py | 5 +-- tests/test_query.py | 14 ++++++++ 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 56213dd3..7c891a96 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -40,6 +40,10 @@ class Alias(ExprNode): def compile(self, c: Compiler) -> str: return f"{c.compile(self.expr)} AS {c.quote(self.name)}" + @property + def type(self): + return self.expr.type + def _drop_skips(exprs): return [e for e in exprs if e is not SKIP] @@ -163,6 +167,10 @@ def compile(self, c: Compiler) -> str: args = ", ".join(c.compile(e) for e in self.args) return f"{self.name}({args})" +def _expr_type(e: Expr): + if isinstance(e, ExprNode): + return e.type + return type(e) @dataclass class CaseWhen(ExprNode): @@ -175,30 +183,40 @@ def compile(self, c: Compiler) -> str: else_ = (" " + c.compile(self.else_)) if self.else_ else "" return f"CASE {when_thens}{else_} END" + @property + def type(self): + when_types = {_expr_type(w) for _c,w in self.cases } + if self.else_: + when_types |= _expr_type(self.else_) + if len(when_types) > 1: + raise RuntimeError(f"Non-matching types in when: {when_types}") + t ,= when_types + return t + class LazyOps: def __add__(self, other): return BinOp("+", [self, other]) def __gt__(self, other): - return BinOp(">", [self, other]) + return BinBoolOp(">", [self, other]) def __ge__(self, other): - return BinOp(">=", [self, other]) + return BinBoolOp(">=", [self, other]) def __eq__(self, other): if other is None: - return BinOp("IS", [self, None]) - return BinOp("=", [self, other]) + return BinBoolOp("IS", [self, None]) + return BinBoolOp("=", [self, other]) def __lt__(self, other): - return BinOp("<", [self, other]) + return BinBoolOp("<", [self, other]) def __le__(self, other): - return BinOp("<=", [self, other]) + return BinBoolOp("<=", [self, other]) def __or__(self, other): - return BinOp("OR", [self, other]) + return BinBoolOp("OR", [self, other]) def is_distinct_from(self, other): return IsDistinctFrom(self, other) @@ -211,6 +229,7 @@ def sum(self): class IsDistinctFrom(ExprNode, LazyOps): a: Expr b: Expr + type = bool def compile(self, c: Compiler) -> str: return c.database.is_distinct_from(c.compile(self.a), c.compile(self.b)) @@ -228,6 +247,9 @@ def compile(self, c: Compiler) -> str: a, b = self.args return f"({c.compile(a)} {self.op} {c.compile(b)})" +class BinBoolOp(BinOp): + type = bool + @dataclass(eq=False, order=False) class Column(ExprNode, LazyOps): @@ -299,8 +321,9 @@ def source_table(self): @property def schema(self): - # TODO combine both tables - return None + assert self.columns # TODO Implement SELECT * + s = self.source_tables[0].schema # XXX + return type(s)({c.name: c.type for c in self.columns}) def on(self, *exprs): if len(exprs) == 1: @@ -375,7 +398,7 @@ def compile(self, parent_c: Compiler) -> str: @dataclass class Select(ExprNode, ITable): - source_table: Expr = None + table: Expr = None columns: Sequence[Expr] = None where_exprs: Sequence[Expr] = None order_by_exprs: Sequence[Expr] = 
None @@ -384,7 +407,14 @@ class Select(ExprNode, ITable): @property def schema(self): - return self.source_table.schema + s = self.table.schema + if s is None or self.columns is None: + return s + return type(s)({c.name: c.type for c in self.columns}) + + @property + def source_table(self): + return self def compile(self, parent_c: Compiler) -> str: c = parent_c.replace(in_select=True) #.add_table_context(self.table) @@ -392,8 +422,8 @@ def compile(self, parent_c: Compiler) -> str: columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" select = f"SELECT {columns}" - if self.source_table: - select += " FROM " + c.compile(self.source_table) + if self.table: + select += " FROM " + c.compile(self.table) if self.where_exprs: select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index 9b5189e1..bcd426df 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -12,11 +12,12 @@ @dataclass class NormalizeAsString(ExprNode): expr: ExprNode - type: ColType = None + expr_type: ColType = None + type = str def compile(self, c: Compiler) -> str: expr = c.compile(self.expr) - return c.database.normalize_value_by_type(expr, self.type or self.expr.type) + return c.database.normalize_value_by_type(expr, self.expr_type or self.expr.type) @dataclass diff --git a/tests/test_query.py b/tests/test_query.py index 4ae4b82e..3e895bc5 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -79,6 +79,7 @@ def test_schema(self): c = Compiler(MockDialect()) schema = dict(id="int", comment="varchar") + # test table t = table("a", schema=CaseInsensitiveDict(schema)) q = t.select(this.Id, t["COMMENT"]) assert c.compile(q) == "SELECT id, comment FROM a" @@ -87,6 +88,19 @@ def test_schema(self): self.assertRaises(KeyError, t.__getitem__, "Id") self.assertRaises(KeyError, t.select, this.Id) + # test select + q = t.select(this.id) + self.assertRaises(KeyError, q.__getitem__, "comment") + + # test join + s = CaseInsensitiveDict({'x': int, 'y': int}) + a = table("a", schema=s) + b = table("b", schema=s) + keys = ["x", "y"] + j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a['x'], b['y'], xsum=a['x'] + b['x']) + j['x'], j['y'], j['xsum'] + self.assertRaises(KeyError, j.__getitem__, "ysum") + def test_commutable_select(self): # c = Compiler(MockDialect()) From da6c2df0d7205af13b97a0b6bf7354e02f6bd1c0 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 4 Oct 2022 09:35:44 +0300 Subject: [PATCH 17/33] Queries: DDL initial (drop/create table, insert) --- data_diff/queries/ast_classes.py | 49 +++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 7c891a96..cc2463f0 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,6 +1,6 @@ from dataclasses import field from datetime import datetime -from typing import Any, Generator, ItemsView, Optional, Sequence, Tuple, Union +from typing import Any, Generator, Optional, Sequence, Tuple, Union from runtype import dataclass @@ -10,6 +10,7 @@ from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple + class ExprNode(Compilable): type: Any = None @@ -284,11 +285,16 @@ class TablePath(ExprNode, ITable): path: DbPath schema: Optional[Schema] = field(default=None, repr=False) + def create(self, if_not_exists=False): + if not self.schema: + raise ValueError("Schema must have a value to create table") + 
return CreateTable(self, if_not_exists=if_not_exists) + def insert_values(self, rows): - pass + raise NotImplementedError() - def insert_query(self, query): - pass + def insert_expr(self, expr: Expr): + return InsertToTable(self, expr) @property def source_table(self): @@ -558,3 +564,38 @@ def compile(self, c: Compiler) -> str: class Random(ExprNode): def compile(self, c: Compiler) -> str: return c.database.random() + + +# DDL + +class Statement(Compilable): + type = None + +def to_sql_type(t): + if isinstance(t, str): + return t + return { + int: "int", + str: "varchar", + bool: "boolean", + }[t] + + +@dataclass +class CreateTable(Statement): + path: TablePath + if_not_exists: bool = False + + def compile(self, c: Compiler) -> str: + schema = ', '.join(f'{k} {to_sql_type(v)}' for k, v in self.path.schema.items()) + ne = 'IF NOT EXISTS ' if self.if_not_exists else '' + return f'CREATE TABLE {ne}{c.compile(self.path)}({schema})' + +@dataclass +class InsertToTable(Statement): + # TODO Support insert for only some columns + path: TablePath + expr: Expr + + def compile(self, c: Compiler) -> str: + return f'INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}' From 9f404a06a1522509f1f26b378401ee28ffec0539 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 5 Oct 2022 11:08:46 +0300 Subject: [PATCH 18/33] Queries: Fix in .type --- data_diff/queries/ast_classes.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index cc2463f0..8a07f4fb 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -32,6 +32,10 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] +def get_type(e: Expr) -> type: + if isinstance(e, ExprNode): + return e.type + return type(e) @dataclass class Alias(ExprNode): @@ -43,7 +47,7 @@ def compile(self, c: Compiler) -> str: @property def type(self): - return self.expr.type + return get_type(self.expr) def _drop_skips(exprs): @@ -392,6 +396,17 @@ class Union(ExprNode, ITable): def source_table(self): return self # TODO is this right? 
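A quick sketch of the DDL entry points defined just above (names illustrative; exact quoting depends on the dialect, and str maps to a plain varchar at this point in the series; CaseInsensitiveDict's location is assumed from the query tests):

    from data_diff.queries.api import table
    from data_diff.utils import CaseInsensitiveDict

    t = table("tmp_diff", schema=CaseInsensitiveDict({"id": int, "name": str}))
    t.create(if_not_exists=True)
    # compiles to roughly: CREATE TABLE IF NOT EXISTS tmp_diff(id int, name varchar)
    t.insert_expr(some_select)  # INSERT INTO tmp_diff <compiled query>; some_select is any table expression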
+ @property + def type(self): + return self.table1.type + + @property + def schema(self): + s1 = self.table1.schema + s2 = self.table2.schema + assert len(s1) == len(s2) + return s1 + def compile(self, parent_c: Compiler) -> str: c = parent_c.replace(in_select=False) union_all = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" @@ -576,7 +591,7 @@ def to_sql_type(t): return t return { int: "int", - str: "varchar", + str: "varchar(1024)", bool: "boolean", }[t] From 5cd424dd49c0013cb509d7979d6b4fecda2f5e1b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 3 Oct 2022 13:16:01 +0300 Subject: [PATCH 19/33] Joindiff: Added support to materialize results as tables (-m) --- data_diff/__main__.py | 14 ++++- data_diff/databases/base.py | 2 +- data_diff/diff_tables.py | 2 +- data_diff/joindiff_tables.py | 89 +++++++++++++++++++++----------- data_diff/queries/ast_classes.py | 29 +++++++++-- data_diff/utils.py | 6 +++ 6 files changed, 104 insertions(+), 38 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index e6d37253..437bc67a 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -9,7 +9,9 @@ import rich import click -from .utils import remove_password_from_url, safezip, match_like +from data_diff.databases.base import parse_table_name + +from .utils import eval_name_template, remove_password_from_url, safezip, match_like from .diff_tables import Algorithm from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR from .joindiff_tables import JoinDiffer @@ -104,6 +106,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.", metavar="NUM", ) +@click.option("-m", "--materialize", default=None, metavar="TABLE_NAME", help="Materialize the diff results into a new table in the database.") @click.option( "--min-age", default=None, @@ -126,6 +129,11 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - is_flag=True, help="Column names are treated as case-sensitive. 
Otherwise, data-diff corrects their case according to schema.", ) +@click.option( + "--assume-unique-key", + is_flag=True, + help="Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.", +) @click.option( "-j", "--threads", @@ -192,6 +200,8 @@ def _main( case_sensitive, json_output, where, + assume_unique_key, + materialize, threads1=None, threads2=None, __conf__=None, @@ -256,6 +266,8 @@ def _main( differ = JoinDiffer( threaded=threaded, max_threadpool_size=threads and threads * 2, + validate_unique_key = not assume_unique_key, + materialize_to_table = materialize and parse_table_name(eval_name_template(materialize)), ) else: assert algorithm == Algorithm.HASHDIFF diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 288bf6fd..bd33165f 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -107,7 +107,7 @@ class Database(AbstractDatabase): def name(self): return type(self).__name__ - def query(self, sql_ast: Expr, res_type: type): + def query(self, sql_ast: Expr, res_type: type = None): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" compiler = Compiler(self) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 78440a95..5ecd1667 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -68,7 +68,7 @@ def _threaded_call_as_completed(self, func, iterable): @contextmanager def _run_in_background(self, *funcs): with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool: - futures = [task_pool.submit(f) for f in funcs] + futures = [task_pool.submit(f) for f in funcs if f is not None] yield futures for f in futures: f.result() diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index d997730f..1afb9467 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -5,10 +5,12 @@ from decimal import Decimal from functools import partial import logging -from typing import Dict, List +from typing import Dict, List, Optional from runtype import dataclass +from data_diff.databases.database_types import DbPath, Schema + from .utils import safezip from .databases.base import Database @@ -17,15 +19,16 @@ from .diff_tables import TableDiffer, DiffResult from .thread_utils import ThreadedYielder -from .queries import table, sum_, min_, max_, avg +from .queries import table, sum_, min_, max_, avg, SKIP from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable -from .queries.ast_classes import Concat, Count, Expr, Random +from .queries.ast_classes import Concat, Count, Expr, Random, TablePath from .queries.compiler import Compiler from .queries.extras import NormalizeAsString - logger = logging.getLogger("joindiff_tables") +WRITE_LIMIT = 1000 + def merge_dicts(dicts): i = iter(dicts) @@ -60,6 +63,18 @@ def create_temp_table(c: Compiler, name: str, expr: Expr): return f"create temporary table {c.quote(name)} as {c.compile(expr)}" +def drop_table(db, name: DbPath): + t = TablePath(name) + db.query(t.drop(if_exists=True)) + +def append_to_table(name: DbPath, expr: Expr): + t = TablePath(name, expr.schema) + yield t.create(if_not_exists=True) # uses expr.schema + yield 'commit' + yield t.insert_expr(expr) + yield 'commit' + + def bool_to_int(x): return if_(x, 1, 0) @@ -117,6 +132,8 @@ class JoinDiffer(TableDiffer): stats: dict = {} validate_unique_key: bool = True sample_exclusive_rows: bool = True + materialize_to_table: DbPath = None + write_limit: int = WRITE_LIMIT 
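When materialize_to_table is set, diff rows are also written out through append_to_table above. Because that helper is a generator of statements, the whole CREATE / INSERT / COMMIT sequence runs on a single connection (see the interpreter rework later in this series). A sketch of the call, with a hypothetical target path:

    # inside the differ, once diff_rows has been built:
    db.query(append_to_table(("public", "diff_results"), diff_rows.limit(WRITE_LIMIT)))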
def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: db = table1.database @@ -128,8 +145,12 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] + if self.materialize_to_table: + drop_table(db, self.materialize_to_table) + db.query('COMMIT') with self._run_in_background(*bg_funcs): + if isinstance(db, (Snowflake, BigQuery)): # Don't segment the table; let the database handling parallelization yield from self._diff_segments(None, table1, table2, None) @@ -147,12 +168,29 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl f"size <= {max_rows}" ) + db = table1.database + diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2) + with self._run_in_background( partial(self._collect_stats, 1, table1), partial(self._collect_stats, 2, table2), partial(self._test_null_keys, table1, table2), + partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), + partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), + partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) if self.materialize_to_table else None, ): - yield from self._outer_join(table1, table2) + + logger.debug("Querying for different rows") + for is_xa, is_xb, *x in db.query(diff_rows, list): + if is_xa and is_xb: + # Can't both be exclusive, meaning a pk is NULL + # This can happen if the explicit null test didn't finish running yet + raise ValueError(f"NULL values in one or more primary keys") + is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) + if not is_xb: + yield "-", tuple(a_row) + if not is_xa: + yield "+", tuple(b_row) def _test_duplicate_keys(self, table1, table2): logger.debug("Testing for duplicate keys") @@ -162,7 +200,7 @@ def _test_duplicate_keys(self, table1, table2): t = ts._make_select() key_columns = [ts.key_column] # XXX - q = t.select(total=Count(), total_distinct=Count(Concat(key_columns), distinct=True)) + q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True)) total, total_distinct = ts.database.query(q, tuple) if total != total_distinct: raise ValueError("Duplicate primary keys") @@ -175,7 +213,7 @@ def _test_null_keys(self, table1, table2): t = ts._make_select() key_columns = [ts.key_column] # XXX - q = t.select(*key_columns).where(or_(this[k] == None for k in key_columns)) + q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) if nulls: raise ValueError(f"NULL values in one or more primary keys") @@ -188,10 +226,10 @@ def _collect_stats(self, i, table): # Metrics col_exprs = merge_dicts( { - f"sum_{c}": sum_(c), - f"avg_{c}": avg(c), - f"min_{c}": min_(c), - f"max_{c}": max_(c), + f"sum_{c}": sum_(this[c]), + f"avg_{c}": avg(this[c]), + f"min_{c}": min_(this[c]), + f"max_{c}": max_(this[c]), } for c in table._relevant_columns if c == "id" # TODO just if the right type @@ -209,8 +247,7 @@ def _collect_stats(self, i, table): # stats.diff_ratio_by_column = diff_stats # stats.diff_ratio_total = diff_stats['total_diff'] - - def _outer_join(self, table1, table2): + def _create_outer_join(self, table1, table2): db = table1.database if db is not table2.database: raise ValueError("Joindiff only applies to tables within the same database") @@ -239,23 +276,8 @@ def _outer_join(self, table1, table2): _outerjoin(db, a, b, 
keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) .where(or_(this[c] == 1 for c in is_diff_cols)) ) + return diff_rows, a_cols, b_cols, is_diff_cols - with self._run_in_background( - partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), - partial(self._count_diff_per_column, db, diff_rows, cols1, is_diff_cols) - ): - - logger.debug("Querying for different rows") - for is_xa, is_xb, *x in db.query(diff_rows, list): - if is_xa and is_xb: - # Can't both be exclusive, meaning a pk is NULL - # This can happen if the explicit null test didn't finish running yet - raise ValueError(f"NULL values in one or more primary keys") - is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) - if not is_xb: - yield "-", tuple(a_row) - if not is_xa: - yield "+", tuple(b_row) def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): logger.info("Counting differences per column") @@ -280,7 +302,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") - yield create_temp_table(c, name, expr) + yield create_temp_table(c, name, expr.limit(self.write_limit)) exclusive_rows = table(name, schema=expr.source_table.schema) count = yield exclusive_rows.count() @@ -293,3 +315,10 @@ def exclusive_rows(expr): # Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter) db.query(exclusive_rows(exclusive_rows_query), None) + + def _materialize_diff(self, db, diff_rows, segment_index=None): + assert self.materialize_to_table + + db.query(append_to_table(self.materialize_to_table, diff_rows.limit(self.write_limit))) + logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.") + diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 8a07f4fb..eec3a200 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -140,7 +140,7 @@ class Concat(ExprNode): def compile(self, c: Compiler) -> str: # We coalesce because on some DBs (e.g. 
MySQL) concat('a', NULL) is NULL - items = [f"coalesce({c.compile(c.database.to_string(expr))}, '')" for expr in self.exprs] + items = [f"coalesce({c.compile(c.database.to_string(c.compile(expr)))}, '')" for expr in self.exprs] assert items if len(items) == 1: return items[0] @@ -294,6 +294,9 @@ def create(self, if_not_exists=False): raise ValueError("Schema must have a value to create table") return CreateTable(self, if_not_exists=if_not_exists) + def drop(self, if_exists=False): + return DropTable(self, if_exists=if_exists) + def insert_values(self, rows): raise NotImplementedError() @@ -513,13 +516,13 @@ def resolve_names(source_table, exprs): if isinstance(expr, ExprNode): for v in expr._dfs_values(): if isinstance(v, _ResolveColumn): - v.resolve(source_table._get_column(v.name)) + v.resolve(source_table._get_column(v.resolve_name)) i += 1 @dataclass(frozen=False, eq=False, order=False) class _ResolveColumn(ExprNode, LazyOps): - name: str + resolve_name: str resolved: Expr = None def resolve(self, expr): @@ -528,15 +531,22 @@ def resolve(self, expr): def compile(self, c: Compiler) -> str: if self.resolved is None: - raise RuntimeError(f"Column not resolved: {self.name}") + raise RuntimeError(f"Column not resolved: {self.resolve_name}") return self.resolved.compile(c) @property def type(self): if self.resolved is None: - raise RuntimeError(f"Column not resolved: {self.name}") + raise RuntimeError(f"Column not resolved: {self.resolve_name}") return self.resolved.type + @property + def name(self): + if self.resolved is None: + raise RuntimeError(f"Column not resolved: {self.name}") + return self.resolved.name + + class This: def __getattr__(self, name): @@ -606,6 +616,15 @@ def compile(self, c: Compiler) -> str: ne = 'IF NOT EXISTS ' if self.if_not_exists else '' return f'CREATE TABLE {ne}{c.compile(self.path)}({schema})' +@dataclass +class DropTable(Statement): + path: TablePath + if_exists: bool = False + + def compile(self, c: Compiler) -> str: + ie = 'IF EXISTS ' if self.if_exists else '' + return f'DROP TABLE {ie}{c.compile(self.path)}' + @dataclass class InsertToTable(Statement): # TODO Support insert for only some columns diff --git a/data_diff/utils.py b/data_diff/utils.py index 642a4b7b..ca05e051 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -9,6 +9,7 @@ import operator import string import threading +from datetime import datetime alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase @@ -295,3 +296,8 @@ def run_as_daemon(threadfunc, *args): def getLogger(name): return logging.getLogger(name.rsplit('.', 1)[-1]) + +def eval_name_template(name): + def get_timestamp(m): + return datetime.now().isoformat('_', 'seconds').replace(':', '_') + return re.sub('%t', get_timestamp, name) From 733972a7c7428800ff1cbe2d6f5ce387b15aa764 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 4 Oct 2022 17:45:47 +0300 Subject: [PATCH 20/33] Queries: Ran black --- data_diff/queries/api.py | 2 ++ data_diff/queries/ast_classes.py | 34 +++++++++++++++++++------------- data_diff/queries/compiler.py | 4 ++-- tests/test_query.py | 10 +++++----- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 136807eb..a07a9084 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -15,10 +15,12 @@ def leftjoin(*tables: ITable): "Left-joins each table into a 'struct'" return Join(tables, "LEFT") + def rightjoin(*tables: ITable): "Right-joins each table into a 'struct'" return 
Join(tables, "RIGHT") + def outerjoin(*tables: ITable): "Outer-joins each table into a 'struct'" return Join(tables, "FULL OUTER") diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index eec3a200..b3552620 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -10,7 +10,6 @@ from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple - class ExprNode(Compilable): type: Any = None @@ -129,7 +128,7 @@ def __getitem__(self, column): def count(self): return Select(self, [Count()]) - def union(self, other: 'ITable'): + def union(self, other: "ITable"): return Union(self, other) @@ -172,11 +171,13 @@ def compile(self, c: Compiler) -> str: args = ", ".join(c.compile(e) for e in self.args) return f"{self.name}({args})" + def _expr_type(e: Expr): if isinstance(e, ExprNode): return e.type return type(e) + @dataclass class CaseWhen(ExprNode): cases: Sequence[Tuple[Expr, Expr]] @@ -190,12 +191,12 @@ def compile(self, c: Compiler) -> str: @property def type(self): - when_types = {_expr_type(w) for _c,w in self.cases } + when_types = {_expr_type(w) for _c, w in self.cases} if self.else_: when_types |= _expr_type(self.else_) if len(when_types) > 1: raise RuntimeError(f"Non-matching types in when: {when_types}") - t ,= when_types + (t,) = when_types return t @@ -252,6 +253,7 @@ def compile(self, c: Compiler) -> str: a, b = self.args return f"({c.compile(a)} {self.op} {c.compile(b)})" + class BinBoolOp(BinOp): type = bool @@ -334,8 +336,8 @@ def source_table(self): @property def schema(self): - assert self.columns # TODO Implement SELECT * - s = self.source_tables[0].schema # XXX + assert self.columns # TODO Implement SELECT * + s = self.source_tables[0].schema # XXX return type(s)({c.name: c.type for c in self.columns}) def on(self, *exprs): @@ -390,6 +392,7 @@ class GroupBy(ITable): def having(self): pass + @dataclass class Union(ExprNode, ITable): table1: ITable @@ -441,7 +444,7 @@ def source_table(self): return self def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(in_select=True) #.add_table_context(self.table) + c = parent_c.replace(in_select=True) # .add_table_context(self.table) columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" select = f"SELECT {columns}" @@ -547,7 +550,6 @@ def name(self): return self.resolved.name - class This: def __getattr__(self, name): return _ResolveColumn(name) @@ -593,9 +595,11 @@ def compile(self, c: Compiler) -> str: # DDL + class Statement(Compilable): type = None + def to_sql_type(t): if isinstance(t, str): return t @@ -612,9 +616,10 @@ class CreateTable(Statement): if_not_exists: bool = False def compile(self, c: Compiler) -> str: - schema = ', '.join(f'{k} {to_sql_type(v)}' for k, v in self.path.schema.items()) - ne = 'IF NOT EXISTS ' if self.if_not_exists else '' - return f'CREATE TABLE {ne}{c.compile(self.path)}({schema})' + schema = ", ".join(f"{k} {to_sql_type(v)}" for k, v in self.path.schema.items()) + ne = "IF NOT EXISTS " if self.if_not_exists else "" + return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" + @dataclass class DropTable(Statement): @@ -622,8 +627,9 @@ class DropTable(Statement): if_exists: bool = False def compile(self, c: Compiler) -> str: - ie = 'IF EXISTS ' if self.if_exists else '' - return f'DROP TABLE {ie}{c.compile(self.path)}' + ie = "IF EXISTS " if self.if_exists else "" + return f"DROP TABLE {ie}{c.compile(self.path)}" + @dataclass class InsertToTable(Statement): @@ -632,4 +638,4 @@ class InsertToTable(Statement): 
expr: Expr def compile(self, c: Compiler) -> str: - return f'INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}' + return f"INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}" diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 62430880..2c48cb86 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -13,7 +13,7 @@ class Compiler: database: AbstractDialect in_select: bool = False # Compilation runtime flag - in_join: bool = False # Compilation runtime flag + in_join: bool = False # Compilation runtime flag _table_context: List = [] # List[ITable] _subqueries: Dict[str, Any] = {} # XXX not thread-safe @@ -32,7 +32,7 @@ def compile(self, elem) -> str: return f"WITH {subq}\n{res}" return res - def _compile(self, elem) -> Union[str, 'ThreadLocalInterpreter']: + def _compile(self, elem) -> Union[str, "ThreadLocalInterpreter"]: if elem is None: return "NULL" elif isinstance(elem, Compilable): diff --git a/tests/test_query.py b/tests/test_query.py index 3e895bc5..5091843e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -93,12 +93,12 @@ def test_schema(self): self.assertRaises(KeyError, q.__getitem__, "comment") # test join - s = CaseInsensitiveDict({'x': int, 'y': int}) + s = CaseInsensitiveDict({"x": int, "y": int}) a = table("a", schema=s) b = table("b", schema=s) keys = ["x", "y"] - j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a['x'], b['y'], xsum=a['x'] + b['x']) - j['x'], j['y'], j['xsum'] + j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a["x"], b["y"], xsum=a["x"] + b["x"]) + j["x"], j["y"], j["xsum"] self.assertRaises(KeyError, j.__getitem__, "ysum") def test_commutable_select(self): @@ -145,8 +145,8 @@ def test_funcs(self): def test_union_all(self): c = Compiler(MockDialect()) - a = table("a").select('x') - b = table("b").select('y') + a = table("a").select("x") + b = table("b").select("y") q = c.compile(a.union(b)) assert q == "SELECT x FROM a UNION SELECT y FROM b" From 00ee4158fc07d2e454145e0c401d607d410e3052 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 4 Oct 2022 17:47:11 +0300 Subject: [PATCH 21/33] Joindiff: Ran black --- data_diff/__main__.py | 30 ++++++++++++----- data_diff/databases/base.py | 3 +- data_diff/databases/presto.py | 2 +- data_diff/diff_tables.py | 23 +++++++++---- data_diff/hashdiff_tables.py | 18 +++++++--- data_diff/joindiff_tables.py | 63 ++++++++++++++++++----------------- data_diff/utils.py | 8 +++-- tests/test_joindiff.py | 18 ++++++++-- 8 files changed, 107 insertions(+), 58 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 437bc67a..c481fdee 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -48,10 +48,10 @@ def _get_schema(pair): def diff_schemas(schema1, schema2, columns): - logging.info('Diffing schemas...') - attrs = 'name', 'type', 'datetime_precision', 'numeric_precision', 'numeric_scale' + logging.info("Diffing schemas...") + attrs = "name", "type", "datetime_precision", "numeric_precision", "numeric_scale" for c in columns: - if c is None: # Skip for convenience + if c is None: # Skip for convenience continue diffs = [] for attr, v1, v2 in safezip(attrs, schema1[c], schema2[c]): @@ -60,6 +60,7 @@ def diff_schemas(schema1, schema2, columns): if diffs: logging.warning(f"Schema mismatch in column '{c}': {', '.join(diffs)}") + class MyHelpFormatter(click.HelpFormatter): def __init__(self, **kwargs): super().__init__(self, **kwargs) @@ -106,7 +107,13 @@ def write_usage(self, prog: str, args: str = 
"", prefix: Optional[str] = None) - help=f"Minimal bisection threshold. Below it, data-diff will download the data and compare it locally. Default={DEFAULT_BISECTION_THRESHOLD}.", metavar="NUM", ) -@click.option("-m", "--materialize", default=None, metavar="TABLE_NAME", help="Materialize the diff results into a new table in the database.") +@click.option( + "-m", + "--materialize", + default=None, + metavar="TABLE_NAME", + help="Materialize the diff results into a new table in the database. (joindiff only)", +) @click.option( "--min-age", default=None, @@ -266,8 +273,8 @@ def _main( differ = JoinDiffer( threaded=threaded, max_threadpool_size=threads and threads * 2, - validate_unique_key = not assume_unique_key, - materialize_to_table = materialize and parse_table_name(eval_name_template(materialize)), + validate_unique_key=not assume_unique_key, + materialize_to_table=materialize and parse_table_name(eval_name_template(materialize)), ) else: assert algorithm == Algorithm.HASHDIFF @@ -326,8 +333,15 @@ def _main( columns = tuple(expanded_columns - {key_column, update_column}) if db1 is db2: - diff_schemas(schema1, schema2, (key_column, update_column,) + columns) - + diff_schemas( + schema1, + schema2, + ( + key_column, + update_column, + ) + + columns, + ) logging.info(f"Diffing using columns: key={key_column} update={update_column} extra={columns}") diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index bd33165f..2956ab63 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -27,7 +27,7 @@ DbPath, ) -from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter +from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter logger = logging.getLogger("database") @@ -75,6 +75,7 @@ def _query_cursor(c, sql_code): logger.exception(e) raise + def _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: c = conn.cursor() diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index c990e06e..85ec4c7c 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -11,6 +11,7 @@ TIMESTAMP_PRECISION_POS, ) + def query_cursor(c, sql_code): c.execute(sql_code) if sql_code.lower().startswith("select"): @@ -87,7 +88,6 @@ def _query(self, sql_code: str) -> list: return query_cursor(c, sql_code) - def close(self): self._conn.close() diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 5ecd1667..7ca0646a 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -20,6 +20,7 @@ logger = getLogger(__name__) + class Algorithm(Enum): AUTO = "auto" JOINDIFF = "joindiff" @@ -28,8 +29,9 @@ class Algorithm(Enum): DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] + def truncate_error(error: str): - first_line = error.split('\n', 1)[0] + first_line = error.split("\n", 1)[0] return re.sub("'(.*?)'", "'***'", first_line) @@ -137,12 +139,19 @@ def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegmen def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: return self._bisect_and_diff_tables(table1, table2) - @abstractmethod - def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): 
... - def _bisect_and_diff_tables(self, table1, table2): key_type = table1._schema[table1.key_column] key_type2 = table2._schema[table2.key_column] @@ -183,7 +192,6 @@ def _bisect_and_diff_tables(self, table1, table2): return ti - def _parse_key_range_result(self, key_type, key_range): mn, mx = key_range cls = key_type.make_value @@ -193,8 +201,9 @@ def _parse_key_range_result(self, key_type, key_range): except (TypeError, ValueError) as e: raise type(e)(f"Cannot apply {key_type} to '{mn}', '{mx}'.") from e - - def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + def _bisect_and_diff_segments( + self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None + ): assert table1.is_bounded and table2.is_bounded # Choose evenly spaced checkpoints (according to min_key and max_key) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 64b05b67..78b33e17 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -66,8 +66,6 @@ def __post_init__(self): if self.bisection_factor < 2: raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") - - def _validate_and_adjust_columns(self, table1, table2): for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): if c1 not in table1._schema: @@ -115,8 +113,16 @@ def _validate_and_adjust_columns(self, table1, table2): "If encoding/formatting differs between databases, it may result in false positives." ) - - def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): logger.info( ". 
" * level + f"Diffing segment {segment_index}/{segment_count}, " f"key-range: {table1.min_key}..{table2.max_key}, " @@ -148,7 +154,9 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl if checksum1 != checksum2: return self._bisect_and_diff_segments(ti, table1, table2, level=level, max_rows=max(count1, count2)) - def _bisect_and_diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None): + def _bisect_and_diff_segments( + self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None + ): assert table1.is_bounded and table2.is_bounded if max_rows is None: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 1afb9467..f488e945 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -50,6 +50,7 @@ class Stats: def sample(table): return table.order_by(Random()).limit(10) + def create_temp_table(c: Compiler, name: str, expr: Expr): db = c.database if isinstance(db, BigQuery): @@ -67,12 +68,13 @@ def drop_table(db, name: DbPath): t = TablePath(name) db.query(t.drop(if_exists=True)) + def append_to_table(name: DbPath, expr: Expr): t = TablePath(name, expr.schema) yield t.create(if_not_exists=True) # uses expr.schema - yield 'commit' + yield "commit" yield t.insert_expr(expr) - yield 'commit' + yield "commit" def bool_to_int(x): @@ -95,10 +97,7 @@ def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List r = rightjoin(a, b).on(*on).select(is_exclusive_a=False, is_exclusive_b=is_exclusive_b, **select_fields) return l.union(r) - return ( - outerjoin(a, b).on(*on) - .select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) - ) + return outerjoin(a, b).on(*on).select(is_exclusive_a=is_exclusive_a, is_exclusive_b=is_exclusive_b, **select_fields) def _slice_tuple(t, *sizes): @@ -115,7 +114,6 @@ def json_friendly_value(v): return v - @dataclass class JoinDiffer(TableDiffer): """Finds the diff between two SQL tables in the same database, using JOINs. 
@@ -143,11 +141,10 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult table1, table2 = self._threaded_call("with_schema", [table1, table2]) - bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] if self.materialize_to_table: drop_table(db, self.materialize_to_table) - db.query('COMMIT') + db.query("COMMIT") with self._run_in_background(*bg_funcs): @@ -158,7 +155,16 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult yield from self._bisect_and_diff_tables(table1, table2) logger.info("Diffing complete") - def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None): + def _diff_segments( + self, + ti: ThreadedYielder, + table1: TableSegment, + table2: TableSegment, + max_rows: int, + level=0, + segment_index=None, + segment_count=None, + ): assert table1.database is table2.database if segment_index or table1.min_key or max_rows: @@ -172,13 +178,15 @@ def _diff_segments(self, ti: ThreadedYielder, table1: TableSegment, table2: Tabl diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2) with self._run_in_background( - partial(self._collect_stats, 1, table1), - partial(self._collect_stats, 2, table2), - partial(self._test_null_keys, table1, table2), - partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), - partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), - partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) if self.materialize_to_table else None, - ): + partial(self._collect_stats, 1, table1), + partial(self._collect_stats, 2, table2), + partial(self._test_null_keys, table1, table2), + partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), + partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), + partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) + if self.materialize_to_table + else None, + ): logger.debug("Querying for different rows") for is_xa, is_xb, *x in db.query(diff_rows, list): @@ -218,7 +226,6 @@ def _test_null_keys(self, table1, table2): if nulls: raise ValueError(f"NULL values in one or more primary keys") - def _collect_stats(self, i, table): logger.info(f"Collecting stats for table #{i}") db = table.database @@ -265,31 +272,27 @@ def _create_outer_join(self, table1, table2): a = table1._make_select() b = table2._make_select() - is_diff_cols = { - f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2) - } + is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)} a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1} b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} - diff_rows = ( - _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) - .where(or_(this[c] == 1 for c in is_diff_cols)) + diff_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}).where( + or_(this[c] == 1 for c in is_diff_cols) ) return diff_rows, a_cols, b_cols, is_diff_cols - def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): logger.info("Counting differences per column") is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) diff_counts = {} for name, count in safezip(cols, is_diff_cols_counts): diff_counts[name] = 
diff_counts.get(name, 0) + (count or 0) - self.stats['diff_counts'] = diff_counts + self.stats["diff_counts"] = diff_counts def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): if isinstance(db, Oracle): - exclusive_rows_query = diff_rows.where((this.is_exclusive_a==1) | (this.is_exclusive_b==1)) + exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) @@ -299,6 +302,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): return logger.info("Counting and sampling exclusive rows") + def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") @@ -306,9 +310,9 @@ def exclusive_rows(expr): exclusive_rows = table(name, schema=expr.source_table.schema) count = yield exclusive_rows.count() - self.stats["exclusive_count"] = self.stats.get('exclusive_count', 0) + count[0][0] + self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0] sample_rows = yield sample(exclusive_rows.select(*this[list(a_cols)], *this[list(b_cols)])) - self.stats["exclusive_sample"] = self.stats.get('exclusive_sample', []) + sample_rows + self.stats["exclusive_sample"] = self.stats.get("exclusive_sample", []) + sample_rows # Only drops if create table succeeded (meaning, the table didn't already exist) yield f"drop table {c.quote(name)}" @@ -321,4 +325,3 @@ def _materialize_diff(self, db, diff_rows, segment_index=None): db.query(append_to_table(self.materialize_to_table, diff_rows.limit(self.write_limit))) logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.") - diff --git a/data_diff/utils.py b/data_diff/utils.py index ca05e051..8224d270 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -295,9 +295,11 @@ def run_as_daemon(threadfunc, *args): def getLogger(name): - return logging.getLogger(name.rsplit('.', 1)[-1]) + return logging.getLogger(name.rsplit(".", 1)[-1]) + def eval_name_template(name): def get_timestamp(m): - return datetime.now().isoformat('_', 'seconds').replace(':', '_') - return re.sub('%t', get_timestamp, name) + return datetime.now().isoformat("_", "seconds").replace(":", "_") + + return re.sub("%t", get_timestamp, name) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index e8db3167..d9726c85 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -25,7 +25,20 @@ def init_instances(): DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} -TEST_DATABASES = {x.__name__ for x in (db.PostgreSQL, db.Snowflake, db.MySQL, db.BigQuery, db.Presto, db.Vertica, db.Trino, db.Oracle, db.Redshift)} +TEST_DATABASES = { + x.__name__ + for x in ( + db.PostgreSQL, + db.Snowflake, + db.MySQL, + db.BigQuery, + db.Presto, + db.Vertica, + db.Trino, + db.Oracle, + db.Redshift, + ) +} _class_per_db_dec = parameterized_class( ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in TEST_DATABASES] @@ -179,14 +192,13 @@ def test_dup_pks(self): x = self.differ.diff_tables(self.table, self.table2) self.assertRaises(ValueError, list, x) - def test_null_pks(self): time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" cols = "id rating timestamp".split() - _insert_row(self.connection, self.table_src, cols, ['null', 9, time_str]) + _insert_row(self.connection, self.table_src, cols, ["null", 9, time_str]) _insert_row(self.connection, self.table_dst, cols, [1, 9, time_str]) x = 
self.differ.diff_tables(self.table, self.table2) From 78e4c84068da493cec3e99b21ed59120e79db16b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 5 Oct 2022 11:08:56 +0300 Subject: [PATCH 22/33] Many fixes; Added materialize tests; Now works for : postgresql, mysql, bigquery, presto, trino, snowflake, oracle, redshift --- data_diff/__main__.py | 34 +++++----- data_diff/databases/base.py | 93 ++++++++++++++++++++------- data_diff/databases/bigquery.py | 21 +++++- data_diff/databases/database_types.py | 6 +- data_diff/databases/databricks.py | 4 +- data_diff/databases/mysql.py | 8 +++ data_diff/databases/oracle.py | 12 +++- data_diff/databases/presto.py | 18 ++---- data_diff/databases/snowflake.py | 9 ++- data_diff/joindiff_tables.py | 57 +++++++++++----- data_diff/queries/__init__.py | 4 +- data_diff/queries/api.py | 2 + data_diff/queries/ast_classes.py | 21 +++--- data_diff/queries/base.py | 6 +- data_diff/queries/compiler.py | 31 ++------- data_diff/utils.py | 3 + tests/test_diff_tables.py | 8 +++ tests/test_joindiff.py | 19 +++++- 18 files changed, 233 insertions(+), 123 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index c481fdee..c6c8fefe 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -9,8 +9,6 @@ import rich import click -from data_diff.databases.base import parse_table_name - from .utils import eval_name_template, remove_password_from_url, safezip, match_like from .diff_tables import Algorithm from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR @@ -269,22 +267,6 @@ def _main( logging.error(f"Error while parsing age expression: {e}") return - if algorithm == Algorithm.JOINDIFF: - differ = JoinDiffer( - threaded=threaded, - max_threadpool_size=threads and threads * 2, - validate_unique_key=not assume_unique_key, - materialize_to_table=materialize and parse_table_name(eval_name_template(materialize)), - ) - else: - assert algorithm == Algorithm.HASHDIFF - differ = HashDiffer( - bisection_factor=bisection_factor, - bisection_threshold=bisection_threshold, - threaded=threaded, - max_threadpool_size=threads and threads * 2, - ) - if database1 is None or database2 is None: logging.error( f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information." 
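Constructing the differ only after the connections exist (the hunk that follows) lets the materialize name be parsed by the target database itself. The %t placeholder is expanded by eval_name_template from utils.py into a filesystem-safe timestamp; an illustrative run:

    from data_diff.utils import eval_name_template

    eval_name_template("diff_results_%t")
    # -> e.g. 'diff_results_2022-10-05_11_08_46'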
@@ -307,6 +289,22 @@ def _main(
     for db in dbs:
         db.enable_interactive()
 
+    if algorithm == Algorithm.JOINDIFF:
+        differ = JoinDiffer(
+            threaded=threaded,
+            max_threadpool_size=threads and threads * 2,
+            validate_unique_key=not assume_unique_key,
+            materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)),
+        )
+    else:
+        assert algorithm == Algorithm.HASHDIFF
+        differ = HashDiffer(
+            bisection_factor=bisection_factor,
+            bisection_threshold=bisection_threshold,
+            threaded=threaded,
+            max_threadpool_size=threads and threads * 2,
+        )
+
     table_names = table1, table2
     table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]

diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index 2956ab63..79a17578 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -1,8 +1,8 @@
 import math
 import sys
 import logging
-from typing import Dict, Generator, Tuple, Optional, Sequence, Type, List, Union
-from functools import wraps
+from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union
+from functools import partial, wraps
 from concurrent.futures import ThreadPoolExecutor
 import threading
 from abc import abstractmethod
@@ -27,7 +27,7 @@
     DbPath,
 )

-from data_diff.queries import Expr, Compiler, table, Select, SKIP, ThreadLocalInterpreter
+from data_diff.queries import Expr, Compiler, table, Select, SKIP

 logger = logging.getLogger("database")

@@ -66,30 +66,39 @@ def _one(seq):
     return x


-def _query_cursor(c, sql_code):
-    try:
-        c.execute(sql_code)
-        if sql_code.lower().startswith("select"):
-            return c.fetchall()
-    except Exception as e:
-        logger.exception(e)
-        raise
+class ThreadLocalInterpreter:
+    """An interpreter used to execute a sequence of queries within the same thread.
+    Useful for cursor-sensitive operations, such as creating a temporary table.
+ """ -def _query_conn(conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: - c = conn.cursor() + def __init__(self, compiler: Compiler, gen: Generator): + self.gen = gen + self.compiler = compiler - if isinstance(sql_code, ThreadLocalInterpreter): - g = sql_code.interpret() - q = next(g) + def apply_queries(self, callback: Callable[[str], Any]): + q: Expr = next(self.gen) while True: - res = _query_cursor(c, q) + sql = self.compiler.compile(q) try: - q = g.send(res) + try: + res = callback(sql) if sql is not SKIP else SKIP + except Exception as e: + q = self.gen.throw(type(e), e) + else: + q = self.gen.send(res) except StopIteration: break + + +def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list: + if isinstance(sql_code, ThreadLocalInterpreter): + return sql_code.apply_queries(callback) else: - return _query_cursor(c, sql_code) + return callback(sql_code) + + + class Database(AbstractDatabase): @@ -108,11 +117,17 @@ class Database(AbstractDatabase): def name(self): return type(self).__name__ - def query(self, sql_ast: Expr, res_type: type = None): + def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" compiler = Compiler(self) - sql_code = compiler.compile(sql_ast) + if isinstance(sql_ast, Generator): + sql_code = ThreadLocalInterpreter(compiler, sql_ast) + else: + sql_code = compiler.compile(sql_ast) + if sql_code is SKIP: + return SKIP + logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): explained_sql = compiler.compile(Explain(sql_ast)) @@ -134,7 +149,7 @@ def query(self, sql_ast: Expr, res_type: type = None): elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: if res_type.__args__ in ((int,), (str,)): return [_one(row) for row in res] - elif res_type.__args__ == (Tuple,): + elif res_type.__args__ in [(Tuple,), (tuple,)]: return [tuple(row) for row in res] else: raise ValueError(res_type) @@ -311,6 +326,34 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: def random(self) -> str: return "RANDOM()" + def type_repr(self, t) -> str: + if isinstance(t, str): + return t + return { + int: "INT", + str: "VARCHAR", + bool: "BOOLEAN", + float: "FLOAT", + }[t] + + def _query_cursor(self, c, sql_code: str): + assert isinstance(sql_code, str), sql_code + try: + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + except Exception as e: + # logger.exception(e) + # logger.error(f'Caused by SQL: {sql_code}') + raise + + def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> list: + c = conn.cursor() + callback = partial(self._query_cursor, c) + return apply_query(callback, sql_code) + + + class ThreadedDatabase(Database): """Access the database through singleton threads. 
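# Illustrative sketch -- not part of the patch. The code above turns a Python
# generator into a sequence of queries that all run on the same thread-local
# cursor: each yielded query is compiled and executed, its result is sent back
# into the generator, and exceptions are thrown back in via gen.throw().
# Assuming the query-builder API from this patch series (table(), .count(),
# .drop(), commit), a caller might look like:

def example_same_cursor_flow(db):
    t = table("tmp_example")  # hypothetical table name, for illustration only

    def queries():
        count = yield t.count()       # the query's result is sent back in here
        yield t.drop(if_exists=True)  # runs on the same thread-local cursor
        yield commit                  # compiles to SKIP on autocommit engines

    # Database.query() detects the Generator and wraps it
    # in a ThreadLocalInterpreter:
    db.query(queries())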
@@ -339,7 +382,7 @@ def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]): "This method runs in a worker thread" if self._init_error: raise self._init_error - return _query_conn(self.thread_local.conn, sql_code) + return self._query_conn(self.thread_local.conn, sql_code) @abstractmethod def create_connection(self): @@ -348,6 +391,10 @@ def create_connection(self): def close(self): self._queue.shutdown() + @property + def is_autocommit(self) -> bool: + return False + CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower MD5_HEXDIGITS = 32 diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 218c9cb4..7044c084 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,6 +1,6 @@ from .database_types import * -from .base import Database, import_helper, parse_table_name, ConnectError -from .base import TIMESTAMP_PRECISION_POS +from .base import Database, import_helper, parse_table_name, ConnectError, apply_query +from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter @import_helper(text="Please install BigQuery and configure your google-cloud access.") @@ -47,7 +47,7 @@ def _normalize_returned_value(self, value): return value.decode() return value - def _query(self, sql_code: str): + def _query_atom(self, sql_code: str): from google.cloud import bigquery try: @@ -60,6 +60,9 @@ def _query(self, sql_code: str): res = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in res] return res + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): + return apply_query(self._query_atom, sql_code) + def to_string(self, s: str): return f"cast({s} as string)" @@ -98,3 +101,15 @@ def parse_table_name(self, name: str) -> DbPath: def random(self) -> str: return "RAND()" + + @property + def is_autocommit(self) -> bool: + return True + + def type_repr(self, t) -> str: + try: + return { + str: "STRING", + }[t] + except KeyError: + return super().type_repr(t) diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 1e9c973e..7fe436ae 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -1,6 +1,6 @@ import logging import decimal -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, abstractproperty from typing import Sequence, Optional, Tuple, Union, Dict, List from datetime import datetime @@ -293,6 +293,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: def _normalize_table_path(self, path: DbPath) -> DbPath: ... + @abstractproperty + def is_autocommit(self) -> bool: + ... 
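# Illustrative sketch -- not part of the patch. The abstract `is_autocommit`
# property above feeds the Commit statement added elsewhere in this patch: on
# engines that commit implicitly, "COMMIT" compiles to SKIP and is dropped.
# Note that concrete overrides must also be decorated with @property to satisfy
# the abstractproperty contract. Conceptually:

def commit_sql_for(db):
    # Mirrors Commit.compile() from this patch; SKIP statements never execute.
    return SKIP if db.is_autocommit else "COMMIT"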
+ Schema = CaseAwareMapping diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index b0ee9fa5..5d381b66 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,7 +1,7 @@ import logging from .database_types import * -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, _query_conn, parse_table_name +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, parse_table_name @import_helper(text="You can install it using 'pip install databricks-sql-connector'") @@ -52,7 +52,7 @@ def __init__( def _query(self, sql_code: str) -> list: "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) + return self._query_conn(self._conn, sql_code) def quote(self, s: str): return f"`{s}`" diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 07c34aaf..b34afb36 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -76,3 +76,11 @@ def is_distinct_from(self, a: str, b: str) -> str: def random(self) -> str: return "RAND()" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 79f7bf31..59004412 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -43,9 +43,9 @@ def create_connection(self): except Exception as e: raise ConnectError(*e.args) from e - def _query(self, sql_code: str): + def _query_cursor(self, c, sql_code: str): try: - return super()._query(sql_code) + return super()._query_cursor(c, sql_code) except self._oracle.DatabaseError as e: raise QueryError(e) @@ -130,3 +130,11 @@ def random(self) -> str: def is_distinct_from(self, a: str, b: str) -> str: return f"DECODE({a}, {b}, 1, 0) = 0" + + def type_repr(self, t) -> str: + try: + return { + str: "VARCHAR(1024)", + }[t] + except KeyError: + return super().type_repr(t) diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 85ec4c7c..2fb041fc 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,10 +1,10 @@ +from functools import partial import re from data_diff.utils import match_regexps -from data_diff.queries import ThreadLocalInterpreter from .database_types import * -from .base import Database, import_helper +from .base import Database, import_helper, ThreadLocalInterpreter from .base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -75,16 +75,7 @@ def _query(self, sql_code: str) -> list: c = self._conn.cursor() if isinstance(sql_code, ThreadLocalInterpreter): - # TODO reuse code from base.py - g = sql_code.interpret() - q = next(g) - while True: - res = query_cursor(c, q) - try: - q = g.send(res) - except StopIteration: - break - return + return sql_code.apply_queries(partial(query_cursor, c)) return query_cursor(c, sql_code) @@ -142,3 +133,6 @@ def _parse_type( def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" + + def is_autocommit(self) -> bool: + return False diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 9b03d833..bbd0958c 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,7 +1,7 @@ import logging from .database_types import * -from .base import ConnectError, Database, import_helper, _query_conn, CHECKSUM_MASK +from .base import ConnectError, 
Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter @import_helper("snowflake") @@ -60,9 +60,9 @@ def __init__(self, *, schema: str, **kw): def close(self): self._conn.close() - def _query(self, sql_code: str) -> list: + def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) + return self._query_conn(self._conn, sql_code) def quote(self, s: str): return f'"{s}"' @@ -87,3 +87,6 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + + def is_autocommit(self) -> bool: + return True diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index f488e945..c42826c6 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -2,6 +2,7 @@ """ +from contextlib import suppress from decimal import Decimal from functools import partial import logging @@ -10,6 +11,7 @@ from runtype import dataclass from data_diff.databases.database_types import DbPath, Schema +from data_diff.databases.base import QueryError from .utils import safezip @@ -19,7 +21,7 @@ from .diff_tables import TableDiffer, DiffResult from .thread_utils import ThreadedYielder -from .queries import table, sum_, min_, max_, avg, SKIP +from .queries import table, sum_, min_, max_, avg, SKIP, commit from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable from .queries.ast_classes import Concat, Count, Expr, Random, TablePath from .queries.compiler import Compiler @@ -51,30 +53,48 @@ def sample(table): return table.order_by(Random()).limit(10) -def create_temp_table(c: Compiler, name: str, expr: Expr): +def create_temp_table(c: Compiler, table: TablePath, expr: Expr): db = c.database if isinstance(db, BigQuery): - name = f"{db.default_schema}.{name}" - return f"create table {c.quote(name)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" + return f"create table {c.compile(table)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" elif isinstance(db, Presto): - return f"create table {c.quote(name)} as {c.compile(expr)}" + return f"create table {c.compile(table)} as {c.compile(expr)}" elif isinstance(db, Oracle): - return f"create global temporary table {c.quote(name)} as {c.compile(expr)}" + return f"create global temporary table {c.compile(table)} as {c.compile(expr)}" else: - return f"create temporary table {c.quote(name)} as {c.compile(expr)}" + return f"create temporary table {c.compile(table)} as {c.compile(expr)}" -def drop_table(db, name: DbPath): +def drop_table_oracle(name: DbPath): t = TablePath(name) - db.query(t.drop(if_exists=True)) + # Experience shows double drop is necessary + with suppress(QueryError): + yield t.drop() + yield t.drop() + yield commit +def drop_table(name: DbPath): + t = TablePath(name) + yield t.drop(if_exists=True) + yield commit + + +def append_to_table_oracle(name: DbPath, expr: Expr): + assert expr.schema, expr + t = TablePath(name, expr.schema) + with suppress(QueryError): + yield t.create() # uses expr.schema + yield commit + yield t.insert_expr(expr) + yield commit def append_to_table(name: DbPath, expr: Expr): + assert expr.schema, expr t = TablePath(name, expr.schema) yield t.create(if_not_exists=True) # uses expr.schema - yield "commit" + yield commit yield 
t.insert_expr(expr) - yield "commit" + yield commit def bool_to_int(x): @@ -143,8 +163,10 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else [] if self.materialize_to_table: - drop_table(db, self.materialize_to_table) - db.query("COMMIT") + if isinstance(db, Oracle): + db.query(drop_table_oracle(self.materialize_to_table)) + else: + db.query(drop_table(self.materialize_to_table)) with self._run_in_background(*bg_funcs): @@ -306,8 +328,8 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") - yield create_temp_table(c, name, expr.limit(self.write_limit)) - exclusive_rows = table(name, schema=expr.source_table.schema) + exclusive_rows = TablePath(name, schema=expr.source_table.schema) + yield create_temp_table(c, exclusive_rows, expr.limit(self.write_limit)) count = yield exclusive_rows.count() self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0] @@ -315,7 +337,7 @@ def exclusive_rows(expr): self.stats["exclusive_sample"] = self.stats.get("exclusive_sample", []) + sample_rows # Only drops if create table succeeded (meaning, the table didn't already exist) - yield f"drop table {c.quote(name)}" + yield exclusive_rows.drop() # Run as a sequence of thread-local queries (compiled into a ThreadLocalInterpreter) db.query(exclusive_rows(exclusive_rows_query), None) @@ -323,5 +345,6 @@ def exclusive_rows(expr): def _materialize_diff(self, db, diff_rows, segment_index=None): assert self.materialize_to_table - db.query(append_to_table(self.materialize_to_table, diff_rows.limit(self.write_limit))) + f = append_to_table_oracle if isinstance(db, Oracle) else append_to_table + db.query(f(self.materialize_to_table, diff_rows.limit(self.write_limit))) logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.") diff --git a/data_diff/queries/__init__.py b/data_diff/queries/__init__.py index 64a6e60f..172e73e4 100644 --- a/data_diff/queries/__init__.py +++ b/data_diff/queries/__init__.py @@ -1,4 +1,4 @@ -from .compiler import Compiler, ThreadLocalInterpreter -from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte +from .compiler import Compiler +from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte, commit from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index a07a9084..c433f548 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -67,3 +67,5 @@ def max_(expr: Expr): def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): return CaseWhen([(cond, then)], else_=else_) + +commit = Commit() diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b3552620..7bd9a520 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -186,7 +186,7 @@ class CaseWhen(ExprNode): def compile(self, c: Compiler) -> str: assert self.cases when_thens = " ".join(f"WHEN {c.compile(when)} THEN {c.compile(then)}" for when, then in self.cases) - else_ = (" " + c.compile(self.else_)) if self.else_ else "" + else_ = (" ELSE " + c.compile(self.else_)) if self.else_ is not None else "" return f"CASE {when_thens}{else_} END" @property @@ -600,23 +600,13 @@ 
class Statement(Compilable): type = None -def to_sql_type(t): - if isinstance(t, str): - return t - return { - int: "int", - str: "varchar(1024)", - bool: "boolean", - }[t] - - @dataclass class CreateTable(Statement): path: TablePath if_not_exists: bool = False def compile(self, c: Compiler) -> str: - schema = ", ".join(f"{k} {to_sql_type(v)}" for k, v in self.path.schema.items()) + schema = ", ".join(f"{k} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) ne = "IF NOT EXISTS " if self.if_not_exists else "" return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" @@ -639,3 +629,10 @@ class InsertToTable(Statement): def compile(self, c: Compiler) -> str: return f"INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}" + + +@dataclass +class Commit(Statement): + + def compile(self, c: Compiler) -> str: + return "COMMIT" if not c.database.is_autocommit else SKIP diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py index 50a57e2f..b5d02bb6 100644 --- a/data_diff/queries/base.py +++ b/data_diff/queries/base.py @@ -3,7 +3,11 @@ from data_diff.databases.database_types import DbPath, DbKey, Schema -SKIP = object() +class _SKIP: + def __repr__(self): + return 'SKIP' + +SKIP = _SKIP() class CompileError(Exception): diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 2c48cb86..e6e3e236 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,12 +1,12 @@ import random from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Generator, Sequence, List, Union +from typing import Any, Dict, Sequence, List, Union from runtype import dataclass from data_diff.utils import ArithString -from data_diff.databases.database_types import AbstractDialect +from data_diff.databases.database_types import AbstractDialect, DbPath @dataclass @@ -32,7 +32,7 @@ def compile(self, elem) -> str: return f"WITH {subq}\n{res}" return res - def _compile(self, elem) -> Union[str, "ThreadLocalInterpreter"]: + def _compile(self, elem) -> str: if elem is None: return "NULL" elif isinstance(elem, Compilable): @@ -47,17 +47,15 @@ def _compile(self, elem) -> Union[str, "ThreadLocalInterpreter"]: return f"b'{elem.decode()}'" elif isinstance(elem, ArithString): return f"'{elem}'" - elif isinstance(elem, Generator): - return ThreadLocalInterpreter(self, elem) assert False, elem def new_unique_name(self, prefix="tmp"): self._counter[0] += 1 return f"{prefix}{self._counter[0]}" - def new_unique_table_name(self, prefix="tmp"): + def new_unique_table_name(self, prefix="tmp") -> DbPath: self._counter[0] += 1 - return f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}" + return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") def add_table_context(self, *tables: Sequence): return self.replace(_table_context=self._table_context + list(tables)) @@ -68,22 +66,3 @@ class Compilable(ABC): def compile(self, c: Compiler) -> str: ... - -class ThreadLocalInterpreter: - """An interpeter used to execute a sequence of queries within the same thread. - - Useful for cursor-sensitive operations, such as creating a temporary table. 
- """ - - def __init__(self, compiler: Compiler, gen: Generator): - self.gen = gen - self.compiler = compiler - - def interpret(self): - q = next(self.gen) - while True: - try: - res = yield self.compiler.compile(q) - q = self.gen.send(res) - except StopIteration: - break diff --git a/data_diff/utils.py b/data_diff/utils.py index 8224d270..b572db1b 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -247,6 +247,9 @@ def keys(self) -> Iterable[str]: def items(self) -> Iterable[Tuple[str, V]]: return ((k, v[1]) for k, v in self._dict.items()) + def __len__(self): + return len(self._dict) + class CaseSensitiveDict(dict, CaseAwareMapping): def get_key(self, key): diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 7668ec61..639587de 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -515,6 +515,10 @@ def setUp(self): ] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i), max_len=10) + if not a and isinstance(self.connection, db.Oracle): + # Skip empty string, because Oracle treats it as NULL .. + continue + queries.append(f"INSERT INTO {self.table_src} VALUES ('{a}', '{i}')") queries += [ @@ -563,6 +567,10 @@ def setUp(self): ] for i in range(0, 10000, 1000): a = ArithAlphanumeric(numberToAlphanum(i * i)) + if not a and isinstance(self.connection, db.Oracle): + # Skip empty string, because Oracle treats it as NULL .. + continue + queries.append(f"INSERT INTO {self.table_src} VALUES ('{a}', '{i}')") queries += [ diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d9726c85..139f62f0 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,6 +1,8 @@ +from typing import List from parameterized import parameterized_class from data_diff.databases.connect import connect +from data_diff.queries.ast_classes import TablePath from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db from data_diff.joindiff_tables import JoinDiffer @@ -8,6 +10,7 @@ from .test_diff_tables import TestPerDatabase, _get_float_type, _get_text_type, _commit, _insert_row, _insert_rows from .common import ( + random_table_suffix, str_to_checksum, CONN_STRINGS, N_THREADS, @@ -80,11 +83,25 @@ def test_diff_small_tables(self): _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str]]) _commit(self.connection) diff = list(self.differ.diff_tables(self.table, self.table2)) - expected = [("-", ("2", time + ".000000"))] + expected_row = ("2", time + ".000000") + expected = [("-", expected_row)] self.assertEqual(expected, diff) self.assertEqual(2, self.differ.stats["table1_count"]) self.assertEqual(1, self.differ.stats["table2_count"]) + # Test materialize + materialize_path = self.connection.parse_table_name(f'test_mat_{random_table_suffix()}') + mdiffer = self.differ.replace(materialize_to_table=materialize_path) + diff = list(mdiffer.diff_tables(self.table, self.table2)) + self.assertEqual(expected, diff) + + t = TablePath(materialize_path) + rows = self.connection.query( t.select(), List[tuple] ) + self.connection.query( t.drop() ) + # is_xa, is_xb, is_diff1, is_diff2, row1, row2 + assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows + + def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" time_str = f"timestamp '{time}'" From b18dbcb3b913500b1ba4c6ac18574ae329565542 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 5 Oct 2022 19:10:44 +0300 Subject: [PATCH 23/33] black --- data_diff/databases/base.py | 5 ----- 
data_diff/joindiff_tables.py | 2 ++ data_diff/queries/api.py | 1 + data_diff/queries/ast_classes.py | 3 ++- data_diff/queries/base.py | 3 ++- data_diff/queries/compiler.py | 1 - tests/test_joindiff.py | 7 +++---- 7 files changed, 10 insertions(+), 12 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 79a17578..376f5e78 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -98,9 +98,6 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal return callback(sql_code) - - - class Database(AbstractDatabase): """Base abstract class for databases. @@ -353,8 +350,6 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis return apply_query(callback, sql_code) - - class ThreadedDatabase(Database): """Access the database through singleton threads. diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index c42826c6..e0579845 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -73,6 +73,7 @@ def drop_table_oracle(name: DbPath): yield t.drop() yield commit + def drop_table(name: DbPath): t = TablePath(name) yield t.drop(if_exists=True) @@ -88,6 +89,7 @@ def append_to_table_oracle(name: DbPath, expr: Expr): yield t.insert_expr(expr) yield commit + def append_to_table(name: DbPath, expr: Expr): assert expr.schema, expr t = TablePath(name, expr.schema) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index c433f548..60636346 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -68,4 +68,5 @@ def max_(expr: Expr): def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): return CaseWhen([(cond, then)], else_=else_) + commit = Commit() diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 7bd9a520..b5081fcb 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -31,11 +31,13 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] + def get_type(e: Expr) -> type: if isinstance(e, ExprNode): return e.type return type(e) + @dataclass class Alias(ExprNode): expr: Expr @@ -633,6 +635,5 @@ def compile(self, c: Compiler) -> str: @dataclass class Commit(Statement): - def compile(self, c: Compiler) -> str: return "COMMIT" if not c.database.is_autocommit else SKIP diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py index b5d02bb6..7b0d96cb 100644 --- a/data_diff/queries/base.py +++ b/data_diff/queries/base.py @@ -5,7 +5,8 @@ class _SKIP: def __repr__(self): - return 'SKIP' + return "SKIP" + SKIP = _SKIP() diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index e6e3e236..02bb48bc 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -65,4 +65,3 @@ class Compilable(ABC): @abstractmethod def compile(self, c: Compiler) -> str: ... 
-
diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py
index 139f62f0..d88c338e 100644
--- a/tests/test_joindiff.py
+++ b/tests/test_joindiff.py
@@ -90,18 +90,17 @@ def test_diff_small_tables(self):
         self.assertEqual(1, self.differ.stats["table2_count"])
 
         # Test materialize
-        materialize_path = self.connection.parse_table_name(f'test_mat_{random_table_suffix()}')
+        materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}")
         mdiffer = self.differ.replace(materialize_to_table=materialize_path)
         diff = list(mdiffer.diff_tables(self.table, self.table2))
         self.assertEqual(expected, diff)
 
         t = TablePath(materialize_path)
-        rows = self.connection.query( t.select(), List[tuple] )
-        self.connection.query( t.drop() )
+        rows = self.connection.query(t.select(), List[tuple])
+        self.connection.query(t.drop())
         # is_xa, is_xb, is_diff1, is_diff2, row1, row2
         assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows
 
-
     def test_diff_table_above_bisection_threshold(self):
         time = "2022-01-01 00:00:00"
         time_str = f"timestamp '{time}'"

From 3a09a779c32e341654719863ea5f6b83a8eb61ad Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Thu, 6 Oct 2022 11:15:48 +0300
Subject: [PATCH 24/33] joindiff: docs, refactor

---
 data_diff/joindiff_tables.py | 27 +++++++++++++++++++--------
 data_diff/queries/api.py     |  7 +++++--
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py
index e0579845..2356912b 100644
--- a/data_diff/joindiff_tables.py
+++ b/data_diff/joindiff_tables.py
@@ -66,7 +66,7 @@ def create_temp_table(c: Compiler, table: TablePath, expr: Expr):
 
 
 def drop_table_oracle(name: DbPath):
-    t = TablePath(name)
+    t = table(name)
     # Experience shows double drop is necessary
     with suppress(QueryError):
         yield t.drop()
@@ -75,14 +75,15 @@ def drop_table_oracle(name: DbPath):
 
 
 def drop_table(name: DbPath):
-    t = TablePath(name)
+    t = table(name)
     yield t.drop(if_exists=True)
     yield commit
 
 
-def append_to_table_oracle(name: DbPath, expr: Expr):
+def append_to_table_oracle(path: DbPath, expr: Expr):
+    """See append_to_table"""
     assert expr.schema, expr
-    t = TablePath(name, expr.schema)
+    t = table(path, schema=expr.schema)
     with suppress(QueryError):
         yield t.create()  # uses expr.schema
         yield commit
@@ -90,9 +91,11 @@ def append_to_table_oracle(path: DbPath, expr: Expr):
     yield commit
 
 
-def append_to_table(name: DbPath, expr: Expr):
+def append_to_table(path: DbPath, expr: Expr):
+    """Append the rows of 'expr' to the table at 'path'.
+    Creates the table if it doesn't already exist, and commits after each step."""
     assert expr.schema, expr
-    t = TablePath(name, expr.schema)
+    t = table(path, schema=expr.schema)
     yield t.create(if_not_exists=True)  # uses expr.schema
     yield commit
     yield t.insert_expr(expr)
@@ -143,17 +146,25 @@ class JoinDiffer(TableDiffer):
     The algorithm uses an OUTER JOIN (or equivalent) with extra checks and statistics.
     The two tables must reside in the same database, and their primary keys must be unique and not null.
 
+    All parameters are optional.
+
     Parameters:
         threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.
         max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto.
            Only relevant when `threaded` is ``True``.
            There may be many pools, so number of actual threads can be a lot higher.
+        validate_unique_key (bool): Enable/disable validating that the key columns are unique.
+            Single query, and can't be threaded, so it's very slow on non-cloud dbs.
+            Future versions will detect UNIQUE constraints in the schema.
+ sample_exclusive_rows (bool): Enable/disable sampling of exclusive rows. Creates a temporary table. + materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided. + write_limit (int): Maximum number of rows to write when materializing, per thread. """ - stats: dict = {} validate_unique_key: bool = True sample_exclusive_rows: bool = True materialize_to_table: DbPath = None write_limit: int = WRITE_LIMIT + stats: dict = {} def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: db = table1.database @@ -330,7 +341,7 @@ def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") - exclusive_rows = TablePath(name, schema=expr.source_table.schema) + exclusive_rows = table(name, schema=expr.source_table.schema) yield create_temp_table(c, exclusive_rows, expr.limit(self.write_limit)) count = yield exclusive_rows.count() diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 60636346..f000ec67 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -30,8 +30,11 @@ def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None) return Cte(expr, name, params) -def table(*path: str, schema: Schema = None) -> ITable: - assert all(isinstance(i, str) for i in path), path +def table(*path: str, schema: Schema = None) -> TablePath: + if len(path) == 1 and isinstance(path[0], tuple): + path ,= path + if not all(isinstance(i, str) for i in path): + raise TypeError(f"All elements of table path must be of type 'str'. Got: {path}") return TablePath(path, schema) From 90cbfb6ba4bc012761f8ac25d624de57af96d49a Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 6 Oct 2022 20:15:04 +0300 Subject: [PATCH 25/33] Queries fix --- data_diff/queries/ast_classes.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b5081fcb..173ce933 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -248,12 +248,17 @@ class BinOp(ExprNode, LazyOps): op: str args: Sequence[Expr] - def __post_init__(self): - assert len(self.args) == 2, self.args - def compile(self, c: Compiler) -> str: - a, b = self.args - return f"({c.compile(a)} {self.op} {c.compile(b)})" + expr = f" {self.op} ".join(c.compile(a) for a in self.args) + return f"({expr})" + + @property + def type(self): + types = {get_type(i) for i in self.args} + if len(types) > 1: + raise TypeError(f"Expected all args to have the same type, got {types}") + t ,= types + return t class BinBoolOp(BinOp): From ad48f5dcccea554a9648fa459ee7814b4fbfae60 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 6 Oct 2022 18:39:41 +0300 Subject: [PATCH 26/33] Composite key - initial (WIP); Refactor: TableSegment.key_column -> key_columns Added test --- data_diff/__init__.py | 25 ++++++---- data_diff/__main__.py | 16 +++---- data_diff/diff_tables.py | 11 ++++- data_diff/joindiff_tables.py | 8 ++-- data_diff/queries/ast_classes.py | 2 +- data_diff/table_segment.py | 27 ++++++----- tests/test_database_types.py | 8 ++-- tests/test_diff_tables.py | 79 +++++++++++++++++--------------- tests/test_joindiff.py | 69 ++++++++++++++++++++++------ tests/test_postgresql.py | 6 +-- 10 files changed, 158 insertions(+), 93 deletions(-) diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 3e8451ba..ae6f021d 100644 --- a/data_diff/__init__.py 
+++ b/data_diff/__init__.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Iterator, Optional, Union
+from typing import Sequence, Tuple, Iterator, Optional, Union
 
 from .tracking import disable_tracking
 from .databases.connect import connect
@@ -12,7 +12,7 @@
 def connect_to_table(
     db_info: Union[str, dict],
     table_name: Union[DbPath, str],
-    key_column: str = "id",
+    key_columns: Union[str, Sequence[str]] = ("id",),
     thread_count: Optional[int] = 1,
     **kwargs,
 ) -> TableSegment:
@@ -21,19 +21,21 @@ def connect_to_table(
     Parameters:
         db_info: Either a URI string, or a dict of connection options.
         table_name: Name of the table as a string, or a tuple that signifies the path.
-        key_column: Name of the key column
+        key_columns: Names of the key columns
         thread_count: Number of threads for this connection (only if using a threadpooled db implementation)
 
     See Also:
         :meth:`connect`
     """
+    if isinstance(key_columns, str):
+        key_columns = (key_columns,)
 
     db = connect(db_info, thread_count=thread_count)
 
     if isinstance(table_name, str):
         table_name = db.parse_table_name(table_name)
 
-    return TableSegment(db, table_name, key_column, **kwargs)
+    return TableSegment(db, table_name, key_columns, **kwargs)
 
 
 def diff_tables(
@@ -41,7 +43,7 @@ def diff_tables(
     table2: TableSegment,
     *,
     # Name of the key column, which uniquely identifies each row (usually id)
-    key_column: str = None,
+    key_columns: Sequence[str] = None,
     # Name of updated column, which signals that rows changed (usually updated_at or last_update)
     update_column: str = None,
     # Extra columns to compare
@@ -67,12 +69,12 @@ def diff_tables(
     """Finds the diff between table1 and table2.
 
     Parameters:
-        key_column (str): Name of the key column, which uniquely identifies each row (usually id)
+        key_columns (Tuple[str, ...]): Names of the key columns, which uniquely identify each row (usually id)
         update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update).
             Used by `min_update` and `max_update`.
         extra_columns (Tuple[str, ...], optional): Extra columns to compare
-        min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment
-        max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment
+        min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment
+        max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment
         min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
         max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
         algorithm (:class:`Algorithm`): Which diffing algorithm to use (`HASHDIFF` or `JOINDIFF`)
@@ -84,7 +86,7 @@ def diff_tables(
 
     Note:
        The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances:
-       `key_column`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's
+       `key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's
        possible to omit them here, and instead set them directly when creating each :class:`TableSegment`.
Example: @@ -98,11 +100,14 @@ def diff_tables( :class:`JoinDiffer` """ + if isinstance(key_columns, str): + key_columns = (key_columns,) + tables = [table1, table2] override_attrs = { k: v for k, v in dict( - key_column=key_column, + key_columns=key_columns, update_column=update_column, extra_columns=extra_columns, min_key=min_key, diff --git a/data_diff/__main__.py b/data_diff/__main__.py index c6c8fefe..0b7fce7c 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -80,7 +80,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - @click.argument("table1", required=False) @click.argument("database2", required=False) @click.argument("table2", required=False) -@click.option("-k", "--key-column", default=None, help="Name of primary key column. Default='id'.", metavar="NAME") +@click.option("-k", "--key-columns", default=[], multiple=True, help="Names of primary key columns. Default='id'.", metavar="NAME") @click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column", metavar="NAME") @click.option( "-c", @@ -187,7 +187,7 @@ def _main( table1, database2, table2, - key_column, + key_columns, update_column, columns, limit, @@ -233,7 +233,7 @@ def _main( logging.error("Cannot specify a limit when using the -s/--stats switch") return - key_column = key_column or "id" + key_columns = key_columns or ("id",) bisection_factor = DEFAULT_BISECTION_FACTOR if bisection_factor is None else int(bisection_factor) bisection_threshold = DEFAULT_BISECTION_THRESHOLD if bisection_threshold is None else int(bisection_threshold) @@ -328,23 +328,23 @@ def _main( expanded_columns |= match - columns = tuple(expanded_columns - {key_column, update_column}) + columns = tuple(expanded_columns - {*key_columns, update_column}) if db1 is db2: diff_schemas( schema1, schema2, ( - key_column, + *key_columns, update_column, + *columns, ) - + columns, ) - logging.info(f"Diffing using columns: key={key_column} update={update_column} extra={columns}") + logging.info(f"Diffing using columns: key={key_columns} update={update_column} extra={columns}") segments = [ - TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema) + TableSegment(db, table_path, key_columns, update_column, columns, **options)._with_raw_schema(raw_schema) for db, table_path, raw_schema in safezip(dbs, table_paths, schemas) ] diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index a90d8ef8..6801e0d2 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -153,8 +153,15 @@ def _diff_segments( ... 
 def _bisect_and_diff_tables(self, table1, table2):
-        key_type = table1._schema[table1.key_column]
-        key_type2 = table2._schema[table2.key_column]
+        if len(table1.key_columns) > 1:
+            raise NotImplementedError("Composite key not supported yet!")
+        if len(table2.key_columns) > 1:
+            raise NotImplementedError("Composite key not supported yet!")
+        key1 ,= table1.key_columns
+        key2 ,= table2.key_columns
+
+        key_type = table1._schema[key1]
+        key_type2 = table2._schema[key2]
         if not isinstance(key_type, IKey):
             raise NotImplementedError(f"Cannot use column of type {key_type} as a key")
         if not isinstance(key_type2, IKey):
diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py
index 2356912b..3a8175e3 100644
--- a/data_diff/joindiff_tables.py
+++ b/data_diff/joindiff_tables.py
@@ -241,7 +241,7 @@ def _test_duplicate_keys(self, table1, table2):
         # Test duplicate keys
         for ts in [table1, table2]:
             t = ts._make_select()
-            key_columns = [ts.key_column]  # XXX
+            key_columns = ts.key_columns
 
             q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True))
             total, total_distinct = ts.database.query(q, tuple)
@@ -254,7 +254,7 @@ def _test_null_keys(self, table1, table2):
         # Test null keys
         for ts in [table1, table2]:
             t = ts._make_select()
-            key_columns = [ts.key_column]  # XXX
+            key_columns = ts.key_columns
 
             q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
             nulls = ts.database.query(q, list)
@@ -294,8 +294,8 @@ def _create_outer_join(self, table1, table2):
         if db is not table2.database:
             raise ValueError("Joindiff only applies to tables within the same database")
 
-        keys1 = [table1.key_column]  # XXX
-        keys2 = [table2.key_column]  # XXX
+        keys1 = table1.key_columns
+        keys2 = table2.key_columns
 
         if len(keys1) != len(keys2):
             raise ValueError("The provided key columns are of a different count")
diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py
index 173ce933..92a17543 100644
--- a/data_diff/queries/ast_classes.py
+++ b/data_diff/queries/ast_classes.py
@@ -562,7 +562,7 @@ def __getattr__(self, name):
         return _ResolveColumn(name)
 
     def __getitem__(self, name):
-        if isinstance(name, list):
+        if isinstance(name, (list, tuple)):
             return [_ResolveColumn(n) for n in name]
         return _ResolveColumn(name)
 
diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py
index 3a4ddbe4..aa1a1498 100644
--- a/data_diff/table_segment.py
+++ b/data_diff/table_segment.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Tuple
+from typing import List, Sequence, Tuple
 import logging
 
 from runtype import dataclass
@@ -22,12 +22,12 @@ class TableSegment:
     Parameters:
         database (Database): Database instance. See :meth:`connect`
         table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')`
-        key_column (str): Name of the key column, which uniquely identifies each row (usually id)
-        update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update).
+        key_columns (Tuple[str, ...]): Names of the key columns, which uniquely identify each row (usually id)
+        update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update).
            Used by `min_update` and `max_update`.
extra_columns (Tuple[str, ...], optional): Extra columns to compare - min_key (:data:`DbKey`, optional): Lowest key_column value, used to restrict the segment - max_key (:data:`DbKey`, optional): Highest key_column value, used to restrict the segment + min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment + max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment where (str, optional): An additional 'where' expression to restrict the search space. @@ -41,7 +41,7 @@ class TableSegment: table_path: DbPath # Columns - key_column: str + key_columns: Tuple[str, ...] update_column: str = None extra_columns: Tuple[str, ...] = () @@ -80,9 +80,13 @@ def with_schema(self) -> "TableSegment": def _make_key_range(self): if self.min_key is not None: - yield self.min_key <= this[self.key_column] + assert len(self.key_columns) == 1 + k ,= self.key_columns + yield self.min_key <= this[k] if self.max_key is not None: - yield this[self.key_column] < self.max_key + assert len(self.key_columns) == 1 + k ,= self.key_columns + yield this[k] < self.max_key def _make_update_range(self): if self.min_update is not None: @@ -144,7 +148,7 @@ def _relevant_columns(self) -> List[str]: if self.update_column and self.update_column not in extras: extras = [self.update_column] + extras - return [self.key_column] + extras + return list(self.key_columns) + extras @property def _relevant_columns_repr(self) -> List[Expr]: @@ -174,9 +178,10 @@ def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation # TODO better error if there is no schema + k ,= self.key_columns select = self._make_select().select( - ApplyFuncAndNormalizeAsString(this[self.key_column], min_), - ApplyFuncAndNormalizeAsString(this[self.key_column], max_), + ApplyFuncAndNormalizeAsString(this[k], min_), + ApplyFuncAndNormalizeAsString(this[k], max_), ) min_key, max_key = self.database.query(select, tuple) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 4ac8d5f4..c9e9042c 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -647,11 +647,11 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego insertion_target_duration = time.monotonic() - start if type_category == "uuid": - self.table = TableSegment(self.src_conn, src_table_path, "col", None, ("id",), case_sensitive=False) - self.table2 = TableSegment(self.dst_conn, dst_table_path, "col", None, ("id",), case_sensitive=False) + self.table = TableSegment(self.src_conn, src_table_path, ("col",), None, ("id",), case_sensitive=False) + self.table2 = TableSegment(self.dst_conn, dst_table_path, ("col",), None, ("id",), case_sensitive=False) else: - self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False) - self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False) + self.table = TableSegment(self.src_conn, src_table_path, ("id",), None, ("col",), case_sensitive=False) + self.table2 = TableSegment(self.dst_conn, dst_table_path, ("id",), None, ("col",), case_sensitive=False) start = time.monotonic() self.assertEqual(N_SAMPLES, self.table.count()) diff 
--git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 639587de..9dfc7de7 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -44,6 +44,11 @@ def _class_per_db_dec(filter_name=None): ] return parameterized_class(("name", "db_name"), names) +def _table_segment(database, table_path, key_columns, *args, **kw): + if isinstance(key_columns, str): + key_columns = (key_columns,) + return TableSegment(database, table_path, key_columns, *args, **kw) + def test_per_database(cls): return _class_per_db_dec()(cls) @@ -102,7 +107,7 @@ class TestPerDatabase(unittest.TestCase): preql = None def setUp(self): - assert self.db_name + assert self.db_name, self.db_name init_instances() self.connection = DATABASE_INSTANCES[self.db_name] @@ -170,17 +175,17 @@ def setUp(self): self.preql.commit() def test_init(self): - a = TableSegment( + a = _table_segment( self.connection, self.table_src_path, "id", "datetime", max_update=self.now.datetime, case_sensitive=False ) self.assertRaises( - ValueError, TableSegment, self.connection, self.table_src_path, "id", max_update=self.now.datetime + ValueError, _table_segment, self.connection, self.table_src_path, "id", max_update=self.now.datetime ) def test_basic(self): differ = HashDiffer(bisection_factor=10, bisection_threshold=100) - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) + a = _table_segment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) + b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) assert a.count() == 6 assert b.count() == 5 @@ -190,23 +195,23 @@ def test_basic(self): def test_offset(self): differ = HashDiffer(bisection_factor=2, bisection_threshold=10) sec1 = self.now.shift(seconds=-1).datetime - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) + a = _table_segment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) + b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) assert a.count() == 4 assert b.count() == 3 assert not list(differ.diff_tables(a, a)) self.assertEqual(len(list(differ.diff_tables(a, b))), 1) - a = TableSegment(self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False) - b = TableSegment(self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False) + a = _table_segment(self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False) + b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False) assert a.count() == 2 assert b.count() == 2 assert not list(differ.diff_tables(a, b)) day1 = self.now.shift(days=-1).datetime - a = TableSegment( + a = _table_segment( self.connection, self.table_src_path, "id", @@ -215,7 +220,7 @@ def test_offset(self): max_update=sec1, case_sensitive=False, ) - b = TableSegment( + b = _table_segment( self.connection, self.table_dst_path, "id", @@ -249,8 +254,8 @@ def setUp(self): ) _commit(self.connection) - self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = 
TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) self.differ = HashDiffer(bisection_factor=3, bisection_threshold=4) @@ -443,8 +448,8 @@ def test_diff_column_names(self): ], ) - table1 = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - table2 = TableSegment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) + table1 = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + table2 = _table_segment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) differ = HashDiffer() diff = list(differ.diff_tables(table1, table2)) @@ -478,8 +483,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_string_keys(self): differ = HashDiffer() @@ -535,8 +540,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_alphanum_keys(self): @@ -549,8 +554,8 @@ def test_alphanum_keys(self): ) _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) @@ -587,8 +592,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_varying_alphanum_keys(self): # Test the class itself @@ -609,8 +614,8 @@ def test_varying_alphanum_keys(self): ) _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = 
_table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) @@ -619,8 +624,8 @@ def test_varying_alphanum_keys(self): class TestTableSegment(TestPerDatabase): def setUp(self) -> None: super().setUp() - self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) def test_table_segment(self): early = datetime.datetime(2021, 1, 1, 0, 0) @@ -641,11 +646,11 @@ def test_case_awareness(self): _insert_rows(self.connection, self.table_src, cols, [[1, 9, time_str], [2, 2, time_str]]) _commit(self.connection) - res = tuple(self.table.replace(key_column="Id", case_sensitive=False).with_schema().query_key_range()) + res = tuple(self.table.replace(key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) assert res == ("1", "2") self.assertRaises( - KeyError, self.table.replace(key_column="Id", case_sensitive=True).with_schema().query_key_range + KeyError, self.table.replace(key_columns=("Id",), case_sensitive=True).with_schema().query_key_range ) @@ -674,8 +679,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_column_with_nulls(self): differ = HashDiffer() @@ -704,8 +709,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_columns_with_nulls(self): """ @@ -762,10 +767,10 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment( + self.a = _table_segment( self.connection, self.table_src_path, "id", extra_columns=("c1", "c2"), case_sensitive=False ) - self.b = TableSegment( + self.b = _table_segment( self.connection, self.table_dst_path, "id", extra_columns=("c1", "c2"), case_sensitive=False ) @@ -819,8 +824,8 @@ def setUp(self): _commit(self.connection) - self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_right_table_empty(self): differ = HashDiffer() diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d88c338e..62f7cd8e 100644 --- 
a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,3 +1,4 @@ +from functools import wraps from typing import List from parameterized import parameterized_class @@ -28,9 +29,7 @@ def init_instances(): DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()} -TEST_DATABASES = { - x.__name__ - for x in ( +TEST_DATABASES = ( db.PostgreSQL, db.Snowflake, db.MySQL, @@ -40,17 +39,63 @@ def init_instances(): db.Trino, db.Oracle, db.Redshift, - ) -} - -_class_per_db_dec = parameterized_class( - ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in TEST_DATABASES] ) -def test_per_database(cls): +def test_per_database(cls, dbs=TEST_DATABASES): + dbs = {db.__name__ for db in dbs} + _class_per_db_dec = parameterized_class( + ("name", "db_name"), [(name, name) for name in DATABASE_URIS if name in dbs] + ) return _class_per_db_dec(cls) +def test_per_database2(*dbs): + @wraps(test_per_database) + def dec(cls): + return test_per_database(cls, dbs) + return dec + + +@test_per_database2(db.Snowflake, db.BigQuery) +class TestCompositeKey(TestPerDatabase): + def setUp(self): + super().setUp() + + float_type = _get_float_type(self.connection) + + self.connection.query( + f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + ) + self.connection.query( + f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", + ) + _commit(self.connection) + + self.differ = JoinDiffer() + + def test_composite_key(self): + time = "2022-01-01 00:00:00" + time_str = f"timestamp '{time}'" + + cols = "id userid movieid rating timestamp".split() + _insert_rows(self.connection, self.table_src, cols, [[1, 1, 1, 9, time_str], [2, 2, 2, 9, time_str]]) + _insert_rows(self.connection, self.table_dst, cols, [[1, 1, 1, 9, time_str], [2, 3, 2, 9, time_str]]) + _commit(self.connection) + + # Sanity + table1 = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", ('userid',), case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("id",), "timestamp", ('userid',), case_sensitive=False) + diff = list(self.differ.diff_tables(table1, table2)) + assert len(diff) == 2 + assert self.differ.stats['exclusive_count'] == 0 + + # Test pks diffed, by checking exclusive_count + table1 = TableSegment(self.connection, self.table_src_path, ("id", "userid"), "timestamp", case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("id", "userid"), "timestamp", case_sensitive=False) + diff = list(self.differ.diff_tables(table1, table2)) + assert len(diff) == 2 + assert self.differ.stats['exclusive_count'] == 2 + @test_per_database class TestJoindiff(TestPerDatabase): @@ -61,16 +106,14 @@ def setUp(self): self.connection.query( f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - None, ) self.connection.query( f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", - None, ) _commit(self.connection) - self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", case_sensitive=False) + self.table2 = TableSegment(self.connection, self.table_dst_path, ("id",), "timestamp", 
case_sensitive=False) self.differ = JoinDiffer() diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 3a4f4239..21e64b3e 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -37,8 +37,8 @@ def test_uuid(self): for query in queries: self.connection.query(query, None) - a = TableSegment(self.connection, (self.table_src,), "id", "comment") - b = TableSegment(self.connection, (self.table_dst,), "id", "comment") + a = TableSegment(self.connection, (self.table_src,), ("id",), "comment") + b = TableSegment(self.connection, (self.table_dst,), ("id",), "comment") differ = HashDiffer() diff = list(differ.diff_tables(a, b)) @@ -56,7 +56,7 @@ def test_uuid(self): mysql_conn.query(f"INSERT INTO {self.table_dst}(id, comment) VALUES ('{uuid}', '{comment}')", None) mysql_conn.query(f"COMMIT", None) - c = TableSegment(mysql_conn, (self.table_dst,), "id", "comment") + c = TableSegment(mysql_conn, (self.table_dst,), ("id",), "comment") diff = list(differ.diff_tables(a, c)) assert not diff, diff diff = list(differ.diff_tables(c, a)) From 377b4a77714550a06d144af6f7abd4dffbd11cc5 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 11:18:14 +0300 Subject: [PATCH 27/33] Fixed interactive mode and explain --- data_diff/databases/base.py | 14 +++++++++----- data_diff/databases/database_types.py | 5 +++++ data_diff/databases/mysql.py | 3 +++ data_diff/databases/snowflake.py | 3 +++ data_diff/queries/ast_classes.py | 7 +++++++ tests/test_query.py | 3 +++ tests/test_sql.py | 2 +- 7 files changed, 31 insertions(+), 6 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 376f5e78..8b3a465d 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -27,7 +27,7 @@ DbPath, ) -from data_diff.queries import Expr, Compiler, table, Select, SKIP +from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain logger = logging.getLogger("database") @@ -114,7 +114,7 @@ class Database(AbstractDatabase): def name(self): return type(self).__name__ - def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): + def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" compiler = Compiler(self) @@ -128,8 +128,9 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): explained_sql = compiler.compile(Explain(sql_ast)) - logger.info("EXPLAIN for SQL SELECT") - logger.info(self._query(explained_sql)) + explain = self._query(explained_sql) + for row, in explain: + logger.debug(f'EXPLAIN: {row}') answer = input("Continue? [y/n] ") if not answer.lower() in ["y", "yes"]: sys.exit(1) @@ -337,7 +338,7 @@ def _query_cursor(self, c, sql_code: str): assert isinstance(sql_code, str), sql_code try: c.execute(sql_code) - if sql_code.lower().startswith("select"): + if sql_code.lower().startswith(("select", "explain", "show")): return c.fetchall() except Exception as e: # logger.exception(e) @@ -349,6 +350,9 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis callback = partial(self._query_cursor, c) return apply_query(callback, sql_code) + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN {query}" + class ThreadedDatabase(Database): """Access the database through singleton threads. 
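The interactive gate above asks the engine for a query plan before running each SELECT. It relies on every database rendering EXPLAIN in its own dialect, which the hunks below wire up through `explain_as_text`. A minimal sketch of that override pattern — the class names here are illustrative, not part of the patch:

```python
class ToyDialect:
    """Hypothetical stand-in for the shared Database base class."""

    def explain_as_text(self, query: str) -> str:
        # Generic default, suitable for engines like PostgreSQL.
        return f"EXPLAIN {query}"


class ToyMySQL(ToyDialect):
    def explain_as_text(self, query: str) -> str:
        # MySQL 8+ can print the plan as a tree, which reads better in logs.
        return f"EXPLAIN FORMAT=TREE {query}"


print(ToyMySQL().explain_as_text("SELECT 1"))  # EXPLAIN FORMAT=TREE SELECT 1
```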
diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py
index 7fe436ae..27249f0d 100644
--- a/data_diff/databases/database_types.py
+++ b/data_diff/databases/database_types.py
@@ -172,6 +172,11 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None
         "Provide SQL fragment for limit and offset inside a select"
         ...
 
+    @abstractmethod
+    def explain_as_text(self, query: str) -> str:
+        "Provide SQL for explaining a query, returned as table(varchar)"
+        ...
+
 
 class AbstractDatabase(AbstractDialect):
     @abstractmethod
diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py
index b34afb36..b666e0c5 100644
--- a/data_diff/databases/mysql.py
+++ b/data_diff/databases/mysql.py
@@ -84,3 +84,6 @@ def type_repr(self, t) -> str:
             }[t]
         except KeyError:
             return super().type_repr(t)
+
+    def explain_as_text(self, query: str) -> str:
+        return f"EXPLAIN FORMAT=TREE {query}"
diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py
index bbd0958c..714fb5f0 100644
--- a/data_diff/databases/snowflake.py
+++ b/data_diff/databases/snowflake.py
@@ -90,3 +90,6 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
 
     def is_autocommit(self) -> bool:
         return True
+
+    def explain_as_text(self, query: str) -> str:
+        return f"EXPLAIN USING TEXT {query}"
diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py
index 92a17543..ac57bbe9 100644
--- a/data_diff/queries/ast_classes.py
+++ b/data_diff/queries/ast_classes.py
@@ -600,6 +600,13 @@ def compile(self, c: Compiler) -> str:
         return c.database.random()
 
 
+@dataclass
+class Explain(ExprNode):
+    select: Select
+
+    def compile(self, c: Compiler) -> str:
+        return c.database.explain_as_text(c.compile(self.select))
+
 # DDL
diff --git a/tests/test_query.py b/tests/test_query.py
index 5091843e..fe0de696 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -32,6 +32,9 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None
         x = offset and f"offset {offset}", limit and f"limit {limit}"
         return " ".join(filter(None, x))
 
+    def explain_as_text(self, query: str) -> str:
+        return f"explain {query}"
+
 
 class TestQuery(unittest.TestCase):
     def setUp(self):
diff --git a/tests/test_sql.py b/tests/test_sql.py
index fe17940b..0e1e8d13 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -91,7 +91,7 @@ def test_count_with_column(self):
         )
 
     def test_explain(self):
-        expected_sql = "EXPLAIN SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
+        expected_sql = "EXPLAIN FORMAT=TREE SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))"
         self.assertEqual(
             expected_sql,
             self.compiler.compile(

From 4c16bac24105729b676123ccda0ae2e8aeec5f35 Mon Sep 17 00:00:00 2001
From: Erez Shinan
Date: Fri, 7 Oct 2022 14:55:29 +0300
Subject: [PATCH 28/33] Update README

---
 README.md             | 34 +++++++++++++++++++++++++++-------
 data_diff/__main__.py |  4 ++++
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index df0d27c5..a36fe630 100644
--- a/README.md
+++ b/README.md
@@ -17,10 +17,20 @@ rows across two different databases.
 * 🔍 Outputs [diff of rows](#example-command-and-output) in detail
 * 🚨 Simple CLI/API to create monitoring and alerts
 * 🔁 Bridges column types of different formats and levels of precision (e.g. Double ⇆ Float ⇆ Decimal)
-* 🔥 Verify 25M+ rows in <10s, and 1B+ rows in ~5min.
+* 🔥 Fast! Verify 25M+ rows in <10s, and 1B+ rows in ~5min.
 * ♾️  Works for tables with 10s of billions of rows
 
-**data-diff** splits the table into smaller segments, then checksums each
+data-diff can diff tables within the same database, or across different databases.
+
+**Same-DB Diff**: Uses an outer-join to diff the rows as efficiently and accurately as possible.
+
+Supports materializing the diff results to a database table.
+
+Can also collect various extra statistics about the tables.
+
+**Cross-DB Diff**: Employs a divide and conquer algorithm based on hashing, optimized for few changes.
+
+data-diff splits the table into smaller segments, then checksums each
 segment in both databases. When the checksums for a segment aren't equal, it
 will further divide that segment into yet smaller segments, checksumming those
 until it gets to the differing row(s). See [Technical Explanation][tech-explain] for more
@@ -69,8 +79,8 @@ better than MySQL.
   may span a half-dozen systems, without verifying each intermediate datastore
   it's extremely difficult to track down where a row got lost.
 * **Detecting hard deletes for an `updated_at`-based pipeline**. If you're
-  copying data to your warehouse based on an `updated_at`-style column, then
-  you'll miss hard-deletes that **data-diff** can find for you.
+  copying data to your warehouse based on an `updated_at`-style column, data-diff
+  can find any hard-deletes that you might have missed.
 * **Make your replication self-healing.** You can use **data-diff** to
   self-heal by using the diff output to write/update rows in the target
   database.
@@ -217,7 +227,7 @@ may be case-sensitive. This is the case for the Snowflake schema and table names
 Options:
 
   - `--help` - Show help message and exit.
-  - `-k` or `--key-column` - Name of the primary key column
+  - `-k` or `--key-columns` - Names of the primary key columns. If none provided, default is 'id'.
   - `-t` or `--update-column` - Name of updated_at/last_updated column
   - `-c` or `--columns` - Names of extra columns to compare. Can be used more than once in the same command.
     Accepts a name or a pattern like in SQL.
@@ -232,12 +242,22 @@
     Example: `--min-age=5min` ignores rows from the last 5 minutes.
     Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years`
   - `--max-age` - Considers only rows younger than specified. See `--min-age`.
-  - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search.
-  - `--bisection-threshold` - Minimal bisection threshold. i.e. maximum size of pages to diff locally.
   - `-j` or `--threads` - Number of worker threads to use per database. Default=1.
   - `-w`, `--where` - An additional 'where' expression to restrict the search space.
   - `--conf`, `--run` - Specify the run and configuration from a TOML file. (see below)
   - `--no-tracking` - data-diff sends home anonymous usage data. Use this to disable it.
+  - `-a`, `--algorithm` `[auto|joindiff|hashdiff]` - Force algorithm choice
+
+Same-DB diff only:
+  - `-m`, `--materialize` - Materialize the diff results into a new table in the database.
+    Use `%t` in the name to place a timestamp.
+    Example: `-m test_mat_%t`
+  - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.
+
+Cross-DB diff only:
+  - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally.
+  - `--bisection-factor` - Segments per iteration. When set to 2, it performs binary search.
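The divide-and-conquer loop described above fits in a few lines of Python. The following is an illustrative model of hashdiff's bisection only, not the library's implementation; `checksum` and `split` are toy stand-ins, and it assumes both sides split on the same key checkpoints so the sub-segments line up:

```python
import hashlib


def checksum(rows):
    # Toy stand-in for the in-database checksum query.
    return hashlib.md5(repr(sorted(rows)).encode()).hexdigest()


def split(rows, k):
    # Toy stand-in for splitting a key range into k sub-segments.
    n = max(1, -(-len(rows) // k))  # ceil division
    return [rows[i : i + n] for i in range(0, len(rows), n)] or [[]]


def diff_segments(seg1, seg2, bisection_factor=32, bisection_threshold=1024):
    """Bisect and checksum recursively until the differing rows are isolated."""
    if checksum(seg1) == checksum(seg2):
        return  # Segments agree; nothing to report.
    if max(len(seg1), len(seg2)) <= bisection_threshold:
        # Small enough: download both sides and compare locally.
        yield from (("-", r) for r in set(seg1) - set(seg2))
        yield from (("+", r) for r in set(seg2) - set(seg1))
        return
    for sub1, sub2 in zip(split(seg1, bisection_factor), split(seg2, bisection_factor)):
        yield from diff_segments(sub1, sub2, bisection_factor, bisection_threshold)


print(list(diff_segments([(1, "a"), (2, "b")], [(1, "a"), (2, "c")], 2, 1)))
```

Setting `bisection_factor=2` collapses this into exactly the binary search mentioned for `--bisection-factor`.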
+ ### How to use with a configuration file diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 0b7fce7c..1380e3b4 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -375,6 +375,10 @@ def _main( print(f"Diff-Total: {len(diff)} changed rows out of {max_table_count}") print(f"Diff-Percent: {percent:.14f}%") print(f"Diff-Split: +{plus} -{minus}") + if differ.stats: + print("Extra-Info:") + for k, v in differ.stats.items(): + print(f' {k} = {v}') else: for op, values in diff_iter: color = COLOR_SCHEME[op] From 472f422b4746ee731ef118ab4fd17717d98b6ab9 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 15:01:41 +0300 Subject: [PATCH 29/33] Added --sample-exclusive-rows switch --- README.md | 1 + data_diff/__main__.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/README.md b/README.md index a36fe630..c9e8d0f4 100644 --- a/README.md +++ b/README.md @@ -253,6 +253,7 @@ Same-DB diff only: Use `%t` in the name to place a timestamp. Example: `-m test_mat_%t` - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs. + - `--sample-exclusive-rows` - Sample several rows that only appear in one of the tables, but not the other. Use with `-s`. Cross-DB diff only: - `--bisection-threshold` - Minimal size of segment to be split. Smaller segments will be downloaded and compared locally. diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 1380e3b4..2201ee58 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -139,6 +139,11 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - is_flag=True, help="Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.", ) +@click.option( + "--sample-exclusive-rows", + is_flag=True, + help="Sample several rows that only appear in one of the tables, but not the other.", +) @click.option( "-j", "--threads", @@ -206,6 +211,7 @@ def _main( json_output, where, assume_unique_key, + sample_exclusive_rows, materialize, threads1=None, threads2=None, @@ -294,6 +300,7 @@ def _main( threaded=threaded, max_threadpool_size=threads and threads * 2, validate_unique_key=not assume_unique_key, + sample_exclusive_rows=sample_exclusive_rows, materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)), ) else: From abaabe84f963080e408606157e0501b7a417aeef Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 16:17:35 +0300 Subject: [PATCH 30/33] README: Updated supported database list --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c9e8d0f4..b54c3648 100644 --- a/README.md +++ b/README.md @@ -138,9 +138,9 @@ $ data-diff \ | PostgreSQL >=10 | `postgresql://:@:5432/` | 💚 | | MySQL | `mysql://:@:5432/` | 💚 | | Snowflake | `"snowflake://[:]@//?warehouse=&role=[&authenticator=externalbrowser]"` | 💚 | +| BigQuery | `bigquery:///` | 💚 | +| Redshift | `redshift://:@:5439/` | 💚 | | Oracle | `oracle://:@/database` | 💛 | -| BigQuery | `bigquery:///` | 💛 | -| Redshift | `redshift://:@:5439/` | 💛 | | Presto | `presto://:@:8080/` | 💛 | | Databricks | `databricks://:@//` | 💛 | | Trino | `trino://:@:8080/` | 💛 | @@ -151,6 +151,8 @@ $ data-diff \ | Pinot | | 📝 | | Druid | | 📝 | | Kafka | | 📝 | +| DuckDB | | 📝 | +| SQLite | | 📝 | * 💚: Implemented and thoroughly tested. * 💛: Implemented, but not thoroughly tested yet. 
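For orientation between these patches: the Extra-Info block added to `__main__.py` earlier in this series extends a plain-text summary that is driven by simple counters over the diff stream. A hedged reconstruction of that summary logic — the percent formula is a plausible reading of the printed output, not lifted from the source:

```python
# Illustrative inputs: a diff as ("+"/"-", row) tuples and a row count.
diff = [("-", ("2", "old comment")), ("+", ("2", "new comment"))]
max_table_count = 25  # stand-in for max(table1_count, table2_count)

plus = sum(1 for op, _ in diff if op == "+")
minus = sum(1 for op, _ in diff if op == "-")
percent = 100 * len(diff) / (max_table_count or 1)

print(f"Diff-Total: {len(diff)} changed rows out of {max_table_count}")
print(f"Diff-Percent: {percent:.14f}%")
print(f"Diff-Split: +{plus} -{minus}")
```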
From 245aeb6c4b2cca10c5ab9350e77fcaf6ab066f2b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 16:27:51 +0300 Subject: [PATCH 31/33] Updated docs; Ran black --- README.md | 1 + data_diff/__main__.py | 10 ++++++---- data_diff/config.py | 6 +++--- data_diff/databases/base.py | 4 ++-- data_diff/diff_tables.py | 6 +++--- data_diff/joindiff_tables.py | 3 +-- data_diff/queries/api.py | 2 +- data_diff/queries/ast_classes.py | 3 ++- data_diff/table_segment.py | 6 +++--- tests/test_diff_tables.py | 17 +++++++++++++---- tests/test_joindiff.py | 32 +++++++++++++++++++------------- 11 files changed, 54 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index b54c3648..83f93ae9 100644 --- a/README.md +++ b/README.md @@ -252,6 +252,7 @@ Options: Same-DB diff only: - `-m`, `--materialize` - Materialize the diff results into a new table in the database. + If a table exists by that name, it will be replaced. Use `%t` in the name to place a timestamp. Example: `-m test_mat_%t` - `--assume-unique-key` - Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs. diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 2201ee58..5ca5e15b 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -80,7 +80,9 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - @click.argument("table1", required=False) @click.argument("database2", required=False) @click.argument("table2", required=False) -@click.option("-k", "--key-columns", default=[], multiple=True, help="Names of primary key columns. Default='id'.", metavar="NAME") +@click.option( + "-k", "--key-columns", default=[], multiple=True, help="Names of primary key columns. Default='id'.", metavar="NAME" +) @click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column", metavar="NAME") @click.option( "-c", @@ -110,7 +112,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - "--materialize", default=None, metavar="TABLE_NAME", - help="Materialize the diff results into a new table in the database. (joindiff only)", + help="(joindiff only) Materialize the diff results into a new table in the database. If a table exists by that name, it will be replaced.", ) @click.option( "--min-age", @@ -345,7 +347,7 @@ def _main( *key_columns, update_column, *columns, - ) + ), ) logging.info(f"Diffing using columns: key={key_columns} update={update_column} extra={columns}") @@ -385,7 +387,7 @@ def _main( if differ.stats: print("Extra-Info:") for k, v in differ.stats.items(): - print(f' {k} = {v}') + print(f" {k} = {v}") else: for op, values in diff_iter: color = COLOR_SCHEME[op] diff --git a/data_diff/config.py b/data_diff/config.py index ad7c972d..941e2643 100644 --- a/data_diff/config.py +++ b/data_diff/config.py @@ -26,13 +26,13 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]): else: run_name = "default" - if 'database1' in kw: - for attr in ('table1', 'database2', 'table2'): + if "database1" in kw: + for attr in ("table1", "database2", "table2"): if kw[attr] is None: raise ValueError(f"Specified database1 but not {attr}. 
Must specify all 4 arguments, or neither.")
 
         for index in "12":
-            run_args[index] = {attr: kw.pop(f"{attr}{index}") for attr in ('database', 'table')}
+            run_args[index] = {attr: kw.pop(f"{attr}{index}") for attr in ("database", "table")}
 
     # Process databases + tables
     for index in "12":
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index 8b3a465d..c96ec6ae 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -129,8 +129,8 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
         if getattr(self, "_interactive", False) and isinstance(sql_ast, Select):
             explained_sql = compiler.compile(Explain(sql_ast))
             explain = self._query(explained_sql)
-            for row, in explain:
-                logger.debug(f'EXPLAIN: {row}')
+            for (row,) in explain:
+                logger.debug(f"EXPLAIN: {row}")
             answer = input("Continue? [y/n] ")
             if not answer.lower() in ["y", "yes"]:
                 sys.exit(1)
diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py
index 6801e0d2..24627c45 100644
--- a/data_diff/diff_tables.py
+++ b/data_diff/diff_tables.py
@@ -157,8 +157,8 @@ def _bisect_and_diff_tables(self, table1, table2):
             raise NotImplementedError("Composite key not supported yet!")
         if len(table2.key_columns) > 1:
             raise NotImplementedError("Composite key not supported yet!")
-        key1 ,= table1.key_columns
-        key2 ,= table2.key_columns
+        (key1,) = table1.key_columns
+        (key2,) = table2.key_columns
 
         key_type = table1._schema[key1]
         key_type2 = table2._schema[key2]
@@ -214,7 +214,7 @@ def _bisect_and_diff_segments(
         assert table1.is_bounded and table2.is_bounded
 
         # Choose evenly spaced checkpoints (according to min_key and max_key)
-        biggest_table = max(table1, table2, key=methodcaller('approximate_size'))
+        biggest_table = max(table1, table2, key=methodcaller("approximate_size"))
         checkpoints = biggest_table.choose_checkpoints(self.bisection_factor - 1)
 
         # Create new instances of TableSegment between each checkpoint
diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py
index 3a8175e3..7617495f 100644
--- a/data_diff/joindiff_tables.py
+++ b/data_diff/joindiff_tables.py
@@ -92,8 +92,7 @@ def append_to_table_oracle(path: DbPath, expr: Expr):
 
 
 def append_to_table(path: DbPath, expr: Expr):
-    """Append to table
-    """
+    """Append to table"""
     assert expr.schema, expr
     t = table(path, schema=expr.schema)
     yield t.create(if_not_exists=True)  # uses expr.schema
diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py
index f000ec67..d9c0945f 100644
--- a/data_diff/queries/api.py
+++ b/data_diff/queries/api.py
@@ -32,7 +32,7 @@ def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None)
 
 def table(*path: str, schema: Schema = None) -> TablePath:
     if len(path) == 1 and isinstance(path[0], tuple):
-        path ,= path
+        (path,) = path
     if not all(isinstance(i, str) for i in path):
         raise TypeError(f"All elements of table path must be of type 'str'.
Got: {path}") return TablePath(path, schema) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index ac57bbe9..a73a69db 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -257,7 +257,7 @@ def type(self): types = {get_type(i) for i in self.args} if len(types) > 1: raise TypeError(f"Expected all args to have the same type, got {types}") - t ,= types + (t,) = types return t @@ -607,6 +607,7 @@ class Explain(ExprNode): def compile(self, c: Compiler) -> str: return c.database.explain_as_text(c.compile(self.select)) + # DDL diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index aa1a1498..c3219dc1 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -81,11 +81,11 @@ def with_schema(self) -> "TableSegment": def _make_key_range(self): if self.min_key is not None: assert len(self.key_columns) == 1 - k ,= self.key_columns + (k,) = self.key_columns yield self.min_key <= this[k] if self.max_key is not None: assert len(self.key_columns) == 1 - k ,= self.key_columns + (k,) = self.key_columns yield this[k] < self.max_key def _make_update_range(self): @@ -178,7 +178,7 @@ def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation # TODO better error if there is no schema - k ,= self.key_columns + (k,) = self.key_columns select = self._make_select().select( ApplyFuncAndNormalizeAsString(this[k], min_), ApplyFuncAndNormalizeAsString(this[k], max_), diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 9dfc7de7..57f9415b 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -44,6 +44,7 @@ def _class_per_db_dec(filter_name=None): ] return parameterized_class(("name", "db_name"), names) + def _table_segment(database, table_path, key_columns, *args, **kw): if isinstance(key_columns, str): key_columns = (key_columns,) @@ -195,16 +196,24 @@ def test_basic(self): def test_offset(self): differ = HashDiffer(bisection_factor=2, bisection_threshold=10) sec1 = self.now.shift(seconds=-1).datetime - a = _table_segment(self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False) - b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False) + a = _table_segment( + self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False + ) + b = _table_segment( + self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False + ) assert a.count() == 4 assert b.count() == 3 assert not list(differ.diff_tables(a, a)) self.assertEqual(len(list(differ.diff_tables(a, b))), 1) - a = _table_segment(self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False) - b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False) + a = _table_segment( + self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False + ) + b = _table_segment( + self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False + ) assert a.count() == 2 assert b.count() == 2 assert not list(differ.diff_tables(a, b)) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 62f7cd8e..5203c5a2 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -30,15 +30,15 @@ 
def init_instances(): TEST_DATABASES = ( - db.PostgreSQL, - db.Snowflake, - db.MySQL, - db.BigQuery, - db.Presto, - db.Vertica, - db.Trino, - db.Oracle, - db.Redshift, + db.PostgreSQL, + db.Snowflake, + db.MySQL, + db.BigQuery, + db.Presto, + db.Vertica, + db.Trino, + db.Oracle, + db.Redshift, ) @@ -49,10 +49,12 @@ def test_per_database(cls, dbs=TEST_DATABASES): ) return _class_per_db_dec(cls) + def test_per_database2(*dbs): @wraps(test_per_database) def dec(cls): return test_per_database(cls, dbs) + return dec @@ -83,18 +85,22 @@ def test_composite_key(self): _commit(self.connection) # Sanity - table1 = TableSegment(self.connection, self.table_src_path, ("id",), "timestamp", ('userid',), case_sensitive=False) - table2 = TableSegment(self.connection, self.table_dst_path, ("id",), "timestamp", ('userid',), case_sensitive=False) + table1 = TableSegment( + self.connection, self.table_src_path, ("id",), "timestamp", ("userid",), case_sensitive=False + ) + table2 = TableSegment( + self.connection, self.table_dst_path, ("id",), "timestamp", ("userid",), case_sensitive=False + ) diff = list(self.differ.diff_tables(table1, table2)) assert len(diff) == 2 - assert self.differ.stats['exclusive_count'] == 0 + assert self.differ.stats["exclusive_count"] == 0 # Test pks diffed, by checking exclusive_count table1 = TableSegment(self.connection, self.table_src_path, ("id", "userid"), "timestamp", case_sensitive=False) table2 = TableSegment(self.connection, self.table_dst_path, ("id", "userid"), "timestamp", case_sensitive=False) diff = list(self.differ.diff_tables(table1, table2)) assert len(diff) == 2 - assert self.differ.stats['exclusive_count'] == 2 + assert self.differ.stats["exclusive_count"] == 2 @test_per_database From e8965fd00b1da6d469a8f3fe83fe8ffb0e23610b Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 7 Oct 2022 16:52:27 +0300 Subject: [PATCH 32/33] Joindiff: Fix stats collections --- data_diff/databases/database_types.py | 92 ++++++++++++++------------- data_diff/databases/presto.py | 2 +- data_diff/joindiff_tables.py | 4 +- data_diff/table_segment.py | 1 - tests/test_diff_tables.py | 6 -- tests/test_joindiff.py | 2 + tests/test_query.py | 5 +- 7 files changed, 58 insertions(+), 54 deletions(-) diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 27249f0d..dc7a806c 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -141,6 +141,8 @@ class UnknownColType(ColType): class AbstractDialect(ABC): + """Dialect-dependent query expressions""" + name: str @abstractmethod @@ -177,56 +179,18 @@ def explain_as_text(self, query: str) -> str: "Provide SQL for explaining a query, returned in as table(varchar)" ... - -class AbstractDatabase(AbstractDialect): @abstractmethod - def timestamp_value(self, t: DbTime) -> str: + def timestamp_value(self, t: datetime) -> str: "Provide SQL for the given timestamp value" ... - @abstractmethod - def md5_to_int(self, s: str) -> str: - "Provide SQL for computing md5 and returning an int" - ... - - @abstractmethod - def _query(self, sql_code: str) -> list: - "Send query to database and return result" - ... - - @abstractmethod - def select_table_schema(self, path: DbPath) -> str: - "Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)" - ... 
- @abstractmethod - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - """Query the table for its schema for table in 'path', and return {column: tuple} - where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) - """ - ... +class AbstractDatadiffDialect(ABC): + """Dialect-dependent query expressions, that are specific to data-diff""" @abstractmethod - def _process_table_schema( - self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None - ): - """Process the result of query_table_schema(). - - Done in a separate step, to minimize the amount of processed columns. - Needed because processing each column may: - * throw errors and warnings - * query the database to sample values - - """ - - @abstractmethod - def parse_table_name(self, name: str) -> DbPath: - "Parse the given table name into a DbPath" - ... - - @abstractmethod - def close(self): - "Close connection(s) to the database instance. Querying will stop functioning." + def md5_to_int(self, s: str) -> str: + "Provide SQL for computing md5 and returning an int" ... @abstractmethod @@ -294,6 +258,48 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.normalize_uuid(value, coltype) return self.to_string(value) + +class AbstractDatabase(AbstractDialect, AbstractDatadiffDialect): + @abstractmethod + def _query(self, sql_code: str) -> list: + "Send query to database and return result" + ... + + @abstractmethod + def select_table_schema(self, path: DbPath) -> str: + "Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)" + ... + + @abstractmethod + def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + """Query the table for its schema for table in 'path', and return {column: tuple} + where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) + """ + ... + + @abstractmethod + def _process_table_schema( + self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None + ): + """Process the result of query_table_schema(). + + Done in a separate step, to minimize the amount of processed columns. + Needed because processing each column may: + * throw errors and warnings + * query the database to sample values + + """ + + @abstractmethod + def parse_table_name(self, name: str) -> DbPath: + "Parse the given table name into a DbPath" + ... + + @abstractmethod + def close(self): + "Close connection(s) to the database instance. Querying will stop functioning." + ... + @abstractmethod def _normalize_table_path(self, path: DbPath) -> DbPath: ... 
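The net effect of the hunk above is a three-layer contract: AbstractDialect for generic SQL rendering, AbstractDatadiffDialect for diffing-specific SQL, and AbstractDatabase for connection and introspection concerns. A toy illustration of the same layering, with made-up class names:

```python
from abc import ABC, abstractmethod


class ToyDialect(ABC):
    """Rendering-only concerns; needs no live connection."""

    @abstractmethod
    def quote(self, s: str) -> str:
        ...


class ToyDatabase(ToyDialect):
    """Adds I/O concerns on top of the dialect."""

    @abstractmethod
    def _query(self, sql_code: str) -> list:
        ...


class ToyPostgres(ToyDatabase):
    def quote(self, s: str) -> str:
        return f'"{s}"'

    def _query(self, sql_code: str) -> list:
        raise NotImplementedError("illustrative only; no real connection")


print(ToyPostgres().quote("id"))  # "id"
```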
diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 2fb041fc..d7204775 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -83,7 +83,7 @@ def close(self): self._conn.close() def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - # TODO + # TODO rounds if coltype.rounds: s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" else: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 7617495f..a1f23b23 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,7 +10,7 @@ from runtype import dataclass -from data_diff.databases.database_types import DbPath, Schema +from data_diff.databases.database_types import DbPath, NumericType, Schema from data_diff.databases.base import QueryError @@ -273,7 +273,7 @@ def _collect_stats(self, i, table): f"max_{c}": max_(this[c]), } for c in table._relevant_columns - if c == "id" # TODO just if the right type + if isinstance(table._schema[c], NumericType) ) col_exprs["count"] = Count() diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index c3219dc1..170955cd 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -177,7 +177,6 @@ def count_and_checksum(self) -> Tuple[int, int]: def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation - # TODO better error if there is no schema (k,) = self.key_columns select = self._make_select().select( ApplyFuncAndNormalizeAsString(this[k], min_), diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 57f9415b..84d70434 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -485,8 +485,6 @@ def setUp(self): self.new_uuid = uuid.uuid1(32132131) queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_uuid}', 'This one is different')") - # TODO test unexpected values? - for query in queries: self.connection.query(query, None) @@ -542,8 +540,6 @@ def setUp(self): self.new_alphanum = "aBcDeFgHiJ" queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')") - # TODO test unexpected values? - for query in queries: self.connection.query(query, None) @@ -594,8 +590,6 @@ def setUp(self): self.new_alphanum = "aBcDeFgHiJ" queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')") - # TODO test unexpected values? 
- for query in queries: self.connection.query(query, None) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 5203c5a2..b1babe35 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -137,6 +137,8 @@ def test_diff_small_tables(self): self.assertEqual(expected, diff) self.assertEqual(2, self.differ.stats["table1_count"]) self.assertEqual(1, self.differ.stats["table2_count"]) + self.assertEqual(3, self.differ.stats["table1_sum_id"]) + self.assertEqual(1, self.differ.stats["table2_sum_id"]) # Test materialize materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}") diff --git a/tests/test_query.py b/tests/test_query.py index fe0de696..3ab26e43 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,4 +1,4 @@ -from cmath import exp +from datetime import datetime from typing import List, Optional import unittest from data_diff.databases.database_types import AbstractDialect, CaseInsensitiveDict, CaseSensitiveDict @@ -35,6 +35,9 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None def explain_as_text(self, query: str) -> str: return f"explain {query}" + def timestamp_value(self, t: datetime) -> str: + return f"timestamp '{t}'" + class TestQuery(unittest.TestCase): def setUp(self): From 47b9faa3202584d30bf3f2765c1ab57ae1e1636d Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sat, 8 Oct 2022 10:37:33 +0300 Subject: [PATCH 33/33] Cleanup and minor fixes (pylint pass) --- data_diff/__init__.py | 17 ++++--- data_diff/databases/base.py | 14 ++++-- data_diff/databases/bigquery.py | 3 +- data_diff/databases/database_types.py | 53 ++++++++++---------- data_diff/databases/databricks.py | 14 +++++- data_diff/databases/mysql.py | 12 ++++- data_diff/databases/oracle.py | 16 +++++- data_diff/databases/postgresql.py | 12 ++++- data_diff/databases/presto.py | 23 +++++++-- data_diff/databases/redshift.py | 3 +- data_diff/databases/snowflake.py | 4 +- data_diff/databases/trino.py | 2 +- data_diff/diff_tables.py | 2 +- data_diff/hashdiff_tables.py | 21 ++++---- data_diff/joindiff_tables.py | 70 +++++++++++---------------- data_diff/queries/ast_classes.py | 20 +++----- data_diff/queries/compiler.py | 2 +- data_diff/table_segment.py | 29 +++++------ data_diff/thread_utils.py | 2 +- data_diff/utils.py | 11 ++++- tests/test_database_types.py | 11 ++--- tests/test_query.py | 2 +- 22 files changed, 203 insertions(+), 140 deletions(-) diff --git a/data_diff/__init__.py b/data_diff/__init__.py index ae6f021d..20c6b57d 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -70,24 +70,27 @@ def diff_tables( Parameters: key_columns (Tuple[str, ...]): Name of the key column, which uniquely identifies each row (usually id) - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update). - Used by `min_update` and `max_update`. + update_column (str, optional): Name of updated column, which signals that rows changed. + Usually updated_at or last_update. Used by `min_update` and `max_update`. 
        extra_columns (Tuple[str, ...], optional): Extra columns to compare
         min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment
         max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment
         min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
         max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
         algorithm (:class:`Algorithm`): Which diffing algorithm to use (`HASHDIFF` or `JOINDIFF`)
-        bisection_factor (int): Into how many segments to bisect per iteration. (when algorithm is `HASHDIFF`)
-        bisection_threshold (Number): When should we stop bisecting and compare locally (when algorithm is `HASHDIFF`; in row count).
+        bisection_factor (int): Into how many segments to bisect per iteration. (Used when algorithm is `HASHDIFF`)
+        bisection_threshold (Number): Minimal row count of segment to bisect, otherwise download
+            and compare locally. (Used when algorithm is `HASHDIFF`).
         threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.
-        max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``.
+        max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto.
+            Only relevant when `threaded` is ``True``.
             There may be many pools, so number of actual threads can be a lot higher.
 
     Note:
         The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances:
-        `key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's
-        possible to omit them here, and instead set them directly when creating each :class:`TableSegment`.
+        `key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`.
+        If different values are needed per table, it's possible to omit them here, and instead set
+        them directly when creating each :class:`TableSegment`.
 
     Example:
         >>> table1 = connect_to_table('postgresql:///', 'Rating', 'id')
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index c96ec6ae..897d4c3a 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -8,6 +8,7 @@ from abc import abstractmethod
 
 from data_diff.utils import is_uuid, safezip
+from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain
 from .database_types import (
     AbstractDatabase,
     ColType,
@@ -27,8 +28,6 @@
     DbPath,
 )
 
-from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain
-
 logger = logging.getLogger("database")
 
@@ -110,6 +109,8 @@ class Database(AbstractDatabase):
     default_schema: str = None
     SUPPORTS_ALPHANUMS = True
 
+    _interactive = False
+
     @property
     def name(self):
         return type(self).__name__
@@ -126,11 +127,14 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
                 return SKIP
 
         logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
-        if getattr(self, "_interactive", False) and isinstance(sql_ast, Select):
+        if self._interactive and isinstance(sql_ast, Select):
             explained_sql = compiler.compile(Explain(sql_ast))
             explain = self._query(explained_sql)
-            for (row,) in explain:
-                logger.debug(f"EXPLAIN: {row}")
+            for row in explain:
+                # Most databases return a 1-tuple. Presto returns a string
+                if isinstance(row, tuple):
+                    row ,= row
+                logger.debug("EXPLAIN: %s", row)
             answer = input("Continue?
[y/n] ") if not answer.lower() in ["y", "yes"]: sys.exit(1) diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 7044c084..603bfecc 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,4 +1,5 @@ -from .database_types import * +from typing import Union +from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType from .base import Database, import_helper, parse_table_name, ConnectError, apply_query from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index dc7a806c..1de1d2fc 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -1,6 +1,6 @@ import logging import decimal -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from typing import Sequence, Optional, Tuple, Union, Dict, List from datetime import datetime @@ -234,30 +234,6 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: """ ... - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized representation. - - The returned expression must accept any SQL value, and return a string. - - The default implementation dispatches to a method according to `coltype`: - - :: - - TemporalType -> normalize_timestamp() - FractionalType -> normalize_number() - *else* -> to_string() - - (`Integer` falls in the *else* category) - - """ - if isinstance(coltype, TemporalType): - return self.normalize_timestamp(value, coltype) - elif isinstance(coltype, FractionalType): - return self.normalize_number(value, coltype) - elif isinstance(coltype, ColType_UUID): - return self.normalize_uuid(value, coltype) - return self.to_string(value) - class AbstractDatabase(AbstractDialect, AbstractDatadiffDialect): @abstractmethod @@ -304,10 +280,35 @@ def close(self): def _normalize_table_path(self, path: DbPath) -> DbPath: ... - @abstractproperty + @property + @abstractmethod def is_autocommit(self) -> bool: ... + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized representation. + + The returned expression must accept any SQL value, and return a string. 
+ + The default implementation dispatches to a method according to `coltype`: + + :: + + TemporalType -> normalize_timestamp() + FractionalType -> normalize_number() + *else* -> to_string() + + (`Integer` falls in the *else* category) + + """ + if isinstance(coltype, TemporalType): + return self.normalize_timestamp(value, coltype) + elif isinstance(coltype, FractionalType): + return self.normalize_number(value, coltype) + elif isinstance(coltype, ColType_UUID): + return self.normalize_uuid(value, coltype) + return self.to_string(value) + Schema = CaseAwareMapping diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 5d381b66..612c1c8d 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,6 +1,18 @@ +from typing import Dict, Sequence import logging -from .database_types import * +from .database_types import ( + Integer, + Float, + Decimal, + Timestamp, + Text, + TemporalType, + NumericType, + DbPath, + ColType, + UnknownColType, +) from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, parse_table_name diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index b666e0c5..3f9eb98c 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,4 +1,14 @@ -from .database_types import * +from .database_types import ( + Datetime, + Timestamp, + Float, + Decimal, + Integer, + Text, + TemporalType, + FractionalType, + ColType_UUID, +) from .base import ThreadedDatabase, import_helper, ConnectError from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 59004412..e65fd65a 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,6 +1,20 @@ +from typing import Dict, List, Optional + from ..utils import match_regexps -from .database_types import * +from .database_types import ( + Decimal, + Float, + Text, + DbPath, + TemporalType, + ColType, + DbTime, + ColType_UUID, + Timestamp, + TimestampTZ, + FractionalType, +) from .base import ThreadedDatabase, import_helper, ConnectError, QueryError from .base import TIMESTAMP_PRECISION_POS diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index d65ac7de..72d26d07 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,4 +1,14 @@ -from .database_types import * +from .database_types import ( + Timestamp, + TimestampTZ, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, +) from .base import ThreadedDatabase, import_helper, ConnectError from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index d7204775..811a9491 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -3,7 +3,19 @@ from data_diff.utils import match_regexps -from .database_types import * +from .database_types import ( + Timestamp, + TimestampTZ, + Integer, + Float, + Text, + FractionalType, + DbPath, + Decimal, + ColType, + ColType_UUID, + TemporalType, +) from .base import Database, import_helper, ThreadLocalInterpreter from .base import ( MD5_HEXDIGITS, @@ -17,7 +29,7 @@ def query_cursor(c, sql_code): if sql_code.lower().startswith("select"): return c.fetchall() # Required for the query to actually run 🤯 - if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): + if 
re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE): return c.fetchone() @@ -98,7 +110,7 @@ def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return ( - "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision " + "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale " "FROM INFORMATION_SCHEMA.COLUMNS " f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) @@ -110,6 +122,7 @@ def _parse_type( type_repr: str, datetime_precision: int = None, numeric_precision: int = None, + numeric_scale: int = None, ) -> ColType: timestamp_regexps = { r"timestamp\((\d)\)": Timestamp, @@ -134,5 +147,9 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" + @property def is_autocommit(self) -> bool: return False + + def explain_as_text(self, query: str) -> str: + return f"EXPLAIN (FORMAT TEXT) {query}" diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index f11b950c..291d180b 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,4 +1,5 @@ -from .database_types import * +from typing import List +from .database_types import Float, TemporalType, FractionalType, DbPath from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 714fb5f0..635ba8f4 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,6 +1,7 @@ +from typing import Union import logging -from .database_types import * +from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath from .base import ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter @@ -88,6 +89,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + @property def is_autocommit(self) -> bool: return True diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index c3e3e581..73ef4a97 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,4 +1,4 @@ -from .database_types import * +from .database_types import TemporalType, ColType_UUID from .presto import Presto from .base import import_helper from .base import TIMESTAMP_PRECISION_POS diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 24627c45..1148041b 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -78,6 +78,7 @@ def _run_in_background(self, *funcs): class TableDiffer(ThreadBase, ABC): bisection_factor = 32 + stats: dict = {} def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: """Diff the given tables. @@ -177,7 +178,6 @@ def _bisect_and_diff_tables(self, table1, table2): table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] logger.info( - # f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. " f"Diffing segments at key-range: {table1.min_key}..{table2.max_key}. 
" f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}" ) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index ec575bdc..38e6fee5 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -9,7 +9,7 @@ from .utils import safezip from .thread_utils import ThreadedYielder -from .databases.database_types import ColType_UUID, IKey, NumericType, PrecisionType, StringType +from .databases.database_types import ColType_UUID, NumericType, PrecisionType, StringType from .table_segment import TableSegment from .diff_tables import TableDiffer @@ -27,7 +27,7 @@ def diff_sets(a: set, b: set) -> Iterator: s2 = set(b) d = defaultdict(list) - # The first item is always the key (see TableDiffer._relevant_columns) + # The first item is always the key (see TableDiffer.relevant_columns) for i in s1 - s2: d[i[0]].append(("-", i)) for i in s2 - s1: @@ -50,7 +50,8 @@ class HashDiffer(TableDiffer): bisection_factor (int): Into how many segments to bisect per iteration. bisection_threshold (Number): When should we stop bisecting and compare locally (in row count). threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. - max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. + Only relevant when `threaded` is ``True``. There may be many pools, so number of actual threads can be a lot higher. """ @@ -67,7 +68,7 @@ def __post_init__(self): raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") def _validate_and_adjust_columns(self, table1, table2): - for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns): + for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns): if c1 not in table1._schema: raise ValueError(f"Column '{c1}' not found in schema for table {table1}") if c2 not in table2._schema: @@ -109,7 +110,7 @@ def _validate_and_adjust_columns(self, table1, table2): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") for t in [table1, table2]: - for c in t._relevant_columns: + for c in t.relevant_columns: ctype = t._schema[c] if not ctype.supported: logger.warning( @@ -144,10 +145,12 @@ def _diff_segments( (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) if count1 == 0 and count2 == 0: - # logger.warning( - # f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). " - # "For better performance, we recommend to increase the bisection-threshold." - # ) + logger.debug( + "Uneven distribution of keys detected in segment %s..%s (big gaps in the key column). 
" + "For better performance, we recommend to increase the bisection-threshold.", + table1.min_key, + table1.max_key, + ) assert checksum1 is None and checksum2 is None return diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index a1f23b23..58246def 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -6,22 +6,22 @@ from decimal import Decimal from functools import partial import logging -from typing import Dict, List, Optional +from typing import List from runtype import dataclass -from data_diff.databases.database_types import DbPath, NumericType, Schema +from data_diff.databases.database_types import DbPath, NumericType from data_diff.databases.base import QueryError from .utils import safezip from .databases.base import Database -from .databases import MySQL, BigQuery, Presto, Oracle, PostgreSQL, Snowflake +from .databases import MySQL, BigQuery, Presto, Oracle, Snowflake from .table_segment import TableSegment from .diff_tables import TableDiffer, DiffResult from .thread_utils import ThreadedYielder -from .queries import table, sum_, min_, max_, avg, SKIP, commit +from .queries import table, sum_, min_, max_, avg, commit from .queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable from .queries.ast_classes import Concat, Count, Expr, Random, TablePath from .queries.compiler import Compiler @@ -40,29 +40,20 @@ def merge_dicts(dicts): return res -@dataclass(frozen=False) -class Stats: - exclusive_count: int - exclusive_sample: List[tuple] - diff_ratio_by_column: Dict[str, float] - diff_ratio_total: float - metrics: Dict[str, float] +def sample(table_expr): + return table_expr.order_by(Random()).limit(10) -def sample(table): - return table.order_by(Random()).limit(10) - - -def create_temp_table(c: Compiler, table: TablePath, expr: Expr): +def create_temp_table(c: Compiler, path: TablePath, expr: Expr): db = c.database if isinstance(db, BigQuery): - return f"create table {c.compile(table)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" + return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" elif isinstance(db, Presto): - return f"create table {c.compile(table)} as {c.compile(expr)}" + return f"create table {c.compile(path)} as {c.compile(expr)}" elif isinstance(db, Oracle): - return f"create global temporary table {c.compile(table)} as {c.compile(expr)}" + return f"create global temporary table {c.compile(path)} as {c.compile(expr)}" else: - return f"create temporary table {c.compile(table)} as {c.compile(expr)}" + return f"create temporary table {c.compile(path)} as {c.compile(expr)}" def drop_table_oracle(name: DbPath): @@ -149,7 +140,8 @@ class JoinDiffer(TableDiffer): Parameters: threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads. - max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``. + max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. + Only relevant when `threaded` is ``True``. There may be many pools, so number of actual threads can be a lot higher. validate_unique_key (bool): Enable/disable validating that the key columns are unique. Single query, and can't be threaded, so it's very slow on non-cloud dbs. 
@@ -227,8 +219,8 @@ def _diff_segments( if is_xa and is_xb: # Can't both be exclusive, meaning a pk is NULL # This can happen if the explicit null test didn't finish running yet - raise ValueError(f"NULL values in one or more primary keys") - is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) + raise ValueError("NULL values in one or more primary keys") + _is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) if not is_xb: yield "-", tuple(a_row) if not is_xa: @@ -239,7 +231,7 @@ def _test_duplicate_keys(self, table1, table2): # Test duplicate keys for ts in [table1, table2]: - t = ts._make_select() + t = ts.make_select() key_columns = ts.key_columns q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True)) @@ -252,17 +244,17 @@ def _test_null_keys(self, table1, table2): # Test null keys for ts in [table1, table2]: - t = ts._make_select() + t = ts.make_select() key_columns = ts.key_columns q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) if nulls: - raise ValueError(f"NULL values in one or more primary keys") + raise ValueError("NULL values in one or more primary keys") - def _collect_stats(self, i, table): + def _collect_stats(self, i, table_seg: TableSegment): logger.info(f"Collecting stats for table #{i}") - db = table.database + db = table_seg.database # Metrics col_exprs = merge_dicts( @@ -272,21 +264,17 @@ def _collect_stats(self, i, table): f"min_{c}": min_(this[c]), f"max_{c}": max_(this[c]), } - for c in table._relevant_columns - if isinstance(table._schema[c], NumericType) + for c in table_seg.relevant_columns + if isinstance(table_seg._schema[c], NumericType) ) col_exprs["count"] = Count() - res = db.query(table._make_select().select(**col_exprs), tuple) + res = db.query(table_seg.make_select().select(**col_exprs), tuple) res = dict(zip([f"table{i}_{n}" for n in col_exprs], map(json_friendly_value, res))) for k, v in res.items(): self.stats[k] = self.stats.get(k, 0) + (v or 0) - # self.stats.update(res) - - logger.debug(f"Done collecting stats for table #{i}") - # stats.diff_ratio_by_column = diff_stats - # stats.diff_ratio_total = diff_stats['total_diff'] + logger.debug("Done collecting stats for table #%s", i) def _create_outer_join(self, table1, table2): db = table1.database @@ -298,13 +286,13 @@ def _create_outer_join(self, table1, table2): if len(keys1) != len(keys2): raise ValueError("The provided key columns are of a different count") - cols1 = table1._relevant_columns - cols2 = table2._relevant_columns + cols1 = table1.relevant_columns + cols2 = table2.relevant_columns if len(cols1) != len(cols2): raise ValueError("The provided columns are of a different count") - a = table1._make_select() - b = table2._make_select() + a = table1.make_select() + b = table2.make_select() is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)} @@ -359,4 +347,4 @@ def _materialize_diff(self, db, diff_rows, segment_index=None): f = append_to_table_oracle if isinstance(db, Oracle) else append_to_table db.query(f(self.materialize_to_table, diff_rows.limit(self.write_limit))) - logger.info(f"Materialized diff to table '{'.'.join(self.materialize_to_table)}'.") + logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table)) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index a73a69db..226c246b 100644 --- 
a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -131,7 +131,7 @@ def count(self): return Select(self, [Count()]) def union(self, other: "ITable"): - return Union(self, other) + return SetUnion(self, other) @dataclass @@ -401,7 +401,7 @@ def having(self): @dataclass -class Union(ExprNode, ITable): +class SetUnion(ExprNode, ITable): table1: ITable table2: ITable @@ -422,12 +422,12 @@ def schema(self): def compile(self, parent_c: Compiler) -> str: c = parent_c.replace(in_select=False) - union_all = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" + union = f"{c.compile(self.table1)} UNION {c.compile(self.table2)}" if parent_c.in_select: - union_all = f"({union_all}) {c.new_unique_name()}" + union = f"({union}) {c.new_unique_name()}" elif parent_c.in_join: - union_all = f"({union_all})" - return union_all + union = f"({union})" + return union @dataclass @@ -567,14 +567,6 @@ def __getitem__(self, name): return _ResolveColumn(name) -@dataclass -class Explain(ExprNode): - sql: Select - - def compile(self, c: Compiler) -> str: - return f"EXPLAIN {c.compile(self.sql)}" - - @dataclass class In(ExprNode): expr: Expr diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 02bb48bc..eda7d981 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,7 +1,7 @@ import random from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Sequence, List, Union +from typing import Any, Dict, Sequence, List from runtype import dataclass diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 170955cd..cddbe9f5 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -1,5 +1,5 @@ import time -from typing import List, Sequence, Tuple +from typing import List, Tuple import logging from runtype import dataclass @@ -12,7 +12,7 @@ logger = logging.getLogger("table_segment") -RECOMMENDED_CHECKSUM_DURATION = 10 +RECOMMENDED_CHECKSUM_DURATION = 20 @dataclass @@ -23,8 +23,8 @@ class TableSegment: database (Database): Database instance. See :meth:`connect` table_path (:data:`DbPath`): Path to table in form of a tuple. e.g. `('my_dataset', 'table_name')` key_columns (Tuple[str]): Name of the key column, which uniquely identifies each row (usually id) - update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update) - Used by `min_update` and `max_update`. + update_column (str, optional): Name of updated column, which signals that rows changed. + Usually updated_at or last_update. Used by `min_update` and `max_update`. 
extra_columns (Tuple[str, ...], optional): Extra columns to compare
        min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment
        max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment
@@ -68,7 +68,7 @@ def __post_init__(self):
         )

     def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
-        schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns, self.where)
+        schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self.where)
         return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))

     def with_schema(self) -> "TableSegment":
@@ -98,12 +98,12 @@ def _make_update_range(self):
     def source_table(self):
         return table(*self.table_path, schema=self._schema)

-    def _make_select(self):
+    def make_select(self):
         return self.source_table.where(*self._make_key_range(), *self._make_update_range(), self.where or SKIP)

     def get_values(self) -> list:
         "Download all the relevant values of the segment from the database"
-        select = self._make_select().select(*self._relevant_columns_repr)
+        select = self.make_select().select(*self._relevant_columns_repr)
         return self.database.query(select, List[Tuple])

     def choose_checkpoints(self, count: int) -> List[DbKey]:
@@ -142,7 +142,7 @@ def new(self, **kwargs) -> "TableSegment":
         return self.replace(**kwargs)

     @property
-    def _relevant_columns(self) -> List[str]:
+    def relevant_columns(self) -> List[str]:
         extras = list(self.extra_columns)

         if self.update_column and self.update_column not in extras:
@@ -152,22 +152,23 @@

     @property
     def _relevant_columns_repr(self) -> List[Expr]:
-        return [NormalizeAsString(this[c]) for c in self._relevant_columns]
+        return [NormalizeAsString(this[c]) for c in self.relevant_columns]

     def count(self) -> Tuple[int, int]:
         """Count how many rows are in the segment, in one pass."""
-        return self.database.query(self._make_select().select(Count()), int)
+        return self.database.query(self.make_select().select(Count()), int)

     def count_and_checksum(self) -> Tuple[int, int]:
         """Count and checksum the rows in the segment, in one pass."""
         start = time.monotonic()
-        q = self._make_select().select(Count(), Checksum(self._relevant_columns_repr))
+        q = self.make_select().select(Count(), Checksum(self._relevant_columns_repr))
         count, checksum = self.database.query(q, tuple)

         duration = time.monotonic() - start
         if duration > RECOMMENDED_CHECKSUM_DURATION:
             logger.warning(
-                f"Checksum is taking longer than expected ({duration:.2f}s). "
-                "We recommend increasing --bisection-factor or decreasing --threads."
+                "Checksum is taking longer than expected (%.2fs). "
+                "We recommend increasing --bisection-factor or decreasing --threads.",
+                duration,
             )

         if count:
@@ -178,7 +179,7 @@ def query_key_range(self) -> Tuple[int, int]:
         """Query database for minimum and maximum key.
This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation (k,) = self.key_columns - select = self._make_select().select( + select = self.make_select().select( ApplyFuncAndNormalizeAsString(this[k], min_), ApplyFuncAndNormalizeAsString(this[k], max_), ) diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 1e0d26b8..1be94ad4 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -1,9 +1,9 @@ import itertools -from concurrent.futures.thread import _WorkItem from queue import PriorityQueue from collections import deque from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor +from concurrent.futures.thread import _WorkItem from time import sleep from typing import Callable, Iterator, Optional diff --git a/data_diff/utils.py b/data_diff/utils.py index b572db1b..2c8ccfba 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -192,7 +192,10 @@ def remove_password_from_url(url: str, replace_with: str = "***") -> str: def join_iter(joiner: Any, iterable: Iterable) -> Iterable: it = iter(iterable) - yield next(it) + try: + yield next(it) + except StopIteration: + return for i in it: yield joiner yield i @@ -221,6 +224,10 @@ def __contains__(self, key: str) -> bool: def __repr__(self): return repr(dict(self.items())) + @abstractmethod + def items(self) -> Iterable[Tuple[str, V]]: + ... + class CaseInsensitiveDict(CaseAwareMapping): def __init__(self, initial): @@ -302,7 +309,7 @@ def getLogger(name): def eval_name_template(name): - def get_timestamp(m): + def get_timestamp(_match): return datetime.now().isoformat("_", "seconds").replace(":", "_") return re.sub("%t", get_timestamp, name) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index c9e9042c..b63eb2b2 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -418,13 +418,10 @@ def __iter__(self): type_pairs = [] for source_db, source_type_categories in DATABASE_TYPES.items(): for target_db, target_type_categories in DATABASE_TYPES.items(): - for ( - type_category, - source_types, - ) in source_type_categories.items(): # int, datetime, .. - for source_type in source_types: - for target_type in target_type_categories[type_category]: - if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get(target_db, False): + if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get(target_db, False): + for type_category, source_types in source_type_categories.items(): # int, datetime, .. + for source_type in source_types: + for target_type in target_type_categories[type_category]: type_pairs.append( ( source_db, diff --git a/tests/test_query.py b/tests/test_query.py index 3ab26e43..d02e9745 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -149,7 +149,7 @@ def test_funcs(self): q = c.compile(t.order_by(Random()).limit(10)) assert q == "SELECT * FROM a ORDER BY random() limit 10" - def test_union_all(self): + def test_union(self): c = Compiler(MockDialect()) a = table("a").select("x") b = table("b").select("y")
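The `join_iter` fix in this patch wraps the first `next()` in try/except because, under PEP 479, a `StopIteration` escaping a generator is re-raised as `RuntimeError`; with the guard, an empty iterable simply yields nothing. A quick sketch of the intended behavior:

    from data_diff.utils import join_iter

    assert list(join_iter(", ", [])) == []                        # empty input: no error, no output
    assert list(join_iter(", ", ["a", "b"])) == ["a", ", ", "b"]  # items interleaved with joiner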