32
32
from ddtrace import tracer , patch , Span
33
33
from ddtrace import __version__ as ddtrace_version
34
34
from ddtrace .propagation .http import HTTPPropagator
35
+ from ddtrace .context import Context
35
36
from datadog_lambda import __version__ as datadog_lambda_version
36
37
from datadog_lambda .trigger import (
37
38
_EventSource ,
53
54
54
55
logger = logging .getLogger (__name__ )
55
56
56
- dd_trace_context = {}
57
+ dd_trace_context = None
57
58
dd_tracing_enabled = os .environ .get ("DD_TRACE_ENABLED" , "false" ).lower () == "true"
58
59
if dd_tracing_enabled :
59
60
# Enable the telemetry client if the user has opted in
@@ -102,11 +103,11 @@ def _get_xray_trace_context():
102
103
)
103
104
if xray_trace_entity is None :
104
105
return None
105
- trace_context = {
106
- "trace-id" : _convert_xray_trace_id (xray_trace_entity .get ("trace_id" )),
107
- "parent-id" : _convert_xray_entity_id (xray_trace_entity .get ("parent_id" )),
108
- "sampling-priority" : _convert_xray_sampling (xray_trace_entity .get ("sampled" )),
109
- }
106
+ trace_context = Context (
107
+ trace_id = _convert_xray_trace_id (xray_trace_entity .get ("trace_id" )),
108
+ span_id = _convert_xray_entity_id (xray_trace_entity .get ("parent_id" )),
109
+ sampling_priority = _convert_xray_sampling (xray_trace_entity .get ("sampled" )),
110
+ )
110
111
logger .debug (
111
112
"Converted trace context %s from X-Ray segment %s" ,
112
113
trace_context ,
@@ -124,26 +125,17 @@ def _get_dd_trace_py_context():
124
125
if not span :
125
126
return None
126
127
127
- parent_id = span .context .span_id
128
- trace_id = span .context .trace_id
129
- sampling_priority = span .context .sampling_priority
130
128
logger .debug (
131
129
"found dd trace context: %s" , (span .context .trace_id , span .context .span_id )
132
130
)
133
- return {
134
- "parent-id" : str (parent_id ),
135
- "trace-id" : str (trace_id ),
136
- "sampling-priority" : str (sampling_priority ),
137
- "source" : TraceContextSource .DDTRACE ,
138
- }
131
+ return span .context
139
132
140
133
141
- def _context_obj_to_headers (obj ):
142
- return {
143
- TraceHeader .TRACE_ID : str (obj .get ("trace-id" )),
144
- TraceHeader .PARENT_ID : str (obj .get ("parent-id" )),
145
- TraceHeader .SAMPLING_PRIORITY : str (obj .get ("sampling-priority" )),
146
- }
134
+ def _is_context_complete (context ):
135
+ return context and \
136
+ context .trace_id and \
137
+ context .span_id and \
138
+ context .sampling_priority is not None
147
139
148
140
149
141
def create_dd_dummy_metadata_subsegment (
@@ -164,28 +156,14 @@ def extract_context_from_lambda_context(lambda_context):
164
156
165
157
dd_trace libraries inject this trace context on synchronous invocations
166
158
"""
159
+ dd_data = None
167
160
client_context = lambda_context .client_context
168
- trace_id = None
169
- parent_id = None
170
- sampling_priority = None
171
161
if client_context and client_context .custom :
162
+ dd_data = client_context .custom
172
163
if "_datadog" in client_context .custom :
173
164
# Legacy trace propagation dict
174
- dd_data = client_context .custom .get ("_datadog" , {})
175
- trace_id = dd_data .get (TraceHeader .TRACE_ID )
176
- parent_id = dd_data .get (TraceHeader .PARENT_ID )
177
- sampling_priority = dd_data .get (TraceHeader .SAMPLING_PRIORITY )
178
- elif (
179
- TraceHeader .TRACE_ID in client_context .custom
180
- and TraceHeader .PARENT_ID in client_context .custom
181
- and TraceHeader .SAMPLING_PRIORITY in client_context .custom
182
- ):
183
- # New trace propagation keys
184
- trace_id = client_context .custom .get (TraceHeader .TRACE_ID )
185
- parent_id = client_context .custom .get (TraceHeader .PARENT_ID )
186
- sampling_priority = client_context .custom .get (TraceHeader .SAMPLING_PRIORITY )
187
-
188
- return trace_id , parent_id , sampling_priority
165
+ dd_data = client_context .custom .get ("_datadog" )
166
+ return propagator .extract (dd_data )
189
167
190
168
191
169
def extract_context_from_http_event_or_context (
@@ -205,33 +183,17 @@ def extract_context_from_http_event_or_context(
205
183
EventTypes .API_GATEWAY , subtype = EventSubtypes .HTTP_API
206
184
)
207
185
injected_authorizer_data = get_injected_authorizer_data (event , is_http_api )
208
- if injected_authorizer_data :
209
- try :
210
- # fail fast on any KeyError here
211
- trace_id = injected_authorizer_data [TraceHeader .TRACE_ID ]
212
- parent_id = injected_authorizer_data [TraceHeader .PARENT_ID ]
213
- sampling_priority = injected_authorizer_data .get (
214
- TraceHeader .SAMPLING_PRIORITY
215
- )
216
- return trace_id , parent_id , sampling_priority
217
- except Exception as e :
218
- logger .debug (
219
- "extract_context_from_authorizer_event returned with error. \
220
- Continue without injecting the authorizer span %s" ,
221
- e ,
222
- )
223
-
224
- headers = event .get ("headers" , {}) or {}
225
- lowercase_headers = {k .lower (): v for k , v in headers .items ()}
186
+ context = propagator .extract (injected_authorizer_data )
187
+ if _is_context_complete (context ):
188
+ return context
226
189
227
- trace_id = lowercase_headers .get (TraceHeader .TRACE_ID )
228
- parent_id = lowercase_headers .get (TraceHeader .PARENT_ID )
229
- sampling_priority = lowercase_headers .get (TraceHeader .SAMPLING_PRIORITY )
190
+ headers = event .get ("headers" )
191
+ context = propagator .extract (headers )
230
192
231
- if not trace_id or not parent_id or not sampling_priority :
193
+ if not _is_context_complete ( context ) :
232
194
return extract_context_from_lambda_context (lambda_context )
233
195
234
- return trace_id , parent_id , sampling_priority
196
+ return context
235
197
236
198
237
199
def create_sns_event (message ):
@@ -262,12 +224,9 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
262
224
263
225
# EventBridge => SQS
264
226
try :
265
- (
266
- trace_id ,
267
- parent_id ,
268
- sampling_priority ,
269
- ) = _extract_context_from_eventbridge_sqs_event (event )
270
- return trace_id , parent_id , sampling_priority
227
+ context = _extract_context_from_eventbridge_sqs_event (event )
228
+ if _is_context_complete (context ):
229
+ return context
271
230
except Exception :
272
231
logger .debug ("Failed extracting context as EventBridge to SQS." )
273
232
@@ -311,11 +270,7 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
311
270
"context from String or Binary SQS/SNS message attributes"
312
271
)
313
272
dd_data = json .loads (dd_json_data )
314
- trace_id = dd_data .get (TraceHeader .TRACE_ID )
315
- parent_id = dd_data .get (TraceHeader .PARENT_ID )
316
- sampling_priority = dd_data .get (TraceHeader .SAMPLING_PRIORITY )
317
-
318
- return trace_id , parent_id , sampling_priority
273
+ return propagator .extract (dd_data )
319
274
except Exception as e :
320
275
logger .debug ("The trace extractor returned with error %s" , e )
321
276
return extract_context_from_lambda_context (lambda_context )
@@ -329,20 +284,12 @@ def _extract_context_from_eventbridge_sqs_event(event):
329
284
This is only possible if first record in `Records` contains a
330
285
`body` field which contains the EventBridge `detail` as a JSON string.
331
286
"""
332
- try :
333
- first_record = event .get ("Records" )[0 ]
334
- if "body" in first_record :
335
- body_str = first_record .get ("body" , {})
336
- body = json .loads (body_str )
337
-
338
- detail = body .get ("detail" )
339
- dd_context = detail .get ("_datadog" )
340
- trace_id = dd_context .get (TraceHeader .TRACE_ID )
341
- parent_id = dd_context .get (TraceHeader .PARENT_ID )
342
- sampling_priority = dd_context .get (TraceHeader .SAMPLING_PRIORITY )
343
- return trace_id , parent_id , sampling_priority
344
- except Exception :
345
- raise
287
+ first_record = event .get ("Records" )[0 ]
288
+ body_str = first_record .get ("body" )
289
+ body = json .loads (body_str )
290
+ detail = body .get ("detail" )
291
+ dd_context = detail .get ("_datadog" )
292
+ return propagator .extract (dd_context )
346
293
347
294
348
295
def extract_context_from_eventbridge_event (event , lambda_context ):
@@ -355,10 +302,7 @@ def extract_context_from_eventbridge_event(event, lambda_context):
355
302
dd_context = detail .get ("_datadog" )
356
303
if not dd_context :
357
304
return extract_context_from_lambda_context (lambda_context )
358
- trace_id = dd_context .get (TraceHeader .TRACE_ID )
359
- parent_id = dd_context .get (TraceHeader .PARENT_ID )
360
- sampling_priority = dd_context .get (TraceHeader .SAMPLING_PRIORITY )
361
- return trace_id , parent_id , sampling_priority
305
+ return propagator .extract (dd_context )
362
306
except Exception as e :
363
307
logger .debug ("The trace extractor returned with error %s" , e )
364
308
return extract_context_from_lambda_context (lambda_context )
@@ -381,10 +325,7 @@ def extract_context_from_kinesis_event(event, lambda_context):
381
325
if not dd_ctx :
382
326
return extract_context_from_lambda_context (lambda_context )
383
327
384
- trace_id = dd_ctx .get (TraceHeader .TRACE_ID )
385
- parent_id = dd_ctx .get (TraceHeader .PARENT_ID )
386
- sampling_priority = dd_ctx .get (TraceHeader .SAMPLING_PRIORITY )
387
- return trace_id , parent_id , sampling_priority
328
+ return propagator .extract (dd_ctx )
388
329
except Exception as e :
389
330
logger .debug ("The trace extractor returned with error %s" , e )
390
331
return extract_context_from_lambda_context (lambda_context )
@@ -417,7 +358,7 @@ def extract_context_from_step_functions(event, lambda_context):
417
358
execution_id + "#" + state_name + "#" + state_entered_time
418
359
)
419
360
sampling_priority = SamplingPriority .AUTO_KEEP
420
- return trace_id , parent_id , sampling_priority
361
+ return Context ( trace_id = trace_id , span_id = parent_id , sampling_priority = sampling_priority )
421
362
except Exception as e :
422
363
logger .debug ("The Step Functions trace extractor returned with error %s" , e )
423
364
return extract_context_from_lambda_context (lambda_context )
@@ -433,12 +374,10 @@ def extract_context_custom_extractor(extractor, event, lambda_context):
433
374
parent_id ,
434
375
sampling_priority ,
435
376
) = extractor (event , lambda_context )
436
- return trace_id , parent_id , sampling_priority
377
+ return Context ( trace_id = trace_id , span_id = parent_id , sampling_priority = sampling_priority )
437
378
except Exception as e :
438
379
logger .debug ("The trace extractor returned with error %s" , e )
439
380
440
- return None , None , None
441
-
442
381
443
382
def is_authorizer_response (response ) -> bool :
444
383
try :
@@ -504,56 +443,25 @@ def extract_dd_trace_context(
504
443
event_source = parse_event_source (event )
505
444
506
445
if extractor is not None :
507
- (
508
- trace_id ,
509
- parent_id ,
510
- sampling_priority ,
511
- ) = extract_context_custom_extractor (extractor , event , lambda_context )
446
+ context = extract_context_custom_extractor (extractor , event , lambda_context )
512
447
elif isinstance (event , (set , dict )) and "headers" in event :
513
- (
514
- trace_id ,
515
- parent_id ,
516
- sampling_priority ,
517
- ) = extract_context_from_http_event_or_context (
448
+ context = extract_context_from_http_event_or_context (
518
449
event , lambda_context , event_source , decode_authorizer_context
519
450
)
520
451
elif event_source .equals (EventTypes .SNS ) or event_source .equals (EventTypes .SQS ):
521
- (
522
- trace_id ,
523
- parent_id ,
524
- sampling_priority ,
525
- ) = extract_context_from_sqs_or_sns_event_or_context (event , lambda_context )
452
+ context = extract_context_from_sqs_or_sns_event_or_context (event , lambda_context )
526
453
elif event_source .equals (EventTypes .EVENTBRIDGE ):
527
- (
528
- trace_id ,
529
- parent_id ,
530
- sampling_priority ,
531
- ) = extract_context_from_eventbridge_event (event , lambda_context )
454
+ context = extract_context_from_eventbridge_event (event , lambda_context )
532
455
elif event_source .equals (EventTypes .KINESIS ):
533
- (
534
- trace_id ,
535
- parent_id ,
536
- sampling_priority ,
537
- ) = extract_context_from_kinesis_event (event , lambda_context )
456
+ context = extract_context_from_kinesis_event (event , lambda_context )
538
457
elif event_source .equals (EventTypes .STEPFUNCTIONS ):
539
- (
540
- trace_id ,
541
- parent_id ,
542
- sampling_priority ,
543
- ) = extract_context_from_step_functions (event , lambda_context )
458
+ context = extract_context_from_step_functions (event , lambda_context )
544
459
else :
545
- trace_id , parent_id , sampling_priority = extract_context_from_lambda_context (
546
- lambda_context
547
- )
460
+ context = extract_context_from_lambda_context (lambda_context )
548
461
549
- if trace_id and parent_id and sampling_priority :
462
+ if _is_context_complete ( context ) :
550
463
logger .debug ("Extracted Datadog trace context from event or context" )
551
- metadata = {
552
- "trace-id" : trace_id ,
553
- "parent-id" : parent_id ,
554
- "sampling-priority" : sampling_priority ,
555
- }
556
- dd_trace_context = metadata .copy ()
464
+ dd_trace_context = context
557
465
trace_context_source = TraceContextSource .EVENT
558
466
else :
559
467
# AWS Lambda runtime caches global variables between invocations,
@@ -579,8 +487,8 @@ def get_dd_trace_context():
579
487
"""
580
488
if dd_tracing_enabled :
581
489
dd_trace_py_context = _get_dd_trace_py_context ()
582
- if dd_trace_py_context is not None :
583
- return _context_obj_to_headers ( dd_trace_py_context )
490
+ if _is_context_complete ( dd_trace_py_context ) :
491
+ return dd_trace_py_context
584
492
585
493
global dd_trace_context
586
494
@@ -592,16 +500,17 @@ def get_dd_trace_context():
592
500
% e
593
501
)
594
502
if not xray_context :
595
- return {}
596
-
597
- if not dd_trace_context :
598
- return _context_obj_to_headers (xray_context )
503
+ return None
599
504
600
- context = dd_trace_context .copy ()
601
- context ["parent-id" ] = xray_context .get ("parent-id" )
602
- logger .debug ("Set parent id from xray trace context: %s" , context .get ("parent-id" ))
505
+ if not _is_context_complete (dd_trace_context ):
506
+ return xray_context
603
507
604
- return _context_obj_to_headers (context )
508
+ logger .debug ("Set parent id from xray trace context: %s" , xray_context .span_id )
509
+ return Context (
510
+ trace_id = dd_trace_context .trace_id ,
511
+ span_id = xray_context .span_id ,
512
+ sampling_priority = dd_trace_context .sampling_priority ,
513
+ )
605
514
606
515
607
516
def set_correlation_ids ():
@@ -620,13 +529,12 @@ def set_correlation_ids():
620
529
return
621
530
622
531
context = get_dd_trace_context ()
623
- if not context :
532
+ if not _is_context_complete ( context ) :
624
533
return
625
534
626
- span = tracer .trace ("dummy.span" )
627
- span .trace_id = int (context [TraceHeader .TRACE_ID ])
628
- span .span_id = int (context [TraceHeader .PARENT_ID ])
629
535
536
+ tracer .context_provider .activate (context )
537
+ tracer .trace ("dummy.span" )
630
538
logger .debug ("correlation ids set" )
631
539
632
540
@@ -669,18 +577,20 @@ def is_lambda_context():
669
577
670
578
def set_dd_trace_py_root (trace_context_source , merge_xray_traces ):
671
579
if trace_context_source == TraceContextSource .EVENT or merge_xray_traces :
672
- context = dict (dd_trace_context )
580
+ context = Context (
581
+ trace_id = dd_trace_context .trace_id ,
582
+ span_id = dd_trace_context .span_id ,
583
+ sampling_priority = dd_trace_context .sampling_priority ,
584
+ )
673
585
if merge_xray_traces :
674
586
xray_context = _get_xray_trace_context ()
675
- if xray_context is not None :
676
- context [ "parent-id" ] = xray_context .get ( "parent-id" )
587
+ if xray_context . span_id :
588
+ context . span_id = xray_context .span_id
677
589
678
- headers = _context_obj_to_headers (context )
679
- span_context = propagator .extract (headers )
680
- tracer .context_provider .activate (span_context )
590
+ tracer .context_provider .activate (context )
681
591
logger .debug (
682
592
"Set dd trace root context to: %s" ,
683
- (span_context .trace_id , span_context .span_id ),
593
+ (context .trace_id , context .span_id ),
684
594
)
685
595
686
596
0 commit comments