add contextvars and fix mutable default args

This commit is contained in:
yanyongyu
2020-08-25 15:23:10 +08:00
parent d66259da2b
commit c5ea8bc1c3
5 changed files with 112 additions and 33 deletions

View File

@ -6,14 +6,17 @@ import inspect
from functools import wraps
from datetime import datetime
from collections import defaultdict
from contextvars import Context, ContextVar, copy_context
from nonebot.rule import Rule
from nonebot.permission import Permission, USER
from nonebot.typing import Bot, Event, Handler, ArgsParser
from nonebot.typing import Type, List, Dict, Callable, Optional, NoReturn
from nonebot.typing import Type, List, Dict, Union, Callable, Optional, NoReturn
from nonebot.typing import Bot, Event, Handler, Message, ArgsParser, MessageSegment
from nonebot.exception import PausedException, RejectedException, FinishedException
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
current_bot: ContextVar = ContextVar("current_bot")
current_event: ContextVar = ContextVar("current_event")
class Matcher:
@ -51,12 +54,12 @@ class Matcher:
type_: str = "",
rule: Rule = Rule(),
permission: Permission = Permission(),
handlers: list = [],
handlers: Optional[list] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
*,
default_state: dict = {},
default_state: Optional[dict] = None,
expire_time: Optional[datetime] = None) -> Type["Matcher"]:
"""创建新的 Matcher
@ -69,12 +72,12 @@ class Matcher:
"type": type_,
"rule": rule,
"permission": permission,
"handlers": handlers,
"handlers": handlers or [],
"temp": temp,
"expire_time": expire_time,
"priority": priority,
"block": block,
"_default_state": default_state
"_default_state": default_state or {}
})
matchers[priority].append(NewMatcher)
@ -117,12 +120,12 @@ class Matcher:
def receive(cls) -> Callable[[Handler], Handler]:
"""接收一条新消息并处理"""
async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn:
async def _receive(bot: Bot, event: Event, state: dict) -> NoReturn:
raise PausedException
if cls.handlers:
# 已有前置handlers则接受一条新的消息否则视为接收初始消息
cls.handlers.append(_handler)
cls.handlers.append(_receive)
def _decorator(func: Handler) -> Handler:
if not cls.handlers or cls.handlers[-1] is not func:
@ -144,8 +147,7 @@ class Matcher:
if key not in state:
state["_current_key"] = key
if prompt:
await bot.send_private_msg(user_id=event.user_id,
message=prompt)
await bot.send(event=event, message=prompt)
raise PausedException
async def _key_parser(bot: Bot, event: Event, state: dict):
@ -176,19 +178,42 @@ class Matcher:
return _decorator
@classmethod
def finish(cls) -> NoReturn:
async def finish(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise FinishedException
@classmethod
def pause(cls) -> NoReturn:
async def pause(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise PausedException
@classmethod
def reject(cls) -> NoReturn:
async def reject(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise RejectedException
# 运行handlers
async def run(self, bot: Bot, event: Event, state: dict):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
try:
# Refresh preprocess state
self.state.update(state)
@ -214,7 +239,6 @@ class Matcher:
block=True,
default_state=self.state,
expire_time=datetime.now() + bot.config.session_expire_timeout)
return
except PausedException:
Matcher.new(
self.type,
@ -226,6 +250,8 @@ class Matcher:
block=True,
default_state=self.state,
expire_time=datetime.now() + bot.config.session_expire_timeout)
return
except FinishedException:
return
pass
finally:
current_bot.reset(b_t)
current_event.reset(e_t)