Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit b237fd8

Browse files
authored
Merge pull request #260 from datafold/oct20_queries
Various small fixes and refactors
2 parents 1fc52c2 + 601d2bb commit b237fd8

File tree

7 files changed

+68
-76
lines changed

7 files changed

+68
-76
lines changed

data_diff/databases/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def _constant_value(self, v):
335335
elif isinstance(v, str):
336336
return f"'{v}'"
337337
elif isinstance(v, datetime):
338+
# TODO use self.timestamp_value
338339
return f"timestamp '{v}'"
339340
elif isinstance(v, UUID):
340341
return f"'{v}'"

data_diff/databases/database_types.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,12 @@ class AbstractDialect(ABC):
147147

148148
@abstractmethod
149149
def quote(self, s: str):
150-
"Quote SQL name (implementation specific)"
150+
"Quote SQL name"
151151
...
152152

153153
@abstractmethod
154154
def concat(self, l: List[str]) -> str:
155-
"Provide SQL for concatenating a bunch of column into a string"
155+
"Provide SQL for concatenating a bunch of columns into a string"
156156
...
157157

158158
@abstractmethod
@@ -162,12 +162,13 @@ def is_distinct_from(self, a: str, b: str) -> str:
162162

163163
@abstractmethod
164164
def to_string(self, s: str) -> str:
165+
# TODO rewrite using cast_to(x, str)
165166
"Provide SQL for casting a column to string"
166167
...
167168

168169
@abstractmethod
169170
def random(self) -> str:
170-
"Provide SQL for generating a random number"
171+
"Provide SQL for generating a random number betweein 0..1"
171172

172173
@abstractmethod
173174
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
@@ -176,7 +177,7 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None
176177

177178
@abstractmethod
178179
def explain_as_text(self, query: str) -> str:
179-
"Provide SQL for explaining a query, returned in as table(varchar)"
180+
"Provide SQL for explaining a query, returned as table(varchar)"
180181
...
181182

182183
@abstractmethod

data_diff/databases/oracle.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,6 @@ def type_repr(self, t) -> str:
153153
return super().type_repr(t)
154154

155155
def constant_values(self, rows) -> str:
156-
return " UNION ALL ".join("SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows)
156+
return " UNION ALL ".join(
157+
"SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows
158+
)

data_diff/queries/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ def or_(*exprs: Expr):
4747
exprs = args_as_tuple(exprs)
4848
if len(exprs) == 1:
4949
return exprs[0]
50-
return BinOp("OR", exprs)
50+
return BinBoolOp("OR", exprs)
5151

5252

5353
def and_(*exprs: Expr):
5454
exprs = args_as_tuple(exprs)
5555
if len(exprs) == 1:
5656
return exprs[0]
57-
return BinOp("AND", exprs)
57+
return BinBoolOp("AND", exprs)
5858

5959

6060
def sum_(expr: Expr):

data_diff/queries/ast_classes.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def cast_to(self, to):
3232
Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None]
3333

3434

35-
def get_type(e: Expr) -> type:
35+
def _expr_type(e: Expr) -> type:
3636
if isinstance(e, ExprNode):
3737
return e.type
3838
return type(e)
@@ -48,7 +48,7 @@ def compile(self, c: Compiler) -> str:
4848

4949
@property
5050
def type(self):
51-
return get_type(self.expr)
51+
return _expr_type(self.expr)
5252

5353

5454
def _drop_skips(exprs):
@@ -156,6 +156,8 @@ class Count(ExprNode):
156156
expr: Expr = "*"
157157
distinct: bool = False
158158

159+
type = int
160+
159161
def compile(self, c: Compiler) -> str:
160162
expr = c.compile(self.expr)
161163
if self.distinct:
@@ -174,12 +176,6 @@ def compile(self, c: Compiler) -> str:
174176
return f"{self.name}({args})"
175177

176178

177-
def _expr_type(e: Expr):
178-
if isinstance(e, ExprNode):
179-
return e.type
180-
return type(e)
181-
182-
183179
@dataclass
184180
class CaseWhen(ExprNode):
185181
cases: Sequence[Tuple[Expr, Expr]]
@@ -226,6 +222,9 @@ def __le__(self, other):
226222
def __or__(self, other):
227223
return BinBoolOp("OR", [self, other])
228224

