mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-06-14 16:17:43 +00:00
👽 ✨ Add forward driver support for mirai-api-http adapter
This commit is contained in:
parent
cda1ad093f
commit
358528b495
@ -28,6 +28,9 @@ Mirai-API-HTTP 的适配器以 `AGPLv3许可`_ 单独开源
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .bot import Bot
|
from .bot import Bot
|
||||||
from .bot_ws import WebsocketBot
|
|
||||||
from .event import *
|
from .event import *
|
||||||
from .message import MessageChain, MessageSegment
|
from .message import MessageChain, MessageSegment
|
||||||
|
"""
|
||||||
|
``WebsocketBot``现在已经和``Bot``合并, 并已经被弃用, 请直接使用``Bot``
|
||||||
|
"""
|
||||||
|
WebsocketBot = Bot
|
@ -1,15 +1,17 @@
|
|||||||
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
|
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from nonebot.config import Config
|
|
||||||
from nonebot.typing import overrides
|
|
||||||
from nonebot.adapters import Bot as BaseBot
|
from nonebot.adapters import Bot as BaseBot
|
||||||
|
from nonebot.config import Config
|
||||||
|
from nonebot.drivers import Driver, ReverseDriver, HTTPConnection, HTTPResponse, WebSocket, ForwardDriver, WebSocketSetup
|
||||||
from nonebot.exception import ApiNotAvailable
|
from nonebot.exception import ApiNotAvailable
|
||||||
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket
|
from nonebot.typing import overrides
|
||||||
|
|
||||||
from .config import Config as MiraiConfig
|
from .config import Config as MiraiConfig
|
||||||
from .event import Event, FriendMessage, GroupMessage, TempMessage
|
from .event import Event, FriendMessage, GroupMessage, TempMessage
|
||||||
@ -152,15 +154,12 @@ class Bot(BaseBot):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_type = 'mirai'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return "mirai"
|
return self._type
|
||||||
|
|
||||||
@property
|
|
||||||
def alive(self) -> bool:
|
|
||||||
assert isinstance(self.request, WebSocket)
|
|
||||||
return not self.request.closed
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def api(self) -> SessionManager:
|
def api(self) -> SessionManager:
|
||||||
@ -190,21 +189,50 @@ class Bot(BaseBot):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
def register(cls, driver: Driver, config: "Config"):
|
def register(cls,
|
||||||
|
driver: Driver,
|
||||||
|
config: "Config",
|
||||||
|
qq: Optional[int] = None):
|
||||||
cls.mirai_config = MiraiConfig(**config.dict())
|
cls.mirai_config = MiraiConfig(**config.dict())
|
||||||
if (cls.mirai_config.auth_key and cls.mirai_config.host and
|
if (cls.mirai_config.auth_key and cls.mirai_config.host and
|
||||||
cls.mirai_config.port) is None:
|
cls.mirai_config.port) is None:
|
||||||
raise ApiNotAvailable('mirai')
|
raise ApiNotAvailable(cls._type)
|
||||||
|
|
||||||
super().register(driver, config)
|
super().register(driver, config)
|
||||||
|
|
||||||
|
if not isinstance(driver, ForwardDriver) and qq:
|
||||||
|
logger.warning(
|
||||||
|
f"Current driver {cls.config.driver} don't support forward connections"
|
||||||
|
)
|
||||||
|
elif isinstance(driver, ForwardDriver) and qq:
|
||||||
|
|
||||||
|
async def url_factory():
|
||||||
|
assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key
|
||||||
|
session = await SessionManager.new(
|
||||||
|
qq, # type: ignore
|
||||||
|
host=cls.mirai_config.host,
|
||||||
|
port=cls.mirai_config.port,
|
||||||
|
auth_key=cls.mirai_config.auth_key)
|
||||||
|
return WebSocketSetup(
|
||||||
|
adapter=cls._type,
|
||||||
|
self_id=str(qq),
|
||||||
|
url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}'
|
||||||
|
f'/all?sessionKey={session.session_key}'))
|
||||||
|
|
||||||
|
driver.setup_websocket(url_factory)
|
||||||
|
elif isinstance(driver, ReverseDriver):
|
||||||
|
logger.debug(
|
||||||
|
'Param "qq" does not set for mirai adapter, use http post instead'
|
||||||
|
)
|
||||||
|
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def handle_message(self, message: dict):
|
async def handle_message(self, message: bytes):
|
||||||
Log.debug(f'received message {message}')
|
Log.debug(f'received message {message}')
|
||||||
try:
|
try:
|
||||||
await process_event(
|
await process_event(
|
||||||
bot=self,
|
bot=self,
|
||||||
event=Event.new({
|
event=Event.new({
|
||||||
**message,
|
**json.loads(message),
|
||||||
'self_id': self.self_id,
|
'self_id': self.self_id,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -1,202 +0,0 @@
|
|||||||
import json
|
|
||||||
import asyncio
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from ipaddress import IPv4Address
|
|
||||||
from typing import Any, Set, Dict, Tuple, TypeVar, Optional, Callable, Coroutine
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import websockets
|
|
||||||
|
|
||||||
from nonebot.log import logger
|
|
||||||
from nonebot.config import Config
|
|
||||||
from nonebot.typing import overrides
|
|
||||||
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket as BaseWebSocket
|
|
||||||
|
|
||||||
from .bot import SessionManager, Bot
|
|
||||||
|
|
||||||
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
|
|
||||||
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
|
|
||||||
bound=WebsocketHandlerFunction)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WebSocket(BaseWebSocket):
|
|
||||||
websocket: websockets.WebSocketClientProtocol = None # type: ignore
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def new(cls, *, host: IPv4Address, port: int,
|
|
||||||
session_key: str) -> "WebSocket":
|
|
||||||
listen_address = httpx.URL(f'ws://{host}:{port}/all',
|
|
||||||
params={'sessionKey': session_key})
|
|
||||||
websocket = await websockets.connect(uri=str(listen_address))
|
|
||||||
await (await websocket.ping())
|
|
||||||
return cls("1.1",
|
|
||||||
listen_address.scheme,
|
|
||||||
listen_address.path,
|
|
||||||
listen_address.query,
|
|
||||||
websocket=websocket)
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
def __init__(self,
|
|
||||||
http_version: str,
|
|
||||||
scheme: str,
|
|
||||||
path: str,
|
|
||||||
query_string: bytes = b"",
|
|
||||||
headers: Dict[str, str] = None,
|
|
||||||
websocket: websockets.WebSocketClientProtocol = None):
|
|
||||||
self.event_handlers: Set[WebsocketHandlerFunction] = set()
|
|
||||||
self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore
|
|
||||||
super(WebSocket, self).__init__(http_version=http_version,
|
|
||||||
scheme=scheme,
|
|
||||||
path=path,
|
|
||||||
query_string=query_string,
|
|
||||||
headers=headers or {})
|
|
||||||
|
|
||||||
@property
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
def closed(self) -> bool:
|
|
||||||
return self.websocket.closed
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def send(self, data: str):
|
|
||||||
return await self.websocket.send(data)
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def send_bytes(self, data: str):
|
|
||||||
return await self.websocket.send(data)
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def receive(self) -> str:
|
|
||||||
return await self.websocket.recv() # type: ignore
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def receive_bytes(self) -> bytes:
|
|
||||||
return await self.websocket.recv() # type: ignore
|
|
||||||
|
|
||||||
async def _dispatcher(self):
|
|
||||||
while not self.closed:
|
|
||||||
try:
|
|
||||||
data = await self.receive()
|
|
||||||
except websockets.ConnectionClosedOK:
|
|
||||||
logger.debug(f'Websocket connection {self.websocket} closed')
|
|
||||||
break
|
|
||||||
except websockets.ConnectionClosedError:
|
|
||||||
logger.exception(f'Websocket connection {self.websocket} '
|
|
||||||
'connection closed abnormally:')
|
|
||||||
break
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.exception(f'Websocket client listened {self.websocket} '
|
|
||||||
f'failed to decode data: {e}')
|
|
||||||
continue
|
|
||||||
asyncio.gather(
|
|
||||||
*map(lambda f: f(data), self.event_handlers), #type: ignore
|
|
||||||
return_exceptions=True)
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def accept(self):
|
|
||||||
asyncio.create_task(self._dispatcher())
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def close(self):
|
|
||||||
await self.websocket.close()
|
|
||||||
|
|
||||||
def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
|
|
||||||
self.event_handlers.add(callable)
|
|
||||||
return callable
|
|
||||||
|
|
||||||
|
|
||||||
class WebsocketBot(Bot):
|
|
||||||
"""
|
|
||||||
mirai-api-http 正向 Websocket 协议 Bot 适配。
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@overrides(Bot)
|
|
||||||
def type(self) -> str:
|
|
||||||
return "mirai-ws"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def alive(self) -> bool:
|
|
||||||
assert isinstance(self.request, WebSocket)
|
|
||||||
return not self.request.closed
|
|
||||||
|
|
||||||
@property
|
|
||||||
def api(self) -> SessionManager:
|
|
||||||
api = SessionManager.get(self_id=int(self.self_id), check_expire=False)
|
|
||||||
assert api is not None, 'SessionManager has not been initialized'
|
|
||||||
return api
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@overrides(Bot)
|
|
||||||
async def check_permission(
|
|
||||||
cls, driver: Driver,
|
|
||||||
request: HTTPConnection) -> Tuple[None, HTTPResponse]:
|
|
||||||
return None, HTTPResponse(501, b'Connection not implented')
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@overrides(Bot)
|
|
||||||
def register(cls, driver: Driver, config: "Config", qq: int):
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
注册该Adapter
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``driver: Driver``: 程序所使用的``Driver``
|
|
||||||
* ``config: Config``: 程序配置对象
|
|
||||||
* ``qq: int``: 要使用的Bot的QQ号 **注意: 在使用正向Websocket时必须指定该值!**
|
|
||||||
"""
|
|
||||||
super().register(driver, config)
|
|
||||||
cls.active = True
|
|
||||||
|
|
||||||
async def _bot_connection():
|
|
||||||
session: SessionManager = await SessionManager.new(
|
|
||||||
qq,
|
|
||||||
host=cls.mirai_config.host, # type: ignore
|
|
||||||
port=cls.mirai_config.port, # type: ignore
|
|
||||||
auth_key=cls.mirai_config.auth_key # type: ignore
|
|
||||||
)
|
|
||||||
websocket = await WebSocket.new(
|
|
||||||
host=cls.mirai_config.host, # type: ignore
|
|
||||||
port=cls.mirai_config.port, # type: ignore
|
|
||||||
session_key=session.session_key)
|
|
||||||
bot = cls(self_id=str(qq), request=websocket)
|
|
||||||
websocket.handle(bot.handle_message)
|
|
||||||
await websocket.accept()
|
|
||||||
return bot
|
|
||||||
|
|
||||||
async def _connection_ensure():
|
|
||||||
self_id = str(qq)
|
|
||||||
if self_id not in driver._clients:
|
|
||||||
bot = await _bot_connection()
|
|
||||||
driver._bot_connect(bot)
|
|
||||||
else:
|
|
||||||
bot = driver._clients[self_id]
|
|
||||||
if not bot.alive:
|
|
||||||
driver._bot_disconnect(bot)
|
|
||||||
return
|
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def _startup():
|
|
||||||
|
|
||||||
async def _checker():
|
|
||||||
while cls.active:
|
|
||||||
try:
|
|
||||||
await _connection_ensure()
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True).warning(
|
|
||||||
'Failed to create mirai connection to '
|
|
||||||
f'<y>{qq}</y>, reason: <r>{e}</r>. '
|
|
||||||
'Will retry after 3 seconds')
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
|
|
||||||
asyncio.create_task(_checker())
|
|
||||||
|
|
||||||
@driver.on_shutdown
|
|
||||||
async def _shutdown():
|
|
||||||
cls.active = False
|
|
||||||
bot = driver._clients.pop(str(qq), None)
|
|
||||||
if bot is None:
|
|
||||||
return
|
|
||||||
await bot.websocket.close() #type:ignore
|
|
Loading…
x
Reference in New Issue
Block a user