add test cases

This commit is contained in:
yanyongyu
2021-12-20 00:28:02 +08:00
parent ca045b2f73
commit c2c3d5ef4b
17 changed files with 432 additions and 55 deletions

View File

@ -5,6 +5,9 @@ from ._bot import Bot
from nonebot.config import Config
from nonebot.drivers import (
Driver,
Request,
Response,
WebSocket,
ForwardDriver,
ReverseDriver,
HTTPServerSetup,
@ -44,6 +47,16 @@ class Adapter(abc.ABC):
raise TypeError("Current driver does not support websocket server")
self.driver.setup_websocket_server(setup)
async def request(self, setup: Request) -> Response:
if not isinstance(self.driver, ForwardDriver):
raise TypeError("Current driver does not support http client")
return await self.driver.request(setup)
async def websocket(self, setup: Request) -> WebSocket:
if not isinstance(self.driver, ForwardDriver):
raise TypeError("Current driver does not support websocket client")
return await self.driver.websocket(setup)
@abc.abstractmethod
async def _call_api(self, api: str, **data) -> Any:
"""

View File

@ -8,7 +8,6 @@ from nonebot.log import logger
from nonebot.config import Config
from nonebot.exception import MockApiException
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
from nonebot.drivers import Driver, HTTPResponse, HTTPConnection
if TYPE_CHECKING:
from ._event import Event

View File

@ -151,6 +151,9 @@ class Dependent(Generic[R]):
) -> Dict[str, Any]:
values: Dict[str, Any] = {}
for param in self.parameterless:
await param._solve(**params)
for field in self.params:
field_info = field.field_info
assert isinstance(field_info, Param), "Params must be subclasses of Param"
@ -168,7 +171,4 @@ class Dependent(Generic[R]):
else:
values[field.name] = value
for param in self.parameterless:
await param._solve(**params)
return values

View File

@ -10,10 +10,14 @@ import asyncio
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
from ._model import URL as URL
from nonebot.log import logger
from nonebot.utils import escape_tag
from ._model import Request as Request
from nonebot.config import Env, Config
from ._model import URL, Request, Response, WebSocket, HTTPVersion
from ._model import Response as Response
from ._model import WebSocket as WebSocket
from ._model import HTTPVersion as HTTPVersion
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING:
@ -204,11 +208,11 @@ class ForwardDriver(Driver):
"""
@abc.abstractmethod
async def request(self, setup: "Request") -> Any:
async def request(self, setup: Request) -> Response:
raise NotImplementedError
@abc.abstractmethod
async def websocket(self, setup: "Request") -> Any:
async def websocket(self, setup: Request) -> WebSocket:
raise NotImplementedError

View File

@ -193,6 +193,7 @@ class Matcher(metaclass=MatcherMeta):
params.BotParam,
params.EventParam,
params.StateParam,
params.ArgParam,
params.MatcherParam,
params.DefaultParam,
]
@ -443,10 +444,10 @@ class Matcher(metaclass=MatcherMeta):
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
if matcher.get_receive(id):
return
if matcher.get_target() == RECEIVE_KEY.format(id=id):
if matcher.get_target() == RECEIVE_KEY.format(id=id or ""):
matcher.set_receive(id, event)
return
matcher.set_target(RECEIVE_KEY.format(id=id))
matcher.set_target(RECEIVE_KEY.format(id=id or ""))
raise RejectedException
parameterless = [params.Depends(_receive), *(parameterless or [])]
@ -472,7 +473,6 @@ class Matcher(metaclass=MatcherMeta):
cls,
key: str,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
args_parser: Optional[T_ArgsParser] = None,
parameterless: Optional[List[Any]] = None,
) -> Callable[[T_Handler], T_Handler]:
"""
@ -495,6 +495,8 @@ class Matcher(metaclass=MatcherMeta):
matcher.set_arg(key, event)
return
matcher.set_target(ARG_KEY.format(key=key))
if prompt is not None:
await matcher.send(prompt)
raise RejectedException
_parameterless = [
@ -517,7 +519,9 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
async def send(
cls, message: Union[str, Message, MessageSegment, MessageTemplate], **kwargs
cls,
message: Union[str, Message, MessageSegment, MessageTemplate],
**kwargs: Any,
) -> Any:
"""
:说明:

View File

@ -58,19 +58,21 @@ EVENT_PCS_PARAMS = [
]
RUN_PREPCS_PARAMS = [
params.DependParam,
params.MatcherParam,
params.BotParam,
params.EventParam,
params.StateParam,
params.ArgParam,
params.MatcherParam,
params.DefaultParam,
]
RUN_POSTPCS_PARAMS = [
params.DependParam,
params.MatcherParam,
params.ExceptionParam,
params.BotParam,
params.EventParam,
params.StateParam,
params.ArgParam,
params.MatcherParam,
params.DefaultParam,
]

View File

@ -1,5 +1,6 @@
import asyncio
import inspect
from typing_extensions import Literal
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
@ -200,7 +201,7 @@ async def _event_message(event: Event) -> Message:
return event.get_message()
def EventMessage() -> Message:
def EventMessage() -> Any:
return Depends(_event_message)
@ -260,7 +261,7 @@ def _command_arg(state=State()) -> Message:
return state[PREFIX_KEY][CMD_ARG_KEY]
def CommandArg() -> Message:
def CommandArg() -> Any:
return Depends(_command_arg, use_cache=False)
@ -332,6 +333,44 @@ def LastReceived(default: Any = None) -> Any:
return Depends(_last_received, use_cache=False)
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["event", "message", "str"]
) -> None:
self.key = key
self.type = type
def Arg(key: Optional[str] = None) -> Any:
return ArgInner(key, "message")
def ArgEvent(key: Optional[str] = None) -> Any:
return ArgInner(key, "event")
def ArgStr(key: Optional[str] = None) -> Any:
return ArgInner(key, "str")
class ArgParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["ArgParam"]:
if isinstance(param.default, ArgInner):
return cls(Required, key=param.default.key or name, type=param.default.type)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
event = matcher.get_arg(self.extra["key"])
if self.extra["type"] == "event":
return event
elif self.extra["type"] == "message":
return event.get_message()
else:
return matcher.get_arg_str(self.extra["key"])
class ExceptionParam(Param):
@classmethod
def _check_param(