225+
def __and__(self, other):
226+
return BinBoolOp("AND", [self, other])
227+
229228
def is_distinct_from(self, other):
230229
return IsDistinctFrom(self, other)
231230

@@ -254,7 +253,7 @@ def compile(self, c: Compiler) -> str:
254253

255254
@property
256255
def type(self):
257-
types = {get_type(i) for i in self.args}
256+
types = {_expr_type(i) for i in self.args}
258257
if len(types) > 1:
259258
raise TypeError(f"Expected all args to have the same type, got {types}")
260259
(t,) = types
@@ -298,6 +297,16 @@ class TablePath(ExprNode, ITable):
298297
path: DbPath
299298
schema: Optional[Schema] = field(default=None, repr=False)
300299

300+
@property
301+
def source_table(self):
302+
return self
303+
304+
def compile(self, c: Compiler) -> str:
305+
path = self.path # c.database._normalize_table_path(self.name)
306+
return ".".join(map(c.quote, path))
307+
308+
# Statement shorthands
309+
301310
def create(self, source_table: ITable = None, *, if_not_exists=False):
302311
if source_table is None and not self.schema:
303312
raise ValueError("Either schema or source table needed to create table")
@@ -323,14 +332,6 @@ def insert_expr(self, expr: Expr):
323332
expr = expr.select()
324333
return InsertToTable(self, expr)
325334

326-
@property
327-
def source_table(self):
328-
return self
329-
330-
def compile(self, c: Compiler) -> str:
331-
path = self.path # c.database._normalize_table_path(self.name)
332-
return ".".join(map(c.quote, path))
333-
334335

335336
@dataclass
336337
class TableAlias(ExprNode, ITable):
@@ -386,7 +387,7 @@ def compile(self, parent_c: Compiler) -> str:
386387
tables = [
387388
t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables
388389
]
389-
c = parent_c.add_table_context(*tables).replace(in_join=True, in_select=False)
390+
c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
390391
op = " JOIN " if self.op is None else f" {self.op} JOIN "
391392
joined = op.join(c.compile(t) for t in tables)
392393

@@ -408,7 +409,7 @@ def compile(self, parent_c: Compiler) -> str:
408409

409410
class GroupBy(ITable):
410411
def having(self):
411-
pass
412+
raise NotImplementedError()
412413

413414

414415
@dataclass
@@ -546,26 +547,26 @@ class _ResolveColumn(ExprNode, LazyOps):
546547
resolve_name: str
547548
resolved: Expr = None
548549

549-
def resolve(self, expr):
550-
assert self.resolved is None
550+
def resolve(self, expr: Expr):
551+
if self.resolved is not None:
552+
raise RuntimeError("Already resolved!")
551553
self.resolved = expr
552554

553-
def compile(self, c: Compiler) -> str:
555+
def _get_resolved(self) -> Expr:
554556
if self.resolved is None:
555557
raise RuntimeError(f"Column not resolved: {self.resolve_name}")
556-
return self.resolved.compile(c)
558+
return self.resolved
559+
560+
def compile(self, c: Compiler) -> str:
561+
return self._get_resolved().compile(c)
557562

558563
@property
559564
def type(self):
560-
if self.resolved is None:
561-
raise RuntimeError(f"Column not resolved: {self.resolve_name}")
562-
return self.resolved.type
565+
return self._get_resolved().type
563566

564567
@property
565568
def name(self):
566-
if self.resolved is None:
567-
raise RuntimeError(f"Column not resolved: {self.name}")
568-
return self.resolved.name
569+
return self._get_resolved().name
569570

570571

571572
class This:
@@ -583,6 +584,8 @@ class In(ExprNode):
583584
expr: Expr
584585
list: Sequence[Expr]
585586

587+
type = bool
588+
586589
def compile(self, c: Compiler):
587590
elems = ", ".join(map(c.compile, self.list))
588591
return f"({c.compile(self.expr)} IN ({elems}))"
@@ -599,6 +602,8 @@ def compile(self, c: Compiler) -> str:
599602

