⚗️ add call_api hook

This commit is contained in:
yanyongyu
2021-03-31 16:51:09 +08:00
parent 4e7592de98
commit 8f99b01fb5
8 changed files with 157 additions and 60 deletions

View File

@ -6,20 +6,30 @@
"""
import abc
import asyncio
from copy import copy
from functools import reduce, partial
from dataclasses import dataclass, field
from typing import Any, Dict, Union, TypeVar, Mapping, Optional, Callable, Iterable, Iterator, Awaitable, TYPE_CHECKING
from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable,
Protocol, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel
from nonebot.log import logger
from nonebot.utils import DataclassEncoder
from nonebot.typing import T_CallingAPIHook
if TYPE_CHECKING:
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
class _ApiCall(Protocol):
def __call__(self, **kwargs: Any) -> Awaitable[Any]:
...
class Bot(abc.ABC):
"""
Bot 基类。用于处理上报消息,并提供 API 调用接口。
@ -29,6 +39,11 @@ class Bot(abc.ABC):
"""Driver 对象"""
config: "Config"
"""Config 配置对象"""
_call_api_hook: Set[T_CallingAPIHook] = set()
"""
:类型: ``Set[T_CallingAPIHook]``
:说明: call_api 时执行的函数
"""
@abc.abstractmethod
def __init__(self,
@ -50,7 +65,7 @@ class Bot(abc.ABC):
self.websocket = websocket
"""Websocket 连接对象"""
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@property
@ -109,7 +124,20 @@ class Bot(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
async def call_api(self, api: str, **data) -> Any:
async def _call_api(self, api: str, **data) -> Any:
"""
:说明:
``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
:参数:
* ``api: str``: API 名称
* ``**data``: API 数据
"""
raise NotImplementedError
async def call_api(self, api: str, **data: Any) -> Any:
"""
:说明:
@ -118,6 +146,7 @@ class Bot(abc.ABC):
:参数:
* ``api: str``: API 名称
* ``self_id: Optional[str]``: 指定调用 API 的机器人
* ``**data``: API 数据
:示例:
@ -127,7 +156,23 @@ class Bot(abc.ABC):
await bot.call_api("send_msg", message="hello world")
await bot.send_msg(message="hello world")
"""
raise NotImplementedError
coros = list(map(lambda x: x(api, data), self._call_api_hook))
if coros:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>")
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot._call_api(api, **data)
return await self._call_api(api, **data)
@abc.abstractmethod
async def send(self, event: "Event", message: Union[str, "Message",
@ -146,6 +191,11 @@ class Bot(abc.ABC):
"""
raise NotImplementedError
@classmethod
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
cls._call_api_hook.add(func)
return func
T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")