Skip to content

Support decrypting API keys encrypted with an encryption context #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jun 10, 2021
6 changes: 2 additions & 4 deletions datadog_lambda/cold_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ def set_cold_start():


def is_cold_start():
"""Returns the value of the global cold_start
"""
"""Returns the value of the global cold_start"""
return _cold_start


def get_cold_start_tag():
"""Returns the cold start tag to be used in metrics
"""
"""Returns the cold start tag to be used in metrics"""
return "cold_start:{}".format(str(is_cold_start()).lower())
51 changes: 46 additions & 5 deletions datadog_lambda/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
import base64
import logging

from botocore.exceptions import ClientError
import boto3
from datadog import api, initialize, statsd
from datadog.threadstats import ThreadStats
from datadog_lambda.extension import should_use_extension
from datadog_lambda.tags import get_enhanced_metrics_tags, tag_dd_lambda_layer


KMS_ENCRYPTION_CONTEXT_KEY = "LambdaFunctionName"
ENHANCED_METRICS_NAMESPACE_PREFIX = "aws.lambda.enhanced"

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -213,12 +214,52 @@ def submit_errors_metric(lambda_context):
submit_enhanced_metric("errors", lambda_context)


# Set API Key and Host in the module, so they only set once per container
def decrypt_kms_api_key(kms_client, ciphertext):
"""
Decodes and deciphers the base64-encoded ciphertext given as a parameter using KMS.
For this to work properly, the Lambda function must have the appropriate IAM permissions.

Args:
kms_client: The KMS client to use for decryption
ciphertext (string): The base64-encoded ciphertext to decrypt
"""
decoded_bytes = base64.b64decode(ciphertext)

"""
The Lambda console UI changed the way it encrypts environment variables. The current behavior
as of May 2021 is to encrypt environment variables using the function name as an encryption
context. Previously, the behavior was to encrypt environment variables without an encryption
context. We need to try both, as supplying the incorrect encryption context will cause
decryption to fail.
"""
# Try with encryption context
function_name = os.environ.get("AWS_LAMBDA_FUNCTION_NAME")
try:
plaintext = kms_client.decrypt(
CiphertextBlob=decoded_bytes,
EncryptionContext={
KMS_ENCRYPTION_CONTEXT_KEY: function_name,
},
)["Plaintext"].decode("utf-8")
except ClientError:
logger.debug(
"Failed to decrypt ciphertext with encryption context, retrying without"
)
# Try without encryption context
plaintext = kms_client.decrypt(CiphertextBlob=decoded_bytes)[
"Plaintext"
].decode("utf-8")

return plaintext


# Set API Key
if not api._api_key:
DD_API_KEY_SECRET_ARN = os.environ.get("DD_API_KEY_SECRET_ARN", "")
DD_API_KEY_SSM_NAME = os.environ.get("DD_API_KEY_SSM_NAME", "")
DD_KMS_API_KEY = os.environ.get("DD_KMS_API_KEY", "")
DD_API_KEY = os.environ.get("DD_API_KEY", os.environ.get("DATADOG_API_KEY", ""))

if DD_API_KEY_SECRET_ARN:
api._api_key = boto3.client("secretsmanager").get_secret_value(
SecretId=DD_API_KEY_SECRET_ARN
Expand All @@ -228,11 +269,11 @@ def submit_errors_metric(lambda_context):
Name=DD_API_KEY_SSM_NAME, WithDecryption=True
)["Parameter"]["Value"]
elif DD_KMS_API_KEY:
api._api_key = boto3.client("kms").decrypt(
CiphertextBlob=base64.b64decode(DD_KMS_API_KEY)
)["Plaintext"]
kms_client = boto3.client("kms")
api._api_key = decrypt_kms_api_key(kms_client, DD_KMS_API_KEY)
else:
api._api_key = DD_API_KEY

logger.debug("Setting DATADOG_API_KEY of length %d", len(api._api_key))

# Set DATADOG_HOST, to send data to a non-default Datadog datacenter
Expand Down
3 changes: 1 addition & 2 deletions datadog_lambda/module_name.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
def modify_module_name(module_name):
"""Returns a valid modified module to get imported
"""
"""Returns a valid modified module to get imported"""
return ".".join(module_name.split("/"))
13 changes: 6 additions & 7 deletions datadog_lambda/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def parse_lambda_tags_from_arn(lambda_context):


