Skip to content

Commit 149c699

Browse files
authored
fix: Fix filter_dfs to be recursive, add support for skip_dependent_tables (#39)
Previously, `filter_dfs` only filtered top-level tables: nested (relation) tables were never filtered, and `skip_dependent_tables` was not supported. This change is essentially a direct port of the Go version (#28). Closes #28.
1 parent 381f9ef commit 149c699

File tree

2 files changed

+171
-14
lines changed

2 files changed

+171
-14
lines changed

cloudquery/sdk/schema/table.py

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import List, Generator, Any
3+
import copy
44
import fnmatch
5+
from typing import List
6+
57
import pyarrow as pa
68

79
from cloudquery.sdk.schema import arrow
@@ -87,17 +89,90 @@ def tables_to_arrow_schemas(tables: List[Table]):
8789

8890

8991
def filter_dfs(
    tables: List[Table],
    include_tables: List[str],
    skip_tables: List[str],
    skip_dependent_tables: bool = False,
) -> List[Table]:
    """Filter a tree of tables by glob patterns, including nested relations.

    Patterns are matched against ``table.name`` with ``fnmatch.fnmatch``.
    Before filtering, every pattern is checked against the flattened tree
    (top-level tables and all of their relations) so that a pattern which
    matches nothing is reported instead of being silently ignored.

    Args:
        tables: Top-level tables; relations are reached via ``table.relations``.
        include_tables: Glob patterns selecting tables to keep.
        skip_tables: Glob patterns selecting tables to drop.
        skip_dependent_tables: Forwarded to ``filter_dfs_func``; controls
            whether relations of an included table are kept implicitly.

    Returns:
        The filtered tables as returned by ``filter_dfs_func``.

    Raises:
        ValueError: If any include or skip pattern matches no table.
    """
    # Collect the names once; both validation loops only need the names.
    table_names = [table.name for table in flatten_tables(tables)]

    for include_pattern in include_tables:
        if not any(fnmatch.fnmatch(name, include_pattern) for name in table_names):
            raise ValueError(
                f"tables include a pattern {include_pattern} with no matches"
            )

    for exclude_pattern in skip_tables:
        if not any(fnmatch.fnmatch(name, exclude_pattern) for name in table_names):
            raise ValueError(
                f"skip_tables include a pattern {exclude_pattern} with no matches"
            )

    def include_func(t):
        # True when the table name matches at least one include pattern.
        return any(
            fnmatch.fnmatch(t.name, include_pattern)
            for include_pattern in include_tables
        )

    def exclude_func(t):
        # True when the table name matches at least one skip pattern.
        return any(
            fnmatch.fnmatch(t.name, exclude_pattern) for exclude_pattern in skip_tables
        )

    return filter_dfs_func(tables, include_func, exclude_func, skip_dependent_tables)
128+
129+
130+
def filter_dfs_func(tt: List[Table], include, exclude, skip_dependent_tables: bool):
    """Filter a list of table trees using include/exclude predicates.

    Each top-level table is deep-copied before filtering, so the input
    tables and their ``relations`` lists are left untouched.

    Args:
        tt: Top-level tables to filter.
        include: Callable taking a table and returning truthy if the table
            is explicitly included.
        exclude: Callable taking a table and returning truthy if the table
            (and therefore its whole subtree) must be dropped.
        skip_dependent_tables: When true, relations of an included table are
            not kept implicitly; they must be included in their own right.

    Returns:
        A new list with the surviving (copied) top-level tables, in the
        original order.
    """
    filtered_tables = []
    for t in tt:
        # The recursive helper rewrites ``relations`` in place, so hand it a copy.
        filtered_table = _filter_dfs_impl(
            copy.deepcopy(t), False, include, exclude, skip_dependent_tables
        )
        if filtered_table is not None:
            filtered_tables.append(filtered_table)
    return filtered_tables


def _filter_dfs_impl(t, parent_matched, include, exclude, skip_dependent_tables):
    """Filter ``t`` and its relations depth-first, mutating ``t.relations``.

    A table is *selected* when ``include(t)`` is truthy, or when its parent was
    selected and ``skip_dependent_tables`` is false. A table is *kept* when it
    is selected or when at least one of its relations is kept (so the path to
    a matching descendant is preserved).

    Args:
        t: Table to filter; its ``relations`` attribute is replaced with the
            filtered list.
        parent_matched: Whether the parent table was selected.
        include: Predicate for explicit inclusion.
        exclude: Predicate for exclusion; an excluded table drops its subtree.
        skip_dependent_tables: See ``filter_dfs_func``.

    Returns:
        ``t`` if it is kept, otherwise ``None``.
    """
    if exclude(t):
        return None

    selected = bool(include(t)) or (parent_matched and not skip_dependent_tables)

    kept_relations = []
    for r in t.relations:
        # Pass only this table's own selection state down. Letting a surviving
        # child flip the flag would make later siblings inherit a match they
        # never earned, so the result would depend on relation order.
        kept = _filter_dfs_impl(r, selected, include, exclude, skip_dependent_tables)
        if kept is not None:
            kept_relations.append(kept)

    t.relations = kept_relations

    # Keep the table if it was selected itself or is needed as the ancestor
    # of a surviving relation.
    if selected or kept_relations:
        return t
    return None
171+
172+
173+
def flatten_tables(tables: List[Table]) -> List[Table]:
    """Return every table in the tree as one flat list.

    Tables are emitted in pre-order: each table is followed immediately by
    the flattened contents of its ``relations`` before the next sibling.
    The returned list holds the same table objects, not copies.

    Args:
        tables: Tables whose ``relations`` attribute lists their child tables.

    Returns:
        A new list containing each table and all of its nested relations.
    """
    flattened: List[Table] = []
    for table in tables:
        flattened.append(table)
        flattened.extend(flatten_tables(table.relations))
    return flattened

tests/schema/test_table.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,90 @@
11
import pyarrow as pa
2+
import pytest
23

3-
from cloudquery.sdk.schema import Table, Column
4+
from cloudquery.sdk.schema import Table, Column, filter_dfs
5+
from cloudquery.sdk.schema.table import flatten_tables
46

57

68
def test_table():
    """Smoke test: a single-column table converts to an Arrow schema without raising."""
    table = Table("test_table", [Column("test_column", pa.int32())])
    table.to_arrow_schema()
11+
12+
13+
def test_filter_dfs_warns_no_matches():
    """filter_dfs raises ValueError when a pattern matches no table."""
    # Build fixtures outside ``pytest.raises`` so that only the call under
    # test can satisfy the expected exception.
    tables = [Table("test1", []), Table("test2", [])]

    # An include pattern that matches nothing.
    with pytest.raises(ValueError, match="no matches"):
        filter_dfs(tables, include_tables=["test3"], skip_tables=[])

    # A skip pattern that matches nothing.
    with pytest.raises(ValueError, match="no matches"):
        filter_dfs(tables, include_tables=["*"], skip_tables=["test3"])
21+
22+
23+
def test_filter_dfs():
    """Table-driven check of filter_dfs over a three-level table tree.

    Tree under test:
        test_top1 -> test_child -> test_grandchild
        test_top2
    """
    table_grandchild = Table("test_grandchild", [Column("test_column", pa.int32())])
    table_child = Table(
        "test_child",
        [Column("test_column", pa.int32())],
        relations=[
            table_grandchild,
        ],
    )
    table_top1 = Table(
        "test_top1",
        [Column("test_column", pa.int32())],
        relations=[
            table_child,
        ],
    )
    table_top2 = Table("test_top2", [Column("test_column", pa.int32())])

    tables = [table_top1, table_top2]

    cases = [
        # Wildcard include keeps the whole tree.
        {
            "include_tables": ["*"],
            "skip_tables": [],
            "skip_dependent_tables": False,
            "expect_top": ["test_top1", "test_top2"],
            "expect_flattened": [
                "test_top1",
                "test_top2",
                "test_child",
                "test_grandchild",
            ],
        },
        # Skipping a parent drops its entire subtree.
        {
            "include_tables": ["*"],
            "skip_tables": ["test_top1"],
            "skip_dependent_tables": False,
            "expect_top": ["test_top2"],
            "expect_flattened": ["test_top2"],
        },
        # With skip_dependent_tables, relations of an included table are dropped.
        {
            "include_tables": ["test_top1"],
            "skip_tables": ["test_top2"],
            "skip_dependent_tables": True,
            "expect_top": ["test_top1"],
            "expect_flattened": ["test_top1"],
        },
        # Including a nested table keeps its ancestor but not its own relations.
        {
            "include_tables": ["test_child"],
            "skip_tables": [],
            "skip_dependent_tables": True,
            "expect_top": ["test_top1"],
            "expect_flattened": ["test_top1", "test_child"],
        },
    ]
    for case in cases:
        got = filter_dfs(
            tables=tables,
            include_tables=case["include_tables"],
            skip_tables=case["skip_tables"],
            skip_dependent_tables=case["skip_dependent_tables"],
        )
        # Names are compared sorted, so ordering is not part of the assertion.
        assert sorted([t.name for t in got]) == sorted(case["expect_top"]), case

        got_flattened = flatten_tables(got)
        want_flattened = sorted(case["expect_flattened"])
        got_flattened = sorted([t.name for t in got_flattened])
        assert got_flattened == want_flattened, case

0 commit comments

Comments
 (0)