Skip to content

Commit ebeee64

Browse files
committed
patch requests when it's used
1 parent 7e860e0 commit ebeee64

File tree

4 files changed

+50
-20
lines changed

4 files changed

+50
-20
lines changed

datadog_lambda/patch.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010

1111
from wrapt import wrap_function_wrapper as wrap
12+
from wrapt.importer import when_imported
1213

1314
from datadog_lambda.tracing import get_dd_trace_context
1415

@@ -29,7 +30,7 @@ def patch_all():
2930
Datadog trace context.
3031
"""
3132
_patch_httplib()
32-
_patch_requests()
33+
_ensure_patch_requests()
3334

3435

3536
def _patch_httplib():
@@ -45,7 +46,20 @@ def _patch_httplib():
4546
logger.debug("Patched %s", httplib_module)
4647

4748

48-
def _patch_requests():
49+
def _ensure_patch_requests():
50+
"""
51+
`requests` is third-party, may not be installed or used,
52+
but ensure it gets patched if installed and used.
53+
"""
54+
if "requests" in sys.modules:
55+
# already imported, patch now
56+
_patch_requests(sys.modules["requests"])
57+
else:
58+
# patch when imported
59+
when_imported("requests")(_patch_requests)
60+
61+
62+
def _patch_requests(module):
4963
"""
5064
Patch the high-level HTTP client module `requests`
5165
if it's installed.
@@ -66,9 +80,9 @@ def _wrap_requests_request(func, instance, args, kwargs):
6680
into the outgoing requests.
6781
"""
6882
context = get_dd_trace_context()
69-
if "headers" in kwargs:
83+
if "headers" in kwargs and isinstance(kwargs["headers"], dict):
7084
kwargs["headers"].update(context)
71-
elif len(args) >= 5:
85+
elif len(args) >= 5 and isinstance(args[4], dict):
7286
args[4].update(context)
7387
else:
7488
kwargs["headers"] = context
@@ -86,9 +100,9 @@ def _wrap_httplib_request(func, instance, args, kwargs):
86100
the Datadog trace headers into the outgoing requests.
87101
"""
88102
context = get_dd_trace_context()
89-
if "headers" in kwargs:
103+
if "headers" in kwargs and isinstance(kwargs["headers"], dict):
90104
kwargs["headers"].update(context)
91-
elif len(args) >= 4:
105+
elif len(args) >= 4 and isinstance(args[3], dict):
92106
args[3].update(context)
93107
else:
94108
kwargs["headers"] = context

tests/integration/handle.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import json
2-
31
from datadog_lambda.metric import lambda_metric
42
from datadog_lambda.wrapper import datadog_lambda_wrapper
53

tests/integration/http_requests.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import requests
32

43
from datadog_lambda.metric import lambda_metric
@@ -12,7 +11,7 @@ def handle(event, context):
1211
"tests.integration.count", 21, tags=["test:integration", "role:hello"]
1312
)
1413

15-
us_response = requests.get("https://ip-ranges.datadoghq.com/")
16-
eu_response = requests.get("https://ip-ranges.datadoghq.eu/")
14+
requests.get("https://ip-ranges.datadoghq.com/")
15+
requests.get("https://ip-ranges.datadoghq.eu/")
1716

1817
return {"statusCode": 200, "body": {"message": "hello, dog!"}}

tests/test_patch.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@
77

88
from datadog_lambda.patch import (
99
_patch_httplib,
10-
_patch_requests,
10+
_ensure_patch_requests,
1111
)
1212
from datadog_lambda.constants import TraceHeader
1313

1414

1515
class TestPatchHTTPClients(unittest.TestCase):
1616

1717
def setUp(self):
18-
patcher = patch('datadog_lambda.patch.get_dd_trace_context')
18+
patcher = patch("datadog_lambda.patch.get_dd_trace_context")
1919
self.mock_get_dd_trace_context = patcher.start()
2020
self.mock_get_dd_trace_context.return_value = {
21-
TraceHeader.TRACE_ID: '123',
22-
TraceHeader.PARENT_ID: '321',
23-
TraceHeader.SAMPLING_PRIORITY: '2',
21+
TraceHeader.TRACE_ID: "123",
22+
TraceHeader.PARENT_ID: "321",
23+
TraceHeader.SAMPLING_PRIORITY: "2",
2424
}
2525
self.addCleanup(patcher.stop)
2626

@@ -34,10 +34,29 @@ def test_patch_httplib(self):
3434
self.mock_get_dd_trace_context.assert_called()
3535

3636
def test_patch_requests(self):
37-
_patch_requests()
37+
_ensure_patch_requests()
3838
import requests
3939
r = requests.get("https://www.datadoghq.com/")
4040
self.mock_get_dd_trace_context.assert_called()
41-
self.assertEqual(r.request.headers[TraceHeader.TRACE_ID], '123')
42-
self.assertEqual(r.request.headers[TraceHeader.PARENT_ID], '321')
43-
self.assertEqual(r.request.headers[TraceHeader.SAMPLING_PRIORITY], '2')
41+
self.assertEqual(r.request.headers[TraceHeader.TRACE_ID], "123")
42+
self.assertEqual(r.request.headers[TraceHeader.PARENT_ID], "321")
43+
self.assertEqual(r.request.headers[TraceHeader.SAMPLING_PRIORITY], "2")
44+
45+
def test_patch_requests_with_headers(self):
46+
_ensure_patch_requests()
47+
import requests
48+
r = requests.get("https://www.datadoghq.com/", headers={"key": "value"})
49+
self.mock_get_dd_trace_context.assert_called()
50+
self.assertEqual(r.request.headers["key"], "value")
51+
self.assertEqual(r.request.headers[TraceHeader.TRACE_ID], "123")
52+
self.assertEqual(r.request.headers[TraceHeader.PARENT_ID], "321")
53+
self.assertEqual(r.request.headers[TraceHeader.SAMPLING_PRIORITY], "2")
54+
55+
def test_patch_requests_with_headers_none(self):
56+
_ensure_patch_requests()
57+
import requests
58+
r = requests.get("https://www.datadoghq.com/", headers=None)
59+
self.mock_get_dd_trace_context.assert_called()
60+
self.assertEqual(r.request.headers[TraceHeader.TRACE_ID], "123")
61+
self.assertEqual(r.request.headers[TraceHeader.PARENT_ID], "321")
62+
self.assertEqual(r.request.headers[TraceHeader.SAMPLING_PRIORITY], "2")

0 commit comments

Comments
 (0)