diff --git a/.gitignore b/.gitignore index b4e4ffd6..7bb319c2 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ pip-selfcheck.json htmlcov venv +.idea diff --git a/aws_xray_sdk/core/recorder.py b/aws_xray_sdk/core/recorder.py index 3169e3a2..63077b77 100644 --- a/aws_xray_sdk/core/recorder.py +++ b/aws_xray_sdk/core/recorder.py @@ -232,7 +232,7 @@ def begin_segment(self, name=None, traceid=None, elif sampling: decision = sampling elif self.sampling: - decision = self._sampler.should_trace() + decision = self._sampler.should_trace({'service': seg_name}) if not decision: segment = DummySegment(seg_name) diff --git a/tests/test_recorder.py b/tests/test_recorder.py index c060d3b7..ee60e5a9 100644 --- a/tests/test_recorder.py +++ b/tests/test_recorder.py @@ -1,7 +1,11 @@ import platform +import time import pytest +from aws_xray_sdk.core.sampling.sampling_rule import SamplingRule +from aws_xray_sdk.core.sampling.rule_cache import RuleCache +from aws_xray_sdk.core.sampling.sampler import DefaultSampler from aws_xray_sdk.version import VERSION from .util import get_new_stubbed_recorder @@ -38,7 +42,6 @@ def test_default_runtime_context(): def test_subsegment_parenting(): - segment = xray_recorder.begin_segment('name') subsegment = xray_recorder.begin_subsegment('name') xray_recorder.end_subsegment('name') @@ -97,7 +100,6 @@ def test_put_annotation_metadata(): def test_pass_through_with_missing_context(): - xray_recorder = get_new_stubbed_recorder() xray_recorder.configure(sampling=False, context_missing='LOG_ERROR') assert not xray_recorder.is_sampled() @@ -175,7 +177,6 @@ def test_in_segment_exception(): assert segment.fault is True assert len(segment.cause['exceptions']) == 1 - with pytest.raises(Exception): with xray_recorder.in_segment('name') as segment: with xray_recorder.in_subsegment('name') as subsegment: @@ -259,7 +260,6 @@ def test_disabled_get_context_entity(): assert type(entity) is DummySegment - def test_max_stack_trace_zero(): xray_recorder.configure(max_trace_back=1) with pytest.raises(Exception): @@ -279,3 +279,41 @@ def test_max_stack_trace_zero(): assert len(segment_with_stack.cause['exceptions'][0].stack) == 1 assert len(segment_no_stack.cause['exceptions'][0].stack) == 0 + + +# CustomSampler to mimic the DefaultSampler, +# but without the rule and target polling logic. +class CustomSampler(DefaultSampler): + def start(self): + pass + + def should_trace(self, sampling_req=None): + rule_cache = RuleCache() + rule_cache.last_updated = int(time.time()) + sampling_rule_a = SamplingRule(name='rule_a', + priority=2, + rate=0.5, + reservoir_size=1, + service='app_a') + sampling_rule_b = SamplingRule(name='rule_b', + priority=2, + rate=0.5, + reservoir_size=1, + service='app_b') + rule_cache.load_rules([sampling_rule_a, sampling_rule_b]) + now = int(time.time()) + if sampling_req and not sampling_req.get('service_type', None): + sampling_req['service_type'] = self._origin + elif sampling_req is None: + sampling_req = {'service_type': self._origin} + matched_rule = rule_cache.get_matched_rule(sampling_req, now) + if matched_rule: + return self._process_matched_rule(matched_rule, now) + else: + return self._local_sampler.should_trace(sampling_req) + + +def test_begin_segment_matches_sampling_rule_on_name(): + xray_recorder.configure(sampling=True, sampler=CustomSampler()) + segment = xray_recorder.begin_segment("app_b") + assert segment.aws.get('xray').get('sampling_rule_name') == 'rule_b'