diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index f0a96f20..189dced4 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -335,6 +335,7 @@ def _constant_value(self, v): elif isinstance(v, str): return f"'{v}'" elif isinstance(v, datetime): + # TODO use self.timestamp_value return f"timestamp '{v}'" elif isinstance(v, UUID): return f"'{v}'" diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 1de1d2fc..86df4489 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -147,12 +147,12 @@ class AbstractDialect(ABC): @abstractmethod def quote(self, s: str): - "Quote SQL name (implementation specific)" + "Quote SQL name" ... @abstractmethod def concat(self, l: List[str]) -> str: - "Provide SQL for concatenating a bunch of column into a string" + "Provide SQL for concatenating a bunch of columns into a string" ... @abstractmethod @@ -162,12 +162,13 @@ def is_distinct_from(self, a: str, b: str) -> str: @abstractmethod def to_string(self, s: str) -> str: + # TODO rewrite using cast_to(x, str) "Provide SQL for casting a column to string" ... @abstractmethod def random(self) -> str: - "Provide SQL for generating a random number" + "Provide SQL for generating a random number between 0..1" @abstractmethod def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): @@ -176,7 +177,7 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None @abstractmethod def explain_as_text(self, query: str) -> str: - "Provide SQL for explaining a query, returned in as table(varchar)" + "Provide SQL for explaining a query, returned as table(varchar)" ... 
@abstractmethod diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 80647ba3..bd849e59 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -153,4 +153,6 @@ def type_repr(self, t) -> str: return super().type_repr(t) def constant_values(self, rows) -> str: - return " UNION ALL ".join("SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows) + return " UNION ALL ".join( + "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows + ) diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 2f5d96be..797fafa5 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -47,14 +47,14 @@ def or_(*exprs: Expr): exprs = args_as_tuple(exprs) if len(exprs) == 1: return exprs[0] - return BinOp("OR", exprs) + return BinBoolOp("OR", exprs) def and_(*exprs: Expr): exprs = args_as_tuple(exprs) if len(exprs) == 1: return exprs[0] - return BinOp("AND", exprs) + return BinBoolOp("AND", exprs) def sum_(expr: Expr): diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b5456b59..88d7ab11 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -32,7 +32,7 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] -def get_type(e: Expr) -> type: +def _expr_type(e: Expr) -> type: if isinstance(e, ExprNode): return e.type return type(e) @@ -48,7 +48,7 @@ def compile(self, c: Compiler) -> str: @property def type(self): - return get_type(self.expr) + return _expr_type(self.expr) def _drop_skips(exprs): @@ -156,6 +156,8 @@ class Count(ExprNode): expr: Expr = "*" distinct: bool = False + type = int + def compile(self, c: Compiler) -> str: expr = c.compile(self.expr) if self.distinct: @@ -174,12 +176,6 @@ def compile(self, c: Compiler) -> str: return f"{self.name}({args})" -def _expr_type(e: Expr): - if isinstance(e, ExprNode): - return e.type - return 
type(e) - - @dataclass class CaseWhen(ExprNode): cases: Sequence[Tuple[Expr, Expr]] @@ -226,6 +222,9 @@ def __le__(self, other): def __or__(self, other): return BinBoolOp("OR", [self, other]) + def __and__(self, other): + return BinBoolOp("AND", [self, other]) + def is_distinct_from(self, other): return IsDistinctFrom(self, other) @@ -254,7 +253,7 @@ def compile(self, c: Compiler) -> str: @property def type(self): - types = {get_type(i) for i in self.args} + types = {_expr_type(i) for i in self.args} if len(types) > 1: raise TypeError(f"Expected all args to have the same type, got {types}") (t,) = types @@ -298,6 +297,16 @@ class TablePath(ExprNode, ITable): path: DbPath schema: Optional[Schema] = field(default=None, repr=False) + @property + def source_table(self): + return self + + def compile(self, c: Compiler) -> str: + path = self.path # c.database._normalize_table_path(self.name) + return ".".join(map(c.quote, path)) + + # Statement shorthands + def create(self, source_table: ITable = None, *, if_not_exists=False): if source_table is None and not self.schema: raise ValueError("Either schema or source table needed to create table") @@ -323,14 +332,6 @@ def insert_expr(self, expr: Expr): expr = expr.select() return InsertToTable(self, expr) - @property - def source_table(self): - return self - - def compile(self, c: Compiler) -> str: - path = self.path # c.database._normalize_table_path(self.name) - return ".".join(map(c.quote, path)) - @dataclass class TableAlias(ExprNode, ITable): @@ -386,7 +387,7 @@ def compile(self, parent_c: Compiler) -> str: tables = [ t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables ] - c = parent_c.add_table_context(*tables).replace(in_join=True, in_select=False) + c = parent_c.add_table_context(*tables, in_join=True, in_select=False) op = " JOIN " if self.op is None else f" {self.op} JOIN " joined = op.join(c.compile(t) for t in tables) @@ -408,7 +409,7 @@ def compile(self, 
parent_c: Compiler) -> str: class GroupBy(ITable): def having(self): - pass + raise NotImplementedError() @dataclass @@ -546,26 +547,26 @@ class _ResolveColumn(ExprNode, LazyOps): resolve_name: str resolved: Expr = None - def resolve(self, expr): - assert self.resolved is None + def resolve(self, expr: Expr): + if self.resolved is not None: + raise RuntimeError("Already resolved!") self.resolved = expr - def compile(self, c: Compiler) -> str: + def _get_resolved(self) -> Expr: if self.resolved is None: raise RuntimeError(f"Column not resolved: {self.resolve_name}") - return self.resolved.compile(c) + return self.resolved + + def compile(self, c: Compiler) -> str: + return self._get_resolved().compile(c) @property def type(self): - if self.resolved is None: - raise RuntimeError(f"Column not resolved: {self.resolve_name}") - return self.resolved.type + return self._get_resolved().type @property def name(self): - if self.resolved is None: - raise RuntimeError(f"Column not resolved: {self.name}") - return self.resolved.name + return self._get_resolved().name class This: @@ -583,6 +584,8 @@ class In(ExprNode): expr: Expr list: Sequence[Expr] + type = bool + def compile(self, c: Compiler): elems = ", ".join(map(c.compile, self.list)) return f"({c.compile(self.expr)} IN ({elems}))" @@ -599,6 +602,8 @@ def compile(self, c: Compiler) -> str: @dataclass class Random(ExprNode): + type = float + def compile(self, c: Compiler) -> str: return c.database.random() @@ -618,6 +623,8 @@ def compile_for_insert(self, c: Compiler): class Explain(ExprNode): select: Select + type = str + def compile(self, c: Compiler) -> str: return c.database.explain_as_text(c.compile(self.select)) @@ -640,7 +647,7 @@ def compile(self, c: Compiler) -> str: if self.source_table: return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" - schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) + schema = ", ".join(f"{c.quote(k)} 
{c.database.type_repr(v)}" for k, v in self.path.schema.items()) return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index eda7d981..31242131 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -21,9 +21,6 @@ class Compiler: _counter: List = [0] - def quote(self, s: str): - return self.database.quote(s) - def compile(self, elem) -> str: res = self._compile(elem) if self.root and self._subqueries: @@ -57,8 +54,11 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath: self._counter[0] += 1 return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") - def add_table_context(self, *tables: Sequence): - return self.replace(_table_context=self._table_context + list(tables)) + def add_table_context(self, *tables: Sequence, **kw): + return self.replace(_table_context=self._table_context + list(tables), **kw) + + def quote(self, s: str): + return self.database.quote(s) class Compilable(ABC): diff --git a/data_diff/utils.py b/data_diff/utils.py index 2c8ccfba..a2b7e801 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -1,8 +1,8 @@ import logging import re import math -from typing import Iterable, Tuple, Union, Any, Sequence, Dict -from typing import TypeVar, Generic +from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict +from typing import TypeVar from abc import ABC, abstractmethod from urllib.parse import urlparse from uuid import UUID @@ -204,58 +204,39 @@ def join_iter(joiner: Any, iterable: Iterable) -> Iterable: V = TypeVar("V") -class CaseAwareMapping(ABC, Generic[V]): +class CaseAwareMapping(MutableMapping[str, V]): @abstractmethod def get_key(self, key: str) -> str: ... - @abstractmethod - def __getitem__(self, key: str) -> V: - ... - - @abstractmethod - def __setitem__(self, key: str, value: V): - ... - - @abstractmethod - def __contains__(self, key: str) -> bool: - ... 
- - def __repr__(self): - return repr(dict(self.items())) - - @abstractmethod - def items(self) -> Iterable[Tuple[str, V]]: - ... - class CaseInsensitiveDict(CaseAwareMapping): def __init__(self, initial): self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} - def get_key(self, key: str) -> str: - return self._dict[key.lower()][0] - def __getitem__(self, key: str) -> V: return self._dict[key.lower()][1] + def __iter__(self) -> Iterator[V]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + def __setitem__(self, key: str, value): k = key.lower() if k in self._dict: key = self._dict[k][0] self._dict[k] = key, value - def __contains__(self, key): - return key.lower() in self._dict - - def keys(self) -> Iterable[str]: - return self._dict.keys() + def __delitem__(self, key: str): + del self._dict[key.lower()] - def items(self) -> Iterable[Tuple[str, V]]: - return ((k, v[1]) for k, v in self._dict.items()) + def get_key(self, key: str) -> str: + return self._dict[key.lower()][0] - def __len__(self): - return len(self._dict) + def __repr__(self) -> str: + return repr(dict(self.items())) class CaseSensitiveDict(dict, CaseAwareMapping):