add di functions

This commit is contained in:
yanyongyu
2021-12-14 01:08:48 +08:00
parent e942f4076c
commit 329a1fd226
6 changed files with 147 additions and 74 deletions

View File

@ -1,12 +1,25 @@
import inspect
from typing import Any, List, Type, Callable, Optional, cast
from functools import wraps, partial
from typing import Any, Tuple, Union, TypeVar, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined
from nonebot.adapters import Bot, Event
from nonebot.typing import T_State, T_Handler
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
REGEX_DICT,
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
REGEX_MATCHED,
WRAPPER_ASSIGNMENTS,
)
from nonebot.utils import (
CacheDict,
get_name,
@ -18,6 +31,8 @@ from nonebot.utils import (
generic_check_issubclass,
)
T = TypeVar("T")
class DependsInner:
def __init__(
@ -175,12 +190,44 @@ class EventParam(Param):
return event
async def _event_type(event: Event) -> str:
return event.get_type()
def EventType() -> str:
return Depends(_event_type)
async def _event_message(event: Event) -> Message:
return event.get_message()
def EventMessage() -> Message:
return Depends(_event_message)
async def _event_plain_text(event: Event) -> str:
return event.get_plaintext()
def EventPlainText() -> str:
return Depends(_event_plain_text)
async def _event_to_me(event: Event) -> bool:
return event.is_tome()
def EventToMe() -> bool:
return Depends(_event_to_me)
class StateInner:
...
def State() -> Any:
return StateInner()
def State() -> T_State:
return StateInner() # type: ignore
class StateParam(Param):
@ -195,6 +242,30 @@ class StateParam(Param):
return state
def _command(state=State()) -> Message:
return state[PREFIX_KEY][CMD_KEY]
def Command() -> Tuple[str, ...]:
return Depends(_command)
def _raw_command(state=State()) -> Message:
return state[PREFIX_KEY][RAW_CMD_KEY]
def RawCommand() -> str:
return Depends(_raw_command)
def _command_arg(state=State()) -> Message:
return state[PREFIX_KEY][CMD_ARG_KEY]
def CommandArg() -> Message:
return Depends(_command_arg)
class MatcherParam(Param):
@classmethod
def _check_param(
@ -209,6 +280,18 @@ class MatcherParam(Param):
return matcher
def _received(matcher: "Matcher", id: str = "", default: T = None) -> Union[Event, T]:
return matcher.get_receive(id, default)
def Received(id: str = "", default: Any = None) -> Any:
return Depends(
wraps(_received, assigned=WRAPPER_ASSIGNMENTS)(
partial(_received, id=id, default=default)
)
)
class ExceptionParam(Param):
@classmethod
def _check_param(