Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 47b9faa

Browse files
committed
Cleanup and minor fixes (pylint pass)
1 parent e8965fd commit 47b9faa

22 files changed

+203
-140
lines changed

data_diff/__init__.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,24 +70,27 @@ def diff_tables(
7070
7171
Parameters:
7272
key_columns (Tuple[str, ...]): Names of the key columns, which together uniquely identify each row (usually id)
73-
update_column (str, optional): Name of updated column, which signals that rows changed (usually updated_at or last_update).
74-
Used by `min_update` and `max_update`.
73+
update_column (str, optional): Name of updated column, which signals that rows changed.
74+
Usually updated_at or last_update. Used by `min_update` and `max_update`.
7575
extra_columns (Tuple[str, ...], optional): Extra columns to compare
7676
min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment
7777
max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment
7878
min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
7979
max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
8080
algorithm (:class:`Algorithm`): Which diffing algorithm to use (`HASHDIFF` or `JOINDIFF`)
81-
bisection_factor (int): Into how many segments to bisect per iteration. (when algorithm is `HASHDIFF`)
82-
bisection_threshold (Number): When should we stop bisecting and compare locally (when algorithm is `HASHDIFF`; in row count).
81+
bisection_factor (int): Into how many segments to bisect per iteration. (Used when algorithm is `HASHDIFF`)
82+
bisection_threshold (Number): Minimal row count of segment to bisect, otherwise download
83+
and compare locally. (Used when algorithm is `HASHDIFF`).
8384
threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.
84-
max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``.
85+
max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto.
86+
Only relevant when `threaded` is ``True``.
8587
There may be many pools, so the actual number of threads can be a lot higher.
8688
8789
Note:
8890
The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances:
89-
`key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`. If different values are needed per table, it's
90-
possible to omit them here, and instead set them directly when creating each :class:`TableSegment`.
91+
`key_columns`, `update_column`, `extra_columns`, `min_key`, `max_key`.
92+
If different values are needed per table, it's possible to omit them here, and instead set
93+
them directly when creating each :class:`TableSegment`.
9194
9295
Example:
9396
>>> table1 = connect_to_table('postgresql:///', 'Rating', 'id')

data_diff/databases/base.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from abc import abstractmethod
99

1010
from data_diff.utils import is_uuid, safezip
11+
from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain
1112
from .database_types import (
1213
AbstractDatabase,
1314
ColType,
@@ -27,8 +28,6 @@
2728
DbPath,
2829
)
2930

30-
from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain
31-
3231
logger = logging.getLogger("database")
3332

3433

@@ -110,6 +109,8 @@ class Database(AbstractDatabase):
110109
default_schema: str = None
111110
SUPPORTS_ALPHANUMS = True
112111

112+
_interactive = False
113+
113114
@property
114115
def name(self):
115116
return type(self).__name__
@@ -126,11 +127,14 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
126127
return SKIP
127128

128129
logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
129-
if getattr(self, "_interactive", False) and isinstance(sql_ast, Select):
130+
if self._interactive and isinstance(sql_ast, Select):
130131
explained_sql = compiler.compile(Explain(sql_ast))
131132
explain = self._query(explained_sql)
132-
for (row,) in explain:
133-
logger.debug(f"EXPLAIN: {row}")
133+
for row in explain:
134+
# Most drivers return each row as a 1-tuple. Presto returns a plain string.
135+
if isinstance(row, tuple):
136+
row ,= row
137+
logger.debug("EXPLAIN: %s", row)
134138
answer = input("Continue? [y/n] ")
135139
if not answer.lower() in ["y", "yes"]:
136140
sys.exit(1)

data_diff/databases/bigquery.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from .database_types import *
1+
from typing import Union
2+
from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType
23
from .base import Database, import_helper, parse_table_name, ConnectError, apply_query
34
from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
45

data_diff/databases/database_types.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import decimal
3-
from abc import ABC, abstractmethod, abstractproperty
3+
from abc import ABC, abstractmethod
44
from typing import Sequence, Optional, Tuple, Union, Dict, List
55
from datetime import datetime
66

@@ -234,30 +234,6 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
234234
"""
235235
...
236236

237-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
238-
"""Creates an SQL expression, that converts 'value' to a normalized representation.
239-
240-
The returned expression must accept any SQL value, and return a string.
241-
242-
The default implementation dispatches to a method according to `coltype`:
243-
244-
::
245-
246-
TemporalType -> normalize_timestamp()
247-
FractionalType -> normalize_number()
248-
*else* -> to_string()
249-
250-
(`Integer` falls in the *else* category)
251-
252-
"""
253-
if isinstance(coltype, TemporalType):
254-
return self.normalize_timestamp(value, coltype)
255-
elif isinstance(coltype, FractionalType):
256-
return self.normalize_number(value, coltype)
257-
elif isinstance(coltype, ColType_UUID):
258-
return self.normalize_uuid(value, coltype)
259-
return self.to_string(value)
260-
261237

262238
class AbstractDatabase(AbstractDialect, AbstractDatadiffDialect):
263239
@abstractmethod
@@ -304,10 +280,35 @@ def close(self):
304280
def _normalize_table_path(self, path: DbPath) -> DbPath:
305281
...
306282

307-
@abstractproperty
283+
@property
284+
@abstractmethod
308285
def is_autocommit(self) -> bool:
309286
...
310287

288+
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
289+
"""Creates an SQL expression, that converts 'value' to a normalized representation.
290+
291+
The returned expression must accept any SQL value, and return a string.
292+
293+
The default implementation dispatches to a method according to `coltype`:
294+
295+
::
296+
297+
TemporalType -> normalize_timestamp()
298+
FractionalType -> normalize_number()
299+
*else* -> to_string()
300+
301+
(`Integer` falls in the *else* category; `ColType_UUID` is additionally dispatched to normalize_uuid())
302+
303+
"""
304+
if isinstance(coltype, TemporalType):
305+
return self.normalize_timestamp(value, coltype)
306+
elif isinstance(coltype, FractionalType):
307+
return self.normalize_number(value, coltype)
308+
elif isinstance(coltype, ColType_UUID):
309+
return self.normalize_uuid(value, coltype)
310+
return self.to_string(value)
311+
311312

312313
Schema = CaseAwareMapping
313314

data_diff/databases/databricks.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1+
from typing import Dict, Sequence
12
import logging
23

3-
from .database_types import *
4+
from .database_types import (
5+
Integer,
6+
Float,
7+
Decimal,
8+
Timestamp,
9+
Text,
10+
TemporalType,
11+
NumericType,
12+
DbPath,
13+
ColType,
14+
UnknownColType,
15+
)
416
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, parse_table_name
517

618

data_diff/databases/mysql.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from .database_types import *
1+
from .database_types import (
2+
Datetime,
3+
Timestamp,
4+
Float,
5+
Decimal,
6+
Integer,
7+
Text,
8+
TemporalType,
9+
FractionalType,
10+
ColType_UUID,
11+
)
212
from .base import ThreadedDatabase, import_helper, ConnectError
313
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS
414

data_diff/databases/oracle.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
1+
from typing import Dict, List, Optional
2+
13
from ..utils import match_regexps
24

3-
from .database_types import *
5+
from .database_types import (
6+
Decimal,
7+
Float,
8+
Text,
9+
DbPath,
10+
TemporalType,
11+
ColType,
12+
DbTime,
13+
ColType_UUID,
14+
Timestamp,
15+
TimestampTZ,
16+
FractionalType,
17+
)
418
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
519
from .base import TIMESTAMP_PRECISION_POS
620

data_diff/databases/postgresql.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from .database_types import *
1+
from .database_types import (
2+
Timestamp,
3+
TimestampTZ,
4+
Float,
5+
Decimal,
6+
Integer,
7+
TemporalType,
8+
Native_UUID,
9+
Text,
10+
FractionalType,
11+
)
212
from .base import ThreadedDatabase, import_helper, ConnectError
313
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS
414

data_diff/databases/presto.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,19 @@
33

44
from data_diff.utils import match_regexps
55

6-
from .database_types import *
6+
from .database_types import (
7+
Timestamp,
8+
TimestampTZ,
9+
Integer,
10+
Float,
11+
Text,
12+
FractionalType,
13+
DbPath,
14+
Decimal,
15+
ColType,
16+
ColType_UUID,
17+
TemporalType,
18+
)
719
from .base import Database, import_helper, ThreadLocalInterpreter
820
from .base import (
921
MD5_HEXDIGITS,
@@ -17,7 +29,7 @@ def query_cursor(c, sql_code):
1729
if sql_code.lower().startswith("select"):
1830
return c.fetchall()
1931
# Required for the query to actually run 🤯
20-
if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE):
32+
if re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE):
2133
return c.fetchone()
2234

2335

@@ -98,7 +110,7 @@ def select_table_schema(self, path: DbPath) -> str:
98110
schema, table = self._normalize_table_path(path)
99111

100112
return (
101-
"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision "
113+
"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale "
102114
"FROM INFORMATION_SCHEMA.COLUMNS "
103115
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
104116
)
@@ -110,6 +122,7 @@ def _parse_type(
110122
type_repr: str,
111123
datetime_precision: int = None,
112124
numeric_precision: int = None,
125+
numeric_scale: int = None,
113126
) -> ColType:
114127
timestamp_regexps = {
115128
r"timestamp\((\d)\)": Timestamp,
@@ -134,5 +147,9 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
134147
# Trim doesn't work on CHAR type
135148
return f"TRIM(CAST({value} AS VARCHAR))"
136149

150+
@property
137151
def is_autocommit(self) -> bool:
138152
return False
153+
154+
def explain_as_text(self, query: str) -> str:
155+
return f"EXPLAIN (FORMAT TEXT) {query}"

data_diff/databases/redshift.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from .database_types import *
1+
from typing import List
2+
from .database_types import Float, TemporalType, FractionalType, DbPath
23
from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS
34

45

data_diff/databases/snowflake.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import Union
12
import logging
23

3-
from .database_types import *
4+
from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath
45
from .base import ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter
56

67

@@ -88,6 +89,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
8889
def normalize_number(self, value: str, coltype: FractionalType) -> str:
8990
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
9091

92+
@property
9193
def is_autocommit(self) -> bool:
9294
return True
9395

data_diff/databases/trino.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .database_types import *
1+
from .database_types import TemporalType, ColType_UUID
22
from .presto import Presto
33
from .base import import_helper
44
from .base import TIMESTAMP_PRECISION_POS

data_diff/diff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def _run_in_background(self, *funcs):
7878

7979
class TableDiffer(ThreadBase, ABC):
8080
bisection_factor = 32
81+
stats: dict = {}
8182

8283
def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
8384
"""Diff the given tables.
@@ -177,7 +178,6 @@ def _bisect_and_diff_tables(self, table1, table2):
177178
table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]
178179

179180
logger.info(
180-
# f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. "
181181
f"Diffing segments at key-range: {table1.min_key}..{table2.max_key}. "
182182
f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}"
183183
)

0 commit comments

Comments
 (0)