mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-07 04:26:45 +00:00
✅ add test cases
This commit is contained in:
@ -5,6 +5,9 @@ from ._bot import Bot
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import (
|
||||
Driver,
|
||||
Request,
|
||||
Response,
|
||||
WebSocket,
|
||||
ForwardDriver,
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
@ -44,6 +47,16 @@ class Adapter(abc.ABC):
|
||||
raise TypeError("Current driver does not support websocket server")
|
||||
self.driver.setup_websocket_server(setup)
|
||||
|
||||
async def request(self, setup: Request) -> Response:
|
||||
if not isinstance(self.driver, ForwardDriver):
|
||||
raise TypeError("Current driver does not support http client")
|
||||
return await self.driver.request(setup)
|
||||
|
||||
async def websocket(self, setup: Request) -> WebSocket:
|
||||
if not isinstance(self.driver, ForwardDriver):
|
||||
raise TypeError("Current driver does not support websocket client")
|
||||
return await self.driver.websocket(setup)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _call_api(self, api: str, **data) -> Any:
|
||||
"""
|
||||
|
@ -8,7 +8,6 @@ from nonebot.log import logger
|
||||
from nonebot.config import Config
|
||||
from nonebot.exception import MockApiException
|
||||
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
|
||||
from nonebot.drivers import Driver, HTTPResponse, HTTPConnection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._event import Event
|
||||
|
@ -151,6 +151,9 @@ class Dependent(Generic[R]):
|
||||
) -> Dict[str, Any]:
|
||||
values: Dict[str, Any] = {}
|
||||
|
||||
for param in self.parameterless:
|
||||
await param._solve(**params)
|
||||
|
||||
for field in self.params:
|
||||
field_info = field.field_info
|
||||
assert isinstance(field_info, Param), "Params must be subclasses of Param"
|
||||
@ -168,7 +171,4 @@ class Dependent(Generic[R]):
|
||||
else:
|
||||
values[field.name] = value
|
||||
|
||||
for param in self.parameterless:
|
||||
await param._solve(**params)
|
||||
|
||||
return values
|
||||
|
@ -10,10 +10,14 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
|
||||
|
||||
from ._model import URL as URL
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import escape_tag
|
||||
from ._model import Request as Request
|
||||
from nonebot.config import Env, Config
|
||||
from ._model import URL, Request, Response, WebSocket, HTTPVersion
|
||||
from ._model import Response as Response
|
||||
from ._model import WebSocket as WebSocket
|
||||
from ._model import HTTPVersion as HTTPVersion
|
||||
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -204,11 +208,11 @@ class ForwardDriver(Driver):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(self, setup: "Request") -> Any:
|
||||
async def request(self, setup: Request) -> Response:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def websocket(self, setup: "Request") -> Any:
|
||||
async def websocket(self, setup: Request) -> WebSocket:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@ -193,6 +193,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
params.BotParam,
|
||||
params.EventParam,
|
||||
params.StateParam,
|
||||
params.ArgParam,
|
||||
params.MatcherParam,
|
||||
params.DefaultParam,
|
||||
]
|
||||
@ -443,10 +444,10 @@ class Matcher(metaclass=MatcherMeta):
|
||||
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
|
||||
if matcher.get_receive(id):
|
||||
return
|
||||
if matcher.get_target() == RECEIVE_KEY.format(id=id):
|
||||
if matcher.get_target() == RECEIVE_KEY.format(id=id or ""):
|
||||
matcher.set_receive(id, event)
|
||||
return
|
||||
matcher.set_target(RECEIVE_KEY.format(id=id))
|
||||
matcher.set_target(RECEIVE_KEY.format(id=id or ""))
|
||||
raise RejectedException
|
||||
|
||||
parameterless = [params.Depends(_receive), *(parameterless or [])]
|
||||
@ -472,7 +473,6 @@ class Matcher(metaclass=MatcherMeta):
|
||||
cls,
|
||||
key: str,
|
||||
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
|
||||
args_parser: Optional[T_ArgsParser] = None,
|
||||
parameterless: Optional[List[Any]] = None,
|
||||
) -> Callable[[T_Handler], T_Handler]:
|
||||
"""
|
||||
@ -495,6 +495,8 @@ class Matcher(metaclass=MatcherMeta):
|
||||
matcher.set_arg(key, event)
|
||||
return
|
||||
matcher.set_target(ARG_KEY.format(key=key))
|
||||
if prompt is not None:
|
||||
await matcher.send(prompt)
|
||||
raise RejectedException
|
||||
|
||||
_parameterless = [
|
||||
@ -517,7 +519,9 @@ class Matcher(metaclass=MatcherMeta):
|
||||
|
||||
@classmethod
|
||||
async def send(
|
||||
cls, message: Union[str, Message, MessageSegment, MessageTemplate], **kwargs
|
||||
cls,
|
||||
message: Union[str, Message, MessageSegment, MessageTemplate],
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
@ -58,19 +58,21 @@ EVENT_PCS_PARAMS = [
|
||||
]
|
||||
RUN_PREPCS_PARAMS = [
|
||||
params.DependParam,
|
||||
params.MatcherParam,
|
||||
params.BotParam,
|
||||
params.EventParam,
|
||||
params.StateParam,
|
||||
params.ArgParam,
|
||||
params.MatcherParam,
|
||||
params.DefaultParam,
|
||||
]
|
||||
RUN_POSTPCS_PARAMS = [
|
||||
params.DependParam,
|
||||
params.MatcherParam,
|
||||
params.ExceptionParam,
|
||||
params.BotParam,
|
||||
params.EventParam,
|
||||
params.StateParam,
|
||||
params.ArgParam,
|
||||
params.MatcherParam,
|
||||
params.DefaultParam,
|
||||
]
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing_extensions import Literal
|
||||
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
|
||||
@ -200,7 +201,7 @@ async def _event_message(event: Event) -> Message:
|
||||
return event.get_message()
|
||||
|
||||
|
||||
def EventMessage() -> Message:
|
||||
def EventMessage() -> Any:
|
||||
return Depends(_event_message)
|
||||
|
||||
|
||||
@ -260,7 +261,7 @@ def _command_arg(state=State()) -> Message:
|
||||
return state[PREFIX_KEY][CMD_ARG_KEY]
|
||||
|
||||
|
||||
def CommandArg() -> Message:
|
||||
def CommandArg() -> Any:
|
||||
return Depends(_command_arg, use_cache=False)
|
||||
|
||||
|
||||
@ -332,6 +333,44 @@ def LastReceived(default: Any = None) -> Any:
|
||||
return Depends(_last_received, use_cache=False)
|
||||
|
||||
|
||||
class ArgInner:
|
||||
def __init__(
|
||||
self, key: Optional[str], type: Literal["event", "message", "str"]
|
||||
) -> None:
|
||||
self.key = key
|
||||
self.type = type
|
||||
|
||||
|
||||
def Arg(key: Optional[str] = None) -> Any:
|
||||
return ArgInner(key, "message")
|
||||
|
||||
|
||||
def ArgEvent(key: Optional[str] = None) -> Any:
|
||||
return ArgInner(key, "event")
|
||||
|
||||
|
||||
def ArgStr(key: Optional[str] = None) -> Any:
|
||||
return ArgInner(key, "str")
|
||||
|
||||
|
||||
class ArgParam(Param):
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
||||
) -> Optional["ArgParam"]:
|
||||
if isinstance(param.default, ArgInner):
|
||||
return cls(Required, key=param.default.key or name, type=param.default.type)
|
||||
|
||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
event = matcher.get_arg(self.extra["key"])
|
||||
if self.extra["type"] == "event":
|
||||
return event
|
||||
elif self.extra["type"] == "message":
|
||||
return event.get_message()
|
||||
else:
|
||||
return matcher.get_arg_str(self.extra["key"])
|
||||
|
||||
|
||||
class ExceptionParam(Param):
|
||||
@classmethod
|
||||
def _check_param(
|
||||
|
Reference in New Issue
Block a user