mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-27 08:11:38 +00:00
⚗️ add call_api hook
This commit is contained in:
@ -6,20 +6,30 @@
|
||||
"""
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from copy import copy
|
||||
from functools import reduce, partial
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Union, TypeVar, Mapping, Optional, Callable, Iterable, Iterator, Awaitable, TYPE_CHECKING
|
||||
from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable,
|
||||
Protocol, Awaitable, TYPE_CHECKING)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.typing import T_CallingAPIHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
|
||||
|
||||
class _ApiCall(Protocol):
|
||||
|
||||
def __call__(self, **kwargs: Any) -> Awaitable[Any]:
|
||||
...
|
||||
|
||||
|
||||
class Bot(abc.ABC):
|
||||
"""
|
||||
Bot 基类。用于处理上报消息,并提供 API 调用接口。
|
||||
@ -29,6 +39,11 @@ class Bot(abc.ABC):
|
||||
"""Driver 对象"""
|
||||
config: "Config"
|
||||
"""Config 配置对象"""
|
||||
_call_api_hook: Set[T_CallingAPIHook] = set()
|
||||
"""
|
||||
:类型: ``Set[T_CallingAPIHook]``
|
||||
:说明: call_api 时执行的函数
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self,
|
||||
@ -50,7 +65,7 @@ class Bot(abc.ABC):
|
||||
self.websocket = websocket
|
||||
"""Websocket 连接对象"""
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
|
||||
def __getattr__(self, name: str) -> _ApiCall:
|
||||
return partial(self.call_api, name)
|
||||
|
||||
@property
|
||||
@ -109,7 +124,20 @@ class Bot(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def call_api(self, api: str, **data) -> Any:
|
||||
async def _call_api(self, api: str, **data) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
||||
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``**data``: API 数据
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def call_api(self, api: str, **data: Any) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -118,6 +146,7 @@ class Bot(abc.ABC):
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``self_id: Optional[str]``: 指定调用 API 的机器人
|
||||
* ``**data``: API 数据
|
||||
|
||||
:示例:
|
||||
@ -127,7 +156,23 @@ class Bot(abc.ABC):
|
||||
await bot.call_api("send_msg", message="hello world")
|
||||
await bot.send_msg(message="hello world")
|
||||
"""
|
||||
raise NotImplementedError
|
||||
coros = list(map(lambda x: x(api, data), self._call_api_hook))
|
||||
if coros:
|
||||
try:
|
||||
logger.debug("Running CallingAPI hooks...")
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
|
||||
"Running cancelled!</bg #f8bbd0></r>")
|
||||
|
||||
if "self_id" in data:
|
||||
self_id = data.pop("self_id")
|
||||
if self_id:
|
||||
bot = self.driver.bots[str(self_id)]
|
||||
return await bot._call_api(api, **data)
|
||||
|
||||
return await self._call_api(api, **data)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, event: "Event", message: Union[str, "Message",
|
||||
@ -146,6 +191,11 @@ class Bot(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
|
||||
cls._call_api_hook.add(func)
|
||||
return func
|
||||
|
||||
|
||||
T_Message = TypeVar("T_Message", bound="Message")
|
||||
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")
|
||||
|
Reference in New Issue
Block a user