Skip to content

Commit 2447967

Browse files
committed
[ext.httpx] Call inject_trace_header with correct subsegment
1 parent 3c04255 commit 2447967

File tree

3 files changed

+90
-68
lines changed

3 files changed

+90
-68
lines changed

aws_xray_sdk/ext/httpx/patch.py

Lines changed: 31 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from aws_xray_sdk.core import xray_recorder
44
from aws_xray_sdk.core.models import http
5-
from aws_xray_sdk.ext.util import UNKNOWN_HOSTNAME, inject_trace_header
5+
from aws_xray_sdk.ext.util import inject_trace_header, get_hostname
66

77

88
def patch():
@@ -32,54 +32,40 @@ def __init__(self, transport: httpx.BaseTransport):
3232
self._wrapped_transport = transport
3333

3434
def handle_request(self, request: httpx.Request) -> httpx.Response:
35-
def httpx_processor(return_value, exception, subsegment, stack, **kwargs):
36-
subsegment.put_http_meta(http.METHOD, request.method)
37-
subsegment.put_http_meta(
38-
http.URL,
39-
str(request.url.copy_with(password=None, query=None, fragment=None)),
40-
)
41-
42-
if return_value is not None:
43-
subsegment.put_http_meta(http.STATUS, return_value.status_code)
44-
elif exception:
45-
subsegment.add_exception(exception, stack)
46-
47-
inject_trace_header(request.headers, xray_recorder.current_subsegment())
48-
return xray_recorder.record_subsegment(
49-
wrapped=self._wrapped_transport.handle_request,
50-
instance=self._wrapped_transport,
51-
args=(request,),
52-
kwargs={},
53-
name=request.url.host or UNKNOWN_HOSTNAME,
54-
namespace="remote",
55-
meta_processor=httpx_processor,
56-
)
35+
with xray_recorder.in_subsegment(
36+
get_hostname(str(request.url)), namespace="remote"
37+
) as subsegment:
38+
if subsegment is not None:
39+
subsegment.put_http_meta(http.METHOD, request.method)
40+
subsegment.put_http_meta(
41+
http.URL,
42+
str(request.url.copy_with(password=None, query=None, fragment=None)),
43+
)
44+
inject_trace_header(request.headers, subsegment)
45+
46+
response = self._wrapped_transport.handle_request(request)
47+
if subsegment is not None:
48+
subsegment.put_http_meta(http.STATUS, response.status_code)
49+
return response
5750

5851

5952
class AsyncInstrumentedTransport(httpx.AsyncBaseTransport):
6053
def __init__(self, transport: httpx.AsyncBaseTransport):
6154
self._wrapped_transport = transport
6255

6356
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
64-
def httpx_processor(return_value, exception, subsegment, stack, **kwargs):
65-
subsegment.put_http_meta(http.METHOD, request.method)
66-
subsegment.put_http_meta(
67-
http.URL,
68-
str(request.url.copy_with(password=None, query=None, fragment=None)),
69-
)
70-
71-
if return_value is not None:
72-
subsegment.put_http_meta(http.STATUS, return_value.status_code)
73-
elif exception:
74-
subsegment.add_exception(exception, stack)
75-
76-
inject_trace_header(request.headers, xray_recorder.current_subsegment())
77-
return await xray_recorder.record_subsegment_async(
78-
wrapped=self._wrapped_transport.handle_async_request,
79-
instance=self._wrapped_transport,
80-
args=(request,),
81-
kwargs={},
82-
name=request.url.host or UNKNOWN_HOSTNAME,
83-
namespace="remote",
84-
meta_processor=httpx_processor,
85-
)
57+
async with xray_recorder.in_subsegment_async(
58+
get_hostname(str(request.url)), namespace="remote"
59+
) as subsegment:
60+
if subsegment is not None:
61+
subsegment.put_http_meta(http.METHOD, request.method)
62+
subsegment.put_http_meta(
63+
http.URL,
64+
str(request.url.copy_with(password=None, query=None, fragment=None)),
65+
)
66+
inject_trace_header(request.headers, subsegment)
67+
68+
response = await self._wrapped_transport.handle_async_request(request)
69+
if subsegment is not None:
70+
subsegment.put_http_meta(http.STATUS, response.status_code)
71+
return response

tests/ext/httpx/test_httpx.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,14 @@ def test_ok(use_client):
3333
url = "http://{}/status/{}?foo=bar".format(BASE_URL, status_code)
3434
if use_client:
3535
with httpx.Client() as client:
36-
client.get(url)
36+
response = client.get(url)
3737
else:
38-
httpx.get(url)
38+
response = httpx.get(url)
39+
assert "x-amzn-trace-id" in response._request.headers
40+
3941
subsegment = xray_recorder.current_segment().subsegments[0]
4042
assert get_hostname(url) == BASE_URL
43+
assert subsegment.namespace == "remote"
4144
assert subsegment.name == get_hostname(url)
4245

