Feature: 存储 matcher 发送 prompt 的结果 (#3155)

This commit is contained in:
Ju4tCode
2024-12-05 20:55:24 +08:00
committed by GitHub
parent ab8dea5a02
commit 32bc2c314a
8 changed files with 271 additions and 22 deletions

View File

@ -22,6 +22,10 @@ REJECT_TARGET: Literal["_current_target"] = "_current_target"
"""当前 `reject` 目标存储 key"""
REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target"
"""下一个 `reject` 目标存储 key"""
PAUSE_PROMPT_RESULT_KEY: Literal["_pause_result"] = "_pause_result"
"""`pause` prompt 发送结果存储 key"""
REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result"
"""`reject` prompt 发送结果存储 key"""
# used by Rule
PREFIX_KEY: Literal["_prefix"] = "_prefix"

View File

@ -27,8 +27,10 @@ from exceptiongroup import BaseExceptionGroup, catch
from nonebot.consts import (
ARG_KEY,
LAST_RECEIVE_KEY,
PAUSE_PROMPT_RESULT_KEY,
RECEIVE_KEY,
REJECT_CACHE_TARGET,
REJECT_PROMPT_RESULT_KEY,
REJECT_TARGET,
)
from nonebot.dependencies import Dependent, Param
@ -560,8 +562,8 @@ class Matcher(metaclass=MatcherMeta):
"""
bot = current_bot.get()
event = current_event.get()
state = current_matcher.get().state
if isinstance(message, MessageTemplate):
state = current_matcher.get().state
_message = message.format(**state)
else:
_message = message
@ -597,8 +599,15 @@ class Matcher(metaclass=MatcherMeta):
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
请参考对应 adapter 的 bot 对象 api
"""
try:
matcher = current_matcher.get()
except Exception:
matcher = None
if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
if matcher is not None:
matcher.state[PAUSE_PROMPT_RESULT_KEY] = result
raise PausedException
@classmethod
@ -615,8 +624,19 @@ class Matcher(metaclass=MatcherMeta):
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
请参考对应 adapter 的 bot 对象 api
"""
try:
matcher = current_matcher.get()
key = matcher.get_target()
except Exception:
matcher = None
key = None
key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None
if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
if key is not None and matcher:
matcher.state[key] = result
raise RejectedException
@classmethod
@ -636,9 +656,12 @@ class Matcher(metaclass=MatcherMeta):
请参考对应 adapter 的 bot 对象 api
"""
matcher = current_matcher.get()
matcher.set_target(ARG_KEY.format(key=key))
arg_key = ARG_KEY.format(key=key)
matcher.set_target(arg_key)
if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result
raise RejectedException
@classmethod
@ -658,9 +681,12 @@ class Matcher(metaclass=MatcherMeta):
请参考对应 adapter 的 bot 对象 api
"""
matcher = current_matcher.get()
matcher.set_target(RECEIVE_KEY.format(id=id))
receive_key = RECEIVE_KEY.format(id=id)
matcher.set_target(receive_key)
if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result
raise RejectedException
@classmethod

View File

@ -18,6 +18,7 @@ from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY
from nonebot.dependencies import Dependent, Param
from nonebot.dependencies.utils import check_field_type
from nonebot.exception import SkippedException
@ -39,7 +40,7 @@ from nonebot.utils import (
)
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher
@ -522,10 +523,10 @@ class MatcherParam(Param):
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
) -> None:
self.key: Optional[str] = key
self.type: Literal["message", "str", "plaintext"] = type
self.type: Literal["message", "str", "plaintext", "prompt"] = type
def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})"
@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str:
return ArgInner(key, "plaintext") # type: ignore
def ArgPromptResult(key: Optional[str] = None) -> Any:
"""`arg` prompt 发送结果"""
return ArgInner(key, "prompt")
class ArgParam(Param):
"""Arg 注入参数
@ -559,7 +565,7 @@ class ArgParam(Param):
self,
*args,
key: str,
type: Literal["message", "str", "plaintext"],
type: Literal["message", "str", "plaintext", "prompt"],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
@ -584,15 +590,32 @@ class ArgParam(Param):
async def _solve( # pyright: ignore[reportIncompatibleMethodOverride]
self, matcher: "Matcher", **kwargs: Any
) -> Any:
message = matcher.get_arg(self.key)
if message is None:
return message
if self.type == "message":
return message
return self._solve_message(matcher)
elif self.type == "str":
return str(message)
return self._solve_str(matcher)
elif self.type == "plaintext":
return self._solve_plaintext(matcher)
elif self.type == "prompt":
return self._solve_prompt(matcher)
else:
return message.extract_plain_text()
raise ValueError(f"Unknown Arg type: {self.type}")
def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
return matcher.get_arg(self.key)
def _solve_str(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return str(message) if message is not None else None
def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return message.extract_plain_text() if message is not None else None
def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
)
class ExceptionParam(Param):

View File

@ -19,9 +19,12 @@ from nonebot.consts import (
ENDSWITH_KEY,
FULLMATCH_KEY,
KEYWORD_KEY,
PAUSE_PROMPT_RESULT_KEY,
PREFIX_KEY,
RAW_CMD_KEY,
RECEIVE_KEY,
REGEX_MATCHED,
REJECT_PROMPT_RESULT_KEY,
SHELL_ARGS,
SHELL_ARGV,
STARTSWITH_KEY,
@ -29,6 +32,7 @@ from nonebot.consts import (
from nonebot.internal.params import Arg as Arg
from nonebot.internal.params import ArgParam as ArgParam
from nonebot.internal.params import ArgPlainText as ArgPlainText
from nonebot.internal.params import ArgPromptResult as ArgPromptResult
from nonebot.internal.params import ArgStr as ArgStr
from nonebot.internal.params import BotParam as BotParam
from nonebot.internal.params import DefaultParam as DefaultParam
@ -252,6 +256,26 @@ def LastReceived(default: Any = None) -> Any:
return Depends(_last_received, use_cache=False)
def ReceivePromptResult(id: Optional[str] = None) -> Any:
"""`receive` prompt 发送结果"""
def _receive_prompt_result(matcher: "Matcher") -> Any:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id))
)
return Depends(_receive_prompt_result, use_cache=False)
def PausePromptResult() -> Any:
"""`pause` prompt 发送结果"""
def _pause_prompt_result(matcher: "Matcher") -> Any:
return matcher.state.get(PAUSE_PROMPT_RESULT_KEY)
return Depends(_pause_prompt_result, use_cache=False)
__autodoc__ = {
"Arg": True,
"ArgStr": True,
@ -265,4 +289,5 @@ __autodoc__ = {
"DefaultParam": True,
"MatcherParam": True,
"ExceptionParam": True,
"ArgPromptResult": True,
}