Skip to content

cold start tracing (python) #299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions datadog_lambda/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import logging
from datadog_lambda.cold_start import initialize_cold_start_tracing

initialize_cold_start_tracing()

# The minor version corresponds to the Lambda layer version.
# E.g., version 0.5.0 gets packaged into layer version 5.
try:
Expand All @@ -7,8 +13,5 @@

__version__ = importlib_metadata.version(__name__)

import os
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.getLevelName(os.environ.get("DD_LOG_LEVEL", "INFO").upper()))
200 changes: 200 additions & 0 deletions datadog_lambda/cold_start.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import logging
import os
import time
from typing import Hashable, List, Optional

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Initial cold-start state for a freshly loaded module.
# NOTE(review): the code that reads and updates these two flags is outside
# this view (presumably is_cold_start()/set_cold_start()) -- confirm there.
_cold_start = True
_lambda_container_initialized = False

Expand All @@ -21,3 +28,196 @@ def is_cold_start():
def get_cold_start_tag():
    """Return the cold start tag to be used in metrics.

    The tag has the form ``cold_start:<value>``, where ``<value>`` is the
    lower-cased string form of whatever ``is_cold_start()`` returns.
    """
    cold_start_flag = str(is_cold_start()).lower()
    return "cold_start:{}".format(cold_start_flag)


class ImportNode(object):
    """Timing record for one module import; nodes link into a tree.

    Attributes:
        module_name: Name of the imported module.
        full_file_path: Origin of the module as given by the caller.
        start_time_ns: Timestamp (ns) taken when the import started.
        end_time_ns: Timestamp (ns) taken when the import finished, or
            ``None`` while it has not been recorded yet.
        children: Nodes for imports that ran while this one was in progress.
    """

    def __init__(self, module_name, full_file_path, start_time_ns, end_time_ns=None):
        self.module_name, self.full_file_path = module_name, full_file_path
        self.start_time_ns, self.end_time_ns = start_time_ns, end_time_ns
        # Every node owns a fresh list; nested imports are appended to it.
        self.children = list()


# Finished outermost imports: pop_node() appends a node here when popping it
# leaves import_stack empty.
root_nodes: List[ImportNode] = []
# Imports currently in progress, innermost last; push_node() appends and
# pop_node() pops.
import_stack: List[ImportNode] = []
# Loaders whose exec_module has already been replaced by wrap_find_spec(),
# so the same loader is not wrapped twice.
already_wrapped_loaders = set()


def reset_node_stacks():
    """Discard every recorded import node.

    Both ``root_nodes`` and ``import_stack`` are emptied *in place* rather
    than rebound to new lists. Rebinding would leave every reference taken
    earlier pointing at the old list: a default argument evaluated at
    definition time, or a ``from ... import root_nodes`` in another module,
    would then no longer see nodes recorded after the reset.
    """
    root_nodes.clear()
    import_stack.clear()


def push_node(module_name, file_path):
    """Record the start of an import.

    Creates an ``ImportNode`` stamped with the current ``time.time_ns()``,
    attaches it as a child of the import currently on top of
    ``import_stack`` (if there is one) and pushes it onto that stack.

    Args:
        module_name: Name of the module being imported.
        file_path: Origin of the module, stored on the node as-is.
    """
    started = ImportNode(module_name, file_path, time.time_ns())
    enclosing = import_stack[-1] if import_stack else None
    if enclosing is not None:
        enclosing.children.append(started)
    import_stack.append(started)


def pop_node(module_name):
    """Record the end of an import.

    Pops the top of ``import_stack``. When the popped node belongs to
    ``module_name`` it is stamped with the current ``time.time_ns()``, and
    if that leaves the stack empty the node is appended to ``root_nodes``.
    A popped node with a different name is dropped without being stamped.
    An empty stack is a no-op.

    Args:
        module_name: Name of the module whose import just finished.
    """
    if not import_stack:
        return
    finished = import_stack.pop()
    if finished.module_name == module_name:
        finished.end_time_ns = time.time_ns()
        if not import_stack:
            # Nothing encloses this import, so it is a root of the tree.
            root_nodes.append(finished)


