diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 9dc03909..871c650d 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -199,10 +199,19 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal class BaseDialect(abc.ABC): SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False SUPPORTS_INDEXES: ClassVar[bool] = False + PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {} PLACEHOLDER_TABLE = None # Used for Oracle + # Some database do not support long string so concatenation might lead to type overflow + + _prevent_overflow_when_concat: bool = False + + def enable_preventing_type_overflow(self) -> None: + logger.info("Preventing type overflow when concatenation is enabled") + self._prevent_overflow_when_concat = True + def parse_table_name(self, name: str) -> DbPath: "Parse the given table name into a DbPath" return parse_table_name(name) @@ -392,10 +401,19 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str: return f"sum({md5})" def render_concat(self, c: Compiler, elem: Concat) -> str: + if self._prevent_overflow_when_concat: + items = [ + f"{self.compile(c, Code(self.md5_as_hex(self.to_string(self.compile(c, expr)))))}" + for expr in elem.exprs + ] + # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL - items = [ - f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '')" for expr in elem.exprs - ] + else: + items = [ + f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '')" + for expr in elem.exprs + ] + assert items if len(items) == 1: return items[0] @@ -769,6 +787,10 @@ def set_timezone_to_utc(self) -> str: def md5_as_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" + @abstractmethod + def md5_as_hex(self, s: str) -> str: + """Method to calculate MD5""" + @abstractmethod def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Creates an SQL expression, that converts 'value' to a normalized timestamp. @@ -885,6 +907,8 @@ class Database(abc.ABC): Instanciated using :meth:`~data_diff.connect` """ + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = BaseDialect + SUPPORTS_ALPHANUMS: ClassVar[bool] = True SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False CONNECT_URI_KWPARAMS: ClassVar[List[str]] = [] @@ -892,6 +916,7 @@ class Database(abc.ABC): default_schema: Optional[str] = None _interactive: bool = False is_closed: bool = False + _dialect: BaseDialect = None @property def name(self): @@ -1120,10 +1145,13 @@ def close(self): return super().close() @property - @abstractmethod def dialect(self) -> BaseDialect: "The dialect of the database. Used internally by Database, and also available publicly." + if not self._dialect: + self._dialect = self.DIALECT_CLASS() + return self._dialect + @property @abstractmethod def CONNECT_URI_HELP(self) -> str: diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index e672b928..26d8aec3 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,5 +1,5 @@ import re -from typing import Any, List, Union +from typing import Any, ClassVar, List, Union, Type import attrs @@ -134,6 +134,9 @@ def parse_table_name(self, name: str) -> DbPath: def md5_as_int(self, s: str) -> str: return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS})) as int64) as numeric) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" @@ -179,9 +182,9 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class BigQuery(Database): + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "bigquery:///" CONNECT_URI_PARAMS = ["dataset"] - dialect = Dialect() project: str dataset: str diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 7a8816d8..13082504 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Type +from typing import Any, ClassVar, Dict, Optional, Type import attrs @@ -105,6 +105,9 @@ def md5_as_int(self, s: str) -> str: f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx}))))) - {CHECKSUM_OFFSET}" ) + def md5_as_hex(self, s: str) -> str: + return f"hex(MD5({s}))" + def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. # For example: @@ -164,7 +167,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Clickhouse(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "clickhouse://:@/" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 7394f2df..19a1f103 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,5 +1,5 @@ import math -from typing import Any, Dict, Sequence +from typing import Any, ClassVar, Dict, Sequence, Type import logging import attrs @@ -82,6 +82,9 @@ def parse_table_name(self, name: str) -> DbPath: def md5_as_int(self, s: str) -> str: return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0)) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Databricks timestamp contains no more than 6 digits in precision""" @@ -104,7 +107,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Databricks(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "databricks://:@/" CONNECT_URI_PARAMS = ["catalog", "schema"] diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index a105b71a..6c65b16b 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, ClassVar, Dict, Union, Type import attrs @@ -100,6 +100,9 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. if coltype.rounds and coltype.precision > 0: @@ -116,7 +119,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class DuckDB(Database): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it CONNECT_URI_HELP = "duckdb://@" CONNECT_URI_PARAMS = ["database", "dbpath"] diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index fd23bef1..8f5195ee 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, Type import attrs @@ -38,7 +38,7 @@ def import_mssql(): class Dialect(BaseDialect): name = "MsSQL" ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True SUPPORTS_INDEXES = True TYPE_CLASSES = { # Timestamps @@ -151,10 +151,13 @@ def normalize_number(self, value: str, coltype: NumericType) -> str: def md5_as_int(self, s: str) -> str: return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1)) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"HashBytes('MD5', {s})" + @attrs.define(frozen=False, init=False, kw_only=True) class MsSQL(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "mssql://:@//" CONNECT_URI_PARAMS = ["database", "schema"] diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index f4993b87..647388f2 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, ClassVar, Dict, Type import attrs @@ -40,7 +40,7 @@ def import_mysql(): class Dialect(BaseDialect): name = "MySQL" ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True SUPPORTS_INDEXES = True TYPE_CLASSES = { # Dates @@ -101,6 +101,9 @@ def set_timezone_to_utc(self) -> str: def md5_as_int(self, s: str) -> str: return f"conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") @@ -117,7 +120,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class MySQL(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect SUPPORTS_ALPHANUMS = False SUPPORTS_UNIQUE_CONSTAINT = True CONNECT_URI_HELP = "mysql://:@/" diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 32bd30ef..ab84f0b6 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional, Type import attrs @@ -43,7 +43,7 @@ class Dialect( BaseDialect, ): name = "Oracle" - SUPPORTS_PRIMARY_KEY = True + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True SUPPORTS_INDEXES = True TYPE_CLASSES: Dict[str, type] = { "NUMBER": Decimal, @@ -137,6 +137,9 @@ def md5_as_int(self, s: str) -> str: # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? return f"to_number(substr(standard_hash({s}, 'MD5'), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 'xxxxxxxxxxxxxxx') - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"standard_hash({s}, 'MD5')" + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Cast is necessary for correct MD5 (trimming not enough) return f"CAST(TRIM({value}) AS VARCHAR(36))" @@ -161,7 +164,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Oracle(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "oracle://:@/" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 075d6aff..4b9e945f 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -42,7 +42,7 @@ def import_postgresql(): class PostgresqlDialect(BaseDialect): name = "PostgreSQL" ROUNDS_ON_PREC_LOSS = True - SUPPORTS_PRIMARY_KEY = True + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True SUPPORTS_INDEXES = True TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { @@ -98,6 +98,9 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" @@ -119,7 +122,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class PostgreSQL(ThreadedDatabase): - dialect = PostgresqlDialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = PostgresqlDialect SUPPORTS_UNIQUE_CONSTAINT = True CONNECT_URI_HELP = "postgresql://:@/" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index f575719a..ba1c7360 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,6 +1,6 @@ from functools import partial import re -from typing import Any +from typing import Any, ClassVar, Type import attrs @@ -128,6 +128,9 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0)) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"to_hex(md5(to_utf8({s})))" + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" @@ -150,7 +153,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Presto(Database): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "presto://@//" CONNECT_URI_PARAMS = ["catalog", "schema"] diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 968f57bb..7a621f57 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -12,6 +12,7 @@ TimestampTZ, ) from data_diff.databases.postgresql import ( + BaseDialect, PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -47,6 +48,9 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"{value}::timestamp(6)" @@ -76,7 +80,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Redshift(PostgreSQL): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "redshift://:@/" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 857e7c89..bedacd80 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,4 +1,4 @@ -from typing import Any, Union, List +from typing import Any, ClassVar, Union, List, Type import logging import attrs @@ -76,6 +76,9 @@ def type_repr(self, t) -> str: def md5_as_int(self, s: str) -> str: return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"md5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" @@ -93,7 +96,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Snowflake(Database): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "snowflake://:@//?warehouse=" CONNECT_URI_PARAMS = ["database", "schema"] CONNECT_URI_KWPARAMS = ["warehouse"] diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index f0c95ee4..b76ba74b 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,11 +1,11 @@ -from typing import Any +from typing import Any, ClassVar, Type import attrs from data_diff.abcs.database_types import TemporalType, ColType_UUID from data_diff.databases import presto from data_diff.databases.base import import_helper -from data_diff.databases.base import TIMESTAMP_PRECISION_POS +from data_diff.databases.base import TIMESTAMP_PRECISION_POS, BaseDialect @import_helper("trino") @@ -34,7 +34,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Trino(presto.Presto): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "trino://@//" CONNECT_URI_PARAMS = ["catalog", "schema"] diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 51dc00fa..23f63acc 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, ClassVar, Dict, List, Type import attrs @@ -36,6 +36,7 @@ def import_vertica(): return vertica_python +@attrs.define(frozen=False) class Dialect(BaseDialect): name = "Vertica" ROUNDS_ON_PREC_LOSS = True @@ -109,6 +110,9 @@ def current_timestamp(self) -> str: def md5_as_int(self, s: str) -> str: return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0)) - {CHECKSUM_OFFSET}" + def md5_as_hex(self, s: str) -> str: + return f"MD5({s})" + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" @@ -131,7 +135,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: @attrs.define(frozen=False, init=False, kw_only=True) class Vertica(ThreadedDatabase): - dialect = Dialect() + DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect CONNECT_URI_HELP = "vertica://:@/" CONNECT_URI_PARAMS = ["database?"] diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 44daba34..66802426 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -208,6 +208,10 @@ def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_ event_json = create_start_event_json(options) run_as_daemon(send_event_json, event_json) + if table1.database.dialect.PREVENT_OVERFLOW_WHEN_CONCAT or table2.database.dialect.PREVENT_OVERFLOW_WHEN_CONCAT: + table1.database.dialect.enable_preventing_type_overflow() + table2.database.dialect.enable_preventing_type_overflow() + start = time.monotonic() error = None try: diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index ed8a31b6..0f664c45 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -270,7 +270,9 @@ def test_null_pks(self): self.assertRaises(ValueError, list, x) -@test_each_database_in_list(d for d in TEST_DATABASES if d.dialect.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT) +@test_each_database_in_list( + d for d in TEST_DATABASES if d.DIALECT_CLASS.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT +) class TestUniqueConstraint(DiffTestCase): def setUp(self): super().setUp() diff --git a/tests/test_query.py b/tests/test_query.py index 0625a75d..2585c02e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -76,6 +76,9 @@ def optimizer_hints(self, s: str): def md5_as_int(self, s: str) -> str: raise NotImplementedError + def md5_as_hex(self, s: str) -> str: + raise NotImplementedError + def normalize_number(self, value: str, coltype: FractionalType) -> str: raise NotImplementedError