diff --git a/data_diff/abcs/mixins.py b/data_diff/abcs/mixins.py deleted file mode 100644 index 3bc566e5..00000000 --- a/data_diff/abcs/mixins.py +++ /dev/null @@ -1,133 +0,0 @@ -from abc import ABC, abstractmethod - -import attrs - -from data_diff.abcs.database_types import ( - Array, - TemporalType, - FractionalType, - ColType_UUID, - Boolean, - ColType, - String_UUID, - JSON, - Struct, -) -from data_diff.abcs.compiler import Compilable - - -@attrs.define(frozen=False) -class AbstractMixin(ABC): - "A mixin for a database dialect" - - -@attrs.define(frozen=False) -class AbstractMixin_NormalizeValue(AbstractMixin): - @abstractmethod - def to_comparable(self, value: str, coltype: ColType) -> str: - """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" - - @abstractmethod - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized timestamp. - - The returned expression must accept any SQL datetime/timestamp, and return a string. - - Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF`` - - Precision of dates should be rounded up/down according to coltype.rounds - """ - - @abstractmethod - def normalize_number(self, value: str, coltype: FractionalType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized number. - - The returned expression must accept any SQL int/numeric/float, and return a string. - - Floats/Decimals are expected in the format - "I.P" - - Where I is the integer part of the number (as many digits as necessary), - and must be at least one digit (0). - P is the fractional digits, the amount of which is specified with - coltype.precision. Trailing zeroes may be necessary. - If P is 0, the dot is omitted. - - Note: We use 'precision' differently than most databases. For decimals, - it's the same as ``numeric_scale``, and for floats, who use binary precision, - it can be calculated as ``log10(2**numeric_precision)``. - """ - - def normalize_boolean(self, value: str, _coltype: Boolean) -> str: - """Creates an SQL expression, that converts 'value' to either '0' or '1'.""" - return self.to_string(value) - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - """Creates an SQL expression, that strips uuids of artifacts like whitespace.""" - if isinstance(coltype, String_UUID): - return f"TRIM({value})" - return self.to_string(value) - - def normalize_json(self, value: str, _coltype: JSON) -> str: - """Creates an SQL expression, that converts 'value' to its minified json string representation.""" - return self.to_string(value) - - def normalize_array(self, value: str, _coltype: Array) -> str: - """Creates an SQL expression, that serialized an array into a JSON string.""" - return self.to_string(value) - - def normalize_struct(self, value: str, _coltype: Struct) -> str: - """Creates an SQL expression, that serialized a typed struct into a JSON string.""" - return self.to_string(value) - - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized representation. - - The returned expression must accept any SQL value, and return a string. - - The default implementation dispatches to a method according to `coltype`: - - :: - - TemporalType -> normalize_timestamp() - FractionalType -> normalize_number() - *else* -> to_string() - - (`Integer` falls in the *else* category) - - """ - if isinstance(coltype, TemporalType): - return self.normalize_timestamp(value, coltype) - elif isinstance(coltype, FractionalType): - return self.normalize_number(value, coltype) - elif isinstance(coltype, ColType_UUID): - return self.normalize_uuid(value, coltype) - elif isinstance(coltype, Boolean): - return self.normalize_boolean(value, coltype) - elif isinstance(coltype, JSON): - return self.normalize_json(value, coltype) - elif isinstance(coltype, Array): - return self.normalize_array(value, coltype) - elif isinstance(coltype, Struct): - return self.normalize_struct(value, coltype) - return self.to_string(value) - - -@attrs.define(frozen=False) -class AbstractMixin_MD5(AbstractMixin): - """Methods for calculating an MD6 hash as an integer.""" - - @abstractmethod - def md5_as_int(self, s: str) -> str: - "Provide SQL for computing md5 and returning an int" - - -@attrs.define(frozen=False) -class AbstractMixin_OptimizerHints(AbstractMixin): - @abstractmethod - def optimizer_hints(self, optimizer_hints: str) -> str: - """Creates a compatible optimizer_hints string - - Parameters: - optimizer_hints - string of optimizer hints - """ diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 0d4184ef..8caa6817 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -17,7 +17,7 @@ import attrs from typing_extensions import Self -from data_diff.abcs.compiler import AbstractCompiler +from data_diff.abcs.compiler import AbstractCompiler, Compilable from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString from data_diff.utils import ArithString, is_uuid, join_iter, safezip from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this @@ -55,6 +55,8 @@ ) from data_diff.abcs.database_types import ( Array, + ColType_UUID, + FractionalType, Struct, ColType, Integer, @@ -73,11 +75,6 @@ Boolean, JSON, ) -from data_diff.abcs.mixins import Compilable -from data_diff.abcs.mixins import ( - AbstractMixin_NormalizeValue, - AbstractMixin_OptimizerHints, -) logger = logging.getLogger("database") cv_params = contextvars.ContextVar("params") @@ -198,12 +195,6 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal return callback(sql_code) -@attrs.define(frozen=False) -class Mixin_OptimizerHints(AbstractMixin_OptimizerHints): - def optimizer_hints(self, hints: str) -> str: - return f"/*+ {hints} */ " - - @attrs.define(frozen=False) class BaseDialect(abc.ABC): SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False @@ -771,6 +762,98 @@ def to_string(self, s: str) -> str: def set_timezone_to_utc(self) -> str: "Provide SQL for setting the session timezone to UTC" + @abstractmethod + def md5_as_int(self, s: str) -> str: + "Provide SQL for computing md5 and returning an int" + + @abstractmethod + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized timestamp. + + The returned expression must accept any SQL datetime/timestamp, and return a string. + + Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF`` + + Precision of dates should be rounded up/down according to coltype.rounds + """ + + @abstractmethod + def normalize_number(self, value: str, coltype: FractionalType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized number. + + The returned expression must accept any SQL int/numeric/float, and return a string. + + Floats/Decimals are expected in the format + "I.P" + + Where I is the integer part of the number (as many digits as necessary), + and must be at least one digit (0). + P is the fractional digits, the amount of which is specified with + coltype.precision. Trailing zeroes may be necessary. + If P is 0, the dot is omitted. + + Note: We use 'precision' differently than most databases. For decimals, + it's the same as ``numeric_scale``, and for floats, who use binary precision, + it can be calculated as ``log10(2**numeric_precision)``. + """ + + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: + """Creates an SQL expression, that converts 'value' to either '0' or '1'.""" + return self.to_string(value) + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + """Creates an SQL expression, that strips uuids of artifacts like whitespace.""" + if isinstance(coltype, String_UUID): + return f"TRIM({value})" + return self.to_string(value) + + def normalize_json(self, value: str, _coltype: JSON) -> str: + """Creates an SQL expression, that converts 'value' to its minified json string representation.""" + return self.to_string(value) + + def normalize_array(self, value: str, _coltype: Array) -> str: + """Creates an SQL expression, that serialized an array into a JSON string.""" + return self.to_string(value) + + def normalize_struct(self, value: str, _coltype: Struct) -> str: + """Creates an SQL expression, that serialized a typed struct into a JSON string.""" + return self.to_string(value) + + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized representation. + + The returned expression must accept any SQL value, and return a string. + + The default implementation dispatches to a method according to `coltype`: + + :: + + TemporalType -> normalize_timestamp() + FractionalType -> normalize_number() + *else* -> to_string() + + (`Integer` falls in the *else* category) + + """ + if isinstance(coltype, TemporalType): + return self.normalize_timestamp(value, coltype) + elif isinstance(coltype, FractionalType): + return self.normalize_number(value, coltype) + elif isinstance(coltype, ColType_UUID): + return self.normalize_uuid(value, coltype) + elif isinstance(coltype, Boolean): + return self.normalize_boolean(value, coltype) + elif isinstance(coltype, JSON): + return self.normalize_json(value, coltype) + elif isinstance(coltype, Array): + return self.normalize_array(value, coltype) + elif isinstance(coltype, Struct): + return self.normalize_struct(value, coltype) + return self.to_string(value) + + def optimizer_hints(self, hints: str) -> str: + return f"/*+ {hints} */ " + T = TypeVar("T", bound=BaseDialect) @@ -966,10 +1049,7 @@ def _refine_coltypes( if not text_columns: return - if isinstance(self.dialect, AbstractMixin_NormalizeValue): - fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] - else: - fields = this[text_columns] + fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] samples_by_row = self.query( table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 02e19323..cca2108c 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -20,12 +20,6 @@ Boolean, UnknownColType, ) -from data_diff.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, -) -from data_diff.abcs.compiler import Compilable -from data_diff.queries.api import this, table, SKIP, code from data_diff.databases.base import ( BaseDialect, Database, @@ -61,7 +55,7 @@ def import_bigquery_service_account_impersonation(): @attrs.define(frozen=False) -class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): +class Dialect(BaseDialect): name = "BigQuery" ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation TYPE_CLASSES = { diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 9fbf2eb8..7a8816d8 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -24,7 +24,6 @@ Timestamp, Boolean, ) -from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue # https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database DEFAULT_DATABASE = "default" @@ -38,7 +37,7 @@ def import_clickhouse(): @attrs.define(frozen=False) -class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): +class Dialect(BaseDialect): name = "Clickhouse" ROUNDS_ON_PREC_LOSS = False TYPE_CLASSES = { diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index f5bbadc5..b6dec72a 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -17,7 +17,6 @@ UnknownColType, Boolean, ) -from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -37,7 +36,7 @@ def import_databricks(): @attrs.define(frozen=False) -class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): +class Dialect(BaseDialect): name = "Databricks" ROUNDS_ON_PREC_LOSS = True TYPE_CLASSES = { diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index 43edcd3f..a105b71a 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -17,10 +17,6 @@ FractionalType, Boolean, ) -from data_diff.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, -) from data_diff.databases.base import ( Database, BaseDialect, @@ -41,7 +37,7 @@ def import_duckdb(): @attrs.define(frozen=False) -class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): +class Dialect(BaseDialect): name = "DuckDB" ROUNDS_ON_PREC_LOSS = False SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 1ada701e..7f039cc6 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -2,10 +2,8 @@ import attrs -from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, - Mixin_OptimizerHints, CHECKSUM_OFFSET, QueryError, ThreadedDatabase, @@ -37,12 +35,7 @@ def import_mssql(): @attrs.define(frozen=False) -class Dialect( - BaseDialect, - Mixin_OptimizerHints, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, -): +class Dialect(BaseDialect): name = "MsSQL" ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 6b11068b..f4993b87 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -15,12 +15,7 @@ Boolean, Date, ) -from data_diff.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, -) from data_diff.databases.base import ( - Mixin_OptimizerHints, ThreadedDatabase, import_helper, ConnectError, @@ -42,12 +37,7 @@ def import_mysql(): @attrs.define(frozen=False) -class Dialect( - BaseDialect, - Mixin_OptimizerHints, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, -): +class Dialect(BaseDialect): name = "MySQL" ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index bcba374d..a8b8b75b 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -16,10 +16,8 @@ TimestampTZ, FractionalType, ) -from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import ( BaseDialect, - Mixin_OptimizerHints, ThreadedDatabase, import_helper, ConnectError, @@ -43,9 +41,6 @@ def import_oracle(): @attrs.define(frozen=False) class Dialect( BaseDialect, - Mixin_OptimizerHints, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ): name = "Oracle" SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 6bc3d488..075d6aff 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -18,7 +18,6 @@ Boolean, Date, ) -from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import BaseDialect, ThreadedDatabase, import_helper, ConnectError from data_diff.databases.base import ( MD5_HEXDIGITS, @@ -40,7 +39,7 @@ def import_postgresql(): @attrs.define(frozen=False) -class PostgresqlDialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): +class PostgresqlDialect(BaseDialect): name = "PostgreSQL" ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index b308ac77..f575719a 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -21,7 +21,6 @@ TemporalType, Boolean, ) -from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import ( BaseDialect, Database, @@ -52,7 +51,7 @@ def import_presto(): return prestodb -class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): +class Dialect(BaseDialect): name = "Presto" ROUNDS_ON_PREC_LOSS = True TYPE_CLASSES = { diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 746b52e0..857e7c89 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -15,12 +15,6 @@ Boolean, Date, ) -from data_diff.abcs.mixins import ( - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, -) -from data_diff.abcs.compiler import Compilable -from data_diff.queries.api import table, this, SKIP, code from data_diff.databases.base import ( BaseDialect, ConnectError, @@ -41,7 +35,7 @@ def import_snowflake(): return snowflake, serialization, default_backend -class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): +class Dialect(BaseDialect): name = "Snowflake" ROUNDS_ON_PREC_LOSS = False TYPE_CLASSES = { diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index e4f86043..51dc00fa 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -27,7 +27,6 @@ Boolean, ColType_UUID, ) -from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue @import_helper("vertica") @@ -37,7 +36,7 @@ def import_vertica(): return vertica_python -class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue): +class Dialect(BaseDialect): name = "Vertica" ROUNDS_ON_PREC_LOSS = True diff --git a/tests/test_query.py b/tests/test_query.py index 0d253dd5..69900b4b 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,6 +1,8 @@ from datetime import datetime from typing import List, Optional import unittest + +from data_diff.abcs.database_types import FractionalType, TemporalType from data_diff.databases.base import Database, BaseDialect from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict @@ -66,6 +68,15 @@ def set_timezone_to_utc(self) -> str: def optimizer_hints(self, s: str): return f"/*+ {s} */ " + def md5_as_int(self, s: str) -> str: + raise NotImplementedError + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + raise NotImplementedError + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + raise NotImplementedError + parse_type = NotImplemented