diff --git a/datadog_lambda/tracing.py b/datadog_lambda/tracing.py index de6e76b4..05882749 100644 --- a/datadog_lambda/tracing.py +++ b/datadog_lambda/tracing.py @@ -20,7 +20,6 @@ from datadog_lambda.constants import ( SamplingPriority, - TraceHeader, TraceContextSource, XrayDaemon, Headers, @@ -32,6 +31,7 @@ from ddtrace import tracer, patch, Span from ddtrace import __version__ as ddtrace_version from ddtrace.propagation.http import HTTPPropagator +from ddtrace.context import Context from datadog_lambda import __version__ as datadog_lambda_version from datadog_lambda.trigger import ( _EventSource, @@ -53,7 +53,7 @@ logger = logging.getLogger(__name__) -dd_trace_context = {} +dd_trace_context = None dd_tracing_enabled = os.environ.get("DD_TRACE_ENABLED", "false").lower() == "true" if dd_tracing_enabled: # Enable the telemetry client if the user has opted in @@ -72,25 +72,21 @@ def _convert_xray_trace_id(xray_trace_id): """ Convert X-Ray trace id (hex)'s last 63 bits to a Datadog trace id (int). """ - return str(0x7FFFFFFFFFFFFFFF & int(xray_trace_id[-16:], 16)) + return 0x7FFFFFFFFFFFFFFF & int(xray_trace_id[-16:], 16) def _convert_xray_entity_id(xray_entity_id): """ Convert X-Ray (sub)segement id (hex) to a Datadog span id (int). """ - return str(int(xray_entity_id, 16)) + return int(xray_entity_id, 16) def _convert_xray_sampling(xray_sampled): """ Convert X-Ray sampled (True/False) to its Datadog counterpart. """ - return ( - str(SamplingPriority.USER_KEEP) - if xray_sampled - else str(SamplingPriority.USER_REJECT) - ) + return SamplingPriority.USER_KEEP if xray_sampled else SamplingPriority.USER_REJECT def _get_xray_trace_context(): @@ -102,11 +98,11 @@ def _get_xray_trace_context(): ) if xray_trace_entity is None: return None - trace_context = { - "trace-id": _convert_xray_trace_id(xray_trace_entity.get("trace_id")), - "parent-id": _convert_xray_entity_id(xray_trace_entity.get("parent_id")), - "sampling-priority": _convert_xray_sampling(xray_trace_entity.get("sampled")), - } + trace_context = Context( + trace_id=_convert_xray_trace_id(xray_trace_entity.get("trace_id")), + span_id=_convert_xray_entity_id(xray_trace_entity.get("parent_id")), + sampling_priority=_convert_xray_sampling(xray_trace_entity.get("sampled")), + ) logger.debug( "Converted trace context %s from X-Ray segment %s", trace_context, @@ -124,26 +120,19 @@ def _get_dd_trace_py_context(): if not span: return None - parent_id = span.context.span_id - trace_id = span.context.trace_id - sampling_priority = span.context.sampling_priority logger.debug( "found dd trace context: %s", (span.context.trace_id, span.context.span_id) ) - return { - "parent-id": str(parent_id), - "trace-id": str(trace_id), - "sampling-priority": str(sampling_priority), - "source": TraceContextSource.DDTRACE, - } + return span.context -def _context_obj_to_headers(obj): - return { - TraceHeader.TRACE_ID: str(obj.get("trace-id")), - TraceHeader.PARENT_ID: str(obj.get("parent-id")), - TraceHeader.SAMPLING_PRIORITY: str(obj.get("sampling-priority")), - } +def _is_context_complete(context): + return ( + context + and context.trace_id + and context.span_id + and context.sampling_priority is not None + ) def create_dd_dummy_metadata_subsegment( @@ -164,28 +153,14 @@ def extract_context_from_lambda_context(lambda_context): dd_trace libraries inject this trace context on synchronous invocations """ + dd_data = None client_context = lambda_context.client_context - trace_id = None - parent_id = None - sampling_priority = None if client_context and client_context.custom: + dd_data = client_context.custom if "_datadog" in client_context.custom: # Legacy trace propagation dict - dd_data = client_context.custom.get("_datadog", {}) - trace_id = dd_data.get(TraceHeader.TRACE_ID) - parent_id = dd_data.get(TraceHeader.PARENT_ID) - sampling_priority = dd_data.get(TraceHeader.SAMPLING_PRIORITY) - elif ( - TraceHeader.TRACE_ID in client_context.custom - and TraceHeader.PARENT_ID in client_context.custom - and TraceHeader.SAMPLING_PRIORITY in client_context.custom - ): - # New trace propagation keys - trace_id = client_context.custom.get(TraceHeader.TRACE_ID) - parent_id = client_context.custom.get(TraceHeader.PARENT_ID) - sampling_priority = client_context.custom.get(TraceHeader.SAMPLING_PRIORITY) - - return trace_id, parent_id, sampling_priority + dd_data = client_context.custom.get("_datadog") + return propagator.extract(dd_data) def extract_context_from_http_event_or_context( @@ -205,33 +180,17 @@ def extract_context_from_http_event_or_context( EventTypes.API_GATEWAY, subtype=EventSubtypes.HTTP_API ) injected_authorizer_data = get_injected_authorizer_data(event, is_http_api) - if injected_authorizer_data: - try: - # fail fast on any KeyError here - trace_id = injected_authorizer_data[TraceHeader.TRACE_ID] - parent_id = injected_authorizer_data[TraceHeader.PARENT_ID] - sampling_priority = injected_authorizer_data.get( - TraceHeader.SAMPLING_PRIORITY - ) - return trace_id, parent_id, sampling_priority - except Exception as e: - logger.debug( - "extract_context_from_authorizer_event returned with error. \ - Continue without injecting the authorizer span %s", - e, - ) - - headers = event.get("headers", {}) or {} - lowercase_headers = {k.lower(): v for k, v in headers.items()} + context = propagator.extract(injected_authorizer_data) + if _is_context_complete(context): + return context - trace_id = lowercase_headers.get(TraceHeader.TRACE_ID) - parent_id = lowercase_headers.get(TraceHeader.PARENT_ID) - sampling_priority = lowercase_headers.get(TraceHeader.SAMPLING_PRIORITY) + headers = event.get("headers") + context = propagator.extract(headers) - if not trace_id or not parent_id or not sampling_priority: + if not _is_context_complete(context): return extract_context_from_lambda_context(lambda_context) - return trace_id, parent_id, sampling_priority + return context def create_sns_event(message): @@ -262,12 +221,9 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context): # EventBridge => SQS try: - ( - trace_id, - parent_id, - sampling_priority, - ) = _extract_context_from_eventbridge_sqs_event(event) - return trace_id, parent_id, sampling_priority + context = _extract_context_from_eventbridge_sqs_event(event) + if _is_context_complete(context): + return context except Exception: logger.debug("Failed extracting context as EventBridge to SQS.") @@ -311,11 +267,7 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context): "context from String or Binary SQS/SNS message attributes" ) dd_data = json.loads(dd_json_data) - trace_id = dd_data.get(TraceHeader.TRACE_ID) - parent_id = dd_data.get(TraceHeader.PARENT_ID) - sampling_priority = dd_data.get(TraceHeader.SAMPLING_PRIORITY) - - return trace_id, parent_id, sampling_priority + return propagator.extract(dd_data) except Exception as e: logger.debug("The trace extractor returned with error %s", e) return extract_context_from_lambda_context(lambda_context) @@ -329,20 +281,12 @@ def _extract_context_from_eventbridge_sqs_event(event): This is only possible if first record in `Records` contains a `body` field which contains the EventBridge `detail` as a JSON string. """ - try: - first_record = event.get("Records")[0] - if "body" in first_record: - body_str = first_record.get("body", {}) - body = json.loads(body_str) - - detail = body.get("detail") - dd_context = detail.get("_datadog") - trace_id = dd_context.get(TraceHeader.TRACE_ID) - parent_id = dd_context.get(TraceHeader.PARENT_ID) - sampling_priority = dd_context.get(TraceHeader.SAMPLING_PRIORITY) - return trace_id, parent_id, sampling_priority - except Exception: - raise + first_record = event.get("Records")[0] + body_str = first_record.get("body") + body = json.loads(body_str) + detail = body.get("detail") + dd_context = detail.get("_datadog") + return propagator.extract(dd_context) def extract_context_from_eventbridge_event(event, lambda_context): @@ -355,10 +299,7 @@ def extract_context_from_eventbridge_event(event, lambda_context): dd_context = detail.get("_datadog") if not dd_context: return extract_context_from_lambda_context(lambda_context) - trace_id = dd_context.get(TraceHeader.TRACE_ID) - parent_id = dd_context.get(TraceHeader.PARENT_ID) - sampling_priority = dd_context.get(TraceHeader.SAMPLING_PRIORITY) - return trace_id, parent_id, sampling_priority + return propagator.extract(dd_context) except Exception as e: logger.debug("The trace extractor returned with error %s", e) return extract_context_from_lambda_context(lambda_context) @@ -381,25 +322,22 @@ def extract_context_from_kinesis_event(event, lambda_context): if not dd_ctx: return extract_context_from_lambda_context(lambda_context) - trace_id = dd_ctx.get(TraceHeader.TRACE_ID) - parent_id = dd_ctx.get(TraceHeader.PARENT_ID) - sampling_priority = dd_ctx.get(TraceHeader.SAMPLING_PRIORITY) - return trace_id, parent_id, sampling_priority + return propagator.extract(dd_ctx) except Exception as e: logger.debug("The trace extractor returned with error %s", e) return extract_context_from_lambda_context(lambda_context) -def _deterministic_md5_hash(s: str) -> str: +def _deterministic_md5_hash(s: str) -> int: """MD5 here is to generate trace_id, not for any encryption.""" hex_number = hashlib.md5(s.encode("ascii")).hexdigest() binary = bin(int(hex_number, 16)) binary_str = str(binary) binary_str_remove_0b = binary_str[2:].rjust(128, "0") most_significant_64_bits_without_leading_1 = "0" + binary_str_remove_0b[1:-64] - result = str(int(most_significant_64_bits_without_leading_1, 2)) - if result == "0" * 64: - return "1" + result = int(most_significant_64_bits_without_leading_1, 2) + if result == 0: + return 1 return result @@ -417,7 +355,9 @@ def extract_context_from_step_functions(event, lambda_context): execution_id + "#" + state_name + "#" + state_entered_time ) sampling_priority = SamplingPriority.AUTO_KEEP - return trace_id, parent_id, sampling_priority + return Context( + trace_id=trace_id, span_id=parent_id, sampling_priority=sampling_priority + ) except Exception as e: logger.debug("The Step Functions trace extractor returned with error %s", e) return extract_context_from_lambda_context(lambda_context) @@ -433,12 +373,14 @@ def extract_context_custom_extractor(extractor, event, lambda_context): parent_id, sampling_priority, ) = extractor(event, lambda_context) - return trace_id, parent_id, sampling_priority + return Context( + trace_id=int(trace_id), + span_id=int(parent_id), + sampling_priority=int(sampling_priority), + ) except Exception as e: logger.debug("The trace extractor returned with error %s", e) - return None, None, None - def is_authorizer_response(response) -> bool: try: @@ -504,56 +446,27 @@ def extract_dd_trace_context( event_source = parse_event_source(event) if extractor is not None: - ( - trace_id, - parent_id, - sampling_priority, - ) = extract_context_custom_extractor(extractor, event, lambda_context) + context = extract_context_custom_extractor(extractor, event, lambda_context) elif isinstance(event, (set, dict)) and "headers" in event: - ( - trace_id, - parent_id, - sampling_priority, - ) = extract_context_from_http_event_or_context( + context = extract_context_from_http_event_or_context( event, lambda_context, event_source, decode_authorizer_context ) elif event_source.equals(EventTypes.SNS) or event_source.equals(EventTypes.SQS): - ( - trace_id, - parent_id, - sampling_priority, - ) = extract_context_from_sqs_or_sns_event_or_context(event, lambda_context) + context = extract_context_from_sqs_or_sns_event_or_context( + event, lambda_context + ) elif event_source.equals(EventTypes.EVENTBRIDGE): - ( - trace_id, - parent_id, - sampling_priority, - ) = extract_context_from_eventbridge_event(event, lambda_context) + context = extract_context_from_eventbridge_event(event, lambda_context) elif event_source.equals(EventTypes.KINESIS): - ( - trace_id, - parent_id, - sampling_priority, - ) = extract_context_from_kinesis_event(event, lambda_context) + context = extract_context_from_kinesis_event(event, lambda_context) elif event_source.equals(EventTypes.STEPFUNCTIONS): - ( - trace_id, - parent_id, - sampling_priority, - ) = extract_context_from_step_functions(event, lambda_context) + context = extract_context_from_step_functions(event, lambda_context) else: - trace_id, parent_id, sampling_priority = extract_context_from_lambda_context( - lambda_context - ) + context = extract_context_from_lambda_context(lambda_context) - if trace_id and parent_id and sampling_priority: + if _is_context_complete(context): logger.debug("Extracted Datadog trace context from event or context") - metadata = { - "trace-id": trace_id, - "parent-id": parent_id, - "sampling-priority": sampling_priority, - } - dd_trace_context = metadata.copy() + dd_trace_context = context trace_context_source = TraceContextSource.EVENT else: # AWS Lambda runtime caches global variables between invocations, @@ -565,7 +478,7 @@ def extract_dd_trace_context( return dd_trace_context, trace_context_source, event_source -def get_dd_trace_context(): +def get_dd_trace_context_obj(): """ Return the Datadog trace context to be propagated on the outgoing requests. @@ -579,8 +492,8 @@ def get_dd_trace_context(): """ if dd_tracing_enabled: dd_trace_py_context = _get_dd_trace_py_context() - if dd_trace_py_context is not None: - return _context_obj_to_headers(dd_trace_py_context) + if _is_context_complete(dd_trace_py_context): + return dd_trace_py_context global dd_trace_context @@ -592,16 +505,32 @@ def get_dd_trace_context(): % e ) if not xray_context: - return {} + return None + + if not _is_context_complete(dd_trace_context): + return xray_context - if not dd_trace_context: - return _context_obj_to_headers(xray_context) + logger.debug("Set parent id from xray trace context: %s", xray_context.span_id) + return Context( + trace_id=dd_trace_context.trace_id, + span_id=xray_context.span_id, + sampling_priority=dd_trace_context.sampling_priority, + meta=dd_trace_context._meta.copy(), + metrics=dd_trace_context._metrics.copy(), + ) - context = dd_trace_context.copy() - context["parent-id"] = xray_context.get("parent-id") - logger.debug("Set parent id from xray trace context: %s", context.get("parent-id")) - return _context_obj_to_headers(context) +def get_dd_trace_context(): + """ + Return the Datadog trace context to be propagated on the outgoing requests, + as a dict of headers. + """ + headers = {} + context = get_dd_trace_context_obj() + if not _is_context_complete(context): + return headers + propagator.inject(context, headers) + return headers def set_correlation_ids(): @@ -619,14 +548,12 @@ def set_correlation_ids(): logger.debug("using ddtrace implementation for spans") return - context = get_dd_trace_context() - if not context: + context = get_dd_trace_context_obj() + if not _is_context_complete(context): return - span = tracer.trace("dummy.span") - span.trace_id = int(context[TraceHeader.TRACE_ID]) - span.span_id = int(context[TraceHeader.PARENT_ID]) - + tracer.context_provider.activate(context) + tracer.trace("dummy.span") logger.debug("correlation ids set") @@ -669,18 +596,20 @@ def is_lambda_context(): def set_dd_trace_py_root(trace_context_source, merge_xray_traces): if trace_context_source == TraceContextSource.EVENT or merge_xray_traces: - context = dict(dd_trace_context) + context = Context( + trace_id=dd_trace_context.trace_id, + span_id=dd_trace_context.span_id, + sampling_priority=dd_trace_context.sampling_priority, + ) if merge_xray_traces: xray_context = _get_xray_trace_context() - if xray_context is not None: - context["parent-id"] = xray_context.get("parent-id") + if xray_context.span_id: + context.span_id = xray_context.span_id - headers = _context_obj_to_headers(context) - span_context = propagator.extract(headers) - tracer.context_provider.activate(span_context) + tracer.context_provider.activate(context) logger.debug( "Set dd trace root context to: %s", - (span_context.trace_id, span_context.span_id), + (context.trace_id, context.span_id), ) diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 24e6dcdd..745bf5d1 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -1,4 +1,5 @@ import unittest +import functools import json import os import copy @@ -29,7 +30,6 @@ _convert_xray_entity_id, _convert_xray_sampling, InferredSpanInfo, - extract_context_from_eventbridge_event, create_service_mapping, determine_service_name, service_mapping as global_service_mapping, @@ -96,6 +96,29 @@ def get_mock_context( return lambda_context +def with_trace_propagation_style(style): + style_list = list(style.split(",")) + + def _wrapper(fn): + @functools.wraps(fn) + def _wrap(*args, **kwargs): + from ddtrace.propagation.http import config + + orig_extract = config._propagation_style_extract + orig_inject = config._propagation_style_inject + config._propagation_style_extract = style_list + config._propagation_style_inject = style_list + try: + return fn(*args, **kwargs) + finally: + config._propagation_style_extract = orig_extract + config._propagation_style_inject = orig_inject + + return _wrap + + return _wrapper + + class TestExtractAndGetDDTraceContext(unittest.TestCase): def setUp(self): global dd_tracing_enabled @@ -114,17 +137,18 @@ def tearDown(self): dd_tracing_enabled = False del os.environ["_X_AMZN_TRACE_ID"] + @with_trace_propagation_style("datadog") def test_without_datadog_trace_headers(self): lambda_ctx = get_mock_context() ctx, source, event_source = extract_dd_trace_context({}, lambda_ctx) self.assertEqual(source, "xray") - self.assertDictEqual( + self.assertEqual( ctx, - { - "trace-id": fake_xray_header_value_root_decimal, - "parent-id": fake_xray_header_value_parent_decimal, - "sampling-priority": "2", - }, + Context( + trace_id=int(fake_xray_header_value_root_decimal), + span_id=int(fake_xray_header_value_parent_decimal), + sampling_priority=2, + ), ) self.assertDictEqual( get_dd_trace_context(), @@ -136,17 +160,18 @@ def test_without_datadog_trace_headers(self): {}, ) + @with_trace_propagation_style("datadog") def test_with_non_object_event(self): lambda_ctx = get_mock_context() ctx, source, event_source = extract_dd_trace_context(b"", lambda_ctx) self.assertEqual(source, "xray") - self.assertDictEqual( + self.assertEqual( ctx, - { - "trace-id": fake_xray_header_value_root_decimal, - "parent-id": fake_xray_header_value_parent_decimal, - "sampling-priority": "2", - }, + Context( + trace_id=int(fake_xray_header_value_root_decimal), + span_id=int(fake_xray_header_value_parent_decimal), + sampling_priority=2, + ), ) self.assertDictEqual( get_dd_trace_context(), @@ -158,6 +183,7 @@ def test_with_non_object_event(self): {}, ) + @with_trace_propagation_style("datadog") def test_with_incomplete_datadog_trace_headers(self): lambda_ctx = get_mock_context() ctx, source, event_source = extract_dd_trace_context( @@ -165,13 +191,13 @@ def test_with_incomplete_datadog_trace_headers(self): lambda_ctx, ) self.assertEqual(source, "xray") - self.assertDictEqual( + self.assertEqual( ctx, - { - "trace-id": fake_xray_header_value_root_decimal, - "parent-id": fake_xray_header_value_parent_decimal, - "sampling-priority": "2", - }, + Context( + trace_id=int(fake_xray_header_value_root_decimal), + span_id=int(fake_xray_header_value_parent_decimal), + sampling_priority=2, + ), ) self.assertDictEqual( get_dd_trace_context(), @@ -182,6 +208,7 @@ def test_with_incomplete_datadog_trace_headers(self): }, ) + @with_trace_propagation_style("datadog") def test_with_complete_datadog_trace_headers(self): lambda_ctx = get_mock_context() ctx, source, event_source = extract_dd_trace_context( @@ -195,10 +222,8 @@ def test_with_complete_datadog_trace_headers(self): lambda_ctx, ) self.assertEqual(source, "event") - self.assertDictEqual( - ctx, - {"trace-id": "123", "parent-id": "321", "sampling-priority": "1"}, - ) + expected_context = Context(trace_id=123, span_id=321, sampling_priority=1) + self.assertEqual(ctx, expected_context) self.assertDictEqual( get_dd_trace_context(), { @@ -211,9 +236,48 @@ def test_with_complete_datadog_trace_headers(self): self.mock_send_segment.assert_called() self.mock_send_segment.assert_called_with( XraySubsegment.TRACE_KEY, - {"trace-id": "123", "parent-id": "321", "sampling-priority": "1"}, + expected_context, + ) + + @with_trace_propagation_style("tracecontext") + def test_with_w3c_trace_headers(self): + lambda_ctx = get_mock_context() + ctx, source, event_source = extract_dd_trace_context( + { + "headers": { + "traceparent": "00-0000000000000000000000000000007b-0000000000000141-01", + "tracestate": "dd=s:2;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + } + }, + lambda_ctx, + ) + self.assertEqual(source, "event") + expected_context = Context( + trace_id=123, + span_id=321, + sampling_priority=2, + meta={ + "traceparent": "00-0000000000000000000000000000007b-0000000000000141-01", + "tracestate": "dd=s:2;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + "_dd.p.dm": "-0", + }, + ) + self.assertEqual(ctx, expected_context) + self.assertDictEqual( + get_dd_trace_context(), + { + "traceparent": "00-0000000000000000000000000000007b-94ae789b969f1cc5-01", + "tracestate": "dd=s:2;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + }, + ) + create_dd_dummy_metadata_subsegment(ctx, XraySubsegment.TRACE_KEY) + self.mock_send_segment.assert_called() + self.mock_send_segment.assert_called_with( + XraySubsegment.TRACE_KEY, + expected_context, ) + @with_trace_propagation_style("datadog") def test_with_extractor_function(self): def extractor_foo(event, context): foo = event.get("foo", {}) @@ -237,13 +301,13 @@ def extractor_foo(event, context): extractor=extractor_foo, ) self.assertEqual(ctx_source, "event") - self.assertDictEqual( + self.assertEqual( ctx, - { - "trace-id": "123", - "parent-id": "321", - "sampling-priority": "1", - }, + Context( + trace_id=123, + span_id=321, + sampling_priority=1, + ), ) self.assertDictEqual( get_dd_trace_context(), @@ -254,6 +318,7 @@ def extractor_foo(event, context): }, ) + @with_trace_propagation_style("datadog") def test_graceful_fail_of_extractor_function(self): def extractor_raiser(event, context): raise Exception("kreator") @@ -271,13 +336,13 @@ def extractor_raiser(event, context): extractor=extractor_raiser, ) self.assertEqual(ctx_source, "xray") - self.assertDictEqual( + self.assertEqual( ctx, - { - "trace-id": fake_xray_header_value_root_decimal, - "parent-id": fake_xray_header_value_parent_decimal, - "sampling-priority": "2", - }, + Context( + trace_id=int(fake_xray_header_value_root_decimal), + span_id=int(fake_xray_header_value_parent_decimal), + sampling_priority=2, + ), ) self.assertDictEqual( get_dd_trace_context(), @@ -288,6 +353,7 @@ def extractor_raiser(event, context): }, ) + @with_trace_propagation_style("datadog") def test_with_sqs_distributed_datadog_trace_data(self): lambda_ctx = get_mock_context() sqs_event = { @@ -323,14 +389,12 @@ def test_with_sqs_distributed_datadog_trace_data(self): } ctx, source, event_source = extract_dd_trace_context(sqs_event, lambda_ctx) self.assertEqual(source, "event") - self.assertDictEqual( - ctx, - { - "trace-id": "123", - "parent-id": "321", - "sampling-priority": "1", - }, + expected_context = Context( + trace_id=123, + span_id=321, + sampling_priority=1, ) + self.assertEqual(ctx, expected_context) self.assertDictEqual( get_dd_trace_context(), { @@ -342,9 +406,69 @@ def test_with_sqs_distributed_datadog_trace_data(self): create_dd_dummy_metadata_subsegment(ctx, XraySubsegment.TRACE_KEY) self.mock_send_segment.assert_called_with( XraySubsegment.TRACE_KEY, - {"trace-id": "123", "parent-id": "321", "sampling-priority": "1"}, + expected_context, + ) + + @with_trace_propagation_style("tracecontext") + def test_with_sqs_distributed_w3c_trace_data(self): + lambda_ctx = get_mock_context() + sqs_event = { + "Records": [ + { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a...", + "body": "Test message.", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185", + }, + "messageAttributes": { + "_datadog": { + "stringValue": json.dumps( + { + "traceparent": "00-0000000000000000000000000000007b-0000000000000141-01", + "tracestate": "dd=s:2;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + } + ), + "dataType": "String", + } + }, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-2", + } + ] + } + ctx, source, event_source = extract_dd_trace_context(sqs_event, lambda_ctx) + self.assertEqual(source, "event") + expected_context = Context( + trace_id=123, + span_id=321, + sampling_priority=2, + meta={ + "traceparent": "00-0000000000000000000000000000007b-0000000000000141-01", + "tracestate": "dd=s:2;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + "_dd.p.dm": "-0", + }, + ) + self.assertEqual(ctx, expected_context) + self.assertDictEqual( + get_dd_trace_context(), + { + "traceparent": "00-0000000000000000000000000000007b-94ae789b969f1cc5-01", + "tracestate": "dd=s:2;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + }, + ) + create_dd_dummy_metadata_subsegment(ctx, XraySubsegment.TRACE_KEY) + self.mock_send_segment.assert_called_with( + XraySubsegment.TRACE_KEY, + expected_context, ) + @with_trace_propagation_style("datadog") def test_with_legacy_client_context_datadog_trace_data(self): lambda_ctx = get_mock_context( custom={ @@ -357,14 +481,12 @@ def test_with_legacy_client_context_datadog_trace_data(self): ) ctx, source, event_source = extract_dd_trace_context({}, lambda_ctx) self.assertEqual(source, "event") - self.assertDictEqual( - ctx, - { - "trace-id": "666", - "parent-id": "777", - "sampling-priority": "1", - }, + expected_context = Context( + trace_id=666, + span_id=777, + sampling_priority=1, ) + self.assertEqual(ctx, expected_context) self.assertDictEqual( get_dd_trace_context(), { @@ -377,9 +499,47 @@ def test_with_legacy_client_context_datadog_trace_data(self): self.mock_send_segment.assert_called() self.mock_send_segment.assert_called_with( XraySubsegment.TRACE_KEY, - {"trace-id": "666", "parent-id": "777", "sampling-priority": "1"}, + expected_context, + ) + + @with_trace_propagation_style("tracecontext") + def test_with_legacy_client_context_w3c_trace_data(self): + lambda_ctx = get_mock_context( + custom={ + "_datadog": { + "traceparent": "00-0000000000000000000000000000029a-0000000000000309-01", + "tracestate": "dd=s:1;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + } + } + ) + ctx, source, event_source = extract_dd_trace_context({}, lambda_ctx) + self.assertEqual(source, "event") + expected_context = Context( + trace_id=666, + span_id=777, + sampling_priority=1, + meta={ + "traceparent": "00-0000000000000000000000000000029a-0000000000000309-01", + "tracestate": "dd=s:1;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + "_dd.p.dm": "-0", + }, + ) + self.assertEqual(ctx, expected_context) + self.assertDictEqual( + get_dd_trace_context(), + { + "traceparent": "00-0000000000000000000000000000029a-94ae789b969f1cc5-01", + "tracestate": "dd=s:1;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + }, + ) + create_dd_dummy_metadata_subsegment(ctx, XraySubsegment.TRACE_KEY) + self.mock_send_segment.assert_called() + self.mock_send_segment.assert_called_with( + XraySubsegment.TRACE_KEY, + expected_context, ) + @with_trace_propagation_style("datadog") def test_with_new_client_context_datadog_trace_data(self): lambda_ctx = get_mock_context( custom={ @@ -390,14 +550,12 @@ def test_with_new_client_context_datadog_trace_data(self): ) ctx, source, event_source = extract_dd_trace_context({}, lambda_ctx) self.assertEqual(source, "event") - self.assertDictEqual( - ctx, - { - "trace-id": "666", - "parent-id": "777", - "sampling-priority": "1", - }, + expected_context = Context( + trace_id=666, + span_id=777, + sampling_priority=1, ) + self.assertEqual(ctx, expected_context) self.assertDictEqual( get_dd_trace_context(), { @@ -410,9 +568,45 @@ def test_with_new_client_context_datadog_trace_data(self): self.mock_send_segment.assert_called() self.mock_send_segment.assert_called_with( XraySubsegment.TRACE_KEY, - {"trace-id": "666", "parent-id": "777", "sampling-priority": "1"}, + expected_context, ) + @with_trace_propagation_style("tracecontext") + def test_with_new_client_context_w3c_trace_data(self): + lambda_ctx = get_mock_context( + custom={ + "traceparent": "00-0000000000000000000000000000029a-0000000000000309-01", + "tracestate": "dd=s:1;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + } + ) + ctx, source, event_source = extract_dd_trace_context({}, lambda_ctx) + self.assertEqual(source, "event") + expected_context = Context( + trace_id=666, + span_id=777, + sampling_priority=1, + meta={ + "traceparent": "00-0000000000000000000000000000029a-0000000000000309-01", + "tracestate": "dd=s:1;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + "_dd.p.dm": "-0", + }, + ) + self.assertEqual(ctx, expected_context) + self.assertDictEqual( + get_dd_trace_context(), + { + "traceparent": "00-0000000000000000000000000000029a-94ae789b969f1cc5-01", + "tracestate": "dd=s:1;t.dm:-0,rojo=00f067aa0ba902b7,congo=t61rcWkgMzE", + }, + ) + create_dd_dummy_metadata_subsegment(ctx, XraySubsegment.TRACE_KEY) + self.mock_send_segment.assert_called() + self.mock_send_segment.assert_called_with( + XraySubsegment.TRACE_KEY, + expected_context, + ) + + @with_trace_propagation_style("datadog") def test_with_complete_datadog_trace_headers_with_mixed_casing(self): lambda_ctx = get_mock_context() extract_dd_trace_context( @@ -455,51 +649,85 @@ def test_with_complete_datadog_trace_headers_with_trigger_tags(self): ] ) + @with_trace_propagation_style("datadog") + def test_step_function_trace_data(self): + lambda_ctx = get_mock_context() + sqs_event = { + "Execution": { + "Id": "665c417c-1237-4742-aaca-8b3becbb9e75", + }, + "StateMachine": {}, + "State": { + "Name": "my-awesome-state", + "EnteredTime": "Mon Nov 13 12:43:33 PST 2023", + }, + } + ctx, source, event_source = extract_dd_trace_context(sqs_event, lambda_ctx) + self.assertEqual(source, "event") + expected_context = Context( + trace_id=1074655265866231755, + span_id=4776286484851030060, + sampling_priority=1, + ) + self.assertEqual(ctx, expected_context) + self.assertEqual( + get_dd_trace_context(), + { + TraceHeader.TRACE_ID: "1074655265866231755", + TraceHeader.PARENT_ID: fake_xray_header_value_parent_decimal, + TraceHeader.SAMPLING_PRIORITY: "1", + }, + ) + create_dd_dummy_metadata_subsegment(ctx, XraySubsegment.TRACE_KEY) + self.mock_send_segment.assert_called_with( + XraySubsegment.TRACE_KEY, + expected_context, + ) + class TestXRayContextConversion(unittest.TestCase): def test_convert_xray_trace_id(self): self.assertEqual( - _convert_xray_trace_id("00000000e1be46a994272793"), "7043144561403045779" + _convert_xray_trace_id("00000000e1be46a994272793"), 7043144561403045779 ) self.assertEqual( - _convert_xray_trace_id("bd862e3fe1be46a994272793"), "7043144561403045779" + _convert_xray_trace_id("bd862e3fe1be46a994272793"), 7043144561403045779 ) self.assertEqual( _convert_xray_trace_id("ffffffffffffffffffffffff"), - "9223372036854775807", # 0x7FFFFFFFFFFFFFFF + 9223372036854775807, # 0x7FFFFFFFFFFFFFFF ) def test_convert_xray_entity_id(self): self.assertEqual( - _convert_xray_entity_id("53995c3f42cd8ad8"), "6023947403358210776" + _convert_xray_entity_id("53995c3f42cd8ad8"), 6023947403358210776 ) self.assertEqual( - _convert_xray_entity_id("1000000000000000"), "1152921504606846976" + _convert_xray_entity_id("1000000000000000"), 1152921504606846976 ) self.assertEqual( - _convert_xray_entity_id("ffffffffffffffff"), "18446744073709551615" + _convert_xray_entity_id("ffffffffffffffff"), 18446744073709551615 ) def test_convert_xray_sampling(self): - self.assertEqual(_convert_xray_sampling(True), str(SamplingPriority.USER_KEEP)) + self.assertEqual(_convert_xray_sampling(True), SamplingPriority.USER_KEEP) - self.assertEqual( - _convert_xray_sampling(False), str(SamplingPriority.USER_REJECT) - ) + self.assertEqual(_convert_xray_sampling(False), SamplingPriority.USER_REJECT) class TestLogsInjection(unittest.TestCase): def setUp(self): - patcher = patch("datadog_lambda.tracing.get_dd_trace_context") + patcher = patch("datadog_lambda.tracing.get_dd_trace_context_obj") self.mock_get_dd_trace_context = patcher.start() - self.mock_get_dd_trace_context.return_value = { - TraceHeader.TRACE_ID: "123", - TraceHeader.PARENT_ID: "456", - } + self.mock_get_dd_trace_context.return_value = Context( + trace_id=int(fake_xray_header_value_root_decimal), + span_id=int(fake_xray_header_value_parent_decimal), + sampling_priority=1, + ) self.addCleanup(patcher.stop) patcher = patch("datadog_lambda.tracing.is_lambda_context") @@ -510,13 +738,13 @@ def setUp(self): def test_set_correlation_ids(self): set_correlation_ids() span = tracer.current_span() - self.assertEqual(span.trace_id, 123) - self.assertEqual(span.span_id, 456) + self.assertEqual(span.trace_id, int(fake_xray_header_value_root_decimal)) + self.assertEqual(span.parent_id, int(fake_xray_header_value_parent_decimal)) span.finish() def test_set_correlation_ids_handle_empty_trace_context(self): # neither x-ray or ddtrace is used. no tracing context at all. - self.mock_get_dd_trace_context.return_value = {} + self.mock_get_dd_trace_context.return_value = Context() # no exception thrown set_correlation_ids() span = tracer.current_span() @@ -1829,10 +2057,10 @@ def test_extract_context_from_eventbridge_event(self): with open(test_file, "r") as event: event = json.load(event) ctx = get_mock_context() - trace, parent, sampling = extract_context_from_eventbridge_event(event, ctx) - self.assertEqual(trace, "12345") - self.assertEqual(parent, "67890"), - self.assertEqual(sampling, "2") + context, source, event_type = extract_dd_trace_context(event, ctx) + self.assertEqual(context.trace_id, 12345) + self.assertEqual(context.span_id, 67890), + self.assertEqual(context.sampling_priority, 2) def test_extract_dd_trace_context_for_eventbridge(self): event_sample_source = "eventbridge-custom" @@ -1841,8 +2069,8 @@ def test_extract_dd_trace_context_for_eventbridge(self): event = json.load(event) ctx = get_mock_context() context, source, event_type = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "12345") - self.assertEqual(context["parent-id"], "67890") + self.assertEqual(context.trace_id, 12345) + self.assertEqual(context.span_id, 67890) def test_extract_context_from_eventbridge_sqs_event(self): event_sample_source = "eventbridge-sqs" @@ -1852,9 +2080,9 @@ def test_extract_context_from_eventbridge_sqs_event(self): ctx = get_mock_context() context, source, event_type = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "7379586022458917877") - self.assertEqual(context["parent-id"], "2644033662113726488") - self.assertEqual(context["sampling-priority"], "1") + self.assertEqual(context.trace_id, 7379586022458917877) + self.assertEqual(context.span_id, 2644033662113726488) + self.assertEqual(context.sampling_priority, 1) def test_extract_context_from_sqs_event_with_string_msg_attr(self): event_sample_source = "sqs-string-msg-attribute" @@ -1863,9 +2091,9 @@ def test_extract_context_from_sqs_event_with_string_msg_attr(self): event = json.load(event) ctx = get_mock_context() context, source, event_type = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "2684756524522091840") - self.assertEqual(context["parent-id"], "7431398482019833808") - self.assertEqual(context["sampling-priority"], "1") + self.assertEqual(context.trace_id, 2684756524522091840) + self.assertEqual(context.span_id, 7431398482019833808) + self.assertEqual(context.sampling_priority, 1) def test_extract_context_from_sqs_batch_event(self): event_sample_source = "sqs-batch" @@ -1874,9 +2102,9 @@ def test_extract_context_from_sqs_batch_event(self): event = json.load(event) ctx = get_mock_context() context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "2684756524522091840") - self.assertEqual(context["parent-id"], "7431398482019833808") - self.assertEqual(context["sampling-priority"], "1") + self.assertEqual(context.trace_id, 2684756524522091840) + self.assertEqual(context.span_id, 7431398482019833808) + self.assertEqual(context.sampling_priority, 1) def test_extract_context_from_sns_event_with_string_msg_attr(self): event_sample_source = "sns-string-msg-attribute" @@ -1885,9 +2113,9 @@ def test_extract_context_from_sns_event_with_string_msg_attr(self): event = json.load(event) ctx = get_mock_context() context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "4948377316357291421") - self.assertEqual(context["parent-id"], "6746998015037429512") - self.assertEqual(context["sampling-priority"], "1") + self.assertEqual(context.trace_id, 4948377316357291421) + self.assertEqual(context.span_id, 6746998015037429512) + self.assertEqual(context.sampling_priority, 1) def test_extract_context_from_sns_event_with_b64_msg_attr(self): event_sample_source = "sns-b64-msg-attribute" @@ -1896,9 +2124,9 @@ def test_extract_context_from_sns_event_with_b64_msg_attr(self): event = json.load(event) ctx = get_mock_context() context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "4948377316357291421") - self.assertEqual(context["parent-id"], "6746998015037429512") - self.assertEqual(context["sampling-priority"], "1") + self.assertEqual(context.trace_id, 4948377316357291421) + self.assertEqual(context.span_id, 6746998015037429512) + self.assertEqual(context.sampling_priority, 1) def test_extract_context_from_sns_batch_event(self): event_sample_source = "sns-batch" @@ -1907,9 +2135,9 @@ def test_extract_context_from_sns_batch_event(self): event = json.load(event) ctx = get_mock_context() context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "4948377316357291421") - self.assertEqual(context["parent-id"], "6746998015037429512") - self.assertEqual(context["sampling-priority"], "1") + self.assertEqual(context.trace_id, 4948377316357291421) + self.assertEqual(context.span_id, 6746998015037429512) + self.assertEqual(context.sampling_priority, 1) def test_extract_context_from_kinesis_event(self): event_sample_source = "kinesis" @@ -1918,9 +2146,9 @@ def test_extract_context_from_kinesis_event(self): event = json.load(event) ctx = get_mock_context() context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "4948377316357291421") - self.assertEqual(context["parent-id"], "2876253380018681026") - self.assertEqual(context["sampling-priority"], "1") + self.assertEqual(context.trace_id, 4948377316357291421) + self.assertEqual(context.span_id, 2876253380018681026) + self.assertEqual(context.sampling_priority, 1) def test_extract_context_from_kinesis_batch_event(self): event_sample_source = "kinesis-batch" @@ -1929,9 +2157,9 @@ def test_extract_context_from_kinesis_batch_event(self): event = json.load(event) ctx = get_mock_context() context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context["trace-id"], "4948377316357291421") - self.assertEqual(context["parent-id"], "2876253380018681026") - self.assertEqual(context["sampling-priority"], "1") + self.assertEqual(context.trace_id, 4948377316357291421) + self.assertEqual(context.span_id, 2876253380018681026) + self.assertEqual(context.sampling_priority, 1) def test_create_inferred_span_from_api_gateway_event_no_apiid(self): event_sample_source = "api-gateway-no-apiid" @@ -1998,14 +2226,14 @@ def test_no_error_with_nonetype_headers(self): class TestStepFunctionsTraceContext(unittest.TestCase): def test_deterministic_m5_hash(self): result = _deterministic_md5_hash("some_testing_random_string") - self.assertEqual("2251275791555400689", result) + self.assertEqual(2251275791555400689, result) def test_deterministic_m5_hash__result_the_same_as_backend(self): result = _deterministic_md5_hash( "arn:aws:states:sa-east-1:601427271234:express:DatadogStateMachine:acaf1a67-336a-e854-1599-2a627eb2dd8a" ":c8baf081-31f1-464d-971f-70cb17d01111#step-one#2022-12-08T21:08:19.224Z" ) - self.assertEqual("8034507082463708833", result) + self.assertEqual(8034507082463708833, result) def test_deterministic_m5_hash__always_leading_with_zero(self): for i in range(100):