def wrap_exec_module(original_exec_module):
    """Wrap a loader's ``exec_module`` so the import gets timed.

    The returned callable pushes a node for the module before delegating to
    ``original_exec_module`` and pops it afterwards, even when the original
    raises. Any failure while reading ``module.__spec__`` or pushing the
    node is swallowed so the import itself is never affected; in that case
    nothing is popped either.

    Args:
        original_exec_module: The callable to delegate to; called with the
            module as its only argument.

    Returns:
        A one-argument callable returning whatever the original returns.
    """

    def wrapped_method(module):
        node_pushed = False
        try:
            module_spec = module.__spec__
            push_node(module_spec.name, module_spec.origin)
        except Exception:
            # Timing is best effort; never let it break the import.
            pass
        else:
            node_pushed = True

        try:
            outcome = original_exec_module(module)
        finally:
            # Only pop what was actually pushed, to keep the stack balanced.
            if node_pushed:
                pop_node(module_spec.name)
        return outcome

    return wrapped_method


def wrap_find_spec(original_find_spec):
    """Wrap a finder's ``find_spec`` so the loaders it returns get timed.

    The returned callable forwards all arguments to ``original_find_spec``
    and returns its result unchanged. As a side effect, when the spec has a
    hashable loader with an ``exec_module`` attribute that has not been
    seen before, that attribute is replaced by ``wrap_exec_module(...)`` of
    itself and the loader is remembered in ``already_wrapped_loaders``.
    Failures while wrapping are logged at debug level and otherwise ignored.

    Args:
        original_find_spec: The callable to delegate to.

    Returns:
        A callable with the same calling convention as the original.
    """

    def wrapped_find_spec(*args, **kwargs):
        found = original_find_spec(*args, **kwargs)
        if found is None:
            return None

        spec_loader = getattr(found, "loader", None)
        # Same checks, in the same order, as a chain of guard conditions.
        if (
            spec_loader is None
            or not isinstance(spec_loader, Hashable)
            or spec_loader in already_wrapped_loaders
            or not hasattr(spec_loader, "exec_module")
        ):
            return found

        try:
            spec_loader.exec_module = wrap_exec_module(spec_loader.exec_module)
            already_wrapped_loaders.add(spec_loader)
        except Exception as e:
            logger.debug("Failed to wrap the loader. %s", e)
        return found

    return wrapped_find_spec


def initialize_cold_start_tracing():
    """Install the import-timing hook on every finder in ``sys.meta_path``.

    Nothing is installed unless all of the following hold:

    * ``is_cold_start()`` is truthy,
    * the ``DD_TRACE_ENABLED`` environment variable (default ``"true"``)
      equals ``"true"`` ignoring case,
    * the ``DD_COLD_START_TRACING`` environment variable (default
      ``"true"``) equals ``"true"`` ignoring case,
    * the interpreter is Python 3.7 or newer.

    Each finder's ``find_spec`` is replaced by ``wrap_find_spec(...)`` of
    itself. Wrapping is best effort: a finder that cannot be wrapped is
    skipped, and the reason is logged at debug level.
    """
    if not is_cold_start():
        return
    if os.environ.get("DD_TRACE_ENABLED", "true").lower() != "true":
        return
    if os.environ.get("DD_COLD_START_TRACING", "true").lower() != "true":
        return

    from sys import version_info, meta_path

    # The current implementation only supports Python 3.7 and newer.
    if version_info < (3, 7):
        return

    for importer in meta_path:
        try:
            importer.find_spec = wrap_find_spec(importer.find_spec)
        except Exception as e:
            # Never let a finder that cannot be wrapped break start-up, but
            # leave a trace of why it was skipped.
            logger.debug("Failed to wrap the importer. %s", e)


