Skip to content

Commit e22428f

Browse files
committed
Defer trace context extraction to ddtrace.
1 parent aa5a1c9 commit e22428f

File tree

2 files changed

+199
-325
lines changed

2 files changed

+199
-325
lines changed

datadog_lambda/tracing.py

Lines changed: 70 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ddtrace import tracer, patch, Span
3333
from ddtrace import __version__ as ddtrace_version
3434
from ddtrace.propagation.http import HTTPPropagator
35+
from ddtrace.context import Context
3536
from datadog_lambda import __version__ as datadog_lambda_version
3637
from datadog_lambda.trigger import (
3738
_EventSource,
@@ -53,7 +54,7 @@
5354

5455
logger = logging.getLogger(__name__)
5556

56-
dd_trace_context = {}
57+
dd_trace_context = None
5758
dd_tracing_enabled = os.environ.get("DD_TRACE_ENABLED", "false").lower() == "true"
5859
if dd_tracing_enabled:
5960
# Enable the telemetry client if the user has opted in
@@ -102,11 +103,11 @@ def _get_xray_trace_context():
102103
)
103104
if xray_trace_entity is None:
104105
return None
105-
trace_context = {
106-
"trace-id": _convert_xray_trace_id(xray_trace_entity.get("trace_id")),
107-
"parent-id": _convert_xray_entity_id(xray_trace_entity.get("parent_id")),
108-
"sampling-priority": _convert_xray_sampling(xray_trace_entity.get("sampled")),
109-
}
106+
trace_context = Context(
107+
trace_id=_convert_xray_trace_id(xray_trace_entity.get("trace_id")),
108+
span_id=_convert_xray_entity_id(xray_trace_entity.get("parent_id")),
109+
sampling_priority=_convert_xray_sampling(xray_trace_entity.get("sampled")),
110+
)
110111
logger.debug(
111112
"Converted trace context %s from X-Ray segment %s",
112113
trace_context,
@@ -124,26 +125,17 @@ def _get_dd_trace_py_context():
124125
if not span:
125126
return None
126127

127-
parent_id = span.context.span_id
128-
trace_id = span.context.trace_id
129-
sampling_priority = span.context.sampling_priority
130128
logger.debug(
131129
"found dd trace context: %s", (span.context.trace_id, span.context.span_id)
132130
)
133-
return {
134-
"parent-id": str(parent_id),
135-
"trace-id": str(trace_id),
136-
"sampling-priority": str(sampling_priority),
137-
"source": TraceContextSource.DDTRACE,
138-
}
131+
return span.context
139132

140133

141-
def _context_obj_to_headers(obj):
142-
return {
143-
TraceHeader.TRACE_ID: str(obj.get("trace-id")),
144-
TraceHeader.PARENT_ID: str(obj.get("parent-id")),
145-
TraceHeader.SAMPLING_PRIORITY: str(obj.get("sampling-priority")),
146-
}
134+
def _is_context_complete(context):
135+
return context and \
136+
context.trace_id and \
137+
context.span_id and \
138+
context.sampling_priority is not None
147139

148140

149141
def create_dd_dummy_metadata_subsegment(
@@ -164,28 +156,14 @@ def extract_context_from_lambda_context(lambda_context):
164156
165157
dd_trace libraries inject this trace context on synchronous invocations
166158
"""
159+
dd_data = None
167160
client_context = lambda_context.client_context
168-
trace_id = None
169-
parent_id = None
170-
sampling_priority = None
171161
if client_context and client_context.custom:
162+
dd_data = client_context.custom
172163
if "_datadog" in client_context.custom:
173164
# Legacy trace propagation dict
174-
dd_data = client_context.custom.get("_datadog", {})
175-
trace_id = dd_data.get(TraceHeader.TRACE_ID)
176-
parent_id = dd_data.get(TraceHeader.PARENT_ID)
177-
sampling_priority = dd_data.get(TraceHeader.SAMPLING_PRIORITY)
178-
elif (
179-
TraceHeader.TRACE_ID in client_context.custom
180-
and TraceHeader.PARENT_ID in client_context.custom
181-
and TraceHeader.SAMPLING_PRIORITY in client_context.custom
182-
):
183-
# New trace propagation keys
184-
trace_id = client_context.custom.get(TraceHeader.TRACE_ID)
185-
parent_id = client_context.custom.get(TraceHeader.PARENT_ID)
186-
sampling_priority = client_context.custom.get(TraceHeader.SAMPLING_PRIORITY)
187-
188-
return trace_id, parent_id, sampling_priority
165+
dd_data = client_context.custom.get("_datadog")
166+
return propagator.extract(dd_data)
189167

190168

191169
def extract_context_from_http_event_or_context(
@@ -205,33 +183,17 @@ def extract_context_from_http_event_or_context(
205183
EventTypes.API_GATEWAY, subtype=EventSubtypes.HTTP_API
206184
)
207185
injected_authorizer_data = get_injected_authorizer_data(event, is_http_api)
208-
if injected_authorizer_data:
209-
try:
210-
# fail fast on any KeyError here
211-
trace_id = injected_authorizer_data[TraceHeader.TRACE_ID]
212-
parent_id = injected_authorizer_data[TraceHeader.PARENT_ID]
213-
sampling_priority = injected_authorizer_data.get(
214-
TraceHeader.SAMPLING_PRIORITY
215-
)
216-
return trace_id, parent_id, sampling_priority
217-
except Exception as e:
218-
logger.debug(
219-
"extract_context_from_authorizer_event returned with error. \
220-
Continue without injecting the authorizer span %s",
221-
e,
222-
)
223-
224-
headers = event.get("headers", {}) or {}
225-
lowercase_headers = {k.lower(): v for k, v in headers.items()}
186+
context = propagator.extract(injected_authorizer_data)
187+
if _is_context_complete(context):
188+
return context
226189

227-
trace_id = lowercase_headers.get(TraceHeader.TRACE_ID)
228-
parent_id = lowercase_headers.get(TraceHeader.PARENT_ID)
229-
sampling_priority = lowercase_headers.get(TraceHeader.SAMPLING_PRIORITY)
190+
headers = event.get("headers")
191+
context = propagator.extract(headers)
230192

231-
if not trace_id or not parent_id or not sampling_priority:
193+
if not _is_context_complete(context):
232194
return extract_context_from_lambda_context(lambda_context)
233195

234-
return trace_id, parent_id, sampling_priority
196+
return context
235197

236198

237199
def create_sns_event(message):
@@ -262,12 +224,9 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
262224

263225
# EventBridge => SQS
264226
try:
265-
(
266-
trace_id,
267-
parent_id,
268-
sampling_priority,
269-
) = _extract_context_from_eventbridge_sqs_event(event)
270-
return trace_id, parent_id, sampling_priority
227+
context = _extract_context_from_eventbridge_sqs_event(event)
228+
if _is_context_complete(context):
229+
return context
271230
except Exception:
272231
logger.debug("Failed extracting context as EventBridge to SQS.")
273232

@@ -311,11 +270,7 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
311270
"context from String or Binary SQS/SNS message attributes"
312271
)
313272
dd_data = json.loads(dd_json_data)
314-
trace_id = dd_data.get(TraceHeader.TRACE_ID)
315-
parent_id = dd_data.get(TraceHeader.PARENT_ID)
316-
sampling_priority = dd_data.get(TraceHeader.SAMPLING_PRIORITY)
317-
318-
return trace_id, parent_id, sampling_priority
273+
return propagator.extract(dd_data)
319274
except Exception as e:
320275
logger.debug("The trace extractor returned with error %s", e)
321276
return extract_context_from_lambda_context(lambda_context)
@@ -329,20 +284,12 @@ def _extract_context_from_eventbridge_sqs_event(event):
329284
This is only possible if first record in `Records` contains a
330285
`body` field which contains the EventBridge `detail` as a JSON string.
331286
"""
332-
try:
333-
first_record = event.get("Records")[0]
334-
if "body" in first_record:
335-
body_str = first_record.get("body", {})
336-
body = json.loads(body_str)
337-
338-
detail = body.get("detail")
339-
dd_context = detail.get("_datadog")
340-
trace_id = dd_context.get(TraceHeader.TRACE_ID)
341-
parent_id = dd_context.get(TraceHeader.PARENT_ID)
342-
sampling_priority = dd_context.get(TraceHeader.SAMPLING_PRIORITY)
343-
return trace_id, parent_id, sampling_priority
344-
except Exception:
345-
raise
287+
first_record = event.get("Records")[0]
288+
body_str = first_record.get("body")
289+
body = json.loads(body_str)
290+
detail = body.get("detail")
291+
dd_context = detail.get("_datadog")
292+
return propagator.extract(dd_context)
346293

347294

348295
def extract_context_from_eventbridge_event(event, lambda_context):
@@ -355,10 +302,7 @@ def extract_context_from_eventbridge_event(event, lambda_context):
355302
dd_context = detail.get("_datadog")
356303
if not dd_context:
357304
return extract_context_from_lambda_context(lambda_context)
358-
trace_id = dd_context.get(TraceHeader.TRACE_ID)
359-
parent_id = dd_context.get(TraceHeader.PARENT_ID)
360-
sampling_priority = dd_context.get(TraceHeader.SAMPLING_PRIORITY)
361-
return trace_id, parent_id, sampling_priority
305+
return propagator.extract(dd_context)
362306
except Exception as e:
363307
logger.debug("The trace extractor returned with error %s", e)
364308
return extract_context_from_lambda_context(lambda_context)
@@ -381,10 +325,7 @@ def extract_context_from_kinesis_event(event, lambda_context):
381325
if not dd_ctx:
382326
return extract_context_from_lambda_context(lambda_context)
383327

384-
trace_id = dd_ctx.get(TraceHeader.TRACE_ID)
385-
parent_id = dd_ctx.get(TraceHeader.PARENT_ID)
386-
sampling_priority = dd_ctx.get(TraceHeader.SAMPLING_PRIORITY)
387-
return trace_id, parent_id, sampling_priority
328+
return propagator.extract(dd_ctx)
388329
except Exception as e:
389330
logger.debug("The trace extractor returned with error %s", e)
390331
return extract_context_from_lambda_context(lambda_context)
@@ -417,7 +358,7 @@ def extract_context_from_step_functions(event, lambda_context):
417358
execution_id + "#" + state_name + "#" + state_entered_time
418359
)
419360
sampling_priority = SamplingPriority.AUTO_KEEP
420-
return trace_id, parent_id, sampling_priority
361+
return Context(trace_id=trace_id, span_id=parent_id, sampling_priority=sampling_priority)
421362
except Exception as e:
422363
logger.debug("The Step Functions trace extractor returned with error %s", e)
423364
return extract_context_from_lambda_context(lambda_context)
@@ -433,12 +374,10 @@ def extract_context_custom_extractor(extractor, event, lambda_context):
433374
parent_id,
434375
sampling_priority,
435376
) = extractor(event, lambda_context)
436-
return trace_id, parent_id, sampling_priority
377+
return Context(trace_id=trace_id, span_id=parent_id, sampling_priority=sampling_priority)
437378
except Exception as e:
438379
logger.debug("The trace extractor returned with error %s", e)
439380

