Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 733972a

Browse files
committed
Queries: Ran black
1 parent 5cd424d commit 733972a

File tree

4 files changed

+29
-21
lines changed

4 files changed

+29
-21
lines changed

data_diff/queries/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ def leftjoin(*tables: ITable):
1515
"Left-joins each table into a 'struct'"
1616
return Join(tables, "LEFT")
1717

18+
1819
def rightjoin(*tables: ITable):
1920
"Right-joins each table into a 'struct'"
2021
return Join(tables, "RIGHT")
2122

23+
2224
def outerjoin(*tables: ITable):
2325
"Outer-joins each table into a 'struct'"
2426
return Join(tables, "FULL OUTER")

data_diff/queries/ast_classes.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple
1111

1212

13-
1413
class ExprNode(Compilable):
1514
type: Any = None
1615

@@ -129,7 +128,7 @@ def __getitem__(self, column):
129128
def count(self):
130129
return Select(self, [Count()])
131130

132-
def union(self, other: 'ITable'):
131+
def union(self, other: "ITable"):
133132
return Union(self, other)
134133

135134

@@ -172,11 +171,13 @@ def compile(self, c: Compiler) -> str:
172171
args = ", ".join(c.compile(e) for e in self.args)
173172
return f"{self.name}({args})"
174173

174+
175175
def _expr_type(e: Expr):
176176
if isinstance(e, ExprNode):
177177
return e.type
178178
return type(e)
179179

180+
180181
@dataclass
181182
class CaseWhen(ExprNode):
182183
cases: Sequence[Tuple[Expr, Expr]]
@@ -190,12 +191,12 @@ def compile(self, c: Compiler) -> str:
190191

191192
@property
192193
def type(self):
193-
when_types = {_expr_type(w) for _c,w in self.cases }
194+
when_types = {_expr_type(w) for _c, w in self.cases}
194195
if self.else_:
195196
when_types |= _expr_type(self.else_)
196197
if len(when_types) > 1:
197198
raise RuntimeError(f"Non-matching types in when: {when_types}")
198-
t ,= when_types
199+
(t,) = when_types
199200
return t
200201

201202

@@ -252,6 +253,7 @@ def compile(self, c: Compiler) -> str:
252253
a, b = self.args
253254
return f"({c.compile(a)} {self.op} {c.compile(b)})"
254255

256+
255257
class BinBoolOp(BinOp):
256258
type = bool
257259

@@ -334,8 +336,8 @@ def source_table(self):
334336

335337
@property
336338
def schema(self):
337-
assert self.columns # TODO Implement SELECT *
338-
s = self.source_tables[0].schema # XXX
339+
assert self.columns # TODO Implement SELECT *
340+
s = self.source_tables[0].schema # XXX
339341
return type(s)({c.name: c.type for c in self.columns})
340342

341343
def on(self, *exprs):
@@ -390,6 +392,7 @@ class GroupBy(ITable):
390392
def having(self):
391393
pass
392394

395+
393396
@dataclass
394397
class Union(ExprNode, ITable):
395398
table1: ITable
@@ -441,7 +444,7 @@ def source_table(self):
441444
return self
442445

443446
def compile(self, parent_c: Compiler) -> str:
444-
c = parent_c.replace(in_select=True) #.add_table_context(self.table)
447+
c = parent_c.replace(in_select=True) # .add_table_context(self.table)
445448

446449
columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*"
447450
select = f"SELECT {columns}"
@@ -547,7 +550,6 @@ def name(self):
547550
return self.resolved.name
548551

549552

550-
551553
class This:
552554
def __getattr__(self, name):
553555
return _ResolveColumn(name)
@@ -593,9 +595,11 @@ def compile(self, c: Compiler) -> str:
593595

594596
# DDL
595597

598+
596599
class Statement(Compilable):
597600
type = None
598601

602+
599603
def to_sql_type(t):
600604
if isinstance(t, str):
601605
return t
@@ -612,18 +616,20 @@ class CreateTable(Statement):
612616
if_not_exists: bool = False
613617

614618
def compile(self, c: Compiler) -> str:
615-
schema = ', '.join(f'{k} {to_sql_type(v)}' for k, v in self.path.schema.items())
616-
ne = 'IF NOT EXISTS ' if self.if_not_exists else ''
617-
return f'CREATE TABLE {ne}{c.compile(self.path)}({schema})'
619+
schema = ", ".join(f"{k} {to_sql_type(v)}" for k, v in self.path.schema.items())
620+
ne = "IF NOT EXISTS " if self.if_not_exists else ""
621+
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"
622+
618623

619624
@dataclass
620625
class DropTable(Statement):
621626
path: TablePath
622627
if_exists: bool = False
623628

624629
def compile(self, c: Compiler) -> str:
625-
ie = 'IF EXISTS ' if self.if_exists else ''
626-
return f'DROP TABLE {ie}{c.compile(self.path)}'
630+
ie = "IF EXISTS " if self.if_exists else ""
631+
return f"DROP TABLE {ie}{c.compile(self.path)}"
632+
627633

628634
@dataclass
629635
class InsertToTable(Statement):
@@ -632,4 +638,4 @@ class InsertToTable(Statement):
632638
expr: Expr
633639

634640
def compile(self, c: Compiler) -> str:
635-
return f'INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}'
641+
return f"INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}"

data_diff/queries/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class Compiler:
1414
database: AbstractDialect
1515
in_select: bool = False # Compilation runtime flag
16-
in_join: bool = False # Compilation runtime flag
16+
in_join: bool = False # Compilation runtime flag
1717

1818
_table_context: List = [] # List[ITable]
1919
_subqueries: Dict[str, Any] = {} # XXX not thread-safe
@@ -32,7 +32,7 @@ def compile(self, elem) -> str:
3232
return f"WITH {subq}\n{res}"
3333
return res
3434

35-
def _compile(self, elem) -> Union[str, 'ThreadLocalInterpreter']:
35+
def _compile(self, elem) -> Union[str, "ThreadLocalInterpreter"]:
3636
if elem is None:
3737
return "NULL"
3838
elif isinstance(elem, Compilable):

tests/test_query.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,12 @@ def test_schema(self):
9393
self.assertRaises(KeyError, q.__getitem__, "comment")
9494

9595
# test join
96-
s = CaseInsensitiveDict({'x': int, 'y': int})
96+
s = CaseInsensitiveDict({"x": int, "y": int})
9797
a = table("a", schema=s)
9898
b = table("b", schema=s)
9999
keys = ["x", "y"]
100-
j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a['x'], b['y'], xsum=a['x'] + b['x'])
101-
j['x'], j['y'], j['xsum']
100+
j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a["x"], b["y"], xsum=a["x"] + b["x"])
101+
j["x"], j["y"], j["xsum"]
102102
self.assertRaises(KeyError, j.__getitem__, "ysum")
103103

104104
def test_commutable_select(self):
@@ -145,8 +145,8 @@ def test_funcs(self):
145145

146146
def test_union_all(self):
147147
c = Compiler(MockDialect())
148-
a = table("a").select('x')
149-
b = table("b").select('y')
148+
a = table("a").select("x")
149+
b = table("b").select("y")
150150

151151
q = c.compile(a.union(b))
152152
assert q == "SELECT x FROM a UNION SELECT y FROM b"

0 commit comments

Comments
 (0)