def get_runtime_tag():
"""Get the runtime tag from the current Python version
"""
"""Get the runtime tag from the current Python version"""
major_version, minor_version, _ = python_version_tuple()

return "runtime:python{major}.{minor}".format(
Expand All @@ -79,14 +78,12 @@ def get_runtime_tag():


def get_library_version_tag():
"""Get datadog lambda library tag
"""
"""Get datadog lambda library tag"""
return "datadog_lambda:v{}".format(__version__)


def get_enhanced_metrics_tags(lambda_context):
"""Get the list of tags to apply to enhanced metrics
"""
"""Get the list of tags to apply to enhanced metrics"""
return parse_lambda_tags_from_arn(lambda_context) + [
get_cold_start_tag(),
"memorysize:{}".format(lambda_context.memory_limit_in_mb),
Expand All @@ -96,7 +93,9 @@ def get_enhanced_metrics_tags(lambda_context):


def check_if_number(alias):
""" Check if the alias is a version or number. Python 2 has no easy way to test this like Python 3
"""
Check if the alias is a version or number.
Python 2 has no easy way to test this like Python 3
"""
try:
float(alias)
Expand Down
14 changes: 10 additions & 4 deletions datadog_lambda/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ def extract_context_custom_extractor(extractor, event, lambda_context):
Extract Datadog trace context using a custom trace extractor function
"""
try:
(trace_id, parent_id, sampling_priority,) = extractor(event, lambda_context)
(
trace_id,
parent_id,
sampling_priority,
) = extractor(event, lambda_context)
return trace_id, parent_id, sampling_priority
except Exception as e:
logger.debug("The trace extractor returned with error %s", e)
Expand All @@ -205,9 +209,11 @@ def extract_dd_trace_context(event, lambda_context, extractor=None):
trace_context_source = None

if extractor is not None:
(trace_id, parent_id, sampling_priority,) = extract_context_custom_extractor(
extractor, event, lambda_context
)
(
trace_id,
parent_id,
sampling_priority,
) = extract_context_custom_extractor(extractor, event, lambda_context)
elif "headers" in event:
(
trace_id,
Expand Down
2 changes: 1 addition & 1 deletion scripts/check_format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ if [ "$PYTHON_VERSION" = "2" ]; then
echo "Skipping formatting, black not compatible with python 2"
exit 0
fi
pip install -Iv black==19.10b0
pip install -Iv black==21.5b2

python -m black --check datadog_lambda/ --diff
python -m black --check tests --diff
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# If building for Python 3, use the latest version of setuptools
"setuptools>=54.2.0; python_version >= '3.0'",
# If building for Python 2, use the latest version that supports Python 2
"setuptools>=44.1.1; python_version < '3.0'"
"setuptools>=44.1.1; python_version < '3.0'",
],
extras_require={
"dev": ["nose2==0.9.1", "flake8==3.7.9", "requests==2.22.0", "boto3==1.10.33"]
Expand Down
61 changes: 60 additions & 1 deletion tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
except ImportError:
from mock import patch, call

from botocore.exceptions import ClientError as BotocoreClientError
from datadog.api.exceptions import ClientError
from datadog_lambda.metric import lambda_metric, ThreadStatsWriter
from datadog_lambda.metric import (
decrypt_kms_api_key,
lambda_metric,
ThreadStatsWriter,
KMS_ENCRYPTION_CONTEXT_KEY,
)
from datadog_lambda.tags import _format_dd_lambda_layer_tag


Expand Down Expand Up @@ -58,3 +64,56 @@ def test_retry_on_remote_disconnected(self):
)
lambda_stats.flush()
self.assertEqual(self.mock_threadstats_flush_distributions.call_count, 2)


MOCK_FUNCTION_NAME = "myFunction"

# An API key encrypted with KMS and encoded as a base64 string
MOCK_ENCRYPTED_API_KEY_BASE64 = "MjIyMjIyMjIyMjIyMjIyMg=="

# The encrypted API key after it has been decoded from base64
MOCK_ENCRYPTED_API_KEY = "2222222222222222"

# The true value of the API key after decryption by KMS
EXPECTED_DECRYPTED_API_KEY = "1111111111111111"


class TestDecryptKMSApiKey(unittest.TestCase):
def test_key_encrypted_with_encryption_context(self):
os.environ["AWS_LAMBDA_FUNCTION_NAME"] = MOCK_FUNCTION_NAME

class MockKMSClient:
def decrypt(self, CiphertextBlob=None, EncryptionContext={}):
if (
EncryptionContext.get(KMS_ENCRYPTION_CONTEXT_KEY)
!= MOCK_FUNCTION_NAME
):
raise BotocoreClientError({}, "Decrypt")
if CiphertextBlob == MOCK_ENCRYPTED_API_KEY.encode("utf-8"):
return {
"Plaintext": EXPECTED_DECRYPTED_API_KEY.encode("utf-8"),
}

mock_kms_client = MockKMSClient()
decrypted_key = decrypt_kms_api_key(
mock_kms_client, MOCK_ENCRYPTED_API_KEY_BASE64
)
self.assertEqual(decrypted_key, EXPECTED_DECRYPTED_API_KEY)

del os.environ["AWS_LAMBDA_FUNCTION_NAME"]

def test_key_encrypted_without_encryption_context(self):
class MockKMSClient:
def decrypt(self, CiphertextBlob=None, EncryptionContext={}):
if EncryptionContext.get(KMS_ENCRYPTION_CONTEXT_KEY) != None:
raise BotocoreClientError({}, "Decrypt")
if CiphertextBlob == MOCK_ENCRYPTED_API_KEY.encode("utf-8"):
return {
"Plaintext": EXPECTED_DECRYPTED_API_KEY.encode("utf-8"),
}

mock_kms_client = MockKMSClient()
decrypted_key = decrypt_kms_api_key(
mock_kms_client, MOCK_ENCRYPTED_API_KEY_BASE64
)
self.assertEqual(decrypted_key, EXPECTED_DECRYPTED_API_KEY)
37 changes: 30 additions & 7 deletions tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def test_without_datadog_trace_headers(self):
ctx, source = extract_dd_trace_context({}, lambda_ctx)
self.assertEqual(source, "xray")
self.assertDictEqual(
ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"},
ctx,
{"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"},
)
self.assertDictEqual(
get_dd_trace_context(),
Expand All @@ -93,7 +94,8 @@ def test_with_incomplete_datadog_trace_headers(self):
)
self.assertEqual(source, "xray")
self.assertDictEqual(
ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"},
ctx,
{"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"},
)
self.assertDictEqual(
get_dd_trace_context(),
Expand All @@ -118,7 +120,8 @@ def test_with_complete_datadog_trace_headers(self):
)
self.assertEqual(source, "event")
self.assertDictEqual(
ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1"},
ctx,
{"trace-id": "123", "parent-id": "321", "sampling-priority": "1"},
)
self.assertDictEqual(
get_dd_trace_context(),
Expand Down Expand Up @@ -160,7 +163,12 @@ def extractor_foo(event, context):
)
self.assertEquals(ctx_source, "event")
self.assertDictEqual(
ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1",},
ctx,
{
"trace-id": "123",
"parent-id": "321",
"sampling-priority": "1",
},
)
self.assertDictEqual(
get_dd_trace_context(),
Expand Down Expand Up @@ -189,7 +197,12 @@ def extractor_raiser(event, context):
)
self.assertEquals(ctx_source, "xray")
self.assertDictEqual(
ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2",},
ctx,
{
"trace-id": "4369",
"parent-id": "65535",
"sampling-priority": "2",
},
)
self.assertDictEqual(
get_dd_trace_context(),
Expand Down Expand Up @@ -235,7 +248,12 @@ def test_with_sqs_distributed_datadog_trace_data(self):
ctx, source = extract_dd_trace_context(sqs_event, lambda_ctx)
self.assertEqual(source, "event")
self.assertDictEqual(
ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1",},
ctx,
{
"trace-id": "123",
"parent-id": "321",
"sampling-priority": "1",
},
)
self.assertDictEqual(
get_dd_trace_context(),
Expand Down Expand Up @@ -267,7 +285,12 @@ def test_with_client_context_datadog_trace_data(self):
ctx, source = extract_dd_trace_context({}, lambda_ctx)
self.assertEqual(source, "event")
self.assertDictEqual(
ctx, {"trace-id": "666", "parent-id": "777", "sampling-priority": "1",},
ctx,
{
"trace-id": "666",
"parent-id": "777",
"sampling-priority": "1",
},
)
self.assertDictEqual(
get_dd_trace_context(),
Expand Down