From d611ff6e826ea49d316ee915d52548880565a68d Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Tue, 2 Jan 2024 16:27:34 -0700 Subject: [PATCH 1/6] Allow nested env var source to override nested init source. --- pydantic_settings/sources.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 28178f03..fc7bba93 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast from dotenv import dotenv_values -from pydantic import AliasChoices, AliasPath, BaseModel, Json +from pydantic import AliasChoices, AliasPath, BaseModel, Json, TypeAdapter from pydantic._internal._typing_extra import origin_is_union from pydantic._internal._utils import deep_update, lenient_issubclass from pydantic.fields import FieldInfo @@ -121,7 +121,7 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, return None, '', False def __call__(self) -> dict[str, Any]: - return self.init_kwargs + return TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs) def __repr__(self) -> str: return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})' From 5519fd6f96d0d162652a9b23648973af238613ab Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Thu, 4 Jan 2024 21:12:07 -0700 Subject: [PATCH 2/6] Add test. --- pydantic_settings/sources.py | 4 ++-- tests/test_settings.py | 37 ++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index fc7bba93..7af1f487 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -7,7 +7,7 @@ from collections import deque from dataclasses import is_dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, Tuple, Union, cast from dotenv import dotenv_values from pydantic import AliasChoices, AliasPath, BaseModel, Json, TypeAdapter @@ -121,7 +121,7 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, return None, '', False def __call__(self) -> dict[str, Any]: - return TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs) + return TypeAdapter(Dict[str, Any]).dump_python(self.init_kwargs) def __repr__(self) -> str: return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})' diff --git a/tests/test_settings.py b/tests/test_settings.py index 9a7fe697..36dcc7e4 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -619,6 +619,43 @@ def settings_customise_sources( assert s.bar == 'env setting' +def test_env_deep_override(env): + class DeepSubModel(BaseModel): + v4: str + + class SubModel(BaseModel): + v1: str + v2: bytes + v3: int + deep: DeepSubModel + + class Settings(BaseSettings, env_nested_delimiter='__'): + v0: str + sub_model: SubModel + + @classmethod + def settings_customise_sources( + cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings + ): + return env_settings, dotenv_settings, init_settings, file_secret_settings + + env.set('SUB_MODEL__DEEP__V4', 'override-v4') + + s_final = {'v0': '0', 'sub_model': {'v1': 'init-v1', 'v2': b'init-v2', 'v3': 3, 'deep': {'v4': 'override-v4'}}} + + s = Settings(v0='0', sub_model={'v1': 'init-v1', 'v2': b'init-v2', 'v3': 3, 'deep': {'v4': 'init-v4'}}) + assert s.model_dump() == s_final + + s = Settings(v0='0', sub_model=SubModel(v1='init-v1', v2=b'init-v2', v3=3, deep=DeepSubModel(v4='init-v4'))) + assert s.model_dump() == s_final + + s = Settings(v0='0', sub_model=SubModel(v1='init-v1', v2=b'init-v2', v3=3, deep={'v4': 'init-v4'})) + assert s.model_dump() == s_final + + s = Settings(v0='0', sub_model={'v1': 'init-v1', 'v2': b'init-v2', 'v3': 3, 'deep': DeepSubModel(v4='init-v4')}) + assert s.model_dump() == s_final + + def test_config_file_settings_nornir(env): """ See https://github.com/pydantic/pydantic/pull/341#issuecomment-450378771 From 9ee8acb70843a699f450c04e8a925aa00b9be791 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Mon, 19 Feb 2024 21:48:14 -0700 Subject: [PATCH 3/6] Update nested deep override to handle objects. --- pydantic_settings/main.py | 30 +++++++++++++++++--- pydantic_settings/sources.py | 4 +-- tests/test_settings.py | 54 ++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 6 deletions(-) diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index 6deca656..87821e8f 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -1,11 +1,11 @@ from __future__ import annotations as _annotations from pathlib import Path -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar -from pydantic import ConfigDict +from pydantic import ConfigDict, TypeAdapter from pydantic._internal._config import config_keys -from pydantic._internal._utils import deep_update +from pydantic._internal._utils import is_model_class from pydantic.main import BaseModel from .sources import ( @@ -19,6 +19,8 @@ SecretsSettingsSource, ) +KeyType = TypeVar('KeyType') + class SettingsConfigDict(ConfigDict, total=False): case_sensitive: bool @@ -184,7 +186,7 @@ def _settings_build_values( file_secret_settings=file_secret_settings, ) if sources: - return deep_update(*reversed([source() for source in sources])) + return BaseSettings._deep_update(*reversed([source() for source in sources])) else: # no one should mean to do this, but I think returning an empty dict is marginally preferable # to an informative error and much better than a confusing error @@ -209,3 +211,23 @@ def _settings_build_values( secrets_dir=None, protected_namespaces=('model_', 'settings_'), ) + + @staticmethod + def _deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]: + """Adapts logic from `pydantic._internal._utils.deep_update` to handle nested partial overrides of BaseModel derived types.""" + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for key, new_val in updating_mapping.items(): + if key in updated_mapping: + old_val = updated_mapping[key] + old_val_type = type(old_val) + if is_model_class(old_val_type) and isinstance(new_val, dict): + old_val = old_val.model_dump() + updated_mapping[key] = ( + TypeAdapter(old_val_type).validate_python(BaseSettings._deep_update(old_val, new_val)) + if isinstance(old_val, dict) and isinstance(new_val, dict) + else new_val + ) + else: + updated_mapping[key] = new_val + return updated_mapping diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 44bcc9f7..82ec7271 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -8,7 +8,7 @@ from collections import deque from dataclasses import is_dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast from dotenv import dotenv_values from pydantic import AliasChoices, AliasPath, BaseModel, Json @@ -165,7 +165,7 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, return None, '', False def __call__(self) -> dict[str, Any]: - return TypeAdapter(Dict[str, Any]).dump_python(self.init_kwargs) + return self.init_kwargs def __repr__(self) -> str: return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})' diff --git a/tests/test_settings.py b/tests/test_settings.py index dc3bacc9..a8af3a8c 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,6 +6,7 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union +from abc import ABC, abstractmethod import pytest from annotated_types import MinLen @@ -684,6 +685,59 @@ def settings_customise_sources( assert s.model_dump() == s_final +def test_env_deep_override_copy_by_reference(env): + class BaseAuth(ABC, BaseModel): + @property + @abstractmethod + def token(self) -> str: + """returns authentication token for XYZ""" + pass + + class CustomAuth(BaseAuth): + url: HttpUrl + username: str + password: SecretStr + + _token: SecretStr + + @property + def token(self): + ... # (re)fetch token + return self._token.get_secret_value() + + + class Settings(BaseSettings, env_nested_delimiter='__'): + auth: BaseAuth + + @classmethod + def settings_customise_sources( + cls, + settings_cls, + init_settings, + env_settings, + dotenv_settings, + file_secret_settings): + return env_settings, init_settings, file_secret_settings + + auth_orig = CustomAuth( + url='https://127.0.0.1', + username='some-username', + password='some-password' + ) + + s = Settings(auth=auth_orig) + assert s.auth is auth_orig + + + env.set('AUTH__URL', 'https://123.4.5.6') + + s = Settings(auth=auth_orig) + assert s.auth is not auth_orig + assert s.auth.username == auth_orig.username + assert s.auth.url == HttpUrl('https://123.4.5.6') + assert s.auth.password is auth_orig.password + + def test_config_file_settings_nornir(env): """ See https://github.com/pydantic/pydantic/pull/341#issuecomment-450378771 From ee9a3f9768ea24572a997c4b5e146f73207c396c Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Mon, 19 Feb 2024 22:00:00 -0700 Subject: [PATCH 4/6] Add assert for type check. --- tests/test_settings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_settings.py b/tests/test_settings.py index a8af3a8c..6e34e107 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -733,6 +733,7 @@ def settings_customise_sources( s = Settings(auth=auth_orig) assert s.auth is not auth_orig + assert type(s.auth) is CustomAuth assert s.auth.username == auth_orig.username assert s.auth.url == HttpUrl('https://123.4.5.6') assert s.auth.password is auth_orig.password From f0bdc7f405d42c9caf82e79750f035ac5d25ebbf Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Mon, 19 Feb 2024 22:06:48 -0700 Subject: [PATCH 5/6] Whitespace cleanup and stricter asserts. --- tests/test_settings.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 6e34e107..03726cc1 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -705,7 +705,6 @@ def token(self): ... # (re)fetch token return self._token.get_secret_value() - class Settings(BaseSettings, env_nested_delimiter='__'): auth: BaseAuth @@ -728,15 +727,15 @@ def settings_customise_sources( s = Settings(auth=auth_orig) assert s.auth is auth_orig - env.set('AUTH__URL', 'https://123.4.5.6') s = Settings(auth=auth_orig) assert s.auth is not auth_orig assert type(s.auth) is CustomAuth - assert s.auth.username == auth_orig.username - assert s.auth.url == HttpUrl('https://123.4.5.6') + assert s.auth.username is auth_orig.username assert s.auth.password is auth_orig.password + assert s.auth.url is not auth_orig.url + assert s.auth.url == HttpUrl('https://123.4.5.6') def test_config_file_settings_nornir(env): From 251ec484bca805bac80e2208e740aa282bd99acc Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Tue, 20 Feb 2024 19:44:32 -0700 Subject: [PATCH 6/6] Lint fix. --- tests/test_settings.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 03726cc1..f9ca1fbc 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -3,10 +3,10 @@ import os import sys import uuid +from abc import ABC, abstractmethod from datetime import datetime, timezone from pathlib import Path from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union -from abc import ABC, abstractmethod import pytest from annotated_types import MinLen @@ -710,19 +710,11 @@ class Settings(BaseSettings, env_nested_delimiter='__'): @classmethod def settings_customise_sources( - cls, - settings_cls, - init_settings, - env_settings, - dotenv_settings, - file_secret_settings): + cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings + ): return env_settings, init_settings, file_secret_settings - auth_orig = CustomAuth( - url='https://127.0.0.1', - username='some-username', - password='some-password' - ) + auth_orig = CustomAuth(url='https://127.0.0.1', username='some-username', password='some-password') s = Settings(auth=auth_orig) assert s.auth is auth_orig