mirror of
https://github.com/nonebot/nonebot2.git
synced 2026-01-15 00:02:14 +00:00
✨ Feature: 允许插件从环境变量中读取配置项并支持 alias (#3673)
Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
This commit is contained in:
@@ -51,7 +51,7 @@ class SettingsError(ValueError): ...
|
||||
|
||||
|
||||
class BaseSettingsSource(abc.ABC):
|
||||
def __init__(self, settings_cls: type["BaseSettings"]) -> None:
|
||||
def __init__(self, settings_cls: type[BaseModel]) -> None:
|
||||
self.settings_cls = settings_cls
|
||||
|
||||
@property
|
||||
@@ -67,7 +67,7 @@ class InitSettingsSource(BaseSettingsSource):
|
||||
__slots__ = ("init_kwargs",)
|
||||
|
||||
def __init__(
|
||||
self, settings_cls: type["BaseSettings"], init_kwargs: dict[str, Any]
|
||||
self, settings_cls: type[BaseModel], init_kwargs: dict[str, Any]
|
||||
) -> None:
|
||||
self.init_kwargs = init_kwargs
|
||||
super().__init__(settings_cls)
|
||||
@@ -82,33 +82,17 @@ class InitSettingsSource(BaseSettingsSource):
|
||||
class DotEnvSettingsSource(BaseSettingsSource):
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type["BaseSettings"],
|
||||
env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL,
|
||||
env_file_encoding: Optional[str] = None,
|
||||
case_sensitive: Optional[bool] = None,
|
||||
settings_cls: type[BaseModel],
|
||||
env_file: Optional[DOTENV_TYPE],
|
||||
env_file_encoding: str,
|
||||
case_sensitive: Optional[bool] = False,
|
||||
env_nested_delimiter: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(settings_cls)
|
||||
self.env_file = (
|
||||
env_file
|
||||
if env_file is not ENV_FILE_SENTINEL
|
||||
else self.config.get("env_file", (".env",))
|
||||
)
|
||||
self.env_file_encoding = (
|
||||
env_file_encoding
|
||||
if env_file_encoding is not None
|
||||
else self.config.get("env_file_encoding", "utf-8")
|
||||
)
|
||||
self.case_sensitive = (
|
||||
case_sensitive
|
||||
if case_sensitive is not None
|
||||
else self.config.get("case_sensitive", False)
|
||||
)
|
||||
self.env_nested_delimiter = (
|
||||
env_nested_delimiter
|
||||
if env_nested_delimiter is not None
|
||||
else self.config.get("env_nested_delimiter", None)
|
||||
)
|
||||
self.env_file = env_file
|
||||
self.env_file_encoding = env_file_encoding
|
||||
self.case_sensitive = case_sensitive
|
||||
self.env_nested_delimiter = env_nested_delimiter
|
||||
|
||||
def _apply_case_sensitive(self, var_name: str) -> str:
|
||||
return var_name if self.case_sensitive else var_name.lower()
|
||||
@@ -212,12 +196,33 @@ class DotEnvSettingsSource(BaseSettingsSource):
|
||||
for field in model_fields(self.settings_cls):
|
||||
field_name = field.name
|
||||
env_name = self._apply_case_sensitive(field_name)
|
||||
alias_name = field.field_info.alias
|
||||
alias_env_name = (
|
||||
None if alias_name is None else self._apply_case_sensitive(alias_name)
|
||||
)
|
||||
|
||||
# pydantic use alias name to validate if exist
|
||||
if alias_name is not None:
|
||||
field_name = alias_name
|
||||
|
||||
# try get values from env vars
|
||||
env_val = env_vars.get(env_name, PydanticUndefined)
|
||||
alias_env_val = (
|
||||
PydanticUndefined
|
||||
if alias_env_name is None
|
||||
else env_vars.get(alias_env_name, PydanticUndefined)
|
||||
)
|
||||
# alias env value has higher priority
|
||||
env_val = (
|
||||
env_val
|
||||
if isinstance(alias_env_val, PydanticUndefinedType)
|
||||
else alias_env_val
|
||||
)
|
||||
# delete from file vars when used
|
||||
if env_name in env_file_vars:
|
||||
del env_file_vars[env_name]
|
||||
if alias_env_name is not None and alias_env_name in env_file_vars:
|
||||
del env_file_vars[alias_env_name]
|
||||
|
||||
is_complex, allow_parse_failure = self._field_is_complex(field)
|
||||
if is_complex:
|
||||
@@ -331,25 +336,48 @@ class BaseSettings(BaseModel):
|
||||
_env_nested_delimiter: Optional[str] = None,
|
||||
**values: Any,
|
||||
) -> None:
|
||||
settings_config = model_config(__settings_self__.__class__)
|
||||
env_file = (
|
||||
_env_file
|
||||
if _env_file is not ENV_FILE_SENTINEL
|
||||
else settings_config.get("env_file", (".env",))
|
||||
)
|
||||
env_file_encoding = (
|
||||
_env_file_encoding
|
||||
if _env_file_encoding is not None
|
||||
else settings_config.get("env_file_encoding", "utf-8")
|
||||
)
|
||||
env_nested_delimiter = (
|
||||
_env_nested_delimiter
|
||||
if _env_nested_delimiter is not None
|
||||
else settings_config.get("env_nested_delimiter", None)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
**__settings_self__._settings_build_values(
|
||||
__settings_self__.__class__,
|
||||
values,
|
||||
env_file=_env_file,
|
||||
env_file_encoding=_env_file_encoding,
|
||||
env_nested_delimiter=_env_nested_delimiter,
|
||||
env_file=env_file,
|
||||
env_file_encoding=env_file_encoding,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
)
|
||||
)
|
||||
|
||||
__settings_self__._env_file = env_file
|
||||
__settings_self__._env_file_encoding = env_file_encoding
|
||||
__settings_self__._env_nested_delimiter = env_nested_delimiter
|
||||
|
||||
@staticmethod
|
||||
def _settings_build_values(
|
||||
self,
|
||||
settings_cls: type[BaseModel],
|
||||
init_kwargs: dict[str, Any],
|
||||
env_file: Optional[DOTENV_TYPE] = None,
|
||||
env_file_encoding: Optional[str] = None,
|
||||
env_nested_delimiter: Optional[str] = None,
|
||||
env_file: Optional[DOTENV_TYPE],
|
||||
env_file_encoding: str,
|
||||
env_nested_delimiter: Optional[str],
|
||||
) -> dict[str, Any]:
|
||||
init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs)
|
||||
init_settings = InitSettingsSource(settings_cls, init_kwargs=init_kwargs)
|
||||
env_settings = DotEnvSettingsSource(
|
||||
self.__class__,
|
||||
settings_cls,
|
||||
env_file=env_file,
|
||||
env_file_encoding=env_file_encoding,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
|
||||
@@ -47,6 +47,7 @@ from pydantic import BaseModel
|
||||
|
||||
from nonebot import get_driver
|
||||
from nonebot.compat import model_dump, type_validate_python
|
||||
from nonebot.config import BaseSettings
|
||||
|
||||
C = TypeVar("C", bound=BaseModel)
|
||||
|
||||
@@ -172,7 +173,17 @@ def get_available_plugin_names() -> set[str]:
|
||||
|
||||
def get_plugin_config(config: type[C]) -> C:
|
||||
"""从全局配置获取当前插件需要的配置项。"""
|
||||
return type_validate_python(config, model_dump(get_driver().config))
|
||||
global_config = get_driver().config
|
||||
return type_validate_python(
|
||||
config,
|
||||
BaseSettings._settings_build_values(
|
||||
config,
|
||||
model_dump(global_config),
|
||||
env_file=global_config._env_file,
|
||||
env_file_encoding=global_config._env_file_encoding,
|
||||
env_nested_delimiter=global_config._env_nested_delimiter,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
from .load import inherit_supported_adapters as inherit_supported_adapters
|
||||
|
||||
Reference in New Issue
Block a user