class ColdStartTracer(object):
    """Turn the recorded import tree into spans under one cold start span.

    Args:
        tracer: Tracer used to create spans. It must provide
            ``trace(name, **kwargs)`` and ``context_provider.activate(ctx)``.
        function_name: Resource name given to the ``aws.lambda.load`` span.
        cold_start_span_finish_time_ns: Finish time of the
            ``aws.lambda.load`` span, in nanoseconds.
        trace_ctx: Trace context activated before a span is started.
        min_duration_ms: Imports shorter than this many milliseconds are
            skipped, together with everything they imported.
        ignored_libs: Module names to skip (exact match), together with
            everything they imported. Defaults to no ignored modules.
    """

    def __init__(
        self,
        tracer,
        function_name,
        cold_start_span_finish_time_ns,
        trace_ctx,
        min_duration_ms: int,
        ignored_libs: Optional[List[str]] = None,
    ):
        self._tracer = tracer
        self.function_name = function_name
        self.cold_start_span_finish_time_ns = cold_start_span_finish_time_ns
        self.min_duration_ms = min_duration_ms
        self.trace_ctx = trace_ctx
        # A fresh list per instance: a shared mutable default would leak
        # changes made through one tracer into every other tracer.
        self.ignored_libs = [] if ignored_libs is None else ignored_libs
        self.need_to_reactivate_context = True

    def trace(self, root_nodes: List[ImportNode] = root_nodes):
        """Create the cold start span and one span per traced import.

        ``root_nodes`` is drained while it is processed. The default is the
        module-level ``root_nodes`` list object as it was bound when this
        class was defined. Nothing is created when the list is empty.
        """
        if not root_nodes:
            return
        cold_start_span_start_time_ns = root_nodes[0].start_time_ns
        cold_start_span = self.create_cold_start_span(cold_start_span_start_time_ns)
        while root_nodes:
            root_node = root_nodes.pop()
            self.trace_tree(root_node, cold_start_span)
        self.finish_span(cold_start_span, self.cold_start_span_finish_time_ns)

    def trace_tree(self, import_node: ImportNode, parent_span):
        """Recursively create spans for ``import_node`` and its children.

        A node is skipped, along with its whole subtree, when it has no
        recorded end time, when it is shorter than ``min_duration_ms``, or
        when its module name is listed in ``ignored_libs``.
        """
        if import_node.end_time_ns is None:
            # The import was never marked as finished, so it has no
            # duration; subtracting from None would raise TypeError and
            # abort the whole cold start trace.
            return
        duration_ns = import_node.end_time_ns - import_node.start_time_ns
        if (
            duration_ns < self.min_duration_ms * 1e6
            or import_node.module_name in self.ignored_libs
        ):
            return

        span = self.start_span(
            "aws.lambda.import", import_node.module_name, import_node.start_time_ns
        )
        tags = {
            "resource_names": import_node.module_name,
            "resource.name": import_node.module_name,
            "filename": import_node.full_file_path,
            "operation_name": self.get_operation_name(import_node.full_file_path),
        }
        span.set_tags(tags)
        if parent_span:
            span.parent_id = parent_span.span_id
        for child_node in import_node.children:
            self.trace_tree(child_node, span)
        self.finish_span(span, import_node.end_time_ns)

    def create_cold_start_span(self, start_time_ns):
        """Start and tag the ``aws.lambda.load`` span for this function."""
        span = self.start_span("aws.lambda.load", self.function_name, start_time_ns)
        tags = {
            "resource_names": self.function_name,
            "resource.name": self.function_name,
            "operation_name": "aws.lambda.load",
        }
        span.set_tags(tags)
        return span

    def start_span(self, span_type, resource, start_time_ns):
        """Start a span named ``span_type`` with its start time overridden.

        ``span_type`` is used both as the span name and as its
        ``span_type``; the service is always ``aws.lambda``.
        """
        if self.need_to_reactivate_context:
            self._tracer.context_provider.activate(
                self.trace_ctx
            )  # reactivate required after each finish() call
            self.need_to_reactivate_context = False
        span_kwargs = {
            "service": "aws.lambda",
            "resource": resource,
            "span_type": span_type,
        }
        span = self._tracer.trace(span_type, **span_kwargs)
        span.start_ns = start_time_ns
        return span

    def finish_span(self, span, finish_time_ns):
        """Finish ``span`` at ``finish_time_ns``, passed on in seconds."""
        span.finish(finish_time_ns / 1e9)
        # The next start_span() call has to activate the context again.
        self.need_to_reactivate_context = True

    def get_operation_name(self, filename: str):
        """Map a module's file path to an operation name.

        ``None`` maps to ``aws.lambda.import_core_module``, paths under
        ``/opt/`` to ``aws.lambda.import_layer``, paths under ``/var/lang/``
        to ``aws.lambda.import_runtime`` and everything else, including
        non-string values, to ``aws.lambda.import``.
        """
        if filename is None:
            return "aws.lambda.import_core_module"
        if not isinstance(filename, str):
            return "aws.lambda.import"
        if filename.startswith("/opt/"):
            return "aws.lambda.import_layer"
        if filename.startswith("/var/lang/"):
            return "aws.lambda.import_runtime"
        return "aws.lambda.import"
