mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-11-03 16:36:44 +00:00 
			
		
		
		
	⏪ revert call method to http post, add api handle
This commit is contained in:
		@@ -1 +1,3 @@
 | 
			
		||||
from .bot import MiraiBot
 | 
			
		||||
from .bot import MiraiBot
 | 
			
		||||
from .event import *
 | 
			
		||||
from .message import MessageChain, MessageSegment
 | 
			
		||||
 
 | 
			
		||||
@@ -1,128 +1,74 @@
 | 
			
		||||
import asyncio
 | 
			
		||||
import json
 | 
			
		||||
from typing import Any, Dict, Optional, Tuple
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from ipaddress import IPv4Address
 | 
			
		||||
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
 | 
			
		||||
                    TypeVar)
 | 
			
		||||
 | 
			
		||||
import httpx
 | 
			
		||||
import websockets
 | 
			
		||||
 | 
			
		||||
from nonebot.adapters import Bot as BaseBot
 | 
			
		||||
from nonebot.adapters import Event as BaseEvent
 | 
			
		||||
from nonebot.config import Config
 | 
			
		||||
from nonebot.drivers import Driver
 | 
			
		||||
from nonebot.drivers import WebSocket as BaseWebSocket
 | 
			
		||||
from nonebot.drivers import Driver, WebSocket
 | 
			
		||||
from nonebot.exception import RequestDenied
 | 
			
		||||
from nonebot.log import logger
 | 
			
		||||
from nonebot.message import handle_event
 | 
			
		||||
from nonebot.typing import overrides
 | 
			
		||||
 | 
			
		||||
from .config import Config as MiraiConfig
 | 
			
		||||
from .event import Event
 | 
			
		||||
 | 
			
		||||
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
 | 
			
		||||
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
 | 
			
		||||
                             bound=WebsocketHandlerFunction)
 | 
			
		||||
from .event import Event, FriendMessage, TempMessage, GroupMessage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
 | 
			
		||||
                            qq: int) -> str:
 | 
			
		||||
