diff --git a/datadog_lambda/trigger.py b/datadog_lambda/trigger.py index 8090e36e..52978d4b 100644 --- a/datadog_lambda/trigger.py +++ b/datadog_lambda/trigger.py @@ -114,10 +114,14 @@ def parse_event_source(event: dict) -> _EventSource: event_source = None - request_context = event.get("requestContext") + # Get requestContext safely and ensure it's a dictionary + request_context = event.get("requestContext", {}) + if not isinstance(request_context, dict): + request_context = {} + if request_context and request_context.get("stage"): if "domainName" in request_context and detect_lambda_function_url_domain( - request_context.get("domainName") + request_context.get("domainName", "") ): return _EventSource(EventTypes.LAMBDA_FUNCTION_URL) event_source = _EventSource(EventTypes.API_GATEWAY) @@ -171,6 +175,8 @@ def parse_event_source(event: dict) -> _EventSource: def detect_lambda_function_url_domain(domain: str) -> bool: # e.g. "etsn5fibjr.lambda-url.eu-south-1.amazonaws.com" + if not isinstance(domain, str): + return False domain_parts = domain.split(".") if len(domain_parts) < 2: return False @@ -283,17 +289,28 @@ def extract_http_tags(event): Extracts HTTP facet tags from the triggering event """ http_tags = {} - request_context = event.get("requestContext") + + # Safely get request_context and ensure it's a dictionary + request_context = event.get("requestContext", {}) + if not isinstance(request_context, dict): + request_context = {} + path = event.get("path") method = event.get("httpMethod") + if request_context and request_context.get("stage"): - if request_context.get("domainName"): - http_tags["http.url"] = request_context.get("domainName") + domain_name = request_context.get("domainName") + if domain_name: + http_tags["http.url"] = domain_name path = request_context.get("path") method = request_context.get("httpMethod") + # Version 2.0 HTTP API Gateway - apigateway_v2_http = request_context.get("http") + apigateway_v2_http = request_context.get("http", {}) + if not isinstance(apigateway_v2_http, dict): + apigateway_v2_http = {} + if event.get("version") == "2.0" and apigateway_v2_http: path = apigateway_v2_http.get("path") method = apigateway_v2_http.get("method") @@ -303,15 +320,23 @@ def extract_http_tags(event): if method: http_tags["http.method"] = method - headers = event.get("headers") + # Safely get headers + headers = event.get("headers", {}) + if not isinstance(headers, dict): + headers = {} + if headers and headers.get("Referer"): http_tags["http.referer"] = headers.get("Referer") # Try to get `routeKey` from API GW v2; otherwise try to get `resource` from API GW v1 route = event.get("routeKey") or event.get("resource") - if route: - # "GET /my/endpoint" = > "/my/endpoint" - http_tags["http.route"] = route.split(" ")[-1] + if route and isinstance(route, str): + try: + # "GET /my/endpoint" = > "/my/endpoint" + http_tags["http.route"] = route.split(" ")[-1] + except Exception: + # If splitting fails, use the route as is + http_tags["http.route"] = route return http_tags diff --git a/tests/test_trigger.py b/tests/test_trigger.py index 9cb088f1..b4da7ff0 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -256,6 +256,30 @@ def test_event_source_unsupported(self): self.assertEqual(event_source.to_string(), "unknown") self.assertEqual(event_source_arn, None) + def test_event_source_with_non_dict_request_context(self): + # Test with requestContext as a string instead of a dict + event = {"requestContext": "not_a_dict"} + event_source = parse_event_source(event) + # Should still return a valid event source (unknown in this case) + self.assertEqual(event_source.to_string(), "unknown") + + def test_event_source_with_invalid_domain_name(self): + # Test with domainName that isn't a string + event = {"requestContext": {"stage": "prod", "domainName": 12345}} + event_source = parse_event_source(event) + # Should detect as API Gateway since stage is present + self.assertEqual(event_source.to_string(), "api-gateway") + + def test_detect_lambda_function_url_domain_with_invalid_input(self): + from datadog_lambda.trigger import detect_lambda_function_url_domain + + # Test with non-string input + self.assertFalse(detect_lambda_function_url_domain(None)) + self.assertFalse(detect_lambda_function_url_domain(12345)) + self.assertFalse(detect_lambda_function_url_domain({"not": "a-string"})) + # Test with string that would normally cause an exception when split + self.assertFalse(detect_lambda_function_url_domain("")) + class GetTriggerTags(unittest.TestCase): def test_extract_trigger_tags_api_gateway(self): @@ -530,6 +554,47 @@ def test_extract_trigger_tags_list_type_event(self): tags = extract_trigger_tags(event, ctx) self.assertEqual(tags, {}) + def test_extract_http_tags_with_invalid_request_context(self): + from datadog_lambda.trigger import extract_http_tags + + # Test with requestContext as a string instead of a dict + event = {"requestContext": "not_a_dict", "path": "/test", "httpMethod": "GET"} + http_tags = extract_http_tags(event) + # Should still extract valid tags from the event + self.assertEqual( + http_tags, {"http.url_details.path": "/test", "http.method": "GET"} + ) + + def test_extract_http_tags_with_invalid_apigateway_http(self): + from datadog_lambda.trigger import extract_http_tags + + # Test with http in requestContext that's not a dict + event = { + "requestContext": {"stage": "prod", "http": "not_a_dict"}, + "version": "2.0", + } + http_tags = extract_http_tags(event) + # Should not raise an exception + self.assertEqual(http_tags, {}) + + def test_extract_http_tags_with_invalid_headers(self): + from datadog_lambda.trigger import extract_http_tags + + # Test with headers that's not a dict + event = {"headers": "not_a_dict"} + http_tags = extract_http_tags(event) + # Should not raise an exception + self.assertEqual(http_tags, {}) + + def test_extract_http_tags_with_invalid_route(self): + from datadog_lambda.trigger import extract_http_tags + + # Test with routeKey that would cause a split error + event = {"routeKey": 12345} # Not a string + http_tags = extract_http_tags(event) + # Should not raise an exception + self.assertEqual(http_tags, {}) + class ExtractHTTPStatusCodeTag(unittest.TestCase): def test_extract_http_status_code_tag_from_response_dict(self):