Skip to content

Commit 315c73d

Browse files
authored
Support decrypting API keys encrypted with an encryption context (#145)
1 parent d467b16 commit 315c73d

File tree

9 files changed

+157
-32
lines changed

9 files changed

+157
-32
lines changed

datadog_lambda/cold_start.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@ def set_cold_start():
1414

1515

1616
def is_cold_start():
17-
"""Returns the value of the global cold_start
18-
"""
17+
"""Returns the value of the global cold_start"""
1918
return _cold_start
2019

2120

2221
def get_cold_start_tag():
23-
"""Returns the cold start tag to be used in metrics
24-
"""
22+
"""Returns the cold start tag to be used in metrics"""
2523
return "cold_start:{}".format(str(is_cold_start()).lower())

datadog_lambda/metric.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
import base64
1010
import logging
1111

12+
from botocore.exceptions import ClientError
1213
import boto3
1314
from datadog import api, initialize, statsd
1415
from datadog.threadstats import ThreadStats
1516
from datadog_lambda.extension import should_use_extension
1617
from datadog_lambda.tags import get_enhanced_metrics_tags, tag_dd_lambda_layer
1718

18-
19+
KMS_ENCRYPTION_CONTEXT_KEY = "LambdaFunctionName"
1920
ENHANCED_METRICS_NAMESPACE_PREFIX = "aws.lambda.enhanced"
2021

2122
logger = logging.getLogger(__name__)
@@ -213,12 +214,52 @@ def submit_errors_metric(lambda_context):
213214
submit_enhanced_metric("errors", lambda_context)
214215

215216

216-
# Set API Key and Host in the module, so they only set once per container
217+
def decrypt_kms_api_key(kms_client, ciphertext):
218+
"""
219+
Decodes and deciphers the base64-encoded ciphertext given as a parameter using KMS.
220+
For this to work properly, the Lambda function must have the appropriate IAM permissions.
221+
222+
Args:
223+
kms_client: The KMS client to use for decryption
224+
ciphertext (string): The base64-encoded ciphertext to decrypt
225+
"""
226+
decoded_bytes = base64.b64decode(ciphertext)
227+
228+
"""
229+
The Lambda console UI changed the way it encrypts environment variables. The current behavior
230+
as of May 2021 is to encrypt environment variables using the function name as an encryption
231+
context. Previously, the behavior was to encrypt environment variables without an encryption
232+
context. We need to try both, as supplying the incorrect encryption context will cause
233+
decryption to fail.
234+
"""
235+
# Try with encryption context
236+
function_name = os.environ.get("AWS_LAMBDA_FUNCTION_NAME")
237+
try:
238+
plaintext = kms_client.decrypt(
239+
CiphertextBlob=decoded_bytes,
240+
EncryptionContext={
241+
KMS_ENCRYPTION_CONTEXT_KEY: function_name,
242+
},
243+
)["Plaintext"].decode("utf-8")
244+
except ClientError:
245+
logger.debug(
246+
"Failed to decrypt ciphertext with encryption context, retrying without"
247+
)
248+
# Try without encryption context
249+
plaintext = kms_client.decrypt(CiphertextBlob=decoded_bytes)[
250+
"Plaintext"
251+
].decode("utf-8")
252+
253+
return plaintext
254+
255+
256+
# Set API Key
217257
if not api._api_key:
218258
DD_API_KEY_SECRET_ARN = os.environ.get("DD_API_KEY_SECRET_ARN", "")
219259
DD_API_KEY_SSM_NAME = os.environ.get("DD_API_KEY_SSM_NAME", "")
220260
DD_KMS_API_KEY = os.environ.get("DD_KMS_API_KEY", "")
221261
DD_API_KEY = os.environ.get("DD_API_KEY", os.environ.get("DATADOG_API_KEY", ""))
262+
222263
if DD_API_KEY_SECRET_ARN:
223264
api._api_key = boto3.client("secretsmanager").get_secret_value(
224265
SecretId=DD_API_KEY_SECRET_ARN
@@ -228,11 +269,11 @@ def submit_errors_metric(lambda_context):
228269
Name=DD_API_KEY_SSM_NAME, WithDecryption=True
229270
)["Parameter"]["Value"]
230271
elif DD_KMS_API_KEY:
231-
api._api_key = boto3.client("kms").decrypt(
232-
CiphertextBlob=base64.b64decode(DD_KMS_API_KEY)
233-
)["Plaintext"]
272+
kms_client = boto3.client("kms")
273+
api._api_key = decrypt_kms_api_key(kms_client, DD_KMS_API_KEY)
234274
else:
235275
api._api_key = DD_API_KEY
276+
236277
logger.debug("Setting DATADOG_API_KEY of length %d", len(api._api_key))
237278

238279
# Set DATADOG_HOST, to send data to a non-default Datadog datacenter

datadog_lambda/module_name.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
def modify_module_name(module_name):
2-
"""Returns a valid modified module to get imported
3-
"""
2+
"""Returns a valid modified module to get imported"""
43
return ".".join(module_name.split("/"))

datadog_lambda/tags.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ def parse_lambda_tags_from_arn(lambda_context):
6969

7070

7171
def get_runtime_tag():
72-
"""Get the runtime tag from the current Python version
73-
"""
72+
"""Get the runtime tag from the current Python version"""
7473
major_version, minor_version, _ = python_version_tuple()
7574

7675
return "runtime:python{major}.{minor}".format(
@@ -79,14 +78,12 @@ def get_runtime_tag():
7978

8079

8180
def get_library_version_tag():
82-
"""Get datadog lambda library tag
83-
"""
81+
"""Get datadog lambda library tag"""
8482
return "datadog_lambda:v{}".format(__version__)
8583

8684

8785
def get_enhanced_metrics_tags(lambda_context):
88-
"""Get the list of tags to apply to enhanced metrics
89-
"""
86+
"""Get the list of tags to apply to enhanced metrics"""
9087
return parse_lambda_tags_from_arn(lambda_context) + [
9188
get_cold_start_tag(),
9289
"memorysize:{}".format(lambda_context.memory_limit_in_mb),
@@ -96,7 +93,9 @@ def get_enhanced_metrics_tags(lambda_context):
9693

9794

9895
def check_if_number(alias):
99-
""" Check if the alias is a version or number. Python 2 has no easy way to test this like Python 3
96+
"""
97+
Check if the alias is a version or number.
98+
Python 2 has no easy way to test this like Python 3
10099
"""
101100
try:
102101
float(alias)

datadog_lambda/tracing.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,11 @@ def extract_context_custom_extractor(extractor, event, lambda_context):
186186
Extract Datadog trace context using a custom trace extractor function
187187
"""
188188
try:
189-
(trace_id, parent_id, sampling_priority,) = extractor(event, lambda_context)
189+
(
190+
trace_id,
191+
parent_id,
192+
sampling_priority,
193+
) = extractor(event, lambda_context)
190194
return trace_id, parent_id, sampling_priority
191195
except Exception as e:
192196
logger.debug("The trace extractor returned with error %s", e)
@@ -205,9 +209,11 @@ def extract_dd_trace_context(event, lambda_context, extractor=None):
205209
trace_context_source = None
206210

207211
if extractor is not None:
208-
(trace_id, parent_id, sampling_priority,) = extract_context_custom_extractor(
209-
extractor, event, lambda_context
210-
)
212+
(
213+
trace_id,
214+
parent_id,
215+
sampling_priority,
216+
) = extract_context_custom_extractor(extractor, event, lambda_context)
211217
elif "headers" in event:
212218
(
213219
trace_id,

scripts/check_format.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ if [ "$PYTHON_VERSION" = "2" ]; then
66
echo "Skipping formatting, black not compatible with python 2"
77
exit 0
88
fi
9-
pip install -Iv black==19.10b0
9+
pip install -Iv black==21.5b2
1010

1111
python -m black --check datadog_lambda/ --diff
1212
python -m black --check tests --diff

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
# If building for Python 3, use the latest version of setuptools
3636
"setuptools>=54.2.0; python_version >= '3.0'",
3737
# If building for Python 2, use the latest version that supports Python 2
38-
"setuptools>=44.1.1; python_version < '3.0'"
38+
"setuptools>=44.1.1; python_version < '3.0'",
3939
],
4040
extras_require={
4141
"dev": ["nose2==0.9.1", "flake8==3.7.9", "requests==2.22.0", "boto3==1.10.33"]

tests/test_metric.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66
except ImportError:
77
from mock import patch, call
88

9+
from botocore.exceptions import ClientError as BotocoreClientError
910
from datadog.api.exceptions import ClientError
10-
from datadog_lambda.metric import lambda_metric, ThreadStatsWriter
11+
from datadog_lambda.metric import (
12+
decrypt_kms_api_key,
13+
lambda_metric,
14+
ThreadStatsWriter,
15+
KMS_ENCRYPTION_CONTEXT_KEY,
16+
)
1117
from datadog_lambda.tags import _format_dd_lambda_layer_tag
1218

1319

@@ -58,3 +64,56 @@ def test_retry_on_remote_disconnected(self):
5864
)
5965
lambda_stats.flush()
6066
self.assertEqual(self.mock_threadstats_flush_distributions.call_count, 2)
67+
68+
69+
MOCK_FUNCTION_NAME = "myFunction"
70+
71+
# An API key encrypted with KMS and encoded as a base64 string
72+
MOCK_ENCRYPTED_API_KEY_BASE64 = "MjIyMjIyMjIyMjIyMjIyMg=="
73+
74+
# The encrypted API key after it has been decoded from base64
75+
MOCK_ENCRYPTED_API_KEY = "2222222222222222"
76+
77+
# The true value of the API key after decryption by KMS
78+
EXPECTED_DECRYPTED_API_KEY = "1111111111111111"
79+
80+
81+
class TestDecryptKMSApiKey(unittest.TestCase):
82+
def test_key_encrypted_with_encryption_context(self):
83+
os.environ["AWS_LAMBDA_FUNCTION_NAME"] = MOCK_FUNCTION_NAME
84+
85+
class MockKMSClient:
86+
def decrypt(self, CiphertextBlob=None, EncryptionContext={}):
87+
if (
88+
EncryptionContext.get(KMS_ENCRYPTION_CONTEXT_KEY)
89+
!= MOCK_FUNCTION_NAME
90+
):
91+
raise BotocoreClientError({}, "Decrypt")
92+
if CiphertextBlob == MOCK_ENCRYPTED_API_KEY.encode("utf-8"):
93+
return {
94+
"Plaintext": EXPECTED_DECRYPTED_API_KEY.encode("utf-8"),
95+
}
96+
97+
mock_kms_client = MockKMSClient()
98+
decrypted_key = decrypt_kms_api_key(
99+
mock_kms_client, MOCK_ENCRYPTED_API_KEY_BASE64
100+
)
101+
self.assertEqual(decrypted_key, EXPECTED_DECRYPTED_API_KEY)
102+
103+
del os.environ["AWS_LAMBDA_FUNCTION_NAME"]
104+
105+
def test_key_encrypted_without_encryption_context(self):
106+
class MockKMSClient:
107+
def decrypt(self, CiphertextBlob=None, EncryptionContext={}):
108+
if EncryptionContext.get(KMS_ENCRYPTION_CONTEXT_KEY) != None:
109+
raise BotocoreClientError({}, "Decrypt")
110+
if CiphertextBlob == MOCK_ENCRYPTED_API_KEY.encode("utf-8"):
111+
return {
112+
"Plaintext": EXPECTED_DECRYPTED_API_KEY.encode("utf-8"),
113+
}
114+
115+
mock_kms_client = MockKMSClient()
116+
decrypted_key = decrypt_kms_api_key(
117+
mock_kms_client, MOCK_ENCRYPTED_API_KEY_BASE64
118+
)
119+
self.assertEqual(decrypted_key, EXPECTED_DECRYPTED_API_KEY)

tests/test_tracing.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def test_without_datadog_trace_headers(self):
7373
ctx, source = extract_dd_trace_context({}, lambda_ctx)
7474
self.assertEqual(source, "xray")
7575
self.assertDictEqual(
76-
ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"},
76+
ctx,
77+
{"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"},
7778
)
7879
self.assertDictEqual(
7980
get_dd_trace_context(),
@@ -93,7 +94,8 @@ def test_with_incomplete_datadog_trace_headers(self):
9394
)
9495
self.assertEqual(source, "xray")
9596
self.assertDictEqual(
96-
ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"},
97+
ctx,
98+
{"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2"},
9799
)
98100
self.assertDictEqual(
99101
get_dd_trace_context(),
@@ -118,7 +120,8 @@ def test_with_complete_datadog_trace_headers(self):
118120
)
119121
self.assertEqual(source, "event")
120122
self.assertDictEqual(
121-
ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1"},
123+
ctx,
124+
{"trace-id": "123", "parent-id": "321", "sampling-priority": "1"},
122125
)
123126
self.assertDictEqual(
124127
get_dd_trace_context(),
@@ -160,7 +163,12 @@ def extractor_foo(event, context):
160163
)
161164
self.assertEquals(ctx_source, "event")
162165
self.assertDictEqual(
163-
ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1",},
166+
ctx,
167+
{
168+
"trace-id": "123",
169+
"parent-id": "321",
170+
"sampling-priority": "1",
171+
},
164172
)
165173
self.assertDictEqual(
166174
get_dd_trace_context(),
@@ -189,7 +197,12 @@ def extractor_raiser(event, context):
189197
)
190198
self.assertEquals(ctx_source, "xray")
191199
self.assertDictEqual(
192-
ctx, {"trace-id": "4369", "parent-id": "65535", "sampling-priority": "2",},
200+
ctx,
201+
{
202+
"trace-id": "4369",
203+
"parent-id": "65535",
204+
"sampling-priority": "2",
205+
},
193206
)
194207
self.assertDictEqual(
195208
get_dd_trace_context(),
@@ -235,7 +248,12 @@ def test_with_sqs_distributed_datadog_trace_data(self):
235248
ctx, source = extract_dd_trace_context(sqs_event, lambda_ctx)
236249
self.assertEqual(source, "event")
237250
self.assertDictEqual(
238-
ctx, {"trace-id": "123", "parent-id": "321", "sampling-priority": "1",},
251+
ctx,
252+
{
253+
"trace-id": "123",
254+
"parent-id": "321",
255+
"sampling-priority": "1",
256+
},
239257
)
240258
self.assertDictEqual(
241259
get_dd_trace_context(),
@@ -267,7 +285,12 @@ def test_with_client_context_datadog_trace_data(self):
267285
ctx, source = extract_dd_trace_context({}, lambda_ctx)
268286
self.assertEqual(source, "event")
269287
self.assertDictEqual(
270-
ctx, {"trace-id": "666", "parent-id": "777", "sampling-priority": "1",},
288+
ctx,
289+
{
290+
"trace-id": "666",
291+
"parent-id": "777",
292+
"sampling-priority": "1",
293+
},
271294
)
272295
self.assertDictEqual(
273296
get_dd_trace_context(),

0 commit comments

Comments
 (0)