class SessionManager:
 | 
			
		||||
    sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
 | 
			
		||||
    session_expiry: timedelta = timedelta(minutes=15)
 | 
			
		||||
 | 
			
		||||
    async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
 | 
			
		||||
        response = await client.request(method, path, **kwargs)
 | 
			
		||||
    def __init__(self, session_key: str, client: httpx.AsyncClient):
 | 
			
		||||
        self.session_key, self.client = session_key, client
 | 
			
		||||
 | 
			
		||||
    async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None):
 | 
			
		||||
        params = {**(params or {}), 'sessionKey': self.session_key}
 | 
			
		||||
        response = await self.client.post(path, json=params)
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        return response.json()
 | 
			
		||||
 | 
			
		||||
    about = await request('GET', path='/about')
 | 
			
		||||
    logger.opt(colors=True).debug('Mirai API HTTP backend version: '
 | 
			
		||||
                                  f'<g><b>{about["data"]["version"]}</b></g>')
 | 
			
		||||
 | 
			
		||||
    status = await request('POST', path='/auth', json={'authKey': auth_key})
 | 
			
		||||
    assert status['code'] == 0
 | 
			
		||||
    session_key = status['session']
 | 
			
		||||
 | 
			
		||||
    verify = await request('POST',
 | 
			
		||||
                           path='/verify',
 | 
			
		||||
                           json={
 | 
			
		||||
                               'sessionKey': session_key,
 | 
			
		||||
                               'qq': qq
 | 
			
		||||
                           })
 | 
			
		||||
    assert verify['code'] == 0, verify['msg']
 | 
			
		||||
    return session_key
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WebSocket(BaseWebSocket):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def new(cls, self_id: int, *, host: IPv4Address, port: int,
 | 
			
		||||
                  auth_key: str):
 | 
			
		||||
        if self_id in cls.sessions:
 | 
			
		||||
            manager = cls.get(self_id)
 | 
			
		||||
            if manager is not None:
 | 
			
		||||
                return manager
 | 
			
		||||
        client = httpx.AsyncClient(base_url=f'http://{host}:{port}')
 | 
			
		||||
        response = await client.post('/auth', json={'authKey': auth_key})
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        auth = response.json()
 | 
			
		||||
        assert auth['code'] == 0
 | 
			
		||||
        session_key = auth['session']
 | 
			
		||||
        response = await client.post('/verify',
 | 
			
		||||
                                     json={
 | 
			
		||||
                                         'sessionKey': session_key,
 | 
			
		||||
                                         'qq': self_id
 | 
			
		||||
                                     })
 | 
			
		||||
        assert response.json()['code'] == 0
 | 
			
		||||
        cls.sessions[self_id] = session_key, datetime.now(), client
 | 
			
		||||
        return cls(session_key, client)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def new(cls, *, host: IPv4Address, port: int,
 | 
			
		||||
                  session_key: str) -> "WebSocket":
 | 
			
		||||
        listen_address = httpx.URL(f'ws://{host}:{port}/all',
 | 
			
		||||
                                   params={'sessionKey': session_key})
 | 
			
		||||
        websocket = await websockets.connect(uri=str(listen_address))
 | 
			
		||||
        return cls(websocket)
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    def __init__(self, websocket: websockets.WebSocketClientProtocol):
 | 
			
		||||
        self.event_handlers: Set[WebsocketHandlerFunction] = set()
 | 
			
		||||
        super().__init__(websocket)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    def websocket(self) -> websockets.WebSocketClientProtocol:
 | 
			
		||||
        return self._websocket
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    def closed(self) -> bool:
 | 
			
		||||
        return self.websocket.closed
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    async def send(self, data: Dict[str, Any]):
 | 
			
		||||
        return await self.websocket.send(json.dumps(data))
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    async def receive(self) -> Dict[str, Any]:
 | 
			
		||||
        received = await self.websocket.recv()
 | 
			
		||||
        return json.loads(received)
 | 
			
		||||
 | 
			
		||||
    async def _dispatcher(self):
 | 
			
		||||
        while not self.closed:
 | 
			
		||||
            try:
 | 
			
		||||
                data = await self.receive()
 | 
			
		||||
            except websockets.ConnectionClosedOK:
 | 
			
		||||
                logger.debug(f'Websocket connection {self.websocket} closed')
 | 
			
		||||
                break
 | 
			
		||||
            except websockets.ConnectionClosedError:
 | 
			
		||||
                logger.exception(f'Websocket connection {self.websocket} '
 | 
			
		||||
                                 'connection closed abnormally:')
 | 
			
		||||
                break
 | 
			
		||||
            except json.JSONDecodeError as e:
 | 
			
		||||
                logger.exception(f'Websocket client listened {self.websocket} '
 | 
			
		||||
                                 f'failed to decode data: {e}')
 | 
			
		||||
                continue
 | 
			
		||||
            asyncio.gather(*map(lambda f: f(data), self.event_handlers),
 | 
			
		||||
                           return_exceptions=True)
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    async def accept(self):
 | 
			
		||||
        asyncio.create_task(self._dispatcher())
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    async def close(self):
 | 
			
		||||
        await self.websocket.close()
 | 
			
		||||
 | 
			
		||||
    def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
 | 
			
		||||
        self.event_handlers.add(callable)
 | 
			
		||||
        return callable
 | 
			
		||||
    def get(cls, self_id: int):
 | 
			
		||||
        key, time, client = cls.sessions[self_id]
 | 
			
		||||
        if datetime.now() - time > cls.session_expiry:
 | 
			
		||||
            return None
 | 
			
		||||
        return cls(key, client)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MiraiBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, connection_type: str, self_id: str, *,
 | 
			
		||||
                 websocket: WebSocket):
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 connection_type: str,
 | 
			
		||||
                 self_id: str,
 | 
			
		||||
                 *,
 | 
			
		||||
                 websocket: Optional[WebSocket] = None):
 | 
			
		||||
        super().__init__(connection_type, self_id, websocket=websocket)
 | 
			
		||||
        websocket.handle(self.handle_message)
 | 
			
		||||
        self.driver._bot_connect(self)
 | 
			
		||||
        self.api = SessionManager.get(int(self_id))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
