diff --git a/aws_embedded_metrics/config/configuration.py b/aws_embedded_metrics/config/configuration.py index 8628e26..5a04ba9 100644 --- a/aws_embedded_metrics/config/configuration.py +++ b/aws_embedded_metrics/config/configuration.py @@ -11,6 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + class Configuration: def __init__( @@ -24,6 +26,7 @@ def __init__( ec2_metadata_endpoint: str = None, namespace: str = None, disable_metric_extraction: bool = False, + environment: Optional[str] = None, ): self.debug_logging_enabled = debug_logging_enabled self.service_name = service_name @@ -34,3 +37,4 @@ def __init__( self.ec2_metadata_endpoint = ec2_metadata_endpoint self.namespace = namespace self.disable_metric_extraction = disable_metric_extraction + self.environment = environment diff --git a/aws_embedded_metrics/config/environment_configuration_provider.py b/aws_embedded_metrics/config/environment_configuration_provider.py index f9102ae..219105a 100644 --- a/aws_embedded_metrics/config/environment_configuration_provider.py +++ b/aws_embedded_metrics/config/environment_configuration_provider.py @@ -25,6 +25,7 @@ EC2_METADATA_ENDPOINT = "EC2_METADATA_ENDPOINT" NAMESPACE = "NAMESPACE" DISABLE_METRIC_EXTRACTION = "DISABLE_METRIC_EXTRACTION" +ENVIRONMENT_OVERRIDE = "ENVIRONMENT" class EnvironmentConfigurationProvider: @@ -43,6 +44,7 @@ def get_configuration(self) -> Configuration: self.__get_env_var(EC2_METADATA_ENDPOINT), self.__get_env_var(NAMESPACE), self.__get_bool_env_var(DISABLE_METRIC_EXTRACTION), + self.__get_env_var(ENVIRONMENT_OVERRIDE), ) @staticmethod diff --git a/aws_embedded_metrics/environment/environment_detector.py b/aws_embedded_metrics/environment/environment_detector.py index a616d0f..d051c02 100644 --- a/aws_embedded_metrics/environment/environment_detector.py +++ b/aws_embedded_metrics/environment/environment_detector.py @@ -12,6 +12,7 @@ # limitations under the License. import logging +from aws_embedded_metrics import config from aws_embedded_metrics.environment import Environment from aws_embedded_metrics.environment.default_environment import DefaultEnvironment from aws_embedded_metrics.environment.lambda_environment import LambdaEnvironment @@ -20,7 +21,11 @@ log = logging.getLogger(__name__) -environments = [LambdaEnvironment(), EC2Environment()] +lambda_environment = LambdaEnvironment() +ec2_environment = EC2Environment() +default_environment = DefaultEnvironment() +environments = [lambda_environment, ec2_environment] +Config = config.get_config() class EnvironmentCache: @@ -32,6 +37,19 @@ async def resolve_environment() -> Environment: log.debug("Environment resolved from cache.") return EnvironmentCache.environment + if Config.environment: + lower_configured_enviroment = Config.environment.lower() + if lower_configured_enviroment == "lambda": + EnvironmentCache.environment = lambda_environment + elif lower_configured_enviroment == "ec2": + EnvironmentCache.environment = ec2_environment + elif lower_configured_enviroment == "default": + EnvironmentCache.environment = default_environment + else: + log.info("Failed to understand environment override: %s", Config.environment) + if EnvironmentCache.environment is not None: + return EnvironmentCache.environment + for env_under_test in environments: is_environment = False try: @@ -49,5 +67,5 @@ async def resolve_environment() -> Environment: return env_under_test log.info("No environment was detected. Using default.") - EnvironmentCache.environment = DefaultEnvironment() + EnvironmentCache.environment = default_environment return EnvironmentCache.environment diff --git a/tests/config/test_config.py b/tests/config/test_config.py index d4c559d..7c7908d 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -24,6 +24,7 @@ def test_can_get_config_from_environment(monkeypatch): ec2_metadata_endpoint = fake.word() namespace = fake.word() disable_metric_extraction = True + environment_override = fake.word() monkeypatch.setenv("AWS_EMF_ENABLE_DEBUG_LOGGING", str(debug_enabled)) monkeypatch.setenv("AWS_EMF_SERVICE_NAME", service_name) @@ -34,6 +35,7 @@ def test_can_get_config_from_environment(monkeypatch): monkeypatch.setenv("AWS_EMF_EC2_METADATA_ENDPOINT", ec2_metadata_endpoint) monkeypatch.setenv("AWS_EMF_NAMESPACE", namespace) monkeypatch.setenv("AWS_EMF_DISABLE_METRIC_EXTRACTION", str(disable_metric_extraction)) + monkeypatch.setenv("AWS_EMF_ENVIRONMENT", environment_override) # act result = get_config() @@ -48,6 +50,7 @@ def test_can_get_config_from_environment(monkeypatch): assert result.ec2_metadata_endpoint == ec2_metadata_endpoint assert result.namespace == namespace assert result.disable_metric_extraction == disable_metric_extraction + assert result.environment == environment_override def test_can_override_config(monkeypatch): @@ -61,6 +64,7 @@ def test_can_override_config(monkeypatch): monkeypatch.setenv("AWS_EMF_EC2_METADATA_ENDPOINT", fake.word()) monkeypatch.setenv("AWS_EMF_NAMESPACE", fake.word()) monkeypatch.setenv("AWS_EMF_DISABLE_METRIC_EXTRACTION", str(True)) + monkeypatch.setenv("AWS_EMF_ENVIRONMENT", fake.word()) config = get_config() @@ -73,6 +77,7 @@ def test_can_override_config(monkeypatch): ec2_metadata_endpoint = fake.word() namespace = fake.word() disable_metric_extraction = False + environment = fake.word() # act config.debug_logging_enabled = debug_enabled @@ -84,6 +89,7 @@ def test_can_override_config(monkeypatch): config.ec2_metadata_endpoint = ec2_metadata_endpoint config.namespace = namespace config.disable_metric_extraction = disable_metric_extraction + config.environment = environment # assert assert config.debug_logging_enabled == debug_enabled @@ -95,3 +101,4 @@ def test_can_override_config(monkeypatch): assert config.ec2_metadata_endpoint == ec2_metadata_endpoint assert config.namespace == namespace assert config.disable_metric_extraction == disable_metric_extraction + assert config.environment == environment diff --git a/tests/environment/test_environment_detector.py b/tests/environment/test_environment_detector.py index ea058c8..b6b26d7 100644 --- a/tests/environment/test_environment_detector.py +++ b/tests/environment/test_environment_detector.py @@ -3,7 +3,7 @@ import pytest from importlib import reload -from aws_embedded_metrics.config import get_config +from aws_embedded_metrics import config from aws_embedded_metrics.environment.lambda_environment import LambdaEnvironment from aws_embedded_metrics.environment.default_environment import DefaultEnvironment @@ -11,7 +11,6 @@ from aws_embedded_metrics.environment import environment_detector fake = Faker() -Config = get_config() @pytest.fixture @@ -60,3 +59,31 @@ async def test_resolve_environment_returns_default_envionment(before): # assert assert isinstance(result, DefaultEnvironment) + + +@pytest.mark.asyncio +async def test_resolve_environment_returns_override_ec2(before, monkeypatch): + # arrange + monkeypatch.setenv("AWS_EMF_ENVIRONMENT", "ec2") + reload(config) + reload(environment_detector) + + # act + result = await environment_detector.resolve_environment() + + # assert + assert isinstance(result, ec2_environment.EC2Environment) + + +@pytest.mark.asyncio +async def test_resolve_environment_returns_override_lambda(before, monkeypatch): + # arrange + monkeypatch.setenv("AWS_EMF_ENVIRONMENT", "lambda") + reload(config) + reload(environment_detector) + + # act + result = await environment_detector.resolve_environment() + + # assert + assert isinstance(result, LambdaEnvironment)