45 changes: 43 additions & 2 deletions datadog_lambda/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# under the Apache License Version 2.0.
# This product includes software developed at Datadog (https://www.datadoghq.com/).
# Copyright 2019 Datadog, Inc.

import base64
import os
import logging
Expand All @@ -12,7 +11,7 @@
from time import time_ns

from datadog_lambda.extension import should_use_extension, flush_extension
from datadog_lambda.cold_start import set_cold_start, is_cold_start
from datadog_lambda.cold_start import set_cold_start, is_cold_start, ColdStartTracer
from datadog_lambda.constants import (
TraceContextSource,
XraySubsegment,
Expand All @@ -38,6 +37,7 @@
create_inferred_span,
InferredSpanInfo,
is_authorizer_response,
tracer,
)
from datadog_lambda.trigger import (
extract_trigger_tags,
Expand Down Expand Up @@ -131,6 +131,28 @@ def __init__(self, func):
self.decode_authorizer_context = (
os.environ.get("DD_DECODE_AUTHORIZER_CONTEXT", "true").lower() == "true"
)
self.cold_start_tracing = (
os.environ.get("DD_COLD_START_TRACING", "true").lower() == "true"
)
self.min_cold_start_trace_duration = 3
if "DD_MIN_COLD_START_DURATION" in os.environ:
try:
self.min_cold_start_trace_duration = int(
os.environ["DD_MIN_COLD_START_DURATION"]
)
except Exception:
logger.debug("Malformatted env DD_MIN_COLD_START_DURATION")
self.cold_start_trace_skip_lib = [
"ddtrace.internal.compat",
"ddtrace.filters",
]
if "DD_COLD_START_TRACE_SKIP_LIB" in os.environ:
try:
self.cold_start_trace_skip_lib = os.environ[
"DD_COLD_START_TRACE_SKIP_LIB"
].split(",")
except Exception:
logger.debug("Malformatted for env DD_COLD_START_TRACE_SKIP_LIB")
self.response = None
if profiling_env_var:
self.prof = profiler.Profiler(env=env_env_var, service=service_env_var)
Expand Down Expand Up @@ -257,6 +279,11 @@ def _after(self, event, context):
create_dd_dummy_metadata_subsegment(
self.trigger_tags, XraySubsegment.LAMBDA_FUNCTION_TAGS_KEY
)
should_trace_cold_start = (
dd_tracing_enabled and self.cold_start_tracing and is_cold_start()
)
if should_trace_cold_start:
trace_ctx = tracer.current_trace_context()

if self.span:
if dd_capture_lambda_payload_enabled:
Expand All @@ -276,6 +303,20 @@ def _after(self, event, context):
else:
self.inferred_span.finish()

if should_trace_cold_start:
try:
following_span = self.span or self.inferred_span
ColdStartTracer(
tracer,
self.function_name,
following_span.start_ns,
trace_ctx,
self.min_cold_start_trace_duration,
self.cold_start_trace_skip_lib,
).trace()
except Exception as e:
logger.debug("Failed to create cold start spans. %s", e)

if not self.flush_to_log or should_use_extension:
flush_stats()
if should_use_extension:
Expand Down
1 change: 1 addition & 0 deletions tests/integration/serverless.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ provider:
DD_TRACE_ENABLED: true
DD_API_KEY: ${env:DD_API_KEY}
DD_TRACE_MANAGED_SERVICES: true
DD_COLD_START_TRACING: false
timeout: 15
deploymentBucket:
name: integration-tests-serververless-deployment-bucket
Expand Down
Loading