♻️ rewrite dependency injection system

This commit is contained in:
yanyongyu
2021-12-12 18:19:08 +08:00
parent 6b5a5e53eb
commit 66ba25494a
17 changed files with 728 additions and 733 deletions

View File

@ -17,6 +17,7 @@ from typing import (
List,
Type,
Union,
TypeVar,
Callable,
NoReturn,
Optional,
@ -25,9 +26,10 @@ from typing import (
from nonebot import params
from nonebot.rule import Rule
from nonebot.log import logger
from nonebot.handler import Handler
from nonebot.dependencies import DependsWrapper
from nonebot.utils import CacheDict
from nonebot.dependencies import Dependent
from nonebot.permission import USER, Permission
from nonebot.consts import ARG_KEY, ARG_STR_KEY, RECEIVE_KEY, REJECT_TARGET
from nonebot.adapters import (
Bot,
Event,
@ -35,6 +37,14 @@ from nonebot.adapters import (
MessageSegment,
MessageTemplate,
)
from nonebot.typing import (
Any,
T_State,
T_Handler,
T_ArgsParser,
T_TypeUpdater,
T_PermissionUpdater,
)
from nonebot.exception import (
PausedException,
StopPropagation,
@ -42,19 +52,12 @@ from nonebot.exception import (
FinishedException,
RejectedException,
)
from nonebot.typing import (
T_State,
T_Handler,
T_ArgsParser,
T_TypeUpdater,
T_StateFactory,
T_DependencyCache,
T_PermissionUpdater,
)
if TYPE_CHECKING:
from nonebot.plugin import Plugin
T = TypeVar("T")
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
"""
:类型: ``Dict[int, List[Type[Matcher]]]``
@ -63,7 +66,7 @@ matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
current_bot: ContextVar[Bot] = ContextVar("current_bot")
current_event: ContextVar[Event] = ContextVar("current_event")
current_state: ContextVar[T_State] = ContextVar("current_state")
current_handler: ContextVar[Handler] = ContextVar("current_handler")
current_handler: ContextVar[Dependent] = ContextVar("current_handler")
class MatcherMeta(type):
@ -131,7 +134,7 @@ class Matcher(metaclass=MatcherMeta):
:类型: ``Permission``
:说明: 事件响应器触发权限
"""
handlers: List[Handler] = []
handlers: List[Dependent[Any]] = []
"""
:类型: ``List[Handler]``
:说明: 事件响应器拥有的事件处理函数列表
@ -163,23 +166,24 @@ class Matcher(metaclass=MatcherMeta):
:说明: 事件响应器默认状态
"""
_default_parser: Optional[T_ArgsParser] = None
_default_parser: Optional[Dependent[None]] = None
"""
:类型: ``Optional[T_ArgsParser]``
:类型: ``Optional[Dependent]``
:说明: 事件响应器默认参数解析函数
"""
_default_type_updater: Optional[T_TypeUpdater] = None
_default_type_updater: Optional[Dependent[str]] = None
"""
:类型: ``Optional[T_TypeUpdater]``
:类型: ``Optional[Dependent]``
:说明: 事件响应器类型更新函数
"""
_default_permission_updater: Optional[T_PermissionUpdater] = None
_default_permission_updater: Optional[Dependent[Permission]] = None
"""
:类型: ``Optional[T_PermissionUpdater]``
:类型: ``Optional[Dependent]``
:说明: 事件响应器权限更新函数
"""
HANDLER_PARAM_TYPES = [
params.DependParam,
params.BotParam,
params.EventParam,
params.StateParam,
@ -207,9 +211,7 @@ class Matcher(metaclass=MatcherMeta):
type_: str = "",
rule: Optional[Rule] = None,
permission: Optional[Permission] = None,
handlers: Optional[
Union[List[T_Handler], List[Handler], List[Union[T_Handler, Handler]]]
] = None,
handlers: Optional[List[Union[T_Handler, Dependent[Any]]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
@ -259,8 +261,10 @@ class Matcher(metaclass=MatcherMeta):
"permission": permission or Permission(),
"handlers": [
handler
if isinstance(handler, Handler)
else Handler(handler, allow_types=cls.HANDLER_PARAM_TYPES)
if isinstance(handler, Dependent)
else Dependent[Any].parse(
call=handler, allow_types=cls.HANDLER_PARAM_TYPES
)
for handler in handlers
]
if handlers
@ -286,7 +290,7 @@ class Matcher(metaclass=MatcherMeta):
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
) -> bool:
"""
:说明:
@ -314,7 +318,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
) -> bool:
"""
:说明:
@ -347,7 +351,9 @@ class Matcher(metaclass=MatcherMeta):
* ``func: T_ArgsParser``: 参数解析函数
"""
cls._default_parser = func
cls._default_parser = Dependent[None].parse(
call=func, allow_types=cls.HANDLER_PARAM_TYPES
)
return func
@classmethod
@ -361,7 +367,9 @@ class Matcher(metaclass=MatcherMeta):
* ``func: T_TypeUpdater``: 响应事件类型更新函数
"""
cls._default_type_updater = func
cls._default_type_updater = Dependent[str].parse(
call=func, allow_types=cls.HANDLER_PARAM_TYPES
)
return func
@classmethod
@ -375,22 +383,26 @@ class Matcher(metaclass=MatcherMeta):
* ``func: T_PermissionUpdater``: 会话权限更新函数
"""
cls._default_permission_updater = func
cls._default_permission_updater = Dependent[Permission].parse(
call=func, allow_types=cls.HANDLER_PARAM_TYPES
)
return func
@classmethod
def append_handler(
cls, handler: T_Handler, dependencies: Optional[List[DependsWrapper]] = None
) -> Handler:
handler_ = Handler(
handler, dependencies=dependencies, allow_types=cls.HANDLER_PARAM_TYPES
cls, handler: T_Handler, parameterless: Optional[List[Any]] = None
) -> Dependent[Any]:
handler_ = Dependent[Any].parse(
call=handler,
parameterless=parameterless,
allow_types=cls.HANDLER_PARAM_TYPES,
)
cls.handlers.append(handler_)
return handler_
@classmethod
def handle(
cls, dependencies: Optional[List[DependsWrapper]] = None
cls, parameterless: Optional[List[Any]] = None
) -> Callable[[T_Handler], T_Handler]:
"""
:说明:
@ -399,18 +411,18 @@ class Matcher(metaclass=MatcherMeta):
:参数:
* ``dependencies: Optional[List[DependsWrapper]]``: 非参数类型依赖列表
* ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表
"""
def _decorator(func: T_Handler) -> T_Handler:
cls.append_handler(func, dependencies=dependencies)
cls.append_handler(func, parameterless=parameterless)
return func
return _decorator
@classmethod
def receive(
cls, dependencies: Optional[List[DependsWrapper]] = None
cls, id: str = "", parameterless: Optional[List[Any]] = None
) -> Callable[[T_Handler], T_Handler]:
"""
:说明:
@ -419,28 +431,30 @@ class Matcher(metaclass=MatcherMeta):
:参数:
* ``dependencies: Optional[List[DependsWrapper]]``: 非参数类型依赖列表
* ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表
"""
async def _receive(state: T_State) -> Union[None, NoReturn]:
if state.get(_receive):
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
if matcher.get_receive(id):
return
state[_receive] = True
del state["_current_key"]
if matcher.get_target() == RECEIVE_KEY.format(id=id):
matcher.set_receive(id, event)
return
matcher.set_target(RECEIVE_KEY.format(id=id))
raise RejectedException
_dependencies = [DependsWrapper(_receive), *(dependencies or [])]
parameterless = [params.Depends(_receive), *(parameterless or [])]
def _decorator(func: T_Handler) -> T_Handler:
if cls.handlers and cls.handlers[-1].call is func:
func_handler = cls.handlers[-1]
for depend in reversed(_dependencies):
func_handler.prepend_dependency(depend)
for depend in reversed(parameterless):
func_handler.prepend_parameterless(depend)
else:
cls.append_handler(
func,
dependencies=_dependencies if cls.handlers else dependencies,
parameterless=parameterless if cls.handlers else parameterless,
)
return func
@ -453,7 +467,7 @@ class Matcher(metaclass=MatcherMeta):
key: str,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
args_parser: Optional[T_ArgsParser] = None,
dependencies: Optional[List[DependsWrapper]] = None,
parameterless: Optional[List[Any]] = None,
) -> Callable[[T_Handler], T_Handler]:
"""
:说明:
@ -465,51 +479,31 @@ class Matcher(metaclass=MatcherMeta):
* ``key: str``: 参数名
* ``prompt: Optional[Union[str, Message, MessageSegment, MessageFormatter]]``: 在参数不存在时向用户发送的消息
* ``args_parser: Optional[T_ArgsParser]``: 可选参数解析函数,空则使用默认解析函数
* ``dependencies: Optional[List[DependsWrapper]]``: 非参数类型依赖列表
* ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表
"""
async def _key_getter(bot: Bot, event: Event, state: T_State):
if state.get(f"_{key}_prompted"):
async def _key_getter(event: Event, matcher: "Matcher"):
if matcher.get_arg(key):
return
state["_current_key"] = key
state[f"_{key}_prompted"] = True
if key not in state:
if prompt is not None:
if isinstance(prompt, MessageTemplate):
_prompt = prompt.format(**state)
else:
_prompt = prompt
await bot.send(event=event, message=_prompt)
raise RejectedException
else:
state[f"_{key}_parsed"] = True
async def _key_parser(bot: Bot, event: Event, state: T_State):
if key in state and state.get(f"_{key}_parsed"):
if matcher.get_target() == ARG_KEY.format(key=key):
matcher.set_arg(key, event)
return
matcher.set_target(ARG_KEY.format(key=key))
raise RejectedException
parser = args_parser or cls._default_parser
if parser:
await parser(bot, event, state)
else:
state[key] = str(event.get_message())
state[f"_{key}_parsed"] = True
_dependencies = [
DependsWrapper(_key_getter),
DependsWrapper(_key_parser),
*(dependencies or []),
_parameterless = [
params.Depends(_key_getter),
*(parameterless or []),
]
def _decorator(func: T_Handler) -> T_Handler:
if cls.handlers and cls.handlers[-1].call is func:
func_handler = cls.handlers[-1]
for depend in reversed(_dependencies):
func_handler.prepend_dependency(depend)
for depend in reversed(_parameterless):
func_handler.prepend_parameterless(depend)
else:
cls.append_handler(func, dependencies=_dependencies)
cls.append_handler(func, parameterless=_parameterless)
return func
@ -609,8 +603,6 @@ class Matcher(metaclass=MatcherMeta):
bot = current_bot.get()
event = current_event.get()
state = current_state.get()
if "_current_key" in state and f"_{state['_current_key']}_parsed" in state:
del state[f"_{state['_current_key']}_parsed"]
if isinstance(prompt, MessageTemplate):
_prompt = prompt.format(**state)
else:
@ -619,6 +611,28 @@ class Matcher(metaclass=MatcherMeta):
await bot.send(event=event, message=_prompt, **kwargs)
raise RejectedException
def get_receive(self, id: str, default: T = None) -> Union[Event, T]:
return self.state.get(RECEIVE_KEY.format(id=id), default)
def set_receive(self, id: str, event: Event) -> None:
self.state[RECEIVE_KEY.format(id=id)] = event
def get_arg(self, key: str, default: T = None) -> Union[Event, T]:
return self.state.get(ARG_KEY.format(key=key), default)
def get_arg_str(self, key: str, default: T = None) -> Union[str, T]:
return self.state.get(ARG_STR_KEY.format(key=key), default)
def set_arg(self, key: str, event: Event) -> None:
self.state[ARG_KEY.format(key=key)] = event
self.state[ARG_STR_KEY.format(key=key)] = str(event.get_message())
def set_target(self, target: str) -> None:
self.state[REJECT_TARGET] = target
def get_target(self, default: T = None) -> Union[str, T]:
return self.state.get(REJECT_TARGET, default)
def stop_propagation(self):
"""
:说明:
@ -631,13 +645,13 @@ class Matcher(metaclass=MatcherMeta):
updater = self.__class__._default_type_updater
if not updater:
return "message"
return await updater(bot, event, self.state, self.type)
return await updater(bot=bot, event=event, state=self.state, matcher=self)
async def update_permission(self, bot: Bot, event: Event) -> Permission:
updater = self.__class__._default_permission_updater
if not updater:
return USER(event.get_session_id(), perm=self.permission)
return await updater(bot, event, self.state, self.permission)
return await updater(bot=bot, event=event, state=self.state, matcher=self)
async def simple_run(
self,
@ -645,7 +659,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
@ -664,8 +678,8 @@ class Matcher(metaclass=MatcherMeta):
bot=bot,
event=event,
state=self.state,
_stack=stack,
_dependency_cache=dependency_cache,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException as e:
logger.debug(
@ -687,7 +701,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
):
try:
await self.simple_run(bot, event, state, stack, dependency_cache)