mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-29 01:01:24 +00:00
🚧 add support of reverse post and forward ws for mirai adapter
This commit is contained in:
@ -1,3 +1,4 @@
|
|||||||
from .bot import MiraiBot
|
from .bot import MiraiBot
|
||||||
|
from .bot_ws import MiraiWebsocketBot
|
||||||
from .event import *
|
from .event import *
|
||||||
from .message import MessageChain, MessageSegment
|
from .message import MessageChain, MessageSegment
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
from typing import Any, Dict, List, NoReturn, Optional, Tuple
|
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from nonebot.adapters import Bot as BaseBot
|
from nonebot.adapters import Bot as BaseBot
|
||||||
from nonebot.adapters import Event as BaseEvent
|
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.drivers import Driver, WebSocket
|
from nonebot.drivers import Driver, WebSocket
|
||||||
from nonebot.exception import RequestDenied
|
|
||||||
from nonebot.exception import ActionFailed as BaseActionFailed
|
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||||
|
from nonebot.exception import RequestDenied
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.message import handle_event
|
from nonebot.message import handle_event
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
|
from nonebot.utils import escape_tag
|
||||||
|
|
||||||
from .config import Config as MiraiConfig
|
from .config import Config as MiraiConfig
|
||||||
from .event import Event, FriendMessage, GroupMessage, TempMessage
|
from .event import Event, FriendMessage, GroupMessage, TempMessage
|
||||||
@ -41,7 +41,8 @@ class SessionManager:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
|
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
code = data.get('code', 0)
|
code = data.get('code', 0)
|
||||||
logger.debug(f'Mirai API returned data: {data}')
|
logger.opt(colors=True).debug('Mirai API returned data: '
|
||||||
|
f'<y>{escape_tag(str(data))}</y>')
|
||||||
if code != 0:
|
if code != 0:
|
||||||
raise ActionFailed(code, message=data['msg'])
|
raise ActionFailed(code, message=data['msg'])
|
||||||
return data
|
return data
|
||||||
@ -85,10 +86,10 @@ class SessionManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
||||||
auth_key: str):
|
auth_key: str):
|
||||||
if self_id in cls.sessions:
|
session = cls.get(self_id)
|
||||||
manager = cls.get(self_id)
|
if session is not None:
|
||||||
if manager is not None:
|
return session
|
||||||
return manager
|
|
||||||
client = httpx.AsyncClient(base_url=f'http://{host}:{port}')
|
client = httpx.AsyncClient(base_url=f'http://{host}:{port}')
|
||||||
response = await client.post('/auth', json={'authKey': auth_key})
|
response = await client.post('/auth', json={'authKey': auth_key})
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -102,10 +103,13 @@ class SessionManager:
|
|||||||
})
|
})
|
||||||
assert response.json()['code'] == 0
|
assert response.json()['code'] == 0
|
||||||
cls.sessions[self_id] = session_key, datetime.now(), client
|
cls.sessions[self_id] = session_key, datetime.now(), client
|
||||||
|
|
||||||
return cls(session_key, client)
|
return cls(session_key, client)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, self_id: int):
|
def get(cls, self_id: int):
|
||||||
|
if self_id not in cls.sessions:
|
||||||
|
return None
|
||||||
key, time, client = cls.sessions[self_id]
|
key, time, client = cls.sessions[self_id]
|
||||||
if datetime.now() - time > cls.session_expiry:
|
if datetime.now() - time > cls.session_expiry:
|
||||||
return None
|
return None
|
||||||
@ -114,6 +118,7 @@ class SessionManager:
|
|||||||
|
|
||||||
class MiraiBot(BaseBot):
|
class MiraiBot(BaseBot):
|
||||||
|
|
||||||
|
@overrides(BaseBot)
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
connection_type: str,
|
connection_type: str,
|
||||||
self_id: str,
|
self_id: str,
|
||||||
@ -179,17 +184,20 @@ class MiraiBot(BaseBot):
|
|||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def send(self,
|
async def send(self,
|
||||||
event: Event,
|
event: Event,
|
||||||
message: MessageChain,
|
message: Union[MessageChain, MessageSegment, str],
|
||||||
at_sender: bool = False,
|
at_sender: bool = False):
|
||||||
**kwargs):
|
if isinstance(message, MessageSegment):
|
||||||
|
message = MessageChain(message)
|
||||||
|
elif isinstance(message, str):
|
||||||
|
message = MessageChain(MessageSegment.plain(message))
|
||||||
if isinstance(event, FriendMessage):
|
if isinstance(event, FriendMessage):
|
||||||
return await self.send_friend_message(target=event.sender.id,
|
return await self.send_friend_message(target=event.sender.id,
|
||||||
message_chain=message)
|
message_chain=message)
|
||||||
elif isinstance(event, GroupMessage):
|
elif isinstance(event, GroupMessage):
|
||||||
return await self.send_group_message(
|
if at_sender:
|
||||||
group=event.sender.group.id,
|
message = MessageSegment.at(event.sender.id) + message
|
||||||
message_chain=message if not at_sender else
|
return await self.send_group_message(group=event.sender.group.id,
|
||||||
(MessageSegment.at(target=event.sender.id) + message))
|
message_chain=message)
|
||||||
elif isinstance(event, TempMessage):
|
elif isinstance(event, TempMessage):
|
||||||
return await self.send_temp_message(qq=event.sender.id,
|
return await self.send_temp_message(qq=event.sender.id,
|
||||||
group=event.sender.group.id,
|
group=event.sender.group.id,
|
||||||
|
@ -7,50 +7,21 @@ from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
|
|||||||
import httpx
|
import httpx
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
from nonebot.adapters import Bot as BaseBot
|
|
||||||
from nonebot.adapters import Event as BaseEvent
|
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.drivers import Driver
|
from nonebot.drivers import Driver
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.exception import RequestDenied
|
from nonebot.exception import RequestDenied
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.message import handle_event
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
|
|
||||||
|
from .bot import MiraiBot, SessionManager
|
||||||
from .config import Config as MiraiConfig
|
from .config import Config as MiraiConfig
|
||||||
from .event import Event
|
|
||||||
|
|
||||||
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
|
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
|
||||||
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
|
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
|
||||||
bound=WebsocketHandlerFunction)
|
bound=WebsocketHandlerFunction)
|
||||||
|
|
||||||
|
|
||||||
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
|
|
||||||
qq: int) -> str:
|
|
||||||
|
|
||||||
async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
|
|
||||||
response = await client.request(method, path, **kwargs)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
about = await request('GET', path='/about')
|
|
||||||
logger.opt(colors=True).debug('Mirai API HTTP backend version: '
|
|
||||||
f'<g><b>{about["data"]["version"]}</b></g>')
|
|
||||||
|
|
||||||
status = await request('POST', path='/auth', json={'authKey': auth_key})
|
|
||||||
assert status['code'] == 0
|
|
||||||
session_key = status['session']
|
|
||||||
|
|
||||||
verify = await request('POST',
|
|
||||||
path='/verify',
|
|
||||||
json={
|
|
||||||
'sessionKey': session_key,
|
|
||||||
'qq': qq
|
|
||||||
})
|
|
||||||
assert verify['code'] == 0, verify['msg']
|
|
||||||
return session_key
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -59,6 +30,7 @@ class WebSocket(BaseWebSocket):
|
|||||||
listen_address = httpx.URL(f'ws://{host}:{port}/all',
|
listen_address = httpx.URL(f'ws://{host}:{port}/all',
|
||||||
params={'sessionKey': session_key})
|
params={'sessionKey': session_key})
|
||||||
websocket = await websockets.connect(uri=str(listen_address))
|
websocket = await websockets.connect(uri=str(listen_address))
|
||||||
|
await (await websocket.ping())
|
||||||
return cls(websocket)
|
return cls(websocket)
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@ -116,25 +88,24 @@ class WebSocket(BaseWebSocket):
|
|||||||
return callable
|
return callable
|
||||||
|
|
||||||
|
|
||||||
class MiraiWebsocketBot(BaseBot):
|
class MiraiWebsocketBot(MiraiBot):
|
||||||
|
|
||||||
|
@overrides(MiraiBot)
|
||||||
def __init__(self, connection_type: str, self_id: str, *,
|
def __init__(self, connection_type: str, self_id: str, *,
|
||||||
websocket: WebSocket):
|
websocket: WebSocket):
|
||||||
super().__init__(connection_type, self_id, websocket=websocket)
|
super().__init__(connection_type, self_id, websocket=websocket)
|
||||||
websocket.handle(self.handle_message)
|
|
||||||
self.driver._bot_connect(self)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(BaseBot)
|
@overrides(MiraiBot)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return "mirai"
|
return "mirai-ws"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def alive(self) -> bool:
|
def alive(self) -> bool:
|
||||||
return not self.websocket.closed
|
return not self.websocket.closed
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@overrides(BaseBot)
|
@overrides(MiraiBot)
|
||||||
async def check_permission(cls, driver: "Driver", connection_type: str,
|
async def check_permission(cls, driver: "Driver", connection_type: str,
|
||||||
headers: dict, body: Optional[dict]) -> NoReturn:
|
headers: dict, body: Optional[dict]) -> NoReturn:
|
||||||
raise RequestDenied(
|
raise RequestDenied(
|
||||||
@ -142,7 +113,7 @@ class MiraiWebsocketBot(BaseBot):
|
|||||||
reason=f'Connection {connection_type} not implented')
|
reason=f'Connection {connection_type} not implented')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@overrides(BaseBot)
|
@overrides(MiraiBot)
|
||||||
def register(cls, driver: "Driver", config: "Config", qq: int):
|
def register(cls, driver: "Driver", config: "Config", qq: int):
|
||||||
cls.mirai_config = MiraiConfig(**config.dict())
|
cls.mirai_config = MiraiConfig(**config.dict())
|
||||||
cls.active = True
|
cls.active = True
|
||||||
@ -152,32 +123,33 @@ class MiraiWebsocketBot(BaseBot):
|
|||||||
super().register(driver, config)
|
super().register(driver, config)
|
||||||
|
|
||||||
async def _bot_connection():
|
async def _bot_connection():
|
||||||
async with httpx.AsyncClient(
|
session: SessionManager = await SessionManager.new(
|
||||||
base_url=
|
qq,
|
||||||
f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
|
host=cls.mirai_config.host, # type: ignore
|
||||||
) as client:
|
port=cls.mirai_config.port, # type: ignore
|
||||||
session_key = await _ws_authorization(
|
auth_key=cls.mirai_config.auth_key # type: ignore
|
||||||
client,
|
)
|
||||||
auth_key=cls.mirai_config.auth_key, # type: ignore
|
|
||||||
qq=qq) # type: ignore
|
|
||||||
|
|
||||||
websocket = await WebSocket.new(
|
websocket = await WebSocket.new(
|
||||||
host=cls.mirai_config.host, # type: ignore
|
host=cls.mirai_config.host, # type: ignore
|
||||||
port=cls.mirai_config.port, # type: ignore
|
port=cls.mirai_config.port, # type: ignore
|
||||||
session_key=session_key)
|
session_key=session.session_key)
|
||||||
bot = cls(connection_type='forward_ws',
|
bot = cls(connection_type='forward_ws',
|
||||||
self_id=str(qq),
|
self_id=str(qq),
|
||||||
websocket=websocket)
|
websocket=websocket)
|
||||||
websocket.handle(bot.handle_message)
|
websocket.handle(bot.handle_message)
|
||||||
driver._clients[str(qq)] = bot
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
return bot
|
||||||
|
|
||||||
async def _connection_ensure():
|
async def _connection_ensure():
|
||||||
if str(qq) not in driver._clients:
|
self_id = str(qq)
|
||||||
await _bot_connection()
|
if self_id not in driver._clients:
|
||||||
elif not driver._clients[str(qq)].alive:
|
bot = await _bot_connection()
|
||||||
driver._clients.pop(str(qq), None)
|
driver._bot_connect(bot)
|
||||||
await _bot_connection()
|
else:
|
||||||
|
bot = driver._clients[self_id]
|
||||||
|
if not bot.alive:
|
||||||
|
driver._bot_disconnect(bot)
|
||||||
|
return
|
||||||
|
|
||||||
@driver.on_startup
|
@driver.on_startup
|
||||||
async def _startup():
|
async def _startup():
|
||||||
@ -202,19 +174,3 @@ class MiraiWebsocketBot(BaseBot):
|
|||||||
if bot is None:
|
if bot is None:
|
||||||
return
|
return
|
||||||
await bot.websocket.close() #type:ignore
|
await bot.websocket.close() #type:ignore
|
||||||
|
|
||||||
@overrides(BaseBot)
|
|
||||||
async def handle_message(self, message: dict):
|
|
||||||
event = Event.new(message)
|
|
||||||
await handle_event(self, event)
|
|
||||||
|
|
||||||
@overrides(BaseBot)
|
|
||||||
async def call_api(self, api: str, **data):
|
|
||||||
return super().call_api(api, **data)
|
|
||||||
|
|
||||||
@overrides(BaseBot)
|
|
||||||
async def send(self, event: "BaseEvent", message: str, **kwargs):
|
|
||||||
return super().send(event, message, **kwargs)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
self.driver._bot_disconnect(self)
|
|
||||||
|
@ -86,7 +86,7 @@ class Event(BaseEvent):
|
|||||||
|
|
||||||
@overrides(BaseEvent)
|
@overrides(BaseEvent)
|
||||||
def get_event_description(self) -> str:
|
def get_event_description(self) -> str:
|
||||||
return str(self.dict())
|
return str(self.normalize_dict())
|
||||||
|
|
||||||
@overrides(BaseEvent)
|
@overrides(BaseEvent)
|
||||||
def get_message(self) -> BaseMessage:
|
def get_message(self) -> BaseMessage:
|
||||||
|
@ -135,10 +135,11 @@ class MessageSegment(BaseMessageSegment):
|
|||||||
return cls(type=MessageType.POKE, name=name)
|
return cls(type=MessageType.POKE, name=name)
|
||||||
|
|
||||||
|
|
||||||
class MessageChain(BaseMessage):
|
class MessageChain(BaseMessage): #type:List[MessageSegment]
|
||||||
|
|
||||||
@overrides(BaseMessage)
|
@overrides(BaseMessage)
|
||||||
def __init__(self, message: Union[List[Dict[str, Any]], MessageSegment],
|
def __init__(self, message: Union[List[Dict[str, Any]],
|
||||||
|
Iterable[MessageSegment], MessageSegment],
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if isinstance(message, MessageSegment):
|
if isinstance(message, MessageSegment):
|
||||||
@ -152,15 +153,16 @@ class MessageChain(BaseMessage):
|
|||||||
|
|
||||||
@overrides(BaseMessage)
|
@overrides(BaseMessage)
|
||||||
def _construct(
|
def _construct(
|
||||||
self, message: Iterable[Union[Dict[str, Any], MessageSegment]]
|
self, message: Union[List[Dict[str, Any]], Iterable[MessageSegment]]
|
||||||
) -> List[MessageSegment]:
|
) -> List[MessageSegment]:
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"String operation is not supported in mirai adapter")
|
"String operation is not supported in mirai adapter")
|
||||||
return [
|
return [
|
||||||
*map(
|
*map(
|
||||||
lambda segment: segment if isinstance(segment, MessageSegment)
|
lambda x: x
|
||||||
else MessageSegment(**segment), message)
|
if isinstance(x, MessageSegment) else MessageSegment(**x),
|
||||||
|
message)
|
||||||
]
|
]
|
||||||
|
|
||||||
def export(self) -> List[Dict[str, Any]]:
|
def export(self) -> List[Dict[str, Any]]:
|
||||||
|
Reference in New Issue
Block a user