add test cases

This commit is contained in:
yanyongyu
2021-12-20 00:28:02 +08:00
parent ca045b2f73
commit c2c3d5ef4b
17 changed files with 432 additions and 55 deletions

View File

@ -1,5 +1,6 @@
import asyncio
import inspect
from typing_extensions import Literal
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
@ -200,7 +201,7 @@ async def _event_message(event: Event) -> Message:
return event.get_message()
def EventMessage() -> Message:
def EventMessage() -> Any:
return Depends(_event_message)
@ -260,7 +261,7 @@ def _command_arg(state=State()) -> Message:
return state[PREFIX_KEY][CMD_ARG_KEY]
def CommandArg() -> Message:
def CommandArg() -> Any:
return Depends(_command_arg, use_cache=False)
@ -332,6 +333,44 @@ def LastReceived(default: Any = None) -> Any:
return Depends(_last_received, use_cache=False)
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["event", "message", "str"]
) -> None:
self.key = key
self.type = type
def Arg(key: Optional[str] = None) -> Any:
return ArgInner(key, "message")
def ArgEvent(key: Optional[str] = None) -> Any:
return ArgInner(key, "event")
def ArgStr(key: Optional[str] = None) -> Any:
return ArgInner(key, "str")
class ArgParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["ArgParam"]:
if isinstance(param.default, ArgInner):
return cls(Required, key=param.default.key or name, type=param.default.type)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
event = matcher.get_arg(self.extra["key"])
if self.extra["type"] == "event":
return event
elif self.extra["type"] == "message":
return event.get_message()
else:
return matcher.get_arg_str(self.extra["key"])
class ExceptionParam(Param):
@classmethod
def _check_param(