diff --git a/sentry_sdk/integrations/celery.py b/sentry_sdk/integrations/celery.py index c2dc4e1e74..ba7aabefa6 100644 --- a/sentry_sdk/integrations/celery.py +++ b/sentry_sdk/integrations/celery.py @@ -16,12 +16,14 @@ capture_internal_exceptions, event_from_exception, logger, + match_regex_list, ) if TYPE_CHECKING: from typing import Any from typing import Callable from typing import Dict + from typing import List from typing import Optional from typing import Tuple from typing import TypeVar @@ -59,10 +61,16 @@ class CeleryIntegration(Integration): identifier = "celery" - def __init__(self, propagate_traces=True, monitor_beat_tasks=False): - # type: (bool, bool) -> None + def __init__( + self, + propagate_traces=True, + monitor_beat_tasks=False, + exclude_beat_tasks=None, + ): + # type: (bool, bool, Optional[List[str]]) -> None self.propagate_traces = propagate_traces self.monitor_beat_tasks = monitor_beat_tasks + self.exclude_beat_tasks = exclude_beat_tasks if monitor_beat_tasks: _patch_beat_apply_entry() @@ -420,9 +428,18 @@ def sentry_apply_entry(*args, **kwargs): app = scheduler.app celery_schedule = schedule_entry.schedule - monitor_config = _get_monitor_config(celery_schedule, app) monitor_name = schedule_entry.name + hub = Hub.current + integration = hub.get_integration(CeleryIntegration) + if integration is None: + return original_apply_entry(*args, **kwargs) + + if match_regex_list(monitor_name, integration.exclude_beat_tasks): + return original_apply_entry(*args, **kwargs) + + monitor_config = _get_monitor_config(celery_schedule, app) + headers = schedule_entry.options.pop("headers", {}) headers.update( { diff --git a/sentry_sdk/tracing_utils.py b/sentry_sdk/tracing_utils.py index d1cd906d2c..d49aad4c8a 100644 --- a/sentry_sdk/tracing_utils.py +++ b/sentry_sdk/tracing_utils.py @@ -7,6 +7,7 @@ from sentry_sdk.utils import ( capture_internal_exceptions, Dsn, + match_regex_list, to_string, ) from sentry_sdk._compat import PY2, iteritems @@ -334,15 +335,7 @@ def should_propagate_trace(hub, url): client = hub.client # type: Any trace_propagation_targets = client.options["trace_propagation_targets"] - if trace_propagation_targets is None: - return False - - for target in trace_propagation_targets: - matched = re.search(target, url) - if matched: - return True - - return False + return match_regex_list(url, trace_propagation_targets, substring_matching=True) # Circular imports diff --git a/sentry_sdk/utils.py b/sentry_sdk/utils.py index 4e557578e4..fa4346ecdb 100644 --- a/sentry_sdk/utils.py +++ b/sentry_sdk/utils.py @@ -1304,6 +1304,22 @@ def is_valid_sample_rate(rate, source): return True +def match_regex_list(item, regex_list=None, substring_matching=False): + # type: (str, Optional[List[str]], bool) -> bool + if regex_list is None: + return False + + for item_matcher in regex_list: + if not substring_matching and item_matcher[-1] != "$": + item_matcher += "$" + + matched = re.search(item_matcher, item) + if matched: + return True + + return False + + if PY37: def nanosecond_time(): diff --git a/tests/integrations/celery/test_celery_beat_crons.py b/tests/integrations/celery/test_celery_beat_crons.py index 431e32642d..a74214a9ee 100644 --- a/tests/integrations/celery/test_celery_beat_crons.py +++ b/tests/integrations/celery/test_celery_beat_crons.py @@ -8,6 +8,7 @@ _get_headers, _get_humanized_interval, _get_monitor_config, + _patch_beat_apply_entry, crons_task_success, crons_task_failure, crons_task_retry, @@ -243,3 +244,56 @@ def test_get_monitor_config_default_timezone(): monitor_config = _get_monitor_config(celery_schedule, app) assert monitor_config["timezone"] == "UTC" + + +@pytest.mark.parametrize( + "task_name,exclude_beat_tasks,task_in_excluded_beat_tasks", + [ + ["some_task_name", ["xxx", "some_task.*"], True], + ["some_task_name", ["xxx", "some_other_task.*"], False], + ], +) +def test_exclude_beat_tasks_option( + task_name, exclude_beat_tasks, task_in_excluded_beat_tasks +): + """ + Test excluding Celery Beat tasks from automatic instrumentation. + """ + fake_apply_entry = mock.MagicMock() + + fake_scheduler = mock.MagicMock() + fake_scheduler.apply_entry = fake_apply_entry + + fake_integration = mock.MagicMock() + fake_integration.exclude_beat_tasks = exclude_beat_tasks + + fake_schedule_entry = mock.MagicMock() + fake_schedule_entry.name = task_name + + fake_get_monitor_config = mock.MagicMock() + + with mock.patch( + "sentry_sdk.integrations.celery.Scheduler", fake_scheduler + ) as Scheduler: # noqa: N806 + with mock.patch( + "sentry_sdk.integrations.celery.Hub.current.get_integration", + return_value=fake_integration, + ): + with mock.patch( + "sentry_sdk.integrations.celery._get_monitor_config", + fake_get_monitor_config, + ) as _get_monitor_config: + # Mimic CeleryIntegration patching of Scheduler.apply_entry() + _patch_beat_apply_entry() + # Mimic Celery Beat calling a task from the Beat schedule + Scheduler.apply_entry(fake_scheduler, fake_schedule_entry) + + if task_in_excluded_beat_tasks: + # Only the original Scheduler.apply_entry() is called, _get_monitor_config is NOT called. + fake_apply_entry.assert_called_once() + _get_monitor_config.assert_not_called() + + else: + # The original Scheduler.apply_entry() is called, AND _get_monitor_config is called. + fake_apply_entry.assert_called_once() + _get_monitor_config.assert_called_once() diff --git a/tests/test_utils.py b/tests/test_utils.py index aa88d26c44..ed8c49b56a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,6 +5,7 @@ from sentry_sdk.utils import ( is_valid_sample_rate, logger, + match_regex_list, parse_url, sanitize_url, serialize_frame, @@ -241,3 +242,24 @@ def test_include_source_context_when_serializing_frame(include_source_context): assert include_source_context ^ ("pre_context" in result) ^ True assert include_source_context ^ ("context_line" in result) ^ True assert include_source_context ^ ("post_context" in result) ^ True + + +@pytest.mark.parametrize( + "item,regex_list,expected_result", + [ + ["", [], False], + [None, [], False], + ["", None, False], + [None, None, False], + ["some-string", [], False], + ["some-string", None, False], + ["some-string", ["some-string"], True], + ["some-string", ["some"], False], + ["some-string", ["some$"], False], # same as above + ["some-string", ["some.*"], True], + ["some-string", ["Some"], False], # we do case sensitive matching + ["some-string", [".*string$"], True], + ], +) +def test_match_regex_list(item, regex_list, expected_result): + assert match_regex_list(item, regex_list) == expected_result