4346
http_meta = subsegment.http
@@ -52,10 +55,13 @@ def test_error(use_client):
5255
url = "http://{}/status/{}".format(BASE_URL, status_code)
5356
if use_client:
5457
with httpx.Client() as client:
55-
client.post(url)
58+
response = client.post(url)
5659
else:
57-
httpx.post(url)
60+
response = httpx.post(url)
61+
assert "x-amzn-trace-id" in response._request.headers
62+
5863
subsegment = xray_recorder.current_segment().subsegments[0]
64+
assert subsegment.namespace == "remote"
5965
assert subsegment.name == get_hostname(url)
6066
assert subsegment.error
6167

@@ -71,10 +77,13 @@ def test_throttle(use_client):
7177
url = "http://{}/status/{}".format(BASE_URL, status_code)
7278
if use_client:
7379
with httpx.Client() as client:
74-
client.head(url)
80+
response = client.head(url)
7581
else:
76-
httpx.head(url)
82+
response = httpx.head(url)
83+
assert "x-amzn-trace-id" in response._request.headers
84+
7785
subsegment = xray_recorder.current_segment().subsegments[0]
86+
assert subsegment.namespace == "remote"
7887
assert subsegment.name == get_hostname(url)
7988
assert subsegment.error
8089
assert subsegment.throttle
@@ -91,10 +100,13 @@ def test_fault(use_client):
91100
url = "http://{}/status/{}".format(BASE_URL, status_code)
92101
if use_client:
93102
with httpx.Client() as client:
94-
client.put(url)
103+
response = client.put(url)
95104
else:
96-
httpx.put(url)
105+
response = httpx.put(url)
106+
assert "x-amzn-trace-id" in response._request.headers
107+
97108
subsegment = xray_recorder.current_segment().subsegments[0]
109+
assert subsegment.namespace == "remote"
98110
assert subsegment.name == get_hostname(url)
99111
assert subsegment.fault
100112

@@ -114,6 +126,7 @@ def test_nonexistent_domain(use_client):
114126
httpx.get("http://doesnt.exist")
115127

116128
subsegment = xray_recorder.current_segment().subsegments[0]
129+
assert subsegment.namespace == "remote"
117130
assert subsegment.fault
118131

119132
exception = subsegment.cause["exceptions"][0]
@@ -131,6 +144,7 @@ def test_invalid_url(use_client):
131144
httpx.get(url)
132145

133146
subsegment = xray_recorder.current_segment().subsegments[0]
147+
assert subsegment.namespace == "remote"
134148
assert subsegment.name == get_hostname(url)
135149
assert subsegment.fault
136150

@@ -152,6 +166,7 @@ def test_name_uses_hostname(use_client):
152166
url1 = "http://{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL)
153167
client.get(url1)
154168
subsegment = xray_recorder.current_segment().subsegments[-1]
169+
assert subsegment.namespace == "remote"
155170
assert subsegment.name == BASE_URL
156171
http_meta1 = subsegment.http
157172
assert http_meta1["request"]["url"] == strip_url(url1)
@@ -160,6 +175,7 @@ def test_name_uses_hostname(use_client):
160175
url2 = "http://{}/".format(BASE_URL)
161176
client.get(url2, params={"some": "payload", "not": "toBeIncluded"})
162177
subsegment = xray_recorder.current_segment().subsegments[-1]
178+
assert subsegment.namespace == "remote"
163179
assert subsegment.name == BASE_URL
164180
http_meta2 = subsegment.http
165181
assert http_meta2["request"]["url"] == strip_url(url2)
@@ -171,6 +187,7 @@ def test_name_uses_hostname(use_client):
171187
except httpx.ConnectError:
172188
pass
173189
subsegment = xray_recorder.current_segment().subsegments[-1]
190+
assert subsegment.namespace == "remote"
174191
assert subsegment.name == "subdomain." + BASE_URL
175192
http_meta3 = subsegment.http
176193
assert http_meta3["request"]["url"] == strip_url(url3)
@@ -186,10 +203,13 @@ def test_strip_http_url(use_client):
186203
url = "http://{}/get?foo=bar".format(BASE_URL)
187204
if use_client:
188205
with httpx.Client() as client:
189-
client.get(url)
206+
response = client.get(url)
190207
else:
191-
httpx.get(url)
208+
response = httpx.get(url)
209+
assert "x-amzn-trace-id" in response._request.headers
210+
192211
subsegment = xray_recorder.current_segment().subsegments[0]
212+
assert subsegment.namespace == "remote"
193213
assert subsegment.name == get_hostname(url)
194214

195215
http_meta = subsegment.http

tests/ext/httpx/test_httpx_async.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,12 @@ async def test_ok_async():
3232
status_code = 200
3333
url = "http://{}/status/{}?foo=bar".format(BASE_URL, status_code)
3434
async with httpx.AsyncClient() as client:
35-
await client.get(url)
35+
response = await client.get(url)
36+
assert "x-amzn-trace-id" in response._request.headers
37+
3638
subsegment = xray_recorder.current_segment().subsegments[0]
3739
assert get_hostname(url) == BASE_URL
40+
assert subsegment.namespace == "remote"
3841
assert subsegment.name == get_hostname(url)
3942

