🐛 update mirai adapter

This commit is contained in:
yanyongyu
2021-06-18 01:23:13 +08:00
parent a4c6d834ff
commit cd12718dcb
9 changed files with 56 additions and 44 deletions

View File

@ -30,12 +30,27 @@ class WebSocket(BaseWebSocket):
params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address))
await (await websocket.ping())
return cls(websocket)
return cls("1.1",
listen_address.scheme,
listen_address.path,
listen_address.query,
websocket=websocket)
@overrides(BaseWebSocket)
def __init__(self, websocket: websockets.WebSocketClientProtocol):
def __init__(self,
http_version: str,
scheme: str,
path: str,
query_string: bytes = b"",
headers: Dict[str, str] = None,
websocket: websockets.WebSocketClientProtocol = None):
self.event_handlers: Set[WebsocketHandlerFunction] = set()
super().__init__(websocket)
self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore
super(WebSocket, self).__init__(http_version=http_version,
scheme=scheme,
path=path,
query_string=query_string,
headers=headers or {})
@property
@overrides(BaseWebSocket)
@ -146,9 +161,7 @@ class WebsocketBot(Bot):
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
session_key=session.session_key)
bot = cls(connection_type='forward_ws',
self_id=str(qq),
websocket=websocket)
bot = cls(self_id=str(qq), request=websocket)
websocket.handle(bot.handle_message)
await websocket.accept()
return bot

View File

@ -1,5 +1,5 @@
"""
\:\:\: warning
r"""
\:\:\: warning
事件中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
部分字段可能与文档在符号上不一致

View File

@ -14,12 +14,12 @@ from nonebot.typing import overrides
class UserPermission(str, Enum):
"""
:说明:
用户权限枚举类
* ``OWNER``: 群主
* ``ADMINISTRATOR``: 群管理
* ``MEMBER``: 普通群成员
用户权限枚举类
* ``OWNER``: 群主
* ``ADMINISTRATOR``: 群管理
* ``MEMBER``: 普通群成员
"""
OWNER = 'OWNER'
ADMINISTRATOR = 'ADMINISTRATOR'

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, List, Dict, Type, Iterable, Optional, Union
from pydantic import validate_arguments
@ -25,7 +25,7 @@ class MessageType(str, Enum):
POKE = 'Poke'
class MessageSegment(BaseMessageSegment):
class MessageSegment(BaseMessageSegment["MessageChain"]):
"""
Mirai-API-HTTP 协议 MessageSegment 适配。具体方法参考 `mirai-api-http 消息类型`_
@ -36,9 +36,13 @@ class MessageSegment(BaseMessageSegment):
type: MessageType
data: Dict[str, Any]
@overrides(BaseMessageSegment)
@classmethod
def get_message_class(cls) -> Type["MessageChain"]:
return MessageChain
@validate_arguments
def __init__(self, type: MessageType, **data):
@overrides(BaseMessageSegment)
def __init__(self, type: MessageType, **data: Any):
super().__init__(type=type,
data={k: v for k, v in data.items() if v is not None})
@ -55,14 +59,6 @@ class MessageSegment(BaseMessageSegment):
),
])
@overrides(BaseMessageSegment)
def __add__(self, other) -> "MessageChain":
return MessageChain(self) + other
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "MessageChain":
return MessageChain(other) + self
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
return self.type == MessageType.PLAIN
@ -273,6 +269,11 @@ class MessageChain(BaseMessage[MessageSegment]):
由于Mirai协议的Message实现较为特殊, 故使用MessageChain命名
"""
@classmethod
@overrides(BaseMessage)
def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment
@overrides(BaseMessage)
def __init__(self, message: Union[List[Dict[str,
Any]], Iterable[MessageSegment],

View File

@ -73,7 +73,7 @@ class InvalidArgument(exception.AdapterException):
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
"""
r"""
:说明:
捕捉函数抛出的httpx网络异常并释放 ``NetworkError`` 异常
@ -170,7 +170,6 @@ def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
async def process_event(bot: "Bot", event: Event) -> None:
if isinstance(event, MessageEvent):
event.message_chain.reduce()
Log.debug(event.message_chain)
event = process_source(bot, event)
if isinstance(event, GroupMessage):