diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index bedacd80..4152a407 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,4 +1,5 @@ -from typing import Any, ClassVar, Union, List, Type +import base64 +from typing import Any, ClassVar, Union, List, Type, Optional import logging import attrs @@ -103,7 +104,7 @@ class Snowflake(Database): _conn: Any - def __init__(self, *, schema: str, **kw): + def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw): super().__init__() snowflake, serialization, default_backend = import_snowflake() logging.getLogger("snowflake.connector").setLevel(logging.WARNING) @@ -113,20 +114,29 @@ def __init__(self, *, schema: str, **kw): logging.getLogger("snowflake.connector.network").disabled = True assert '"' not in schema, "Schema name should not contain quotes!" + if key_content and key: + raise ConnectError("Only key value or key file path can be specified, not both") + + key_bytes = None + if key: + with open(key, "rb") as f: + key_bytes = f.read() + if key_content: + key_bytes = base64.b64decode(key_content) + # If a private key is used, read it from the specified path and pass it as "private_key" to the connector. - if "key" in kw: - with open(kw.get("key"), "rb") as key: - if "password" in kw: - raise ConnectError("Cannot use password and key at the same time") - if kw.get("private_key_passphrase"): - encoded_passphrase = kw.get("private_key_passphrase").encode() - else: - encoded_passphrase = None - p_key = serialization.load_pem_private_key( - key.read(), - password=encoded_passphrase, - backend=default_backend(), - ) + if key_bytes: + if "password" in kw: + raise ConnectError("Cannot use password and key at the same time") + if kw.get("private_key_passphrase"): + encoded_passphrase = kw.get("private_key_passphrase").encode() + else: + encoded_passphrase = None + p_key = serialization.load_pem_private_key( + key_bytes, + password=encoded_passphrase, + backend=default_backend(), + ) kw["private_key"] = p_key.private_bytes( encoding=serialization.Encoding.DER,