Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 119ade5

Browse files
authored
Merge pull request #756 from pik94/support-top-operator
feat: support TOP operator
2 parents c8a598b + 6da7db7 commit 119ade5

File tree

4 files changed

+36
-18
lines changed

4 files changed

+36
-18
lines changed

data_diff/databases/base.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -493,7 +493,7 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
493493

494494
if elem.limit_expr is not None:
495495
has_order_by = bool(elem.order_by_exprs)
496-
select += " " + self.offset_limit(0, elem.limit_expr, has_order_by=has_order_by)
496+
select = self.limit_select(select_query=select, offset=0, limit=elem.limit_expr, has_order_by=has_order_by)
497497

498498
if parent_c.in_select:
499499
select = f"({select}) {c.new_unique_name()}"
@@ -605,14 +605,17 @@ def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str:
605605

606606
return f"INSERT INTO {self.compile(c, elem.path)}{columns} {expr}"
607607

608-
def offset_limit(
609-
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
608+
def limit_select(
609+
self,
610+
select_query: str,
611+
offset: Optional[int] = None,
612+
limit: Optional[int] = None,
613+
has_order_by: Optional[bool] = None,
610614
) -> str:
611-
"Provide SQL fragment for limit and offset inside a select"
612615
if offset:
613616
raise NotImplementedError("No support for OFFSET in query")
614617

615-
return f"LIMIT {limit}"
618+
return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT LIMIT {limit}"
616619

617620
def concat(self, items: List[str]) -> str:
618621
"Provide SQL for concatenating a bunch of columns into a string"
@@ -1103,7 +1106,7 @@ def _query_cursor(self, c, sql_code: str) -> QueryResult:
11031106
return result
11041107
except Exception as _e:
11051108
# logger.exception(e)
1106-
# logger.error(f'Caused by SQL: {sql_code}')
1109+
# logger.error(f"Caused by SQL: {sql_code}")
11071110
raise
11081111

11091112
def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:

data_diff/databases/mssql.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,12 @@ def is_distinct_from(self, a: str, b: str) -> str:
110110
# See: https://stackoverflow.com/a/18684859/857383
111111
return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))"
112112

113-
def offset_limit(
114-
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
113+
def limit_select(
114+
self,
115+
select_query: str,
116+
offset: Optional[int] = None,
117+
limit: Optional[int] = None,
118+
has_order_by: Optional[bool] = None,
115119
) -> str:
116120
if offset:
117121
raise NotImplementedError("No support for OFFSET in query")
@@ -121,7 +125,7 @@ def offset_limit(
121125
result += "ORDER BY 1"
122126

123127
result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"
124-
return result
128+
return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT {result}"
125129

126130
def constant_values(self, rows) -> str:
127131
values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)

data_diff/databases/oracle.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,17 @@ def quote(self, s: str):
6464
def to_string(self, s: str):
6565
return f"cast({s} as varchar(1024))"
6666

67-
def offset_limit(
68-
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
67+
def limit_select(
68+
self,
69+
select_query: str,
70+
offset: Optional[int] = None,
71+
limit: Optional[int] = None,
72+
has_order_by: Optional[bool] = None,
6973
) -> str:
7074
if offset:
7175
raise NotImplementedError("No support for OFFSET in query")
7276

73-
return f"FETCH NEXT {limit} ROWS ONLY"
77+
return f"SELECT * FROM ({select_query}) FETCH NEXT {limit} ROWS ONLY"
7478

7579
def concat(self, items: List[str]) -> str:
7680
joined_exprs = " || ".join(items)

tests/test_query.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,16 @@ def current_database(self) -> str:
5050
def current_schema(self) -> str:
5151
return "current_schema()"
5252

53-
def offset_limit(
54-
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
53+
def limit_select(
54+
self,
55+
select_query: str,
56+
offset: Optional[int] = None,
57+
limit: Optional[int] = None,
58+
has_order_by: Optional[bool] = None,
5559
) -> str:
5660
x = offset and f"OFFSET {offset}", limit and f"LIMIT {limit}"
57-
return " ".join(filter(None, x))
61+
result = " ".join(filter(None, x))
62+
return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT {result}"
5863

5964
def explain_as_text(self, query: str) -> str:
6065
return f"explain {query}"
@@ -192,7 +197,7 @@ def test_funcs(self):
192197
t = table("a")
193198

194199
q = c.compile(t.order_by(Random()).limit(10))
195-
self.assertEqual(q, "SELECT * FROM a ORDER BY random() LIMIT 10")
200+
self.assertEqual(q, "SELECT * FROM (SELECT * FROM a ORDER BY random()) AS LIMITED_SELECT LIMIT 10")
196201

197202
q = c.compile(t.select(coalesce(this.a, this.b)))
198203
self.assertEqual(q, "SELECT COALESCE(a, b) FROM a")
@@ -210,7 +215,7 @@ def test_select_distinct(self):
210215

211216
# selects stay apart
212217
q = c.compile(t.limit(10).select(this.b, distinct=True))
213-
self.assertEqual(q, "SELECT DISTINCT b FROM (SELECT * FROM a LIMIT 10) tmp1")
218+
self.assertEqual(q, "SELECT DISTINCT b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1")
214219

215220
q = c.compile(t.select(this.b, distinct=True).select(distinct=False))
216221
self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2")
@@ -226,7 +231,9 @@ def test_select_with_optimizer_hints(self):
226231
self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)")
227232

228233
q = c.compile(t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
229-
self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM a LIMIT 10) tmp1")
234+
self.assertEqual(
235+
q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1"
236+
)
230237

231238
q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)"))
232239
self.assertEqual(

0 commit comments

Comments
 (0)