diff --git a/docs/api/adapters/README.md b/docs/api/adapters/README.md
index 05fe3b49..bd4aab2a 100644
--- a/docs/api/adapters/README.md
+++ b/docs/api/adapters/README.md
@@ -27,6 +27,21 @@ Driver 对象
Config 配置对象
+### `_call_api_hook`
+
+
+* **类型**
+
+ `Set[T_CallingAPIHook]`
+
+
+
+* **说明**
+
+ call_api 时执行的函数
+
+
+
### _abstract_ `__init__(connection_type, self_id, *, websocket=None)`
@@ -127,7 +142,26 @@ Adapter 类型
-### _abstract async_ `call_api(api, **data)`
+### _abstract async_ `_call_api(api, **data)`
+
+
+* **说明**
+
+ `adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
+
+
+
+* **参数**
+
+
+ * `api: str`: API 名称
+
+
+ * `**data`: API 数据
+
+
+
+### _async_ `call_api(api, **data)`
* **说明**
@@ -142,6 +176,9 @@ Adapter 类型
* `api: str`: API 名称
+ * `self_id: Optional[str]`: 指定调用 API 的机器人
+
+
* `**data`: API 数据
diff --git a/docs/api/adapters/ding.md b/docs/api/adapters/ding.md
index 7cfe5932..2c531a7b 100644
--- a/docs/api/adapters/ding.md
+++ b/docs/api/adapters/ding.md
@@ -129,6 +129,9 @@ sidebarDepth: 0
* `api: str`: API 名称
+ * `event: Optional[MessageEvent]`: Event 对象
+
+
* `**data: Any`: API 参数
diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py
index 8f4c2898..d36407a9 100644
--- a/nonebot/adapters/_base.py
+++ b/nonebot/adapters/_base.py
@@ -6,20 +6,30 @@
"""
import abc
+import asyncio
from copy import copy
from functools import reduce, partial
from dataclasses import dataclass, field
-from typing import Any, Dict, Union, TypeVar, Mapping, Optional, Callable, Iterable, Iterator, Awaitable, TYPE_CHECKING
+from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable,
+ Protocol, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel
+from nonebot.log import logger
from nonebot.utils import DataclassEncoder
+from nonebot.typing import T_CallingAPIHook
if TYPE_CHECKING:
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
+class _ApiCall(Protocol):
+
+ def __call__(self, **kwargs: Any) -> Awaitable[Any]:
+ ...
+
+
class Bot(abc.ABC):
"""
Bot 基类。用于处理上报消息,并提供 API 调用接口。
@@ -29,6 +39,11 @@ class Bot(abc.ABC):
"""Driver 对象"""
config: "Config"
"""Config 配置对象"""
+ _call_api_hook: Set[T_CallingAPIHook] = set()
+ """
+ :类型: ``Set[T_CallingAPIHook]``
+ :说明: call_api 时执行的函数
+ """
@abc.abstractmethod
def __init__(self,
@@ -50,7 +65,7 @@ class Bot(abc.ABC):
self.websocket = websocket
"""Websocket 连接对象"""
- def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
+ def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@property
@@ -109,7 +124,20 @@ class Bot(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
- async def call_api(self, api: str, **data) -> Any:
+ async def _call_api(self, api: str, **data) -> Any:
+ """
+ :说明:
+
+ ``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
+
+ :参数:
+
+ * ``api: str``: API 名称
+ * ``**data``: API 数据
+ """
+ raise NotImplementedError
+
+ async def call_api(self, api: str, **data: Any) -> Any:
"""
:说明:
@@ -118,6 +146,7 @@ class Bot(abc.ABC):
:参数:
* ``api: str``: API 名称
+ * ``self_id: Optional[str]``: 指定调用 API 的机器人
* ``**data``: API 数据
:示例:
@@ -127,7 +156,23 @@ class Bot(abc.ABC):
await bot.call_api("send_msg", message="hello world")
await bot.send_msg(message="hello world")
"""
- raise NotImplementedError
+ coros = list(map(lambda x: x(api, data), self._call_api_hook))
+ if coros:
+ try:
+ logger.debug("Running CallingAPI hooks...")
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running CallingAPI hook. "
+ "Running cancelled!")
+
+ if "self_id" in data:
+ self_id = data.pop("self_id")
+ if self_id:
+ bot = self.driver.bots[str(self_id)]
+ return await bot._call_api(api, **data)
+
+ return await self._call_api(api, **data)
@abc.abstractmethod
async def send(self, event: "Event", message: Union[str, "Message",
@@ -146,6 +191,11 @@ class Bot(abc.ABC):
"""
raise NotImplementedError
+ @classmethod
+ def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
+ cls._call_api_hook.add(func)
+ return func
+
T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")
diff --git a/nonebot/typing.py b/nonebot/typing.py
index dd2f24c5..c1dc008a 100644
--- a/nonebot/typing.py
+++ b/nonebot/typing.py
@@ -71,6 +71,7 @@ T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]]
WebSocket 连接断开时执行的函数
"""
+T_CallingAPIHook = Callable[[str, Dict[str, Any]], Awaitable[None]]
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
"""
diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py
index ca477559..20e5015d 100644
--- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py
+++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py
@@ -328,32 +328,7 @@ class Bot(BaseBot):
)
@overrides(BaseBot)
- async def call_api(self, api: str, **data) -> Any:
- """
- :说明:
-
- 调用 CQHTTP 协议 API
-
- :参数:
-
- * ``api: str``: API 名称
- * ``**data: Any``: API 参数
-
- :返回:
-
- - ``Any``: API 调用返回数据
-
- :异常:
-
- - ``NetworkError``: 网络错误
- - ``ActionFailed``: API 调用失败
- """
- if "self_id" in data:
- self_id = data.pop("self_id")
- if self_id:
- bot = self.driver.bots[str(self_id)]
- return await bot.call_api(api, **data)
-
+ async def _call_api(self, api: str, **data) -> Any:
log("DEBUG", f"Calling API {api}")
if self.connection_type == "websocket":
seq = ResultStore.get_seq()
@@ -396,6 +371,29 @@ class Bot(BaseBot):
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
+ @overrides(BaseBot)
+ async def call_api(self, api: str, **data) -> Any:
+ """
+ :说明:
+
+ 调用 CQHTTP 协议 API
+
+ :参数:
+
+ * ``api: str``: API 名称
+ * ``**data: Any``: API 参数
+
+ :返回:
+
+ - ``Any``: API 调用返回数据
+
+ :异常:
+
+ - ``NetworkError``: 网络错误
+ - ``ActionFailed``: API 调用失败
+ """
+ return super().call_api(api, **data)
+
@overrides(BaseBot)
async def send(self,
event: Event,
diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.pyi b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.pyi
index 7ba09f8a..ad8d459c 100644
--- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.pyi
+++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.pyi
@@ -68,7 +68,8 @@ class Bot(BaseBot):
async def handle_message(self, message: dict):
...
- async def call_api(self, api: str, **data) -> Any:
+ async def call_api(self, api: str, *, self_id: Optional[str],
+ **data) -> Any:
...
async def send(self, event: Event, message: Union[str, Message,
diff --git a/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py b/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py
index 08175ce4..410515bb 100644
--- a/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py
+++ b/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py
@@ -109,37 +109,13 @@ class Bot(BaseBot):
return
@overrides(BaseBot)
- async def call_api(self,
- api: str,
- event: Optional[MessageEvent] = None,
- **data) -> Any:
- """
- :说明:
-
- 调用 钉钉 协议 API
-
- :参数:
-
- * ``api: str``: API 名称
- * ``**data: Any``: API 参数
-
- :返回:
-
- - ``Any``: API 调用返回数据
-
- :异常:
-
- - ``NetworkError``: 网络错误
- - ``ActionFailed``: API 调用失败
- """
+ async def _call_api(self,
+ api: str,
+ event: Optional[MessageEvent] = None,
+ **data) -> Any:
if self.connection_type != "http":
log("ERROR", "Only support http connection.")
return
- if "self_id" in data:
- self_id = data.pop("self_id")
- if self_id:
- bot = self.driver.bots[str(self_id)]
- return await bot.call_api(api, **data)
log("DEBUG", f"Calling API {api}")
params = {}
@@ -192,6 +168,33 @@ class Bot(BaseBot):
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
+ @overrides(BaseBot)
+ async def call_api(self,
+ api: str,
+ event: Optional[MessageEvent] = None,
+ **data) -> Any:
+ """
+ :说明:
+
+ 调用 钉钉 协议 API
+
+ :参数:
+
+ * ``api: str``: API 名称
+ * ``event: Optional[MessageEvent]``: Event 对象
+ * ``**data: Any``: API 参数
+
+ :返回:
+
+ - ``Any``: API 调用返回数据
+
+ :异常:
+
+ - ``NetworkError``: 网络错误
+ - ``ActionFailed``: API 调用失败
+ """
+ return super().call_api(api, event=event, **data)
+
@overrides(BaseBot)
async def send(self,
event: MessageEvent,
diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py
index 1b598ebf..ebce2d74 100644
--- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py
+++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py
@@ -218,6 +218,10 @@ class Bot(BaseBot):
except Exception as e:
Log.error(f'Failed to handle message: {message}', e)
+ @overrides(BaseBot)
+ async def _call_api(self, api: str, **data) -> NoReturn:
+ raise NotImplementedError
+
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn:
"""