600603
@dataclass
601604
class Random(ExprNode):
605+
type = float
606+
602607
def compile(self, c: Compiler) -> str:
603608
return c.database.random()
604609

@@ -618,6 +623,8 @@ def compile_for_insert(self, c: Compiler):
618623
class Explain(ExprNode):
619624
select: Select
620625

626+
type = str
627+
621628
def compile(self, c: Compiler) -> str:
622629
return c.database.explain_as_text(c.compile(self.select))
623630

@@ -640,7 +647,7 @@ def compile(self, c: Compiler) -> str:
640647
if self.source_table:
641648
return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"
642649

643-
schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
650+
schema = ", ".join(f"{c.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
644651
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"
645652

646653

data_diff/queries/compiler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ class Compiler:
2121

2222
_counter: List = [0]
2323

24-
def quote(self, s: str):
25-
return self.database.quote(s)
26-
2724
def compile(self, elem) -> str:
2825
res = self._compile(elem)
2926
if self.root and self._subqueries:
@@ -57,8 +54,11 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
5754
self._counter[0] += 1
5855
return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}")
5956

60-
def add_table_context(self, *tables: Sequence):
61-
return self.replace(_table_context=self._table_context + list(tables))
57+
def add_table_context(self, *tables: Sequence, **kw):
58+
return self.replace(_table_context=self._table_context + list(tables), **kw)
59+
60+
def quote(self, s: str):
61+
return self.database.quote(s)
6262

6363

6464
class Compilable(ABC):

data_diff/utils.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import logging
22
import re
33
import math
4-
from typing import Iterable, Tuple, Union, Any, Sequence, Dict
5-
from typing import TypeVar, Generic
4+
from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict
5+
from typing import TypeVar
66
from abc import ABC, abstractmethod
77
from urllib.parse import urlparse
88
from uuid import UUID
@@ -204,58 +204,39 @@ def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
204204
V = TypeVar("V")
205205

206206

207-
class CaseAwareMapping(ABC, Generic[V]):
207+
class CaseAwareMapping(MutableMapping[str, V]):
208208
@abstractmethod
209209
def get_key(self, key: str) -> str:
210210
...
211211

212-
@abstractmethod
213-
def __getitem__(self, key: str) -> V:
214-
...
215-
216-
@abstractmethod
217-
def __setitem__(self, key: str, value: V):
218-
...
219-
220-
@abstractmethod
221-
def __contains__(self, key: str) -> bool:
222-
...
223-
224-
def __repr__(self):
225-
return repr(dict(self.items()))
226-
227-
@abstractmethod
228-
def items(self) -> Iterable[Tuple[str, V]]:
229-
...
230-
231212

232213
class CaseInsensitiveDict(CaseAwareMapping):
233214
def __init__(self, initial):
234215
self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}
235216

236-
def get_key(self, key: str) -> str:
237-
return self._dict[key.lower()][0]
238-
239217
def __getitem__(self, key: str) -> V:
240218
return self._dict[key.lower()][1]
241219

220+
def __iter__(self) -> Iterator[V]:
221+
return iter(self._dict)
222+
223+
def __len__(self) -> int:
224+
return len(self._dict)
225+
242226
def __setitem__(self, key: str, value):
243227
k = key.lower()
244228
if k in self._dict:
245229
key = self._dict[k][0]
246230
self._dict[k] = key, value
247231

248-
def __contains__(self, key):
249-
return key.lower() in self._dict
250-
251-
def keys(self) -> Iterable[str]:
252-
return self._dict.keys()
232+
def __delitem__(self, key: str):
233+
del self._dict[key.lower()]
253234

254-
def items(self) -> Iterable[Tuple[str, V]]:
255-
return ((k, v[1]) for k, v in self._dict.items())
235+
def get_key(self, key: str) -> str:
236+
return self._dict[key.lower()][0]
256237

257-
def __len__(self):
258-
return len(self._dict)
238+
def __repr__(self) -> str:
239+
return repr(dict(self.items()))
259240

260241

261242
class CaseSensitiveDict(dict, CaseAwareMapping):

0 commit comments

Comments
 (0)