Skip to content

fix: safely getting all the values for trigger tags #593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions datadog_lambda/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,14 @@ def parse_event_source(event: dict) -> _EventSource:

event_source = None

request_context = event.get("requestContext")
# Get requestContext safely and ensure it's a dictionary
request_context = event.get("requestContext", {})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to do `request_context = event.get("requestContext") or {}` because it will avoid an allocation in the case that the key is found in the dict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's some expert level advice on allocation, thanks!

if not isinstance(request_context, dict):
request_context = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would probably be better in this case to set `request_context = None`. This way we avoid the extra allocation and we don't even attempt any of the `request_context.get` calls that happen after.


if request_context and request_context.get("stage"):
if "domainName" in request_context and detect_lambda_function_url_domain(
request_context.get("domainName")
request_context.get("domainName", "")
):
return _EventSource(EventTypes.LAMBDA_FUNCTION_URL)
event_source = _EventSource(EventTypes.API_GATEWAY)
Expand Down Expand Up @@ -171,6 +175,8 @@ def parse_event_source(event: dict) -> _EventSource:

def detect_lambda_function_url_domain(domain: str) -> bool:
# e.g. "etsn5fibjr.lambda-url.eu-south-1.amazonaws.com"
if not isinstance(domain, str):
return False
domain_parts = domain.split(".")
if len(domain_parts) < 2:
return False
Expand Down Expand Up @@ -283,17 +289,28 @@ def extract_http_tags(event):
Extracts HTTP facet tags from the triggering event
"""
http_tags = {}
request_context = event.get("requestContext")

# Safely get request_context and ensure it's a dictionary
request_context = event.get("requestContext", {})
if not isinstance(request_context, dict):
request_context = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. I think we can do this more efficiently.


path = event.get("path")
method = event.get("httpMethod")

if request_context and request_context.get("stage"):
if request_context.get("domainName"):
http_tags["http.url"] = request_context.get("domainName")
domain_name = request_context.get("domainName")
if domain_name:
http_tags["http.url"] = domain_name

path = request_context.get("path")
method = request_context.get("httpMethod")

# Version 2.0 HTTP API Gateway
apigateway_v2_http = request_context.get("http")
apigateway_v2_http = request_context.get("http", {})
if not isinstance(apigateway_v2_http, dict):
apigateway_v2_http = {}

if event.get("version") == "2.0" and apigateway_v2_http:
path = apigateway_v2_http.get("path")
method = apigateway_v2_http.get("method")
Expand All @@ -303,15 +320,23 @@ def extract_http_tags(event):
if method:
http_tags["http.method"] = method

headers = event.get("headers")
# Safely get headers
headers = event.get("headers", {})
if not isinstance(headers, dict):
headers = {}

if headers and headers.get("Referer"):
http_tags["http.referer"] = headers.get("Referer")

# Try to get `routeKey` from API GW v2; otherwise try to get `resource` from API GW v1
route = event.get("routeKey") or event.get("resource")
if route:
# "GET /my/endpoint" = > "/my/endpoint"
http_tags["http.route"] = route.split(" ")[-1]
if route and isinstance(route, str):
try:
# "GET /my/endpoint" = > "/my/endpoint"
http_tags["http.route"] = route.split(" ")[-1]
except Exception:
# If splitting fails, use the route as is
http_tags["http.route"] = route

return http_tags

Expand Down
65 changes: 65 additions & 0 deletions tests/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,30 @@ def test_event_source_unsupported(self):
self.assertEqual(event_source.to_string(), "unknown")
self.assertEqual(event_source_arn, None)

def test_event_source_with_non_dict_request_context(self):
# Test with requestContext as a string instead of a dict
event = {"requestContext": "not_a_dict"}
event_source = parse_event_source(event)
# Should still return a valid event source (unknown in this case)
self.assertEqual(event_source.to_string(), "unknown")

def test_event_source_with_invalid_domain_name(self):
# Test with domainName that isn't a string
event = {"requestContext": {"stage": "prod", "domainName": 12345}}
event_source = parse_event_source(event)
# Should detect as API Gateway since stage is present
self.assertEqual(event_source.to_string(), "api-gateway")

def test_detect_lambda_function_url_domain_with_invalid_input(self):
from datadog_lambda.trigger import detect_lambda_function_url_domain

# Test with non-string input
self.assertFalse(detect_lambda_function_url_domain(None))
self.assertFalse(detect_lambda_function_url_domain(12345))
self.assertFalse(detect_lambda_function_url_domain({"not": "a-string"}))
# Test with string that would normally cause an exception when split
self.assertFalse(detect_lambda_function_url_domain(""))


class GetTriggerTags(unittest.TestCase):
def test_extract_trigger_tags_api_gateway(self):
Expand Down Expand Up @@ -530,6 +554,47 @@ def test_extract_trigger_tags_list_type_event(self):
tags = extract_trigger_tags(event, ctx)
self.assertEqual(tags, {})

def test_extract_http_tags_with_invalid_request_context(self):
from datadog_lambda.trigger import extract_http_tags

# Test with requestContext as a string instead of a dict
event = {"requestContext": "not_a_dict", "path": "/test", "httpMethod": "GET"}
http_tags = extract_http_tags(event)
# Should still extract valid tags from the event
self.assertEqual(
http_tags, {"http.url_details.path": "/test", "http.method": "GET"}
)

def test_extract_http_tags_with_invalid_apigateway_http(self):
from datadog_lambda.trigger import extract_http_tags

# Test with http in requestContext that's not a dict
event = {
"requestContext": {"stage": "prod", "http": "not_a_dict"},
"version": "2.0",
}
http_tags = extract_http_tags(event)
# Should not raise an exception
self.assertEqual(http_tags, {})

def test_extract_http_tags_with_invalid_headers(self):
from datadog_lambda.trigger import extract_http_tags

# Test with headers that's not a dict
event = {"headers": "not_a_dict"}
http_tags = extract_http_tags(event)
# Should not raise an exception
self.assertEqual(http_tags, {})

def test_extract_http_tags_with_invalid_route(self):
from datadog_lambda.trigger import extract_http_tags

# Test with routeKey that would cause a split error
event = {"routeKey": 12345} # Not a string
http_tags = extract_http_tags(event)
# Should not raise an exception
self.assertEqual(http_tags, {})


class ExtractHTTPStatusCodeTag(unittest.TestCase):
def test_extract_http_status_code_tag_from_response_dict(self):
Expand Down
Loading