@@ -136,85 +82,44 @@ class MiraiBot(BaseBot):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def check_permission(cls, driver: "Driver", connection_type: str,
 | 
			
		||||
                               headers: dict, body: Optional[dict]) -> NoReturn:
 | 
			
		||||
        raise RequestDenied(
 | 
			
		||||
            status_code=501,
 | 
			
		||||
            reason=f'Connection {connection_type} not implented')
 | 
			
		||||
                               headers: dict, body: Optional[dict]) -> str:
 | 
			
		||||
        if connection_type == 'ws':
 | 
			
		||||
            raise RequestDenied(
 | 
			
		||||
                status_code=501,
 | 
			
		||||
                reason='Websocket connection is not implemented')
 | 
			
		||||
        self_id: Optional[str] = headers.get('bot')
 | 
			
		||||
        if self_id is None:
 | 
			
		||||
            raise RequestDenied(status_code=400,
 | 
			
		||||
                                reason='Header `Bot` is required.')
 | 
			
		||||
        self_id = str(self_id).strip()
 | 
			
		||||
        await SessionManager.new(
 | 
			
		||||
            int(self_id),
 | 
			
		||||
            host=cls.mirai_config.host,  # type: ignore
 | 
			
		||||
            port=cls.mirai_config.port,  #type: ignore
 | 
			
		||||
            auth_key=cls.mirai_config.auth_key)  # type: ignore
 | 
			
		||||
        return self_id
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    def register(cls, driver: "Driver", config: "Config", qq: int):
 | 
			
		||||
    def register(cls, driver: "Driver", config: "Config"):
 | 
			
		||||
        cls.mirai_config = MiraiConfig(**config.dict())
 | 
			
		||||
        cls.active = True
 | 
			
		||||
        assert cls.mirai_config.auth_key is not None
 | 
			
		||||
        assert cls.mirai_config.host is not None
 | 
			
		||||
        assert cls.mirai_config.port is not None
 | 
			
		||||
        super().register(driver, config)
 | 
			
		||||
 | 
			
		||||
        async def _bot_connection():
 | 
			
		||||
            async with httpx.AsyncClient(
 | 
			
		||||
                    base_url=
 | 
			
		||||
                    f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
 | 
			
		||||
            ) as client:
 | 
			
		||||
                session_key = await _ws_authorization(
 | 
			
		||||
                    client,
 | 
			
		||||
                    auth_key=cls.mirai_config.auth_key,  # type: ignore
 | 
			
		||||
                    qq=qq)  # type: ignore
 | 
			
		||||
 | 
			
		||||
            websocket = await WebSocket.new(
 | 
			
		||||
                host=cls.mirai_config.host,  # type: ignore
 | 
			
		||||
                port=cls.mirai_config.port,  # type: ignore
 | 
			
		||||
                session_key=session_key)
 | 
			
		||||
            bot = cls(connection_type='forward_ws',
 | 
			
		||||
                      self_id=str(qq),
 | 
			
		||||
                      websocket=websocket)
 | 
			
		||||
            websocket.handle(bot.handle_message)
 | 
			
		||||
            driver._clients[str(qq)] = bot
 | 
			
		||||
            await websocket.accept()
 | 
			
		||||
 | 
			
		||||
        async def _connection_ensure():
 | 
			
		||||
            if str(qq) not in driver._clients:
 | 
			
		||||
                await _bot_connection()
 | 
			
		||||
            elif not driver._clients[str(qq)].alive:
 | 
			
		||||
                driver._clients.pop(str(qq), None)
 | 
			
		||||
                await _bot_connection()
 | 
			
		||||
 | 
			
		||||
        @driver.on_startup
 | 
			
		||||
        async def _startup():
 | 
			
		||||
 | 
			
		||||
            async def _checker():
 | 
			
		||||
                while cls.active:
 | 
			
		||||
                    try:
 | 
			
		||||
                        await _connection_ensure()
 | 
			
		||||
                    except Exception as e:
 | 
			
		||||
                        logger.opt(colors=True).warning(
 | 
			
		||||
                            'Failed to create mirai connection to '
 | 
			
		||||
                            f'<y>{qq}</y>, reason: <r>{e}</r>. '
 | 
			
		||||
                            'Will retry after 3 seconds')
 | 
			
		||||
                    await asyncio.sleep(3)
 | 
			
		||||
 | 
			
		||||
            asyncio.create_task(_checker())
 | 
			
		||||
 | 
			
		||||
        @driver.on_shutdown
 | 
			
		||||
        async def _shutdown():
 | 
			
		||||
            cls.active = False
 | 
			
		||||
            bot = driver._clients.pop(str(qq), None)
 | 
			
		||||
            if bot is None:
 | 
			
		||||
                return
 | 
			
		||||
            await bot.websocket.close()  #type:ignore
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def handle_message(self, message: dict):
 | 
			
		||||
        event = Event.new(message)
 | 
			
		||||
        await handle_event(self, event)
 | 
			
		||||
        await handle_event(bot=self,
 | 
			
		||||
                           event=Event.new({
 | 
			
		||||
                               **message,
 | 
			
		||||
                               'self_id': self.self_id,
 | 
			
		||||
                           }))
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def call_api(self, api: str, **data):
 | 
			
		||||
        return super().call_api(api, **data)
 | 
			
		||||
        return await self.api.post('/' + api, params=data)
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def send(self, event: "BaseEvent", message: str, **kwargs):
 | 
			
		||||
        return super().send(event, message, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def __del__(self):
 | 
			
		||||
        self.driver._bot_disconnect(self)
 | 
			
		||||
        pass
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										220
									
								
								nonebot/adapters/mirai/bot_ws.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										220
									
								
								nonebot/adapters/mirai/bot_ws.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,220 @@
 | 
			
		||||
import asyncio
 | 
			
		||||
import json
 | 
			
		||||
from ipaddress import IPv4Address
 | 
			
		||||
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
 | 
			
		||||
                    TypeVar)
 | 
			
		||||
 | 
			
		||||