4043
http_meta = subsegment.http
@@ -48,8 +51,11 @@ async def test_error_async():
4851
status_code = 400
4952
url = "http://{}/status/{}".format(BASE_URL, status_code)
5053
async with httpx.AsyncClient() as client:
51-
await client.post(url)
54+
response = await client.post(url)
55+
assert "x-amzn-trace-id" in response._request.headers
56+
5257
subsegment = xray_recorder.current_segment().subsegments[0]
58+
assert subsegment.namespace == "remote"
5359
assert subsegment.name == get_hostname(url)
5460
assert subsegment.error
5561

@@ -64,8 +70,11 @@ async def test_throttle_async():
6470
status_code = 429
6571
url = "http://{}/status/{}".format(BASE_URL, status_code)
6672
async with httpx.AsyncClient() as client:
67-
await client.head(url)
73+
response = await client.head(url)
74+
assert "x-amzn-trace-id" in response._request.headers
75+
6876
subsegment = xray_recorder.current_segment().subsegments[0]
77+
assert subsegment.namespace == "remote"
6978
assert subsegment.name == get_hostname(url)
7079
assert subsegment.error
7180
assert subsegment.throttle
@@ -81,8 +90,11 @@ async def test_fault_async():
8190
status_code = 500
8291
url = "http://{}/status/{}".format(BASE_URL, status_code)
8392
async with httpx.AsyncClient() as client:
84-
await client.put(url)
93+
response = await client.put(url)
94+
assert "x-amzn-trace-id" in response._request.headers
95+
8596
subsegment = xray_recorder.current_segment().subsegments[0]
97+
assert subsegment.namespace == "remote"
8698
assert subsegment.name == get_hostname(url)
8799
assert subsegment.fault
88100

@@ -94,13 +106,12 @@ async def test_fault_async():
94106

95107
@pytest.mark.asyncio
96108
async def test_nonexistent_domain_async():
97-
try:
109+
with pytest.raises(httpx.ConnectError):
98110
async with httpx.AsyncClient() as client:
99111
await client.get("http://doesnt.exist")
100-
except Exception:
101-
# prevent uncatch exception from breaking test run
102-
pass
112+
103113
subsegment = xray_recorder.current_segment().subsegments[0]
114+
assert subsegment.namespace == "remote"
104115
assert subsegment.fault
105116

106117
exception = subsegment.cause["exceptions"][0]
@@ -110,13 +121,12 @@ async def test_nonexistent_domain_async():
110121
@pytest.mark.asyncio
111122
async def test_invalid_url_async():
112123
url = "KLSDFJKLSDFJKLSDJF"
113-
try:
124+
with pytest.raises(httpx.UnsupportedProtocol):
114125
async with httpx.AsyncClient() as client:
115126
await client.get(url)
116-
except Exception:
117-
# prevent uncatch exception from breaking test run
118-
pass
127+
119128
subsegment = xray_recorder.current_segment().subsegments[0]
129+
assert subsegment.namespace == "remote"
120130
assert subsegment.name == get_hostname(url)
121131
assert subsegment.fault
122132

@@ -133,6 +143,7 @@ async def test_name_uses_hostname_async():
133143
url1 = "http://{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL)
134144
await client.get(url1)
135145
subsegment = xray_recorder.current_segment().subsegments[-1]
146+
assert subsegment.namespace == "remote"
136147
assert subsegment.name == BASE_URL
137148
http_meta1 = subsegment.http
138149
assert http_meta1["request"]["url"] == strip_url(url1)
@@ -141,6 +152,7 @@ async def test_name_uses_hostname_async():
141152
url2 = "http://{}/".format(BASE_URL)
142153
await client.get(url2, params={"some": "payload", "not": "toBeIncluded"})
143154
subsegment = xray_recorder.current_segment().subsegments[-1]
155+
assert subsegment.namespace == "remote"
144156
assert subsegment.name == BASE_URL
145157
http_meta2 = subsegment.http
146158
assert http_meta2["request"]["url"] == strip_url(url2)
@@ -153,6 +165,7 @@ async def test_name_uses_hostname_async():
153165
# This is an invalid url so we dont want to break the test
154166
pass
155167
subsegment = xray_recorder.current_segment().subsegments[-1]
168+
assert subsegment.namespace == "remote"
156169
assert subsegment.name == "subdomain." + BASE_URL
157170
http_meta3 = subsegment.http
158171
assert http_meta3["request"]["url"] == strip_url(url3)
@@ -164,8 +177,11 @@ async def test_strip_http_url_async():
164177
status_code = 200
165178
url = "http://{}/get?foo=bar".format(BASE_URL)
166179
async with httpx.AsyncClient() as client:
167-
await client.get(url)
180+
response = await client.get(url)
181+
assert "x-amzn-trace-id" in response._request.headers
182+
168183
subsegment = xray_recorder.current_segment().subsegments[0]
184+
assert subsegment.namespace == "remote"
169185
assert subsegment.name == get_hostname(url)
170186

171187
http_meta = subsegment.http

0 commit comments

Comments
 (0)