|
17 | 17 | import attrs
|
18 | 18 | from typing_extensions import Self
|
19 | 19 |
|
20 |
| -from data_diff.abcs.compiler import AbstractCompiler |
| 20 | +from data_diff.abcs.compiler import AbstractCompiler, Compilable |
21 | 21 | from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
|
22 | 22 | from data_diff.utils import ArithString, is_uuid, join_iter, safezip
|
23 | 23 | from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
|
|
55 | 55 | )
|
56 | 56 | from data_diff.abcs.database_types import (
|
57 | 57 | Array,
|
| 58 | + ColType_UUID, |
| 59 | + FractionalType, |
58 | 60 | Struct,
|
59 | 61 | ColType,
|
60 | 62 | Integer,
|
|
73 | 75 | Boolean,
|
74 | 76 | JSON,
|
75 | 77 | )
|
76 |
| -from data_diff.abcs.mixins import Compilable |
77 |
| -from data_diff.abcs.mixins import ( |
78 |
| - AbstractMixin_NormalizeValue, |
79 |
| - AbstractMixin_OptimizerHints, |
80 |
| -) |
81 | 78 |
|
82 | 79 | logger = logging.getLogger("database")
|
83 | 80 | cv_params = contextvars.ContextVar("params")
|
@@ -198,12 +195,6 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
|
198 | 195 | return callback(sql_code)
|
199 | 196 |
|
200 | 197 |
|
201 |
| -@attrs.define(frozen=False) |
202 |
| -class Mixin_OptimizerHints(AbstractMixin_OptimizerHints): |
203 |
| - def optimizer_hints(self, hints: str) -> str: |
204 |
| - return f"/*+ {hints} */ " |
205 |
| - |
206 |
| - |
207 | 198 | @attrs.define(frozen=False)
|
208 | 199 | class BaseDialect(abc.ABC):
|
209 | 200 | SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
|
@@ -771,6 +762,98 @@ def to_string(self, s: str) -> str:
|
771 | 762 | def set_timezone_to_utc(self) -> str:
|
772 | 763 | "Provide SQL for setting the session timezone to UTC"
|
773 | 764 |
|
| 765 | + @abstractmethod |
| 766 | + def md5_as_int(self, s: str) -> str: |
| 767 | + "Provide SQL for computing md5 and returning an int" |
| 768 | + |
| 769 | + @abstractmethod |
| 770 | + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: |
| 771 | + """Creates an SQL expression, that converts 'value' to a normalized timestamp. |
| 772 | +
|
| 773 | + The returned expression must accept any SQL datetime/timestamp, and return a string. |
| 774 | +
|
| 775 | + Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF`` |
| 776 | +
|
| 777 | + Precision of dates should be rounded up/down according to coltype.rounds |
| 778 | + """ |
| 779 | + |
| 780 | + @abstractmethod |
| 781 | + def normalize_number(self, value: str, coltype: FractionalType) -> str: |
| 782 | + """Creates an SQL expression, that converts 'value' to a normalized number. |
| 783 | +
|
| 784 | + The returned expression must accept any SQL int/numeric/float, and return a string. |
| 785 | +
|
| 786 | + Floats/Decimals are expected in the format |
| 787 | + "I.P" |
| 788 | +
|
| 789 | + Where I is the integer part of the number (as many digits as necessary), |
| 790 | + and must be at least one digit (0). |
| 791 | + P is the fractional digits, the amount of which is specified with |
| 792 | + coltype.precision. Trailing zeroes may be necessary. |
| 793 | + If P is 0, the dot is omitted. |
| 794 | +
|
| 795 | + Note: We use 'precision' differently than most databases. For decimals, |
| 796 | + it's the same as ``numeric_scale``, and for floats, who use binary precision, |
| 797 | + it can be calculated as ``log10(2**numeric_precision)``. |
| 798 | + """ |
| 799 | + |
| 800 | + def normalize_boolean(self, value: str, _coltype: Boolean) -> str: |
| 801 | + """Creates an SQL expression, that converts 'value' to either '0' or '1'.""" |
| 802 | + return self.to_string(value) |
| 803 | + |
| 804 | + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: |
| 805 | + """Creates an SQL expression, that strips uuids of artifacts like whitespace.""" |
| 806 | + if isinstance(coltype, String_UUID): |
| 807 | + return f"TRIM({value})" |
| 808 | + return self.to_string(value) |
| 809 | + |
| 810 | + def normalize_json(self, value: str, _coltype: JSON) -> str: |
| 811 | + """Creates an SQL expression, that converts 'value' to its minified json string representation.""" |
| 812 | + return self.to_string(value) |
| 813 | + |
| 814 | + def normalize_array(self, value: str, _coltype: Array) -> str: |
| 815 | + """Creates an SQL expression, that serialized an array into a JSON string.""" |
| 816 | + return self.to_string(value) |
| 817 | + |
| 818 | + def normalize_struct(self, value: str, _coltype: Struct) -> str: |
| 819 | + """Creates an SQL expression, that serialized a typed struct into a JSON string.""" |
| 820 | + return self.to_string(value) |
| 821 | + |
| 822 | + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: |
| 823 | + """Creates an SQL expression, that converts 'value' to a normalized representation. |
| 824 | +
|
| 825 | + The returned expression must accept any SQL value, and return a string. |
| 826 | +
|
| 827 | + The default implementation dispatches to a method according to `coltype`: |
| 828 | +
|
| 829 | + :: |
| 830 | +
|
| 831 | + TemporalType -> normalize_timestamp() |
| 832 | + FractionalType -> normalize_number() |
| 833 | + *else* -> to_string() |
| 834 | +
|
| 835 | + (`Integer` falls in the *else* category) |
| 836 | +
|
| 837 | + """ |
| 838 | + if isinstance(coltype, TemporalType): |
| 839 | + return self.normalize_timestamp(value, coltype) |
| 840 | + elif isinstance(coltype, FractionalType): |
| 841 | + return self.normalize_number(value, coltype) |
| 842 | + elif isinstance(coltype, ColType_UUID): |
| 843 | + return self.normalize_uuid(value, coltype) |
| 844 | + elif isinstance(coltype, Boolean): |
| 845 | + return self.normalize_boolean(value, coltype) |
| 846 | + elif isinstance(coltype, JSON): |
| 847 | + return self.normalize_json(value, coltype) |
| 848 | + elif isinstance(coltype, Array): |
| 849 | + return self.normalize_array(value, coltype) |
| 850 | + elif isinstance(coltype, Struct): |
| 851 | + return self.normalize_struct(value, coltype) |
| 852 | + return self.to_string(value) |
| 853 | + |
| 854 | + def optimizer_hints(self, hints: str) -> str: |
| 855 | + return f"/*+ {hints} */ " |
| 856 | + |
774 | 857 |
|
775 | 858 | T = TypeVar("T", bound=BaseDialect)
|
776 | 859 |
|
@@ -966,10 +1049,7 @@ def _refine_coltypes(
|
966 | 1049 | if not text_columns:
|
967 | 1050 | return
|
968 | 1051 |
|
969 |
| - if isinstance(self.dialect, AbstractMixin_NormalizeValue): |
970 |
| - fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] |
971 |
| - else: |
972 |
| - fields = this[text_columns] |
| 1052 | + fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] |
973 | 1053 |
|
974 | 1054 | samples_by_row = self.query(
|
975 | 1055 | table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list
|
|
0 commit comments