Feature: 存储 matcher 发送 prompt 的结果 (#3155)

This commit is contained in:
Ju4tCode
2024-12-05 20:55:24 +08:00
committed by GitHub
parent ab8dea5a02
commit 32bc2c314a
8 changed files with 271 additions and 22 deletions

View File

@ -18,6 +18,7 @@ from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY
from nonebot.dependencies import Dependent, Param
from nonebot.dependencies.utils import check_field_type
from nonebot.exception import SkippedException
@ -39,7 +40,7 @@ from nonebot.utils import (
)
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher
@ -522,10 +523,10 @@ class MatcherParam(Param):
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
) -> None:
self.key: Optional[str] = key
self.type: Literal["message", "str", "plaintext"] = type
self.type: Literal["message", "str", "plaintext", "prompt"] = type
def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})"
@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str:
return ArgInner(key, "plaintext") # type: ignore
def ArgPromptResult(key: Optional[str] = None) -> Any:
"""`arg` prompt 发送结果"""
return ArgInner(key, "prompt")
class ArgParam(Param):
"""Arg 注入参数
@ -559,7 +565,7 @@ class ArgParam(Param):
self,
*args,
key: str,
type: Literal["message", "str", "plaintext"],
type: Literal["message", "str", "plaintext", "prompt"],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
@ -584,15 +590,32 @@ class ArgParam(Param):
async def _solve( # pyright: ignore[reportIncompatibleMethodOverride]
self, matcher: "Matcher", **kwargs: Any
) -> Any:
message = matcher.get_arg(self.key)
if message is None:
return message
if self.type == "message":
return message
return self._solve_message(matcher)
elif self.type == "str":
return str(message)
return self._solve_str(matcher)
elif self.type == "plaintext":
return self._solve_plaintext(matcher)
elif self.type == "prompt":
return self._solve_prompt(matcher)
else:
return message.extract_plain_text()
raise ValueError(f"Unknown Arg type: {self.type}")
def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
return matcher.get_arg(self.key)
def _solve_str(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return str(message) if message is not None else None
def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return message.extract_plain_text() if message is not None else None
def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
)
class ExceptionParam(Param):