import httpx
 | 
			
		||||
import websockets
 | 
			
		||||
 | 
			
		||||
from nonebot.adapters import Bot as BaseBot
 | 
			
		||||
from nonebot.adapters import Event as BaseEvent
 | 
			
		||||
from nonebot.config import Config
 | 
			
		||||
from nonebot.drivers import Driver
 | 
			
		||||
from nonebot.drivers import WebSocket as BaseWebSocket
 | 
			
		||||
from nonebot.exception import RequestDenied
 | 
			
		||||
from nonebot.log import logger
 | 
			
		||||
from nonebot.message import handle_event
 | 
			
		||||
from nonebot.typing import overrides
 | 
			
		||||
 | 
			
		||||
from .config import Config as MiraiConfig
 | 
			
		||||
from .event import Event
 | 
			
		||||
 | 
			
		||||
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
 | 
			
		||||
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
 | 
			
		||||
                             bound=WebsocketHandlerFunction)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
 | 
			
		||||
                            qq: int) -> str:
 | 
			
		||||
 | 
			
		||||
    async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
 | 
			
		||||
        response = await client.request(method, path, **kwargs)
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        return response.json()
 | 
			
		||||
 | 
			
		||||
    about = await request('GET', path='/about')
 | 
			
		||||
    logger.opt(colors=True).debug('Mirai API HTTP backend version: '
 | 
			
		||||
                                  f'<g><b>{about["data"]["version"]}</b></g>')
 | 
			
		||||
 | 
			
		||||
    status = await request('POST', path='/auth', json={'authKey': auth_key})
 | 
			
		||||
    assert status['code'] == 0
 | 
			
		||||
    session_key = status['session']
 | 
			
		||||
 | 
			
		||||
    verify = await request('POST',
 | 
			
		||||
                           path='/verify',
 | 
			
		||||
                           json={
 | 
			
		||||
                               'sessionKey': session_key,
 | 
			
		||||
                               'qq': qq
 | 
			
		||||
                           })
 | 
			
		||||
    assert verify['code'] == 0, verify['msg']
 | 
			
		||||
    return session_key
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WebSocket(BaseWebSocket):
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def new(cls, *, host: IPv4Address, port: int,
 | 
			
		||||
                  session_key: str) -> "WebSocket":
 | 
			
		||||
        listen_address = httpx.URL(f'ws://{host}:{port}/all',
 | 
			
		||||
                                   params={'sessionKey': session_key})
 | 
			
		||||
        websocket = await websockets.connect(uri=str(listen_address))
 | 
			
		||||
        return cls(websocket)
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    def __init__(self, websocket: websockets.WebSocketClientProtocol):
 | 
			
		||||
        self.event_handlers: Set[WebsocketHandlerFunction] = set()
 | 
			
		||||
        super().__init__(websocket)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    def websocket(self) -> websockets.WebSocketClientProtocol:
 | 
			
		||||
        return self._websocket
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    def closed(self) -> bool:
 | 
			
		||||
        return self.websocket.closed
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    async def send(self, data: Dict[str, Any]):
 | 
			
		||||
        return await self.websocket.send(json.dumps(data))
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    async def receive(self) -> Dict[str, Any]:
 | 
			
		||||
        received = await self.websocket.recv()
 | 
			
		||||
        return json.loads(received)
 | 
			
		||||
 | 
			
		||||
    async def _dispatcher(self):
 | 
			
		||||
        while not self.closed:
 | 
			
		||||
            try:
 | 
			
		||||
                data = await self.receive()
 | 
			
		||||
            except websockets.ConnectionClosedOK:
 | 
			
		||||
                logger.debug(f'Websocket connection {self.websocket} closed')
 | 
			
		||||
                break
 | 
			
		||||
            except websockets.ConnectionClosedError:
 | 
			
		||||
                logger.exception(f'Websocket connection {self.websocket} '
 | 
			
		||||
                                 'connection closed abnormally:')
 | 
			
		||||
                break
 | 
			
		||||
            except json.JSONDecodeError as e:
 | 
			
		||||
                logger.exception(f'Websocket client listened {self.websocket} '
 | 
			
		||||
                                 f'failed to decode data: {e}')
 | 
			
		||||
                continue
 | 
			
		||||
            asyncio.gather(*map(lambda f: f(data), self.event_handlers),
 | 
			
		||||
                           return_exceptions=True)
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    async def accept(self):
 | 
			
		||||
        asyncio.create_task(self._dispatcher())
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseWebSocket)
 | 
			
		||||
    async def close(self):
 | 
			
		||||
        await self.websocket.close()
 | 
			
		||||
 | 
			
		||||
    def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
 | 
			
		||||
        self.event_handlers.add(callable)
 | 
			
		||||
        return callable
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MiraiWebsocketBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, connection_type: str, self_id: str, *,
 | 
			
		||||
                 websocket: WebSocket):
 | 
			
		||||
        super().__init__(connection_type, self_id, websocket=websocket)
 | 
			
		||||
        websocket.handle(self.handle_message)
 | 
			
		||||
        self.driver._bot_connect(self)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    def type(self) -> str:
 | 
			
		||||
        return "mirai"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def alive(self) -> bool:
 | 
			
		||||
        return not self.websocket.closed
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def check_permission(cls, driver: "Driver", connection_type: str,
 | 
			
		||||
                               headers: dict, body: Optional[dict]) -> NoReturn:
 | 
			
		||||
        raise RequestDenied(
 | 
			
		||||
            status_code=501,
 | 
			
		||||
            reason=f'Connection {connection_type} not implented')
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    def register(cls, driver: "Driver", config: "Config", qq: int):
 | 
			
		||||
        cls.mirai_config = MiraiConfig(**config.dict())
 | 
			
		||||
        cls.active = True
 | 
			
		||||
        assert cls.mirai_config.auth_key is not None
 | 
			
		||||
        assert cls.mirai_config.host is not None
 | 
			
		||||
        assert cls.mirai_config.port is not None
 | 
			
		||||
        super().register(driver, config)
 | 
			
		||||
 | 
			
		||||
        async def _bot_connection():
 | 
			
		||||
            async with httpx.AsyncClient(
 | 
			
		||||
                    base_url=
 | 
			
		||||
                    f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
 | 
			
		||||
            ) as client:
 | 
			
		||||
                session_key = await _ws_authorization(
 | 
			
		||||
                    client,
 | 
			
		||||
                    auth_key=cls.mirai_config.auth_key,  # type: ignore
 | 
			
		||||
                    qq=qq)  # type: ignore
 | 
			
		||||
 | 
			
		||||
            websocket = await WebSocket.new(
 | 
			
		||||
                host=cls.mirai_config.host,  # type: ignore
 | 
			
		||||
                port=cls.mirai_config.port,  # type: ignore
 | 
			
		||||
                session_key=session_key)
 | 
			
		||||
            bot = cls(connection_type='forward_ws',
 | 
			
		||||
                      self_id=str(qq),
 | 
			
		||||
                      websocket=websocket)
 | 
			
		||||
            websocket.handle(bot.handle_message)
 | 
			
		||||
            driver._clients[str(qq)] = bot
 | 
			
		||||
            await websocket.accept()
 | 
			
		||||
 | 
			
		||||
        async def _connection_ensure():
 | 
			
		||||
            if str(qq) not in driver._clients:
 | 
			
		||||
                await _bot_connection()
 | 
			
		||||
            elif not driver._clients[str(qq)].alive:
 | 
			
		||||
                driver._clients.pop(str(qq), None)
 | 
			
		||||
                await _bot_connection()
 | 
			
		||||
 | 
			
		||||
        @driver.on_startup
 | 
			
		||||
        async def _startup():
 | 
			
		||||
 | 
			
		||||
            async def _checker():
 | 
			
		||||
                while cls.active:
 | 
			
		||||
                    try:
 | 
			
		||||
                        await _connection_ensure()
 | 
			
		||||
                    except Exception as e:
 | 
			
		||||
                        logger.opt(colors=True).warning(
 | 
			
		||||
                            'Failed to create mirai connection to '
 | 
			
		||||
                            f'<y>{qq}</y>, reason: <r>{e}</r>. '
 | 
			
		||||
                            'Will retry after 3 seconds')
 | 
			
		||||
                    await asyncio.sleep(3)
 | 
			
		||||
 | 
			
		||||
            asyncio.create_task(_checker())
 | 
			
		||||
 | 
			
		||||
        @driver.on_shutdown
 | 
			
		||||
        async def _shutdown():
 | 
			
		||||
            cls.active = False
 | 
			
		||||
            bot = driver._clients.pop(str(qq), None)
 | 
			
		||||
            if bot is None:
 | 
			
		||||
                return
 | 
			
		||||
            await bot.websocket.close()  #type:ignore
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def handle_message(self, message: dict):
 | 
			
		||||
        event = Event.new(message)
 | 
			
		||||
        await handle_event(self, event)
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def call_api(self, api: str, **data):
 | 
			
		||||
        return super().call_api(api, **data)
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def send(self, event: "BaseEvent", message: str, **kwargs):
 | 
			
		||||
        return super().send(event, message, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def __del__(self):
 | 
			
		||||
        self.driver._bot_disconnect(self)
 | 
			
		||||
@@ -37,6 +37,7 @@ class PrivateSenderInfo(BaseModel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Event(BaseEvent):
 | 
			
		||||
    self_id: int
 | 
			
		||||
    type: str
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user