添加依赖注入支持,重构函数调用上下文,优化插件加载机制

This commit is contained in:
2024-12-17 13:25:30 +08:00
parent a0f657b239
commit a2c4fb220e
4 changed files with 132 additions and 42 deletions

View File

@ -3,11 +3,13 @@ from typing import Any
from nonebot import logger
from nonebot.adapters import Bot, Event
from nonebot.matcher import Matcher
from nonebot.permission import Permission
from nonebot.rule import Rule
from nonebot.typing import T_State
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
from .models import SessionContext, SessionContextDepends
from .utils import async_wrap, is_coroutine_callable
_caller_data: dict[str, "Caller"] = {}
@ -19,10 +21,15 @@ class Caller:
self._description = description
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
self._parameters: dict[str, Any] = {}
"""依赖注入的参数"""
self.bot: Bot | None = None
self.event: Event | None = None
self.state: T_State | None = None
"""声明参数"""
self.di: SessionContextDepends = SessionContextDepends()
"""依赖注入的参数信息"""
self.default: dict[str, Any] = {}
"""默认值"""
self.ctx: SessionContext | None = None
self._permission: Permission | None = None
self._rule: Rule | None = None
@ -36,14 +43,20 @@ class Caller:
return self
async def pre_check(self) -> tuple[bool, str]:
if self.bot is None or self.event is None:
if self.ctx is None:
return False, "上下文为空"
if self.ctx.bot is None or self.ctx.event is None:
return False, "Context is None"
if self._permission and not await self._permission(self.bot, self.event):
if self._permission and not await self._permission(
self.ctx.bot, self.ctx.event
):
return False, "告诉用户 Permission Denied 权限不足"
if self.state is None:
if self.ctx.state is None:
return False, "State is None"
if self._rule and not await self._rule(self.bot, self.event, self.state):
if self._rule and not await self._rule(
self.ctx.bot, self.ctx.event, self.ctx.state
):
return False, "告诉用户 Rule Denied 规则不匹配"
return True, ""
@ -86,6 +99,35 @@ class Caller:
self._name = f"{module_name}-{func.__name__}"
_caller_data[self._name] = self
# 检查函数签名,确定依赖注入参数
sig = inspect.signature(func)
for name, param in sig.parameters.items():
if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event
):
self.di.event = name
if issubclass(param.annotation, Caller) or isinstance(
param.annotation, Caller
):
self.di.caller = name
if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot):
self.di.bot = name
if issubclass(param.annotation, Matcher) or isinstance(
param.annotation, Matcher
):
self.di.matcher = name
if param.annotation == T_State:
self.di.state = name
# 检查默认值情况
for name, param in sig.parameters.items():
if param.default is not inspect.Parameter.empty:
self.default[name] = param.default
if is_coroutine_callable(func):
self.func = func # type: ignore
else:
@ -126,11 +168,30 @@ class Caller:
},
}
def set_event(self, event: Event):
self.event = event
def set_ctx(self, ctx: SessionContext) -> None:
"""设置依赖注入上下文
def set_bot(self, bot: Bot):
self.bot = bot
Args:
ctx (SessionContext): 依赖注入上下文
"""
ctx.caller = self
self.ctx = ctx
for type_name, arg_name in self.di.model_dump().items():
if arg_name:
self.default[arg_name] = ctx.__getattribute__(type_name)
def with_ctx(self, ctx: SessionContext) -> "Caller":
"""设置依赖注入上下文
Args:
ctx (SessionContext): 依赖注入上下文
Returns:
Caller: Caller对象
"""
self.set_ctx(ctx)
return self
async def call(self, *args: Any, **kwargs: Any) -> Any:
"""调用函数
@ -145,28 +206,11 @@ class Caller:
if self.func is None:
raise ValueError("未注册函数对象")
sig = inspect.signature(self.func)
for name, param in sig.parameters.items():
if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event
):
kwargs[name] = self.event
if issubclass(param.annotation, Caller) or isinstance(
param.annotation, Caller
):
kwargs[name] = self
if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot):
kwargs[name] = self.bot
if param.annotation == T_State:
kwargs[name] = self.state
# 检查形参是否有默认值或传入若没有则用parameters中的默认值填充
for name, param in sig.parameters.items():
# 检查形参是否有默认值或传入若没有则用default中的默认值填充
for name, value in self.default.items():
if name not in kwargs:
kwargs[name] = self._parameters.get(name, param.default)
kwargs[name] = value
return await self.func(*args, **kwargs)