add state factory support #113

This commit is contained in:
yanyongyu
2020-12-20 11:59:23 +08:00
parent 168cc3801a
commit 1b00fe7921
4 changed files with 239 additions and 164 deletions

View File

@ -15,7 +15,7 @@ from typing import Type, List, Dict, Union, Callable, Optional, NoReturn, TYPE_C
from nonebot.rule import Rule
from nonebot.log import logger
from nonebot.permission import Permission, USER
from nonebot.typing import T_State, T_Handler, T_ArgsParser
from nonebot.typing import T_State, T_StateFactory, T_Handler, T_ArgsParser
from nonebot.exception import PausedException, RejectedException, FinishedException
if TYPE_CHECKING:
@ -95,6 +95,11 @@ class Matcher(metaclass=MatcherMeta):
:类型: ``T_State``
:说明: 事件响应器默认状态
"""
_default_state_factory: Optional[T_StateFactory] = None
"""
:类型: ``Optional[T_State]``
:说明: 事件响应器默认工厂函数
"""
_default_parser: Optional[T_ArgsParser] = None
"""
@ -126,6 +131,7 @@ class Matcher(metaclass=MatcherMeta):
*,
module: Optional[str] = None,
default_state: Optional[T_State] = None,
default_state_factory: Optional[T_StateFactory] = None,
expire_time: Optional[datetime] = None) -> Type["Matcher"]:
"""
:说明:
@ -143,6 +149,7 @@ class Matcher(metaclass=MatcherMeta):
* ``block: bool``: 是否阻止事件向更低优先级的响应器传播
* ``module: Optional[str]``: 事件响应器所在模块名称
* ``default_state: Optional[T_State]``: 默认状态 ``state``
* ``default_state_factory: Optional[T_StateFactory]``: 默认状态 ``state`` 的工厂函数
* ``expire_time: Optional[datetime]``: 事件响应器最终有效时间点,过时即被删除
:返回:
@ -161,7 +168,8 @@ class Matcher(metaclass=MatcherMeta):
"expire_time": expire_time,
"priority": priority,
"block": block,
"_default_state": default_state or {}
"_default_state": default_state or {},
"_default_state_factory": default_state_factory
})
matchers[priority].append(NewMatcher)
@ -452,11 +460,13 @@ class Matcher(metaclass=MatcherMeta):
e_t = current_event.set(event)
try:
# Refresh preprocess state
self.state.update(state)
state_ = await self._default_state_factory(
bot, event) if self._default_state_factory else self.state
state_.update(state)
for _ in range(len(self.handlers)):
handler = self.handlers.pop(0)
await self.run_handler(handler, bot, event, self.state)
await self.run_handler(handler, bot, event, state_)
except RejectedException:
self.handlers.insert(0, handler) # type: ignore