diff --git a/datadog_lambda/__init__.py b/datadog_lambda/__init__.py index cbec8f4f..20b42443 100644 --- a/datadog_lambda/__init__.py +++ b/datadog_lambda/__init__.py @@ -1,3 +1,9 @@ +import os +import logging +from datadog_lambda.cold_start import initialize_cold_start_tracing + +initialize_cold_start_tracing() + # The minor version corresponds to the Lambda layer version. # E.g.,, version 0.5.0 gets packaged into layer version 5. try: @@ -7,8 +13,5 @@ __version__ = importlib_metadata.version(__name__) -import os -import logging - logger = logging.getLogger(__name__) logger.setLevel(logging.getLevelName(os.environ.get("DD_LOG_LEVEL", "INFO").upper())) diff --git a/datadog_lambda/cold_start.py b/datadog_lambda/cold_start.py index c8862bf1..fdb43b81 100644 --- a/datadog_lambda/cold_start.py +++ b/datadog_lambda/cold_start.py @@ -1,3 +1,10 @@ +import time +import os +from typing import List, Hashable +import logging + +logger = logging.getLogger(__name__) + _cold_start = True _lambda_container_initialized = False @@ -21,3 +28,196 @@ def is_cold_start(): def get_cold_start_tag(): """Returns the cold start tag to be used in metrics""" return "cold_start:{}".format(str(is_cold_start()).lower()) + + +class ImportNode(object): + def __init__(self, module_name, full_file_path, start_time_ns, end_time_ns=None): + self.module_name = module_name + self.full_file_path = full_file_path + self.start_time_ns = start_time_ns + self.end_time_ns = end_time_ns + self.children = [] + + +root_nodes: List[ImportNode] = [] +import_stack: List[ImportNode] = [] +already_wrapped_loaders = set() + + +def reset_node_stacks(): + global root_nodes + root_nodes = [] + global import_stack + import_stack = [] + + +def push_node(module_name, file_path): + node = ImportNode(module_name, file_path, time.time_ns()) + global import_stack + if import_stack: + import_stack[-1].children.append(node) + import_stack.append(node) + + +def pop_node(module_name): + global import_stack + if not import_stack: + return + node = import_stack.pop() + if node.module_name != module_name: + return + end_time_ns = time.time_ns() + node.end_time_ns = end_time_ns + if not import_stack: # import_stack empty, a root node has been found + global root_nodes + root_nodes.append(node) + + +def wrap_exec_module(original_exec_module): + def wrapped_method(module): + should_pop = False + try: + spec = module.__spec__ + push_node(spec.name, spec.origin) + should_pop = True + except Exception: + pass + try: + return original_exec_module(module) + finally: + if should_pop: + pop_node(spec.name) + + return wrapped_method + + +def wrap_find_spec(original_find_spec): + def wrapped_find_spec(*args, **kwargs): + spec = original_find_spec(*args, **kwargs) + if spec is None: + return None + loader = getattr(spec, "loader", None) + if ( + loader is not None + and isinstance(loader, Hashable) + and loader not in already_wrapped_loaders + ): + if hasattr(loader, "exec_module"): + try: + loader.exec_module = wrap_exec_module(loader.exec_module) + already_wrapped_loaders.add(loader) + except Exception as e: + logger.debug("Failed to wrap the loader. %s", e) + return spec + + return wrapped_find_spec + + +def initialize_cold_start_tracing(): + if ( + is_cold_start() + and os.environ.get("DD_TRACE_ENABLED", "true").lower() == "true" + and os.environ.get("DD_COLD_START_TRACING", "true").lower() == "true" + ): + from sys import version_info, meta_path + + if version_info >= (3, 7): # current implementation only support version > 3.7 + for importer in meta_path: + try: + importer.find_spec = wrap_find_spec(importer.find_spec) + except Exception: + pass + + +class ColdStartTracer(object): + def __init__( + self, + tracer, + function_name, + cold_start_span_finish_time_ns, + trace_ctx, + min_duration_ms: int, + ignored_libs: List[str] = [], + ): + self._tracer = tracer + self.function_name = function_name + self.cold_start_span_finish_time_ns = cold_start_span_finish_time_ns + self.min_duration_ms = min_duration_ms + self.trace_ctx = trace_ctx + self.ignored_libs = ignored_libs + self.need_to_reactivate_context = True + + def trace(self, root_nodes: List[ImportNode] = root_nodes): + if not root_nodes: + return + cold_start_span_start_time_ns = root_nodes[0].start_time_ns + cold_start_span = self.create_cold_start_span(cold_start_span_start_time_ns) + while root_nodes: + root_node = root_nodes.pop() + self.trace_tree(root_node, cold_start_span) + self.finish_span(cold_start_span, self.cold_start_span_finish_time_ns) + + def trace_tree(self, import_node: ImportNode, parent_span): + if ( + import_node.end_time_ns - import_node.start_time_ns + < self.min_duration_ms * 1e6 + or import_node.module_name in self.ignored_libs + ): + return + + span = self.start_span( + "aws.lambda.import", import_node.module_name, import_node.start_time_ns + ) + tags = { + "resource_names": import_node.module_name, + "resource.name": import_node.module_name, + "filename": import_node.full_file_path, + "operation_name": self.get_operation_name(import_node.full_file_path), + } + span.set_tags(tags) + if parent_span: + span.parent_id = parent_span.span_id + for child_node in import_node.children: + self.trace_tree(child_node, span) + self.finish_span(span, import_node.end_time_ns) + + def create_cold_start_span(self, start_time_ns): + span = self.start_span("aws.lambda.load", self.function_name, start_time_ns) + tags = { + "resource_names": self.function_name, + "resource.name": self.function_name, + "operation_name": "aws.lambda.load", + } + span.set_tags(tags) + return span + + def start_span(self, span_type, resource, start_time_ns): + if self.need_to_reactivate_context: + self._tracer.context_provider.activate( + self.trace_ctx + ) # reactivate required after each finish() call + self.need_to_reactivate_context = False + span_kwargs = { + "service": "aws.lambda", + "resource": resource, + "span_type": span_type, + } + span = self._tracer.trace(span_type, **span_kwargs) + span.start_ns = start_time_ns + return span + + def finish_span(self, span, finish_time_ns): + span.finish(finish_time_ns / 1e9) + self.need_to_reactivate_context = True + + def get_operation_name(self, filename: str): + if filename is None: + return "aws.lambda.import_core_module" + if not isinstance(filename, str): + return "aws.lambda.import" + if filename.startswith("/opt/"): + return "aws.lambda.import_layer" + elif filename.startswith("/var/lang/"): + return "aws.lambda.import_runtime" + else: + return "aws.lambda.import" diff --git a/datadog_lambda/wrapper.py b/datadog_lambda/wrapper.py index 51ffccec..fb849cec 100644 --- a/datadog_lambda/wrapper.py +++ b/datadog_lambda/wrapper.py @@ -2,7 +2,6 @@ # under the Apache License Version 2.0. # This product includes software developed at Datadog (https://www.datadoghq.com/). # Copyright 2019 Datadog, Inc. - import base64 import os import logging @@ -12,7 +11,7 @@ from time import time_ns from datadog_lambda.extension import should_use_extension, flush_extension -from datadog_lambda.cold_start import set_cold_start, is_cold_start +from datadog_lambda.cold_start import set_cold_start, is_cold_start, ColdStartTracer from datadog_lambda.constants import ( TraceContextSource, XraySubsegment, @@ -38,6 +37,7 @@ create_inferred_span, InferredSpanInfo, is_authorizer_response, + tracer, ) from datadog_lambda.trigger import ( extract_trigger_tags, @@ -131,6 +131,28 @@ def __init__(self, func): self.decode_authorizer_context = ( os.environ.get("DD_DECODE_AUTHORIZER_CONTEXT", "true").lower() == "true" ) + self.cold_start_tracing = ( + os.environ.get("DD_COLD_START_TRACING", "true").lower() == "true" + ) + self.min_cold_start_trace_duration = 3 + if "DD_MIN_COLD_START_DURATION" in os.environ: + try: + self.min_cold_start_trace_duration = int( + os.environ["DD_MIN_COLD_START_DURATION"] + ) + except Exception: + logger.debug("Malformatted env DD_MIN_COLD_START_DURATION") + self.cold_start_trace_skip_lib = [ + "ddtrace.internal.compat", + "ddtrace.filters", + ] + if "DD_COLD_START_TRACE_SKIP_LIB" in os.environ: + try: + self.cold_start_trace_skip_lib = os.environ[ + "DD_COLD_START_TRACE_SKIP_LIB" + ].split(",") + except Exception: + logger.debug("Malformatted for env DD_COLD_START_TRACE_SKIP_LIB") self.response = None if profiling_env_var: self.prof = profiler.Profiler(env=env_env_var, service=service_env_var) @@ -257,6 +279,11 @@ def _after(self, event, context): create_dd_dummy_metadata_subsegment( self.trigger_tags, XraySubsegment.LAMBDA_FUNCTION_TAGS_KEY ) + should_trace_cold_start = ( + dd_tracing_enabled and self.cold_start_tracing and is_cold_start() + ) + if should_trace_cold_start: + trace_ctx = tracer.current_trace_context() if self.span: if dd_capture_lambda_payload_enabled: @@ -276,6 +303,20 @@ def _after(self, event, context): else: self.inferred_span.finish() + if should_trace_cold_start: + try: + following_span = self.span or self.inferred_span + ColdStartTracer( + tracer, + self.function_name, + following_span.start_ns, + trace_ctx, + self.min_cold_start_trace_duration, + self.cold_start_trace_skip_lib, + ).trace() + except Exception as e: + logger.debug("Failed to create cold start spans. %s", e) + if not self.flush_to_log or should_use_extension: flush_stats() if should_use_extension: diff --git a/tests/integration/serverless.yml b/tests/integration/serverless.yml index 9bb8a79b..27112f54 100644 --- a/tests/integration/serverless.yml +++ b/tests/integration/serverless.yml @@ -11,6 +11,7 @@ provider: DD_TRACE_ENABLED: true DD_API_KEY: ${env:DD_API_KEY} DD_TRACE_MANAGED_SERVICES: true + DD_COLD_START_TRACING: false timeout: 15 deploymentBucket: name: integration-tests-serververless-deployment-bucket diff --git a/tests/test_cold_start.py b/tests/test_cold_start.py new file mode 100644 index 00000000..22e7dc9c --- /dev/null +++ b/tests/test_cold_start.py @@ -0,0 +1,200 @@ +import unittest +import datadog_lambda.cold_start as cold_start +from sys import modules, meta_path +import os +from unittest.mock import MagicMock + + +class TestColdStartTracingSetup(unittest.TestCase): + def test_initialize_cold_start_tracing(self): + cold_start.initialize_cold_start_tracing() # testing double wrapping + cold_start.initialize_cold_start_tracing() + cold_start.reset_node_stacks() + for module_name in ["ast", "dis", "inspect"]: + if module_name in modules: + del modules[module_name] + import inspect # import some package + + self.assertTrue(inspect.ismodule(inspect)) + self.assertEqual(len(cold_start.root_nodes), 1) + self.assertEqual(cold_start.root_nodes[0].module_name, "inspect") + + def test_bad_importer_find_spec_attribute_error(self): + mock_importer = object() # AttributeError when accessing find_spec + meta_path.append(mock_importer) + cold_start.initialize_cold_start_tracing() # safe to call + meta_path.pop() + + def test_not_wrapping_case(self): + os.environ["DD_COLD_START_TRACING"] = "false" + mock_importer = MagicMock() + mock_module_spec = MagicMock() + mock_module_spec.name = "test_name" + mock_loader = object() + mock_module_spec.loader = mock_loader + + def find_spec(*args, **kwargs): + return mock_module_spec + + mock_importer.find_spec = find_spec + meta_path.append(mock_importer) + cold_start.initialize_cold_start_tracing() + self.assertFalse(mock_loader in cold_start.already_wrapped_loaders) + meta_path.pop() + os.environ["DD_COLD_START_TRACING"] = "true" + + def test_exec_module_failure_case(self): + mock_importer = MagicMock() + mock_module_spec = MagicMock() + mock_module_spec.name = "test_name" + mock_loader = MagicMock() + + def bad_exec_module(*args, **kwargs): + raise Exception("Module failed to load") + + mock_loader.exec_module = bad_exec_module + mock_module_spec.loader = mock_loader + + def find_spec(*args, **kwargs): + return mock_module_spec + + mock_importer.find_spec = find_spec + meta_path.insert(0, mock_importer) + cold_start.initialize_cold_start_tracing() + cold_start.reset_node_stacks() + try: + import dummy_module + except Exception as e: + self.assertEqual(str(e), "Module failed to load") + self.assertEqual(len(cold_start.root_nodes), 1) + self.assertEqual(cold_start.root_nodes[0].module_name, mock_module_spec.name) + meta_path.pop(0) + + +class TestColdStartTracer(unittest.TestCase): + def setUp(self) -> None: + mock_tracer = MagicMock() + self.output_spans = [] + self.shared_mock_span = MagicMock() + self.shared_mock_span.current_spans = [] + self.finish_call_count = 0 + + def _finish(finish_time_s): + module_name = self.shared_mock_span.current_spans.pop() + self.output_spans.append(module_name) + self.finish_call_count += 1 + + self.shared_mock_span.finish = _finish + + def _trace(*args, **kwargs): + module_name = kwargs["resource"] + self.shared_mock_span.current_spans.append(module_name) + return self.shared_mock_span + + mock_tracer.trace = _trace + self.mock_activate = MagicMock() + mock_tracer.context_provider.activate = self.mock_activate + self.mock_trace_ctx = MagicMock() + self.first_node_start_time_ns = 1676217209680116000 + self.cold_start_tracer = cold_start.ColdStartTracer( + mock_tracer, + "unittest_cold_start", + self.first_node_start_time_ns + 2e9, + self.mock_trace_ctx, + 3, + ["ignored_module_a", "ignored_module_b"], + ) + self.test_time_unit = (self.cold_start_tracer.min_duration_ms + 1) * 1e6 + + def test_trace_empty_root_nodes(self): + self.cold_start_tracer.trace([]) + self.assertEqual(len(self.output_spans), 0) + + def test_trace_one_root_node_no_children(self): + node_0 = cold_start.ImportNode("node_0", None, self.first_node_start_time_ns) + node_0.end_time_ns = self.first_node_start_time_ns + 4e6 + self.cold_start_tracer.trace([node_0]) + self.mock_activate.assert_called_once_with(self.mock_trace_ctx) + self.assertEqual(self.output_spans, ["node_0", "unittest_cold_start"]) + + def test_trace_one_root_node_with_children(self): + node_0 = cold_start.ImportNode("node_0", None, self.first_node_start_time_ns) + node_0.end_time_ns = self.first_node_start_time_ns + self.test_time_unit * 2 + node_1 = cold_start.ImportNode("node_1", None, self.first_node_start_time_ns) + node_1.end_time_ns = self.first_node_start_time_ns + self.test_time_unit + node_2 = cold_start.ImportNode( + "node_2", None, self.first_node_start_time_ns + self.test_time_unit + ) + node_2.end_time_ns = self.first_node_start_time_ns + self.test_time_unit * 2 + node_3 = cold_start.ImportNode("node_3", None, self.first_node_start_time_ns) + node_3.end_time_ns = self.first_node_start_time_ns + self.test_time_unit + nodes = [node_0] + node_0.children = [node_1, node_2] + node_1.children = [node_3] + self.cold_start_tracer.trace(nodes) + self.mock_activate.assert_called_with(self.mock_trace_ctx) + self.assertEqual(self.finish_call_count, 5) + self.assertEqual(self.mock_activate.call_count, 2) + self.assertEqual( + self.output_spans, + ["node_3", "node_1", "node_2", "node_0", "unittest_cold_start"], + ) + + def test_trace_multiple_root_nodes(self): + node_0 = cold_start.ImportNode("node_0", None, self.first_node_start_time_ns) + node_0.end_time_ns = self.first_node_start_time_ns + self.test_time_unit * 2 + node_1 = cold_start.ImportNode( + "node_1", None, self.first_node_start_time_ns + self.test_time_unit * 2 + ) + node_1.end_time_ns = self.first_node_start_time_ns + self.test_time_unit * 3 + node_2 = cold_start.ImportNode("node_2", None, self.first_node_start_time_ns) + node_2.end_time_ns = self.first_node_start_time_ns + self.test_time_unit + node_3 = cold_start.ImportNode( + "node_3", None, self.first_node_start_time_ns + self.test_time_unit + ) + node_3.end_time_ns = self.first_node_start_time_ns + self.test_time_unit * 2 + node_4 = cold_start.ImportNode( + "node_4", None, self.first_node_start_time_ns + self.test_time_unit * 2 + ) + node_4.end_time_ns = self.first_node_start_time_ns + self.test_time_unit * 3 + nodes = [node_0, node_1] + node_0.children = [node_2, node_3] + node_1.children = [node_4] + self.cold_start_tracer.trace(nodes) + self.mock_activate.assert_called_with(self.mock_trace_ctx) + self.assertEqual(self.finish_call_count, 6) + self.assertEqual(self.mock_activate.call_count, 3) + self.assertEqual( + self.output_spans, + ["node_4", "node_1", "node_2", "node_3", "node_0", "unittest_cold_start"], + ) + + def test_trace_min_duration(self): + node_0 = cold_start.ImportNode("node_0", None, self.first_node_start_time_ns) + node_0.end_time_ns = ( + self.first_node_start_time_ns + + self.cold_start_tracer.min_duration_ms * 1e6 + - 1e5 + ) + self.cold_start_tracer.trace([node_0]) + self.mock_activate.assert_called_once_with(self.mock_trace_ctx) + self.assertEqual(self.output_spans, ["unittest_cold_start"]) + + def test_trace_ignore_libs(self): + node_0 = cold_start.ImportNode("node_0", None, self.first_node_start_time_ns) + node_0.end_time_ns = self.first_node_start_time_ns + self.test_time_unit + node_1 = cold_start.ImportNode( + "ignored_module_a", + None, + self.first_node_start_time_ns + self.test_time_unit, + ) + node_1.end_time_ns = self.first_node_start_time_ns + self.test_time_unit * 2 + node_2 = cold_start.ImportNode( + "ignored_module_b", None, self.first_node_start_time_ns + ) + node_2.end_time_ns = self.first_node_start_time_ns + self.test_time_unit + nodes = [node_0, node_1] + node_0.children = [node_2] + self.cold_start_tracer.trace(nodes) + self.mock_activate.assert_called_once_with(self.mock_trace_ctx) + self.assertEqual(self.output_spans, ["node_0", "unittest_cold_start"])