diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py
index 5adc7a16..68c35ca4 100644
--- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py
+++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py
@@ -28,6 +28,9 @@ Mirai-API-HTTP 的适配器以 `AGPLv3许可`_ 单独开源
"""
from .bot import Bot
-from .bot_ws import WebsocketBot
from .event import *
from .message import MessageChain, MessageSegment
+"""
+``WebsocketBot``现在已经和``Bot``合并, 并已经被弃用, 请直接使用``Bot``
+"""
+WebsocketBot = Bot
\ No newline at end of file
diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py
index 4b10d446..de96a29e 100644
--- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py
+++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py
@@ -1,15 +1,17 @@
+import json
from datetime import datetime, timedelta
from io import BytesIO
from ipaddress import IPv4Address
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx
+from loguru import logger
-from nonebot.config import Config
-from nonebot.typing import overrides
from nonebot.adapters import Bot as BaseBot
+from nonebot.config import Config
+from nonebot.drivers import Driver, ReverseDriver, HTTPConnection, HTTPResponse, WebSocket, ForwardDriver, WebSocketSetup
from nonebot.exception import ApiNotAvailable
-from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket
+from nonebot.typing import overrides
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
@@ -152,15 +154,12 @@ class Bot(BaseBot):
"""
+ _type = 'mirai'
+
@property
@overrides(BaseBot)
def type(self) -> str:
- return "mirai"
-
- @property
- def alive(self) -> bool:
- assert isinstance(self.request, WebSocket)
- return not self.request.closed
+ return self._type
@property
def api(self) -> SessionManager:
@@ -190,21 +189,50 @@ class Bot(BaseBot):
@classmethod
@overrides(BaseBot)
- def register(cls, driver: Driver, config: "Config"):
+ def register(cls,
+ driver: Driver,
+ config: "Config",
+ qq: Optional[int] = None):
cls.mirai_config = MiraiConfig(**config.dict())
if (cls.mirai_config.auth_key and cls.mirai_config.host and
cls.mirai_config.port) is None:
- raise ApiNotAvailable('mirai')
+ raise ApiNotAvailable(cls._type)
+
super().register(driver, config)
+ if not isinstance(driver, ForwardDriver) and qq:
+ logger.warning(
+ f"Current driver {cls.config.driver} don't support forward connections"
+ )
+ elif isinstance(driver, ForwardDriver) and qq:
+
+ async def url_factory():
+ assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key
+ session = await SessionManager.new(
+ qq, # type: ignore
+ host=cls.mirai_config.host,
+ port=cls.mirai_config.port,
+ auth_key=cls.mirai_config.auth_key)
+ return WebSocketSetup(
+ adapter=cls._type,
+ self_id=str(qq),
+ url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}'
+ f'/all?sessionKey={session.session_key}'))
+
+ driver.setup_websocket(url_factory)
+ elif isinstance(driver, ReverseDriver):
+ logger.debug(
+ 'Param "qq" does not set for mirai adapter, use http post instead'
+ )
+
@overrides(BaseBot)
- async def handle_message(self, message: dict):
+ async def handle_message(self, message: bytes):
Log.debug(f'received message {message}')
try:
await process_event(
bot=self,
event=Event.new({
- **message,
+ **json.loads(message),
'self_id': self.self_id,
}),
)
diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py
deleted file mode 100644
index 29fc12bf..00000000
--- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py
+++ /dev/null
@@ -1,202 +0,0 @@
-import json
-import asyncio
-from dataclasses import dataclass
-from ipaddress import IPv4Address
-from typing import Any, Set, Dict, Tuple, TypeVar, Optional, Callable, Coroutine
-
-import httpx
-import websockets
-
-from nonebot.log import logger
-from nonebot.config import Config
-from nonebot.typing import overrides
-from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket as BaseWebSocket
-
-from .bot import SessionManager, Bot
-
-WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
-WebsocketHandler_T = TypeVar('WebsocketHandler_T',
- bound=WebsocketHandlerFunction)
-
-
-@dataclass
-class WebSocket(BaseWebSocket):
- websocket: websockets.WebSocketClientProtocol = None # type: ignore
-
- @classmethod
- async def new(cls, *, host: IPv4Address, port: int,
- session_key: str) -> "WebSocket":
- listen_address = httpx.URL(f'ws://{host}:{port}/all',
- params={'sessionKey': session_key})
- websocket = await websockets.connect(uri=str(listen_address))
- await (await websocket.ping())
- return cls("1.1",
- listen_address.scheme,
- listen_address.path,
- listen_address.query,
- websocket=websocket)
-
- @overrides(BaseWebSocket)
- def __init__(self,
- http_version: str,
- scheme: str,
- path: str,
- query_string: bytes = b"",
- headers: Dict[str, str] = None,
- websocket: websockets.WebSocketClientProtocol = None):
- self.event_handlers: Set[WebsocketHandlerFunction] = set()
- self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore
- super(WebSocket, self).__init__(http_version=http_version,
- scheme=scheme,
- path=path,
- query_string=query_string,
- headers=headers or {})
-
- @property
- @overrides(BaseWebSocket)
- def closed(self) -> bool:
- return self.websocket.closed
-
- @overrides(BaseWebSocket)
- async def send(self, data: str):
- return await self.websocket.send(data)
-
- @overrides(BaseWebSocket)
- async def send_bytes(self, data: str):
- return await self.websocket.send(data)
-
- @overrides(BaseWebSocket)
- async def receive(self) -> str:
- return await self.websocket.recv() # type: ignore
-
- @overrides(BaseWebSocket)
- async def receive_bytes(self) -> bytes:
- return await self.websocket.recv() # type: ignore
-
- async def _dispatcher(self):
- while not self.closed:
- try:
- data = await self.receive()
- except websockets.ConnectionClosedOK:
- logger.debug(f'Websocket connection {self.websocket} closed')
- break
- except websockets.ConnectionClosedError:
- logger.exception(f'Websocket connection {self.websocket} '
- 'connection closed abnormally:')
- break
- except json.JSONDecodeError as e:
- logger.exception(f'Websocket client listened {self.websocket} '
- f'failed to decode data: {e}')
- continue
- asyncio.gather(
- *map(lambda f: f(data), self.event_handlers), #type: ignore
- return_exceptions=True)
-
- @overrides(BaseWebSocket)
- async def accept(self):
- asyncio.create_task(self._dispatcher())
-
- @overrides(BaseWebSocket)
- async def close(self):
- await self.websocket.close()
-
- def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
- self.event_handlers.add(callable)
- return callable
-
-
-class WebsocketBot(Bot):
- """
- mirai-api-http 正向 Websocket 协议 Bot 适配。
- """
-
- @property
- @overrides(Bot)
- def type(self) -> str:
- return "mirai-ws"
-
- @property
- def alive(self) -> bool:
- assert isinstance(self.request, WebSocket)
- return not self.request.closed
-
- @property
- def api(self) -> SessionManager:
- api = SessionManager.get(self_id=int(self.self_id), check_expire=False)
- assert api is not None, 'SessionManager has not been initialized'
- return api
-
- @classmethod
- @overrides(Bot)
- async def check_permission(
- cls, driver: Driver,
- request: HTTPConnection) -> Tuple[None, HTTPResponse]:
- return None, HTTPResponse(501, b'Connection not implented')
-
- @classmethod
- @overrides(Bot)
- def register(cls, driver: Driver, config: "Config", qq: int):
- """
- :说明:
-
- 注册该Adapter
-
- :参数:
-
- * ``driver: Driver``: 程序所使用的``Driver``
- * ``config: Config``: 程序配置对象
- * ``qq: int``: 要使用的Bot的QQ号 **注意: 在使用正向Websocket时必须指定该值!**
- """
- super().register(driver, config)
- cls.active = True
-
- async def _bot_connection():
- session: SessionManager = await SessionManager.new(
- qq,
- host=cls.mirai_config.host, # type: ignore
- port=cls.mirai_config.port, # type: ignore
- auth_key=cls.mirai_config.auth_key # type: ignore
- )
- websocket = await WebSocket.new(
- host=cls.mirai_config.host, # type: ignore
- port=cls.mirai_config.port, # type: ignore
- session_key=session.session_key)
- bot = cls(self_id=str(qq), request=websocket)
- websocket.handle(bot.handle_message)
- await websocket.accept()
- return bot
-
- async def _connection_ensure():
- self_id = str(qq)
- if self_id not in driver._clients:
- bot = await _bot_connection()
- driver._bot_connect(bot)
- else:
- bot = driver._clients[self_id]
- if not bot.alive:
- driver._bot_disconnect(bot)
- return
-
- @driver.on_startup
- async def _startup():
-
- async def _checker():
- while cls.active:
- try:
- await _connection_ensure()
- except Exception as e:
- logger.opt(colors=True).warning(
- 'Failed to create mirai connection to '
- f'{qq}, reason: {e}. '
- 'Will retry after 3 seconds')
- await asyncio.sleep(3)
-
- asyncio.create_task(_checker())
-
- @driver.on_shutdown
- async def _shutdown():
- cls.active = False
- bot = driver._clients.pop(str(qq), None)
- if bot is None:
- return
- await bot.websocket.close() #type:ignore