diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py
index 72ab3797..5592bfae 100644
--- a/data_diff/databases/bigquery.py
+++ b/data_diff/databases/bigquery.py
@@ -53,6 +53,13 @@ def import_bigquery_service_account():
     return service_account
 
 
+def import_bigquery_service_account_impersonation():
+    """Import ``google.auth.impersonated_credentials`` on first use and return the module."""
+    from google.auth import impersonated_credentials
+
+    return impersonated_credentials
+
+
 @attrs.define(frozen=False)
 class Mixin_MD5(AbstractMixin_MD5):
     def md5_as_int(self, s: str) -> str:
@@ -248,6 +255,17 @@ def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
                 keyfile,
                 scopes=["https://www.googleapis.com/auth/cloud-platform"],
             )
+        elif kw.get("impersonate_service_account"):
+            # NOTE(review): `credentials` becomes the source identity; confirm it is never None on this path.
+            bigquery_service_account_impersonation = import_bigquery_service_account_impersonation()
+            credentials = bigquery_service_account_impersonation.Credentials(
+                source_credentials=credentials,
+                target_principal=kw["impersonate_service_account"],
+                target_scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
 
+        # `impersonate_service_account` is a data-diff option, not a bigquery.Client argument;
+        # drop it (even when it is unset/None) so it is not forwarded through **kw below.
+        kw.pop("impersonate_service_account", None)
         self._client = bigquery.Client(project=project, credentials=credentials, **kw)
         self.project = project
diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py
index 592f8b3a..5ca183cd 100644
--- a/data_diff/dbt_parser.py
+++ b/data_diff/dbt_parser.py
@@ -378,6 +378,7 @@ def set_connection(self):
                 "driver": conn_type,
                 "project": credentials.get("project") or credentials.get("database"),
                 "dataset": credentials.get("dataset") or credentials.get("schema"),
+                "impersonate_service_account": credentials.get("impersonate_service_account"),
             }
 
             self.threads = credentials.get("threads")
diff --git a/tests/test_dbt_parser.py b/tests/test_dbt_parser.py
index 4fbdbde1..cc967551 100644
--- a/tests/test_dbt_parser.py
+++ b/tests/test_dbt_parser.py
@@ -269,6 +269,29 @@ def test_set_connection_bigquery_oauth(self):
         self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
         self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])
 
+    def test_set_connection_bigquery_oauth_sa_impersonation(self):
+        """The impersonated service account from the credentials is copied into the connection dict."""
+        expected_driver = "bigquery"
+        expected_credentials = {
+            "method": "oauth",
+            "project": "a_project",
+            "dataset": "a_dataset",
+            "impersonate_service_account": "a_service_account@yourproject.iam.gserviceaccount.com",
+        }
+        mock_self = Mock()
+        mock_self.get_connection_creds.return_value = (expected_credentials, expected_driver)
+
+        DbtParser.set_connection(mock_self)
+
+        self.assertIsInstance(mock_self.connection, dict)
+        self.assertEqual(mock_self.connection.get("driver"), expected_driver)
+        self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
+        self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])
+        self.assertEqual(
+            mock_self.connection.get("impersonate_service_account"),
+            expected_credentials["impersonate_service_account"],
+        )
+
     def test_set_connection_bigquery_svc_account(self):
         expected_driver = "bigquery"
         expected_credentials = {