diff --git a/datadog_lambda/cold_start.py b/datadog_lambda/cold_start.py index e53341f4..c8862bf1 100644 --- a/datadog_lambda/cold_start.py +++ b/datadog_lambda/cold_start.py @@ -14,12 +14,10 @@ def set_cold_start(): def is_cold_start(): - """Returns the value of the global cold_start - """ + """Returns the value of the global cold_start""" return _cold_start def get_cold_start_tag(): - """Returns the cold start tag to be used in metrics - """ + """Returns the cold start tag to be used in metrics""" return "cold_start:{}".format(str(is_cold_start()).lower()) diff --git a/datadog_lambda/metric.py b/datadog_lambda/metric.py index 5eb1c2ac..62100464 100644 --- a/datadog_lambda/metric.py +++ b/datadog_lambda/metric.py @@ -9,13 +9,14 @@ import base64 import logging +from botocore.exceptions import ClientError import boto3 from datadog import api, initialize, statsd from datadog.threadstats import ThreadStats from datadog_lambda.extension import should_use_extension from datadog_lambda.tags import get_enhanced_metrics_tags, tag_dd_lambda_layer - +KMS_ENCRYPTION_CONTEXT_KEY = "LambdaFunctionName" ENHANCED_METRICS_NAMESPACE_PREFIX = "aws.lambda.enhanced" logger = logging.getLogger(__name__) @@ -213,12 +214,52 @@ def submit_errors_metric(lambda_context): submit_enhanced_metric("errors", lambda_context) -# Set API Key and Host in the module, so they only set once per container +def decrypt_kms_api_key(kms_client, ciphertext): + """ + Decodes the base64-encoded ciphertext, decrypts it using KMS and returns the plaintext. + For this to work properly, the Lambda function must have the appropriate IAM permissions. + + Args: + kms_client: The KMS client to use for decryption + ciphertext (string): The base64-encoded ciphertext to decrypt + """ + decoded_bytes = base64.b64decode(ciphertext) + + # The Lambda console UI changed the way it encrypts environment variables. 
The current behavior + # as of May 2021 is to encrypt environment variables using the function name as an encryption + # context. Previously, the behavior was to encrypt environment variables without an encryption + # context. We need to try both, as supplying the incorrect encryption context will cause + # decryption to fail. An unset function name falls back to "" so that botocore's parameter + # validation passes and the mismatch surfaces as a ClientError, which triggers the retry + # without encryption context. + # Try with encryption context + function_name = os.environ.get("AWS_LAMBDA_FUNCTION_NAME", "") + try: + plaintext = kms_client.decrypt( + CiphertextBlob=decoded_bytes, + EncryptionContext={ + KMS_ENCRYPTION_CONTEXT_KEY: function_name, + }, + )["Plaintext"].decode("utf-8") + except ClientError: + logger.debug( + "Failed to decrypt ciphertext with encryption context, retrying without" + ) + # Try without encryption context + plaintext = kms_client.decrypt(CiphertextBlob=decoded_bytes)[ + "Plaintext" + ].decode("utf-8") + + return plaintext + + +# Set API Key if not api._api_key: DD_API_KEY_SECRET_ARN = os.environ.get("DD_API_KEY_SECRET_ARN", "") DD_API_KEY_SSM_NAME = os.environ.get("DD_API_KEY_SSM_NAME", "") DD_KMS_API_KEY = os.environ.get("DD_KMS_API_KEY", "") DD_API_KEY = os.environ.get("DD_API_KEY", os.environ.get("DATADOG_API_KEY", "")) + if DD_API_KEY_SECRET_ARN: api._api_key = boto3.client("secretsmanager").get_secret_value( SecretId=DD_API_KEY_SECRET_ARN @@ -228,11 +269,11 @@ def submit_errors_metric(lambda_context): Name=DD_API_KEY_SSM_NAME, WithDecryption=True )["Parameter"]["Value"] elif DD_KMS_API_KEY: - api._api_key = boto3.client("kms").decrypt( - CiphertextBlob=base64.b64decode(DD_KMS_API_KEY) - )["Plaintext"] + kms_client = boto3.client("kms") + api._api_key = decrypt_kms_api_key(kms_client, DD_KMS_API_KEY) else: api._api_key = DD_API_KEY + logger.debug("Setting DATADOG_API_KEY of length %d", len(api._api_key)) # Set DATADOG_HOST, to send data to a non-default Datadog datacenter diff --git a/datadog_lambda/module_name.py b/datadog_lambda/module_name.py index f018c5d6..9e4a93e5 100644 --- a/datadog_lambda/module_name.py +++ 
b/datadog_lambda/module_name.py @@ -1,4 +1,3 @@ def modify_module_name(module_name): - """Returns a valid modified module to get imported - """ + """Returns a valid modified module to get imported""" return ".".join(module_name.split("/")) diff --git a/datadog_lambda/tags.py b/datadog_lambda/tags.py index f21c8f92..2ba6ae72 100644 --- a/datadog_lambda/tags.py +++ b/datadog_lambda/tags.py @@ -69,8 +69,7 @@ def parse_lambda_tags_from_arn(lambda_context): def get_runtime_tag(): - """Get the runtime tag from the current Python version - """ + """Get the runtime tag from the current Python version""" major_version, minor_version, _ = python_version_tuple() return "runtime:python{major}.{minor}".format( @@ -79,14 +78,12 @@ def get_runtime_tag(): def get_library_version_tag(): - """Get datadog lambda library tag - """ + """Get datadog lambda library tag""" return "datadog_lambda:v{}".format(__version__) def get_enhanced_metrics_tags(lambda_context): - """Get the list of tags to apply to enhanced metrics - """ + """Get the list of tags to apply to enhanced metrics""" return parse_lambda_tags_from_arn(lambda_context) + [ get_cold_start_tag(), "memorysize:{}".format(lambda_context.memory_limit_in_mb), @@ -96,7 +93,9 @@ def get_enhanced_metrics_tags(lambda_context): def check_if_number(alias): - """ Check if the alias is a version or number. Python 2 has no easy way to test this like Python 3 + """ + Check if the alias is a version or number. 
+ Python 2 has no easy way to test this like Python 3 """ try: float(alias) diff --git a/datadog_lambda/tracing.py b/datadog_lambda/tracing.py index 4f528cc3..ae48a161 100644 --- a/datadog_lambda/tracing.py +++ b/datadog_lambda/tracing.py @@ -186,7 +186,11 @@ def extract_context_custom_extractor(extractor, event, lambda_context): Extract Datadog trace context using a custom trace extractor function """ try: - (trace_id, parent_id, sampling_priority,) = extractor(event, lambda_context) + ( + trace_id, + parent_id, + sampling_priority, + ) = extractor(event, lambda_context) return trace_id, parent_id, sampling_priority except Exception as e: logger.debug("The trace extractor returned with error %s", e) @@ -205,9 +209,11 @@ def extract_dd_trace_context(event, lambda_context, extractor=None): trace_context_source = None if extractor is not None: - (trace_id, parent_id, sampling_priority,) = extract_context_custom_extractor( - extractor, event, lambda_context - ) + ( + trace_id, + parent_id, + sampling_priority, + ) = extract_context_custom_extractor(extractor, event, lambda_context) elif "headers" in event: ( trace_id, diff --git a/scripts/check_format.sh b/scripts/check_format.sh index 7c38d48b..d616017e 100755 --- a/scripts/check_format.sh +++ b/scripts/check_format.sh @@ -6,7 +6,7 @@ if [ "$PYTHON_VERSION" = "2" ]; then echo "Skipping formatting, black not compatible with python 2" exit 0 fi -pip install -Iv black==19.10b0 +pip install -Iv black==21.5b2 python -m black --check datadog_lambda/ --diff python -m black --check tests --diff diff --git a/setup.py b/setup.py index 921aab4c..820f320f 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ # If building for Python 3, use the latest version of setuptools "setuptools>=54.2.0; python_version >= '3.0'", # If building for Python 2, use the latest version that supports Python 2 - "setuptools>=44.1.1; python_version < '3.0'" + "setuptools>=44.1.1; python_version < '3.0'", ], extras_require={ "dev": ["nose2==0.9.1", 
"flake8==3.7.9", "requests==2.22.0", "boto3==1.10.33"] diff --git a/tests/test_metric.py b/tests/test_metric.py index 848a9328..d22cf881 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -6,8 +6,14 @@ except ImportError: from mock import patch, call +from botocore.exceptions import ClientError as BotocoreClientError from datadog.api.exceptions import ClientError -from datadog_lambda.metric import lambda_metric, ThreadStatsWriter +from datadog_lambda.metric import ( + decrypt_kms_api_key, + lambda_metric, + ThreadStatsWriter, + KMS_ENCRYPTION_CONTEXT_KEY, +) from datadog_lambda.tags import _format_dd_lambda_layer_tag @@ -58,3 +64,56 @@ def test_retry_on_remote_disconnected(self): ) lambda_stats.flush() self.assertEqual(self.mock_threadstats_flush_distributions.call_count, 2) + + +MOCK_FUNCTION_NAME = "myFunction" + +# An API key encrypted with KMS and encoded as a base64 string +MOCK_ENCRYPTED_API_KEY_BASE64 = "MjIyMjIyMjIyMjIyMjIyMg==" + +# The encrypted API key after it has been decoded from base64 +MOCK_ENCRYPTED_API_KEY = "2222222222222222" + +# The true value of the API key after decryption by KMS +EXPECTED_DECRYPTED_API_KEY = "1111111111111111" + + +class TestDecryptKMSApiKey(unittest.TestCase): + def test_key_encrypted_with_encryption_context(self): + os.environ["AWS_LAMBDA_FUNCTION_NAME"] = MOCK_FUNCTION_NAME + self.addCleanup(os.environ.pop, "AWS_LAMBDA_FUNCTION_NAME", None) + + class MockKMSClient: + def decrypt(self, CiphertextBlob=None, EncryptionContext=None): + if ( + (EncryptionContext or {}).get(KMS_ENCRYPTION_CONTEXT_KEY) + != MOCK_FUNCTION_NAME + ): + raise BotocoreClientError({}, "Decrypt") + if CiphertextBlob == MOCK_ENCRYPTED_API_KEY.encode("utf-8"): + return { + "Plaintext": EXPECTED_DECRYPTED_API_KEY.encode("utf-8"), + } + + mock_kms_client = MockKMSClient() + decrypted_key = decrypt_kms_api_key( + mock_kms_client, MOCK_ENCRYPTED_API_KEY_BASE64 + ) + self.assertEqual(decrypted_key, EXPECTED_DECRYPTED_API_KEY) + + def 
test_key_encrypted_without_encryption_context(self): + class MockKMSClient: + def decrypt(self, CiphertextBlob=None, EncryptionContext=None): + # A key encrypted without context must reject any context + if EncryptionContext: + raise BotocoreClientError({}, "Decrypt") + if CiphertextBlob == MOCK_ENCRYPTED_API_KEY.encode("utf-8"): + return { + "Plaintext": EXPECTED_DECRYPTED_API_KEY.encode("utf-8"), + } + + mock_kms_client = MockKMSClient() + decrypted_key = decrypt_kms_api_key( + mock_kms_client, MOCK_ENCRYPTED_API_KEY_BASE64 + ) + self.assertEqual(decrypted_key, EXPECTED_DECRYPTED_API_KEY) diff --git a/tests/test_tracing.py b/tests/test_tracing.py index b7365df4..431f339f 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -73,7 +73,8 @@ def test_without_datadog_trace_headers(self): ctx, source = extract_dd_trace_context({}, lambda_ctx) self.assertEqual(source, "xray") self.assertDictEqual( - ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"}, + ctx, + {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"}, ) self.assertDictEqual( get_dd_trace_context(), @@ -93,7 +94,8 @@ def test_with_incomplete_datadog_trace_headers(self): ) self.assertEqual(source, "xray") self.assertDictEqual( - ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"}, + ctx, + {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"}, ) self.assertDictEqual( get_dd_trace_context(), @@ -118,7 +120,8 @@ def test_with_complete_datadog_trace_headers(self): ) self.assertEqual(source, "event") self.assertDictEqual( - ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1"}, + ctx, + {"trace-id": "123", "parent-id": "321", "sampling-priority": "1"}, ) self.assertDictEqual( get_dd_trace_context(), @@ -160,7 +163,12 @@ def extractor_foo(event, context): ) self.assertEquals(ctx_source, "event") self.assertDictEqual( - ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1",}, + ctx, + { + "trace-id": "123", + 
"parent-id": "321", + "sampling-priority": "1", + }, ) self.assertDictEqual( get_dd_trace_context(), @@ -189,7 +197,12 @@ def extractor_raiser(event, context): ) self.assertEquals(ctx_source, "xray") self.assertDictEqual( - ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2",}, + ctx, + { + "trace-id": "4369", + "parent-id": "65535", + "sampling-priority": "2", + }, ) self.assertDictEqual( get_dd_trace_context(), @@ -235,7 +248,12 @@ def test_with_sqs_distributed_datadog_trace_data(self): ctx, source = extract_dd_trace_context(sqs_event, lambda_ctx) self.assertEqual(source, "event") self.assertDictEqual( - ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1",}, + ctx, + { + "trace-id": "123", + "parent-id": "321", + "sampling-priority": "1", + }, ) self.assertDictEqual( get_dd_trace_context(), @@ -267,7 +285,12 @@ def test_with_client_context_datadog_trace_data(self): ctx, source = extract_dd_trace_context({}, lambda_ctx) self.assertEqual(source, "event") self.assertDictEqual( - ctx, {"trace-id": "666", "parent-id": "777", "sampling-priority": "1",}, + ctx, + { + "trace-id": "666", + "parent-id": "777", + "sampling-priority": "1", + }, ) self.assertDictEqual( get_dd_trace_context(),