440-
return None, None, None
441-
442381

443382
def is_authorizer_response(response) -> bool:
444383
try:
@@ -504,56 +443,25 @@ def extract_dd_trace_context(
504443
event_source = parse_event_source(event)
505444

506445
if extractor is not None:
507-
(
508-
trace_id,
509-
parent_id,
510-
sampling_priority,
511-
) = extract_context_custom_extractor(extractor, event, lambda_context)
446+
context = extract_context_custom_extractor(extractor, event, lambda_context)
512447
elif isinstance(event, (set, dict)) and "headers" in event:
513-
(
514-
trace_id,
515-
parent_id,
516-
sampling_priority,
517-
) = extract_context_from_http_event_or_context(
448+
context = extract_context_from_http_event_or_context(
518449
event, lambda_context, event_source, decode_authorizer_context
519450
)
520451
elif event_source.equals(EventTypes.SNS) or event_source.equals(EventTypes.SQS):
521-
(
522-
trace_id,
523-
parent_id,
524-
sampling_priority,
525-
) = extract_context_from_sqs_or_sns_event_or_context(event, lambda_context)
452+
context = extract_context_from_sqs_or_sns_event_or_context(event, lambda_context)
526453
elif event_source.equals(EventTypes.EVENTBRIDGE):
527-
(
528-
trace_id,
529-
parent_id,
530-
sampling_priority,
531-
) = extract_context_from_eventbridge_event(event, lambda_context)
454+
context = extract_context_from_eventbridge_event(event, lambda_context)
532455
elif event_source.equals(EventTypes.KINESIS):
533-
(
534-
trace_id,
535-
parent_id,
536-
sampling_priority,
537-
) = extract_context_from_kinesis_event(event, lambda_context)
456+
context = extract_context_from_kinesis_event(event, lambda_context)
538457
elif event_source.equals(EventTypes.STEPFUNCTIONS):
539-
(
540-
trace_id,
541-
parent_id,
542-
sampling_priority,
543-
) = extract_context_from_step_functions(event, lambda_context)
458+
context = extract_context_from_step_functions(event, lambda_context)
544459
else:
545-
trace_id, parent_id, sampling_priority = extract_context_from_lambda_context(
546-
lambda_context
547-
)
460+
context = extract_context_from_lambda_context(lambda_context)
548461

549-
if trace_id and parent_id and sampling_priority:
462+
if _is_context_complete(context):
550463
logger.debug("Extracted Datadog trace context from event or context")
551-
metadata = {
552-
"trace-id": trace_id,
553-
"parent-id": parent_id,
554-
"sampling-priority": sampling_priority,
555-
}
556-
dd_trace_context = metadata.copy()
464+
dd_trace_context = context
557465
trace_context_source = TraceContextSource.EVENT
558466
else:
559467
# AWS Lambda runtime caches global variables between invocations,
@@ -579,8 +487,8 @@ def get_dd_trace_context():
579487
"""
580488
if dd_tracing_enabled:
581489
dd_trace_py_context = _get_dd_trace_py_context()
582-
if dd_trace_py_context is not None:
583-
return _context_obj_to_headers(dd_trace_py_context)
490+
if _is_context_complete(dd_trace_py_context):
491+
return dd_trace_py_context
584492

585493
global dd_trace_context
586494

@@ -592,16 +500,17 @@ def get_dd_trace_context():
592500
% e
593501
)
594502
if not xray_context:
595-
return {}
596-
597-
if not dd_trace_context:
598-
return _context_obj_to_headers(xray_context)
503+
return None
599504

600-
context = dd_trace_context.copy()
601-
context["parent-id"] = xray_context.get("parent-id")
602-
logger.debug("Set parent id from xray trace context: %s", context.get("parent-id"))
505+
if not _is_context_complete(dd_trace_context):
506+
return xray_context
603507

604-
return _context_obj_to_headers(context)
508+
logger.debug("Set parent id from xray trace context: %s", xray_context.span_id)
509+
return Context(
510+
trace_id=dd_trace_context.trace_id,
511+
span_id=xray_context.span_id,
512+
sampling_priority=dd_trace_context.sampling_priority,
513+
)
605514

606515

607516
def set_correlation_ids():
@@ -620,13 +529,12 @@ def set_correlation_ids():
620529
return
621530

622531
context = get_dd_trace_context()
623-
if not context:
532+
if not _is_context_complete(context):
624533
return
625534

626-
span = tracer.trace("dummy.span")
627-
span.trace_id = int(context[TraceHeader.TRACE_ID])
628-
span.span_id = int(context[TraceHeader.PARENT_ID])
629535

536+
tracer.context_provider.activate(context)
537+
tracer.trace("dummy.span")
630538
logger.debug("correlation ids set")
631539

632540

@@ -669,18 +577,20 @@ def is_lambda_context():
669577

670578
def set_dd_trace_py_root(trace_context_source, merge_xray_traces):
671579
if trace_context_source == TraceContextSource.EVENT or merge_xray_traces:
672-
context = dict(dd_trace_context)
580+
context = Context(
581+
trace_id=dd_trace_context.trace_id,
582+
span_id=dd_trace_context.span_id,
583+
sampling_priority=dd_trace_context.sampling_priority,
584+
)
673585
if merge_xray_traces:
674586
xray_context = _get_xray_trace_context()
675-
if xray_context is not None:
676-
context["parent-id"] = xray_context.get("parent-id")
587+
if xray_context.span_id:
588+
context.span_id = xray_context.span_id
677589

678-
headers = _context_obj_to_headers(context)
679-
span_context = propagator.extract(headers)
680-
tracer.context_provider.activate(span_context)
590+
tracer.context_provider.activate(context)
681591
logger.debug(
682592
"Set dd trace root context to: %s",
683-
(span_context.trace_id, span_context.span_id),
593+
(context.trace_id, context.span_id),
684594
)
685595

686596

0 commit comments

Comments
 (0)