From d9d54127c24a8f5b2faf8d795c891f9f7bf68808 Mon Sep 17 00:00:00 2001 From: Herman Schaaf Date: Thu, 10 Aug 2023 16:13:30 +0100 Subject: [PATCH 1/2] Fix filter_dfs to be recursive, add support for skip_dependent_tables --- cloudquery/sdk/schema/table.py | 98 ++++++++++++++++++++++++++-------- tests/schema/test_table.py | 71 +++++++++++++++++++++++- 2 files changed, 147 insertions(+), 22 deletions(-) diff --git a/cloudquery/sdk/schema/table.py b/cloudquery/sdk/schema/table.py index d0b56ef..19c67bf 100644 --- a/cloudquery/sdk/schema/table.py +++ b/cloudquery/sdk/schema/table.py @@ -1,7 +1,9 @@ from __future__ import annotations -from typing import List, Generator, Any +import copy import fnmatch +from typing import List + import pyarrow as pa from cloudquery.sdk.schema import arrow @@ -14,14 +16,14 @@ class Client: class Table: def __init__( - self, - name: str, - columns: List[Column], - title: str = "", - description: str = "", - parent: Table = None, - relations: List[Table] = None, - is_incremental: bool = False, + self, + name: str, + columns: List[Column], + title: str = "", + description: str = "", + parent: Table = None, + relations: List[Table] = None, + is_incremental: bool = False, ) -> None: self.name = name self.columns = columns @@ -87,17 +89,71 @@ def tables_to_arrow_schemas(tables: List[Table]): def filter_dfs( - tables: List[Table], include_tables: List[str], skip_tables: List[str] + tables: List[Table], + include_tables: List[str], + skip_tables: List[str], + skip_dependent_tables: bool = False ) -> List[Table]: - filtered: List[Table] = [] + flattened_tables = flatten_tables(tables) + for include_pattern in include_tables: + matched = any(fnmatch.fnmatch(table.name, include_pattern) for table in flattened_tables) + if not matched: + raise ValueError(f"tables include a pattern {include_pattern} with no matches") + + for exclude_pattern in skip_tables: + matched = any(fnmatch.fnmatch(table.name, exclude_pattern) for table in flattened_tables) + if not matched: + raise ValueError(f"skip_tables include a pattern {exclude_pattern} with no matches") + + def include_func(t): + return any(fnmatch.fnmatch(t.name, include_pattern) for include_pattern in include_tables) + + def exclude_func(t): + return any(fnmatch.fnmatch(t.name, exclude_pattern) for exclude_pattern in skip_tables) + + return filter_dfs_func(tables, include_func, exclude_func, skip_dependent_tables) + + +def filter_dfs_func(tt: List[Table], include, exclude, skip_dependent_tables: bool): + filtered_tables = [] + for t in tt: + filtered_table = copy.deepcopy(t) + filtered_table = _filter_dfs_impl(filtered_table, False, include, exclude, skip_dependent_tables) + if filtered_table is not None: + filtered_tables.append(filtered_table) + return filtered_tables + + +def _filter_dfs_impl(t, parent_matched, include, exclude, skip_dependent_tables): + def filter_dfs_child(r, matched, include, exclude, skip_dependent_tables): + filtered_child = _filter_dfs_impl(r, matched, include, exclude, skip_dependent_tables) + if filtered_child is not None: + return True, r + return matched, None + + if exclude(t): + return None + + matched = parent_matched and not skip_dependent_tables + if include(t): + matched = True + + filtered_relations = [] + for r in t.relations: + matched, filtered_child = filter_dfs_child(r, matched, include, exclude, skip_dependent_tables) + if filtered_child is not None: + filtered_relations.append(filtered_child) + + t.relations = filtered_relations + + if matched: + return t + return None + + +def flatten_tables(tables: List[Table]) -> List[Table]: + flattened: List[Table] = [] for table in tables: - matched = False - for include_table in include_tables: - if fnmatch.fnmatch(table.name, include_table): - matched = True - for skip_table in skip_tables: - if fnmatch.fnmatch(table.name, skip_table): - matched = False - if matched: - filtered.append(table) - return filtered + flattened.append(table) + flattened.extend(flatten_tables(table.relations)) + return flattened diff --git a/tests/schema/test_table.py b/tests/schema/test_table.py index 9958ad7..85e3b97 100644 --- a/tests/schema/test_table.py +++ b/tests/schema/test_table.py @@ -1,8 +1,77 @@ import pyarrow as pa +import pytest -from cloudquery.sdk.schema import Table, Column +from cloudquery.sdk.schema import Table, Column, filter_dfs +from cloudquery.sdk.schema.table import flatten_tables def test_table(): table = Table("test_table", [Column("test_column", pa.int32())]) table.to_arrow_schema() + + +def test_filter_dfs_warns_no_matches(): + with pytest.raises(ValueError): + tables = [Table("test1", []), Table("test2", [])] + filter_dfs(tables, include_tables=["test3"], skip_tables=[]) + + with pytest.raises(ValueError): + tables = [Table("test1", []), Table("test2", [])] + filter_dfs(tables, include_tables=["*"], skip_tables=["test3"]) + + +def test_filter_dfs(): + table_grandchild = Table("test_grandchild", [Column("test_column", pa.int32())]) + table_child = Table("test_child", [Column("test_column", pa.int32())], relations=[ + table_grandchild, + ]) + table_top1 = Table("test_top1", [Column("test_column", pa.int32())], relations=[ + table_child, + ]) + table_top2 = Table("test_top2", [Column("test_column", pa.int32())]) + + tables = [table_top1, table_top2] + + cases = [ + { + "include_tables": ["*"], + "skip_tables": [], + "skip_dependent_tables": False, + "expect_top": ["test_top1", "test_top2"], + "expect_flattened": ["test_top1", "test_top2", "test_child", "test_grandchild"], + }, + { + "include_tables": ["*"], + "skip_tables": ["test_top1"], + "skip_dependent_tables": False, + "expect_top": ["test_top2"], + "expect_flattened": ["test_top2"], + }, + { + "include_tables": ["test_top1"], + "skip_tables": ["test_top2"], + "skip_dependent_tables": True, + "expect_top": ["test_top1"], + "expect_flattened": ["test_top1"], + }, + { + "include_tables": ["test_child"], + "skip_tables": [], + "skip_dependent_tables": True, + "expect_top": ["test_top1"], + "expect_flattened": ["test_top1", "test_child"], + }, + ] + for case in cases: + got = filter_dfs( + tables=tables, + include_tables=case["include_tables"], + skip_tables=case["skip_tables"], + skip_dependent_tables=case["skip_dependent_tables"], + ) + assert sorted([t.name for t in got]) == sorted(case["expect_top"]), case + + got_flattened = flatten_tables(got) + want_flattened = sorted(case["expect_flattened"]) + got_flattened = sorted([t.name for t in got_flattened]) + assert got_flattened == want_flattened, case From 976b406b8317b0267217ff161753eb5bd1dd6a81 Mon Sep 17 00:00:00 2001 From: Herman Schaaf Date: Thu, 10 Aug 2023 16:15:38 +0100 Subject: [PATCH 2/2] fmt --- cloudquery/sdk/schema/table.py | 61 ++++++++++++++++++++++------------ tests/schema/test_table.py | 27 +++++++++++---- 2 files changed, 60 insertions(+), 28 deletions(-) diff --git a/cloudquery/sdk/schema/table.py b/cloudquery/sdk/schema/table.py index 19c67bf..63c6228 100644 --- a/cloudquery/sdk/schema/table.py +++ b/cloudquery/sdk/schema/table.py @@ -16,14 +16,14 @@ class Client: class Table: def __init__( - self, - name: str, - columns: List[Column], - title: str = "", - description: str = "", - parent: Table = None, - relations: List[Table] = None, - is_incremental: bool = False, + self, + name: str, + columns: List[Column], + title: str = "", + description: str = "", + parent: Table = None, + relations: List[Table] = None, + is_incremental: bool = False, ) -> None: self.name = name self.columns = columns @@ -89,27 +89,40 @@ def tables_to_arrow_schemas(tables: List[Table]): def filter_dfs( - tables: List[Table], - include_tables: List[str], - skip_tables: List[str], - skip_dependent_tables: bool = False + tables: List[Table], + include_tables: List[str], + skip_tables: List[str], + skip_dependent_tables: bool = False, ) -> List[Table]: flattened_tables = flatten_tables(tables) for include_pattern in include_tables: - matched = any(fnmatch.fnmatch(table.name, include_pattern) for table in flattened_tables) + matched = any( + fnmatch.fnmatch(table.name, include_pattern) for table in flattened_tables + ) if not matched: - raise ValueError(f"tables include a pattern {include_pattern} with no matches") + raise ValueError( + f"tables include a pattern {include_pattern} with no matches" + ) for exclude_pattern in skip_tables: - matched = any(fnmatch.fnmatch(table.name, exclude_pattern) for table in flattened_tables) + matched = any( + fnmatch.fnmatch(table.name, exclude_pattern) for table in flattened_tables + ) if not matched: - raise ValueError(f"skip_tables include a pattern {exclude_pattern} with no matches") + raise ValueError( + f"skip_tables include a pattern {exclude_pattern} with no matches" + ) def include_func(t): - return any(fnmatch.fnmatch(t.name, include_pattern) for include_pattern in include_tables) + return any( + fnmatch.fnmatch(t.name, include_pattern) + for include_pattern in include_tables + ) def exclude_func(t): - return any(fnmatch.fnmatch(t.name, exclude_pattern) for exclude_pattern in skip_tables) + return any( + fnmatch.fnmatch(t.name, exclude_pattern) for exclude_pattern in skip_tables + ) return filter_dfs_func(tables, include_func, exclude_func, skip_dependent_tables) @@ -118,7 +131,9 @@ def filter_dfs_func(tt: List[Table], include, exclude, skip_dependent_tables: bo filtered_tables = [] for t in tt: filtered_table = copy.deepcopy(t) - filtered_table = _filter_dfs_impl(filtered_table, False, include, exclude, skip_dependent_tables) + filtered_table = _filter_dfs_impl( + filtered_table, False, include, exclude, skip_dependent_tables + ) if filtered_table is not None: filtered_tables.append(filtered_table) return filtered_tables @@ -126,7 +141,9 @@ def filter_dfs_func(tt: List[Table], include, exclude, skip_dependent_tables: bo def _filter_dfs_impl(t, parent_matched, include, exclude, skip_dependent_tables): def filter_dfs_child(r, matched, include, exclude, skip_dependent_tables): - filtered_child = _filter_dfs_impl(r, matched, include, exclude, skip_dependent_tables) + filtered_child = _filter_dfs_impl( + r, matched, include, exclude, skip_dependent_tables + ) if filtered_child is not None: return True, r return matched, None @@ -140,7 +157,9 @@ def filter_dfs_child(r, matched, include, exclude, skip_dependent_tables): filtered_relations = [] for r in t.relations: - matched, filtered_child = filter_dfs_child(r, matched, include, exclude, skip_dependent_tables) + matched, filtered_child = filter_dfs_child( + r, matched, include, exclude, skip_dependent_tables + ) if filtered_child is not None: filtered_relations.append(filtered_child) diff --git a/tests/schema/test_table.py b/tests/schema/test_table.py index 85e3b97..e7bb073 100644 --- a/tests/schema/test_table.py +++ b/tests/schema/test_table.py @@ -22,12 +22,20 @@ def test_filter_dfs_warns_no_matches(): def test_filter_dfs(): table_grandchild = Table("test_grandchild", [Column("test_column", pa.int32())]) - table_child = Table("test_child", [Column("test_column", pa.int32())], relations=[ - table_grandchild, - ]) - table_top1 = Table("test_top1", [Column("test_column", pa.int32())], relations=[ - table_child, - ]) + table_child = Table( + "test_child", + [Column("test_column", pa.int32())], + relations=[ + table_grandchild, + ], + ) + table_top1 = Table( + "test_top1", + [Column("test_column", pa.int32())], + relations=[ + table_child, + ], + ) table_top2 = Table("test_top2", [Column("test_column", pa.int32())]) tables = [table_top1, table_top2] @@ -38,7 +46,12 @@ def test_filter_dfs(): "skip_tables": [], "skip_dependent_tables": False, "expect_top": ["test_top1", "test_top2"], - "expect_flattened": ["test_top1", "test_top2", "test_child", "test_grandchild"], + "expect_flattened": [ + "test_top1", + "test_top2", + "test_child", + "test_grandchild", + ], }, { "include_tables": ["*"],