diff --git a/sentry_sdk/tracing_utils.py b/sentry_sdk/tracing_utils.py index 6aa4e4882a..ba56695740 100644 --- a/sentry_sdk/tracing_utils.py +++ b/sentry_sdk/tracing_utils.py @@ -5,7 +5,7 @@ import sys from collections.abc import Mapping from datetime import timedelta -from decimal import ROUND_DOWN, Decimal +from decimal import ROUND_DOWN, Context, Decimal from functools import wraps from random import Random from urllib.parse import quote, unquote @@ -871,7 +871,11 @@ def _generate_sample_rand( sample_rand = rng.uniform(lower, upper) # Round down to exactly six decimal-digit precision. - return Decimal(sample_rand).quantize(Decimal("0.000001"), rounding=ROUND_DOWN) + # Setting the context is needed to avoid an InvalidOperation exception + # in case the user has changed the default precision. + return Decimal(sample_rand).quantize( + Decimal("0.000001"), rounding=ROUND_DOWN, context=Context(prec=6) + ) def _sample_rand_range(parent_sampled, sample_rate): diff --git a/tests/tracing/test_sample_rand.py b/tests/tracing/test_sample_rand.py index b8f5c042ed..ef277a3dec 100644 --- a/tests/tracing/test_sample_rand.py +++ b/tests/tracing/test_sample_rand.py @@ -1,3 +1,4 @@ +import decimal from unittest import mock import pytest @@ -53,3 +54,28 @@ def test_transaction_uses_incoming_sample_rand( # Transaction event captured if sample_rand < sample_rate, indicating that # sample_rand is used to make the sampling decision. assert len(events) == int(sample_rand < sample_rate) + + +def test_decimal_context(sentry_init, capture_events): + """ + Ensure that having a decimal context with a precision below 6 + does not cause an InvalidOperation exception. + """ + sentry_init(traces_sample_rate=1.0) + events = capture_events() + + old_prec = decimal.getcontext().prec + decimal.getcontext().prec = 2 + + try: + with mock.patch( + "sentry_sdk.tracing_utils.Random.uniform", return_value=0.123456789 + ): + with sentry_sdk.start_transaction() as transaction: + assert ( + transaction.get_baggage().sentry_items["sample_rand"] == "0.123456" + ) + finally: + decimal.getcontext().prec = old_prec + + assert len(events) == 1