Feature: 使用自定义配置加载替代 pydantic-settings (#2521)

Co-authored-by: uy/sun <hmy0119@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ju4tCode
2024-01-15 13:24:55 +08:00
committed by GitHub
parent 6c532f5926
commit 1153c5ff17
6 changed files with 520 additions and 180 deletions

View File

@ -12,76 +12,257 @@ FrontMatter:
"""
import os
import abc
from pathlib import Path
from datetime import timedelta
from ipaddress import IPv4Address
from typing import TYPE_CHECKING, Any, Set, Dict, Tuple, Union, Mapping, Optional
from pydantic.utils import deep_update
from pydantic.fields import Undefined, UndefinedType
from pydantic import Extra, Field, BaseSettings, IPvAnyAddress
from pydantic.env_settings import (
DotenvType,
SettingsError,
EnvSettingsSource,
InitSettingsSource,
SettingsSourceCallable,
from typing_extensions import TypeAlias, get_origin
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
ClassVar,
Optional,
)
from dotenv import dotenv_values
from pydantic.typing import is_union
from pydantic.utils import deep_update
from pydantic.fields import Undefined, ModelField, UndefinedType
from pydantic import Extra, Field, BaseModel, BaseConfig, JsonWrapper, IPvAnyAddress
from nonebot.log import logger
DOTENV_TYPE: TypeAlias = Union[
Path, str, List[Union[Path, str]], Tuple[Union[Path, str], ...]
]
class CustomEnvSettings(EnvSettingsSource):
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
ENV_FILE_SENTINEL = Path("")
class SettingsError(ValueError):
...
class BaseSettingsSource(abc.ABC):
def __init__(self, settings_cls: Type["BaseSettings"]) -> None:
self.settings_cls = settings_cls
@property
def config(self) -> Type["SettingsConfig"]:
return self.settings_cls.__config__
@abc.abstractmethod
def __call__(self) -> Dict[str, Any]:
raise NotImplementedError
class InitSettingsSource(BaseSettingsSource):
__slots__ = ("init_kwargs",)
def __init__(
self, settings_cls: Type["BaseSettings"], init_kwargs: Dict[str, Any]
) -> None:
self.init_kwargs = init_kwargs
super().__init__(settings_cls)
def __call__(self) -> Dict[str, Any]:
return self.init_kwargs
def __repr__(self) -> str:
return f"InitSettingsSource(init_kwargs={self.init_kwargs!r})"
class DotEnvSettingsSource(BaseSettingsSource):
def __init__(
self,
settings_cls: Type["BaseSettings"],
env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL,
env_file_encoding: Optional[str] = None,
case_sensitive: Optional[bool] = None,
env_nested_delimiter: Optional[str] = None,
) -> None:
super().__init__(settings_cls)
self.env_file = (
env_file if env_file is not ENV_FILE_SENTINEL else self.config.env_file
)
self.env_file_encoding = (
env_file_encoding
if env_file_encoding is not None
else self.config.env_file_encoding
)
self.case_sensitive = (
case_sensitive if case_sensitive is not None else self.config.case_sensitive
)
self.env_nested_delimiter = (
env_nested_delimiter
if env_nested_delimiter is not None
else self.config.env_nested_delimiter
)
def _apply_case_sensitive(self, var_name: str) -> str:
return var_name if self.case_sensitive else var_name.lower()
def _field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
try:
if isinstance(field.annotation, type) and issubclass(
field.annotation, JsonWrapper
):
return True, False
except TypeError:
pass
if field.is_complex():
return True, False
elif (
is_union(get_origin(field.type_))
and field.sub_fields
and any(f.is_complex() for f in field.sub_fields)
):
return True, True
return False, False
def _parse_env_vars(
self, env_vars: Mapping[str, Optional[str]]
) -> Dict[str, Optional[str]]:
return {
self._apply_case_sensitive(key): value for key, value in env_vars.items()
}
def _read_env_file(self, file_path: Path) -> Dict[str, Optional[str]]:
file_vars = dotenv_values(file_path, encoding=self.env_file_encoding)
return self._parse_env_vars(file_vars)
def _read_env_files(self) -> Dict[str, Optional[str]]:
env_files = self.env_file
if env_files is None:
return {}
if isinstance(env_files, (str, os.PathLike)):
env_files = [env_files]
dotenv_vars: Dict[str, Optional[str]] = {}
for env_file in env_files:
env_path = Path(env_file).expanduser()
if env_path.is_file():
dotenv_vars.update(self._read_env_file(env_path))
return dotenv_vars
def _next_field(
self, field: Optional[ModelField], key: str
) -> Optional[ModelField]:
if not field or is_union(get_origin(field.annotation)):
return None
elif (
field.annotation
and isinstance(
(fields := getattr(field.annotation, "__fields__", None)), dict
)
and (field := fields.get(key))
):
return field
return None
def _explode_env_vars(
self,
field: ModelField,
env_vars: Dict[str, Optional[str]],
env_file_vars: Dict[str, Optional[str]],
) -> Dict[str, Any]:
if self.env_nested_delimiter is None:
return {}
prefix = f"{field.name}{self.env_nested_delimiter}"
result: Dict[str, Any] = {}
for env_name, env_val in env_vars.items():
if not env_name.startswith(prefix):
continue
# delete from file vars when used
if env_name in env_file_vars:
del env_file_vars[env_name]
_, *keys, last_key = env_name.split(self.env_nested_delimiter)
env_var = result
target_field: Optional[ModelField] = field
for key in keys:
target_field = self._next_field(target_field, key)
env_var = env_var.setdefault(key, {})
target_field = self._next_field(target_field, last_key)
if target_field and env_val:
is_complex, allow_parse_failure = self._field_is_complex(target_field)
if is_complex:
try:
env_val = self.settings_cls.__config__.json_loads(env_val)
except ValueError as e:
if not allow_parse_failure:
raise SettingsError(
f'error parsing env var "{env_name}"'
) from e
env_var[last_key] = env_val
return result
def __call__(self) -> Dict[str, Any]:
"""从环境变量和 dotenv 配置文件中读取配置项。"""
d: Dict[str, Any] = {}
if settings.__config__.case_sensitive:
env_vars: Mapping[str, Optional[str]] = os.environ # pragma: no cover
else:
env_vars = {k.lower(): v for k, v in os.environ.items()}
env_file_vars = self._read_env_files(settings.__config__.case_sensitive)
env_vars = self._parse_env_vars(os.environ)
env_file_vars = self._read_env_files()
env_vars = {**env_file_vars, **env_vars}
for field in settings.__fields__.values():
env_val: Union[str, None, UndefinedType] = Undefined
for env_name in field.field_info.extra["env_names"]:
env_val = env_vars.get(env_name, Undefined)
if env_name in env_file_vars:
del env_file_vars[env_name]
if env_val is not Undefined:
break
for field in self.settings_cls.__fields__.values():
field_key = field.name
env_name = self._apply_case_sensitive(field_key)
# try get values from env vars
env_val = env_vars.get(env_name, Undefined)
# delete from file vars when used
if env_name in env_file_vars:
del env_file_vars[env_name]
is_complex, allow_parse_failure = self.field_is_complex(field)
is_complex, allow_parse_failure = self._field_is_complex(field)
if is_complex:
if isinstance(env_val, UndefinedType):
# field is complex but no value found so far, try explode_env_vars
if env_val_built := self.explode_env_vars(field, env_vars):
d[field.alias] = env_val_built
if env_val_built := self._explode_env_vars(
field, env_vars, env_file_vars
):
d[field_key] = env_val_built
elif env_val is None:
d[field.alias] = env_val
d[field_key] = env_val
else:
# field is complex and there's a value
# decode that as JSON, then add explode_env_vars
try:
env_val = settings.__config__.parse_env_var(field.name, env_val)
env_val = self.settings_cls.__config__.json_loads(env_val)
except ValueError as e:
if not allow_parse_failure:
raise SettingsError(
f'error parsing env var "{env_name}"' # type: ignore
f'error parsing env var "{env_name}"'
) from e
if isinstance(env_val, dict):
d[field.alias] = deep_update(
env_val, self.explode_env_vars(field, env_vars)
# field value is a dict
# try explode_env_vars to find more sub-values
d[field_key] = deep_update(
env_val,
self._explode_env_vars(field, env_vars, env_file_vars),
)
else:
d[field.alias] = env_val
d[field_key] = env_val
elif not isinstance(env_val, UndefinedType):
# simplest case, field is not complex
# we only need to add the value if it was found
d[field.alias] = env_val
d[field_key] = env_val
# remain user custom config
for env_name in env_file_vars:
@ -89,7 +270,7 @@ class CustomEnvSettings(EnvSettingsSource):
if env_val and (val_striped := env_val.strip()):
# there's a value, decode that as JSON
try:
env_val = settings.__config__.parse_env_var(env_name, val_striped)
env_val = self.settings_cls.__config__.json_loads(val_striped)
except ValueError:
logger.trace(
"Error while parsing JSON for "
@ -113,38 +294,58 @@ class CustomEnvSettings(EnvSettingsSource):
return d
class BaseConfig(BaseSettings):
class SettingsConfig(BaseConfig):
extra = Extra.allow
env_file: Optional[DOTENV_TYPE] = None
env_file_encoding: str = "utf-8"
case_sensitive: bool = False
env_nested_delimiter: Optional[str] = "__"
class BaseSettings(BaseModel):
if TYPE_CHECKING:
__config__: ClassVar[Type[SettingsConfig]]
# dummy getattr for pylance checking, actually not used
def __getattr__(self, name: str) -> Any: # pragma: no cover
return self.__dict__.get(name)
class Config:
extra = Extra.allow
env_nested_delimiter = "__"
Config = SettingsConfig
@classmethod
def customise_sources(
cls,
init_settings: InitSettingsSource,
env_settings: EnvSettingsSource,
file_secret_settings: SettingsSourceCallable,
) -> Tuple[SettingsSourceCallable, ...]:
common_config = init_settings.init_kwargs.pop("_common_config", {})
return (
init_settings,
CustomEnvSettings(
env_settings.env_file,
env_settings.env_file_encoding,
env_settings.env_nested_delimiter,
env_settings.env_prefix_len,
),
InitSettingsSource(common_config),
file_secret_settings,
def __init__(
__settings_self__, # pyright: ignore[reportSelfClsParameterName]
_env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL,
_env_file_encoding: Optional[str] = None,
_env_nested_delimiter: Optional[str] = None,
**values: Any,
) -> None:
super().__init__(
**__settings_self__._settings_build_values(
values,
env_file=_env_file,
env_file_encoding=_env_file_encoding,
env_nested_delimiter=_env_nested_delimiter,
)
)
def _settings_build_values(
self,
init_kwargs: Dict[str, Any],
env_file: Optional[DOTENV_TYPE] = None,
env_file_encoding: Optional[str] = None,
env_nested_delimiter: Optional[str] = None,
) -> Dict[str, Any]:
init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs)
env_settings = DotEnvSettingsSource(
self.__class__,
env_file=env_file,
env_file_encoding=env_file_encoding,
env_nested_delimiter=env_nested_delimiter,
)
return deep_update(env_settings(), init_settings())
class Env(BaseConfig):
class Env(BaseSettings):
"""运行环境配置。大小写不敏感。
将会从 **环境变量** > **dotenv 配置文件** 的优先级读取环境信息。
@ -160,7 +361,7 @@ class Env(BaseConfig):
env_file = ".env"
class Config(BaseConfig):
class Config(BaseSettings):
"""NoneBot 主要配置。大小写不敏感。
除了 NoneBot 的配置项外,还可以自行添加配置项到 `.env.{environment}` 文件中。
@ -169,7 +370,7 @@ class Config(BaseConfig):
配置方法参考: [配置](https://nonebot.dev/docs/appendices/config)
"""
_env_file: DotenvType = ".env", ".env.prod"
_env_file: Optional[DOTENV_TYPE] = ".env", ".env.prod"
# nonebot configs
driver: str = "~fastapi"
@ -259,6 +460,10 @@ class Config(BaseConfig):
__autodoc__ = {
"CustomEnvSettings": False,
"BaseConfig": False,
"SettingsError": False,
"BaseSettingsSource": False,
"InitSettingsSource": False,
"DotEnvSettingsSource": False,
"SettingsConfig": False,
"BaseSettings": False,
}