Feature: 存储 matcher 发送 prompt 的结果 (#3155)

This commit is contained in:
Ju4tCode
2024-12-05 20:55:24 +08:00
committed by GitHub
parent ab8dea5a02
commit 32bc2c314a
8 changed files with 271 additions and 22 deletions

View File

@ -1,7 +1,7 @@
from typing import Annotated
from typing import Annotated, Any
from nonebot.adapters import Message
from nonebot.params import Arg, ArgPlainText, ArgStr
from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr
async def arg(key: Message = Arg()) -> Message:
@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
return key
async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any:
return key
# test dependency priority
async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()):
return key

View File

@ -1,8 +1,13 @@
from typing import TypeVar, Union
from typing import Any, TypeVar, Union
from nonebot.adapters import Event
from nonebot.matcher import Matcher
from nonebot.params import LastReceived, Received
from nonebot.params import (
LastReceived,
PausePromptResult,
Received,
ReceivePromptResult,
)
async def matcher(m: Matcher) -> Matcher:
@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event:
async def last_receive(e: Event = LastReceived()) -> Event:
return e
async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any:
return result
async def pause_prompt_result(result: Any = PausePromptResult()) -> Any:
return result

View File

@ -1,3 +1,4 @@
from contextlib import suppress
import re
from exceptiongroup import BaseExceptionGroup
@ -5,6 +6,7 @@ from nonebug import App
import pytest
from nonebot.consts import (
ARG_KEY,
CMD_ARG_KEY,
CMD_KEY,
CMD_START_KEY,
@ -14,13 +16,14 @@ from nonebot.consts import (
KEYWORD_KEY,
PREFIX_KEY,
RAW_CMD_KEY,
RECEIVE_KEY,
REGEX_MATCHED,
SHELL_ARGS,
SHELL_ARGV,
STARTSWITH_KEY,
)
from nonebot.dependencies import Dependent
from nonebot.exception import TypeMisMatch
from nonebot.exception import PausedException, RejectedException, TypeMisMatch
from nonebot.matcher import Matcher
from nonebot.params import (
ArgParam,
@ -469,8 +472,10 @@ async def test_matcher(app: App):
matcher,
not_legacy_matcher,
not_matcher,
pause_prompt_result,
postpone_matcher,
receive,
receive_prompt_result,
sub_matcher,
union_matcher,
)
@ -538,12 +543,42 @@ async def test_matcher(app: App):
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(event_next)
fake_matcher.set_target(RECEIVE_KEY.format(id="test"), cache=False)
async with app.test_api() as ctx:
bot = ctx.create_bot()
ctx.should_call_send(event, "test", result=True, bot=bot)
with fake_matcher.ensure_context(bot, event):
with suppress(RejectedException):
await fake_matcher.reject("test")
async with app.test_dependent(
receive_prompt_result, allow_types=[MatcherParam, DependParam]
) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(True)
async with app.test_api() as ctx:
bot = ctx.create_bot()
ctx.should_call_send(event, "test", result=False, bot=bot)
with fake_matcher.ensure_context(bot, event):
fake_matcher.set_target("test")
with suppress(PausedException):
await fake_matcher.pause("test")
async with app.test_dependent(
pause_prompt_result, allow_types=[MatcherParam, DependParam]
) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(False)
@pytest.mark.anyio
async def test_arg(app: App):
from plugins.param.param_arg import (
annotated_arg,
annotated_arg_plain_text,
annotated_arg_prompt_result,
annotated_arg_str,
annotated_multi_arg,
annotated_prior_arg,
@ -553,6 +588,7 @@ async def test_arg(app: App):
)
matcher = Matcher()
event = make_fake_event()()
message = FakeMessage("text")
matcher.set_arg("key", message)
@ -582,6 +618,21 @@ async def test_arg(app: App):
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())
matcher.set_target(ARG_KEY.format(key="key"), cache=False)
async with app.test_api() as ctx:
bot = ctx.create_bot()
ctx.should_call_send(event, "test", result="arg", bot=bot)
with matcher.ensure_context(bot, event):
with suppress(RejectedException):
await matcher.reject("test")
async with app.test_dependent(
annotated_arg_prompt_result, allow_types=[ArgParam]
) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return("arg")
async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())