🐛 fix cannot reject preset arg

This commit is contained in:
yanyongyu 2021-12-24 14:09:43 +08:00
parent 17f3c8fd09
commit 6643f951ef
4 changed files with 49 additions and 10 deletions

View File

@ -3,6 +3,7 @@ RECEIVE_KEY = "_receive_{id}"
LAST_RECEIVE_KEY = "_last_receive" LAST_RECEIVE_KEY = "_last_receive"
ARG_KEY = "{key}" ARG_KEY = "{key}"
REJECT_TARGET = "_current_target" REJECT_TARGET = "_current_target"
REJECT_CACHE_TARGET = "_next_target"
# used by Rule # used by Rule
PREFIX_KEY = "_prefix" PREFIX_KEY = "_prefix"

View File

@ -28,7 +28,6 @@ from nonebot.rule import Rule
from nonebot.log import logger from nonebot.log import logger
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
from nonebot.consts import ARG_KEY, RECEIVE_KEY, REJECT_TARGET, LAST_RECEIVE_KEY
from nonebot.adapters import ( from nonebot.adapters import (
Bot, Bot,
Event, Event,
@ -36,6 +35,13 @@ from nonebot.adapters import (
MessageSegment, MessageSegment,
MessageTemplate, MessageTemplate,
) )
from nonebot.consts import (
ARG_KEY,
RECEIVE_KEY,
REJECT_TARGET,
LAST_RECEIVE_KEY,
REJECT_CACHE_TARGET,
)
from nonebot.exception import ( from nonebot.exception import (
PausedException, PausedException,
StopPropagation, StopPropagation,
@ -432,12 +438,12 @@ class Matcher(metaclass=MatcherMeta):
""" """
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]: async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
matcher.set_target(RECEIVE_KEY.format(id=id))
if matcher.get_target() == RECEIVE_KEY.format(id=id): if matcher.get_target() == RECEIVE_KEY.format(id=id):
matcher.set_receive(id, event) matcher.set_receive(id, event)
return return
if matcher.get_receive(id): if matcher.get_receive(id):
return return
matcher.set_target(RECEIVE_KEY.format(id=id))
raise RejectedException raise RejectedException
_parameterless = [params.Depends(_receive), *(parameterless or [])] _parameterless = [params.Depends(_receive), *(parameterless or [])]
@ -476,12 +482,13 @@ class Matcher(metaclass=MatcherMeta):
""" """
async def _key_getter(event: Event, matcher: "Matcher"): async def _key_getter(event: Event, matcher: "Matcher"):
print(key, matcher.state)
matcher.set_target(ARG_KEY.format(key=key))
if matcher.get_target() == ARG_KEY.format(key=key): if matcher.get_target() == ARG_KEY.format(key=key):
matcher.set_arg(key, event.get_message()) matcher.set_arg(key, event.get_message())
return return
if matcher.get_arg(key): if matcher.get_arg(key):
return return
matcher.set_target(ARG_KEY.format(key=key))
if prompt is not None: if prompt is not None:
await matcher.send(prompt) await matcher.send(prompt)
raise RejectedException raise RejectedException
@ -654,8 +661,11 @@ class Matcher(metaclass=MatcherMeta):
def set_arg(self, key: str, message: Message) -> None: def set_arg(self, key: str, message: Message) -> None:
self.state[ARG_KEY.format(key=key)] = message self.state[ARG_KEY.format(key=key)] = message
def set_target(self, target: str) -> None: def set_target(self, target: str, cache: bool = True) -> None:
self.state[REJECT_TARGET] = target if cache:
self.state[REJECT_CACHE_TARGET] = target
else:
self.state[REJECT_TARGET] = target
def get_target(self, default: T = None) -> Union[str, T]: def get_target(self, default: T = None) -> Union[str, T]:
return self.state.get(REJECT_TARGET, default) return self.state.get(REJECT_TARGET, default)
@ -680,6 +690,11 @@ class Matcher(metaclass=MatcherMeta):
return USER(event.get_session_id(), perm=self.permission) return USER(event.get_session_id(), perm=self.permission)
return await updater(bot=bot, event=event, state=self.state, matcher=self) return await updater(bot=bot, event=event, state=self.state, matcher=self)
async def resolve_reject(self):
handler = current_handler.get()
self.handlers.insert(0, handler)
self.state[REJECT_TARGET] = self.state[REJECT_CACHE_TARGET]
async def simple_run( async def simple_run(
self, self,
bot: Bot, bot: Bot,
@ -734,9 +749,7 @@ class Matcher(metaclass=MatcherMeta):
await self.simple_run(bot, event, state, stack, dependency_cache) await self.simple_run(bot, event, state, stack, dependency_cache)
except RejectedException: except RejectedException:
handler = current_handler.get() await self.resolve_reject()
self.handlers.insert(0, handler)
type_ = await self.update_type(bot, event) type_ = await self.update_type(bot, event)
permission = await self.update_permission(bot, event) permission = await self.update_permission(bot, event)

View File

@ -1,6 +1,7 @@
from nonebot import on_message from nonebot import on_message
from nonebot.adapters import Event from nonebot.matcher import Matcher
from nonebot.params import ArgStr, Received, LastReceived from nonebot.adapters import Event, Message
from nonebot.params import ArgStr, Received, EventMessage, LastReceived
test_handle = on_message() test_handle = on_message()
@ -54,3 +55,19 @@ async def combine(a: str = ArgStr(), b: str = ArgStr(), r: Event = Received()):
assert a == "text_next" assert a == "text_next"
assert b == "text_next" assert b == "text_next"
assert str(r.get_message()) == "text_next" assert str(r.get_message()) == "text_next"
test_preset = on_message()
@test_preset.handle()
async def preset(matcher: Matcher, message: Message = EventMessage()):
matcher.set_arg("a", message)
@test_preset.got("a")
async def reject_preset(a: str = ArgStr()):
if a == "text":
await test_preset.reject_arg("a")
assert a == "text_next"

View File

@ -9,6 +9,7 @@ async def test_matcher(app: App, load_plugin):
from plugins.matcher import ( from plugins.matcher import (
test_got, test_got,
test_handle, test_handle,
test_preset,
test_combine, test_combine,
test_receive, test_receive,
) )
@ -58,3 +59,10 @@ async def test_matcher(app: App, load_plugin):
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
ctx.should_rejected() ctx.should_rejected()
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
assert len(test_preset.handlers) == 2
async with app.test_matcher(test_preset) as ctx:
bot = ctx.create_bot()
ctx.receive_event(bot, event)
ctx.should_rejected()
ctx.receive_event(bot, event_next)