Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Bigquery dbt impersonation #715

Merged
merged 9 commits into from
Oct 13, 2023
Merged
13 changes: 13 additions & 0 deletions data_diff/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ def import_bigquery_service_account():
return service_account


def import_bigquery_service_account_impersonation():
    """Lazily import and return the ``google.auth.impersonated_credentials`` module.

    The import happens inside the function body, so ``google.auth`` is only
    loaded when this function is actually called.
    """
    from google.auth import impersonated_credentials as impersonation_module

    return impersonation_module


@attrs.define(frozen=False)
class Mixin_MD5(AbstractMixin_MD5):
def md5_as_int(self, s: str) -> str:
Expand Down Expand Up @@ -248,6 +254,13 @@ def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
keyfile,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
elif kw.get("impersonate_service_account"):
bigquery_service_account_impersonation = import_bigquery_service_account_impersonation()
credentials = bigquery_service_account_impersonation.Credentials(
source_credentials=credentials,
target_principal=kw["impersonate_service_account"],
target_scopes=["https://www.googleapis.com/auth/cloud-platform"],
)

self._client = bigquery.Client(project=project, credentials=credentials, **kw)
self.project = project
Expand Down
1 change: 1 addition & 0 deletions data_diff/dbt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def set_connection(self):
"driver": conn_type,
"project": credentials.get("project") or credentials.get("database"),
"dataset": credentials.get("dataset") or credentials.get("schema"),
"impersonate_service_account": credentials.get("impersonate_service_account"),
}

self.threads = credentials.get("threads")
Expand Down
22 changes: 22 additions & 0 deletions tests/test_dbt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,28 @@ def test_set_connection_bigquery_oauth(self):
self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])

def test_set_connection_bigquery_oauth_sa_impersonation(self):
    """Check that set_connection copies the impersonation target into the connection.

    Feeds a BigQuery oauth profile that carries ``impersonate_service_account``
    through ``DbtParser.set_connection`` and verifies that the resulting
    connection dict exposes the driver, project, dataset and the
    service account to impersonate.
    """
    driver = "bigquery"
    profile = {
        "method": "oauth",
        "project": "a_project",
        "dataset": "a_dataset",
        "impersonate_service_account": "a_service_account@yourproject.iam.gserviceaccount.com",
    }
    parser = Mock()
    parser.get_connection_creds.return_value = (profile, driver)

    DbtParser.set_connection(parser)

    connection = parser.connection
    self.assertIsInstance(connection, dict)
    self.assertEqual(connection.get("driver"), driver)
    # Each of these profile values must be carried over under the same key.
    for key in ("project", "dataset", "impersonate_service_account"):
        self.assertEqual(connection.get(key), profile[key])

def test_set_connection_bigquery_svc_account(self):
expected_driver = "bigquery"
expected_credentials = {
Expand Down