🎨 format code using black and isort

This commit is contained in:
yanyongyu
2021-11-22 23:21:26 +08:00
parent 602185a34e
commit a98d98cd12
86 changed files with 2893 additions and 2095 deletions

View File

@ -12,8 +12,15 @@ from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import ApiNotAvailable
from nonebot.drivers import (Driver, WebSocket, HTTPResponse, ForwardDriver,
ReverseDriver, HTTPConnection, WebSocketSetup)
from nonebot.drivers import (
Driver,
WebSocket,
HTTPResponse,
ForwardDriver,
ReverseDriver,
HTTPConnection,
WebSocketSetup,
)
from .config import Config as MiraiConfig
from .message import MessageChain, MessageSegment
@ -23,16 +30,14 @@ from .utils import Log, process_event, argument_validation, catch_network_error
class SessionManager:
"""Bot会话管理器, 提供API主动调用接口"""
sessions: Dict[int, Tuple[str, httpx.AsyncClient]] = {}
def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client
@catch_network_error
async def post(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Any:
async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None) -> Any:
"""
:说明:
@ -51,7 +56,7 @@ class SessionManager:
path,
json={
**(params or {}),
'sessionKey': self.session_key,
"sessionKey": self.session_key,
},
timeout=3,
)
@ -59,10 +64,9 @@ class SessionManager:
return response.json()
@catch_network_error
async def request(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Any:
async def request(
self, path: str, *, params: Optional[Dict[str, Any]] = None
) -> Any:
"""
:说明:
@ -77,7 +81,7 @@ class SessionManager:
path,
params={
**(params or {}),
'sessionKey': self.session_key,
"sessionKey": self.session_key,
},
timeout=3,
)
@ -98,7 +102,7 @@ class SessionManager:
"""
files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
form = {k: v for k, v in params.items() if k not in files}
form['sessionKey'] = self.session_key
form["sessionKey"] = self.session_key
response = await self.client.post(
path,
data=form,
@ -109,25 +113,25 @@ class SessionManager:
return response.json()
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
auth_key: str) -> "SessionManager":
async def new(
cls, self_id: int, *, host: IPv4Address, port: int, auth_key: str
) -> "SessionManager":
session = cls.get(self_id)
if session is not None:
return session
client = httpx.AsyncClient(base_url=f'http://{host}:{port}',
follow_redirects=True)
response = await client.post('/auth', json={'authKey': auth_key})
client = httpx.AsyncClient(
base_url=f"http://{host}:{port}", follow_redirects=True
)
response = await client.post("/auth", json={"authKey": auth_key})
response.raise_for_status()
auth = response.json()
assert auth['code'] == 0
session_key = auth['session']
response = await client.post('/verify',
json={
'sessionKey': session_key,
'qq': self_id
})
assert response.json()['code'] == 0
assert auth["code"] == 0
session_key = auth["session"]
response = await client.post(
"/verify", json={"sessionKey": session_key, "qq": self_id}
)
assert response.json()["code"] == 0
cls.sessions[self_id] = session_key, client
return cls(session_key, client)
@ -152,7 +156,7 @@ class Bot(BaseBot):
"""
_type = 'mirai'
_type = "mirai"
@property
@overrides(BaseBot)
@ -166,37 +170,42 @@ class Bot(BaseBot):
if api is None:
if isinstance(self.request, WebSocket):
asyncio.create_task(self.request.close(1000))
assert api is not None, 'SessionManager has not been initialized'
assert api is not None, "SessionManager has not been initialized"
return api
@classmethod
@overrides(BaseBot)
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
cls, driver: Driver, request: HTTPConnection
) -> Tuple[Optional[str], HTTPResponse]:
if isinstance(request, WebSocket):
return None, HTTPResponse(
501, b'Websocket connection is not implemented')
self_id: Optional[str] = request.headers.get('bot')
return None, HTTPResponse(501, b"Websocket connection is not implemented")
self_id: Optional[str] = request.headers.get("bot")
if self_id is None:
return None, HTTPResponse(400, b'Header `Bot` is required.')
return None, HTTPResponse(400, b"Header `Bot` is required.")
self_id = str(self_id).strip()
await SessionManager.new(
int(self_id),
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, #type: ignore
auth_key=cls.mirai_config.auth_key) # type: ignore
return self_id, HTTPResponse(204, b'')
port=cls.mirai_config.port, # type: ignore
auth_key=cls.mirai_config.auth_key, # type: ignore
)
return self_id, HTTPResponse(204, b"")
@classmethod
@overrides(BaseBot)
def register(cls,
driver: Driver,
config: "Config",
qq: Optional[Union[int, List[int]]] = None):
def register(
cls,
driver: Driver,
config: "Config",
qq: Optional[Union[int, List[int]]] = None,
):
cls.mirai_config = MiraiConfig(**config.dict())
if (cls.mirai_config.auth_key and cls.mirai_config.host and
cls.mirai_config.port) is None:
if (
cls.mirai_config.auth_key
and cls.mirai_config.host
and cls.mirai_config.port
) is None:
raise ApiNotAvailable(cls._type)
super().register(driver, config)
@ -209,17 +218,25 @@ class Bot(BaseBot):
self_ids = [qq] if isinstance(qq, int) else qq
async def url_factory(qq: int):
assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key
assert (
cls.mirai_config.host
and cls.mirai_config.port
and cls.mirai_config.auth_key
)
session = await SessionManager.new(
qq,
host=cls.mirai_config.host,
port=cls.mirai_config.port,
auth_key=cls.mirai_config.auth_key)
auth_key=cls.mirai_config.auth_key,
)
return WebSocketSetup(
adapter=cls._type,
self_id=str(qq),
url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}'
f'/all?sessionKey={session.session_key}'))
url=(
f"ws://{cls.mirai_config.host}:{cls.mirai_config.port}"
f"/all?sessionKey={session.session_key}"
),
)
for self_id in self_ids:
driver.setup_websocket(partial(url_factory, qq=self_id))
@ -234,13 +251,15 @@ class Bot(BaseBot):
try:
await process_event(
bot=self,
event=Event.new({
**json.loads(message),
'self_id': self.self_id,
}),
event=Event.new(
{
**json.loads(message),
"self_id": self.self_id,
}
),
)
except Exception as e:
Log.error(f'Failed to handle message: {message}', e)
Log.error(f"Failed to handle message: {message}", e)
@overrides(BaseBot)
async def _call_api(self, api: str, **data) -> NoReturn:
@ -266,10 +285,12 @@ class Bot(BaseBot):
@overrides(BaseBot)
@argument_validation
async def send(self,
event: Event,
message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False):
async def send(
self,
event: Event,
message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False,
):
"""
:说明:
@ -284,23 +305,24 @@ class Bot(BaseBot):
if not isinstance(message, MessageChain):
message = MessageChain(message)
if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)
return await self.send_friend_message(
target=event.sender.id, message_chain=message
)
elif isinstance(event, GroupMessage):
if at_sender:
message = MessageSegment.at(event.sender.id) + message
return await self.send_group_message(group=event.sender.group.id,
message_chain=message)
return await self.send_group_message(
group=event.sender.group.id, message_chain=message
)
elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id,
message_chain=message)
return await self.send_temp_message(
qq=event.sender.id, group=event.sender.group.id, message_chain=message
)
else:
raise ValueError(f'Unsupported event type {event!r}.')
raise ValueError(f"Unsupported event type {event!r}.")
@argument_validation
async def send_friend_message(self, target: int,
message_chain: MessageChain):
async def send_friend_message(self, target: int, message_chain: MessageChain):
"""
:说明:
@ -311,15 +333,13 @@ class Bot(BaseBot):
* ``target: int``: 发送消息目标好友的 QQ 号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
"""
return await self.api.post('sendFriendMessage',
params={
'target': target,
'messageChain': message_chain.export()
})
return await self.api.post(
"sendFriendMessage",
params={"target": target, "messageChain": message_chain.export()},
)
@argument_validation
async def send_temp_message(self, qq: int, group: int,
message_chain: MessageChain):
async def send_temp_message(self, qq: int, group: int, message_chain: MessageChain):
"""
:说明:
@ -331,18 +351,15 @@ class Bot(BaseBot):
* ``group: int``: 临时会话群号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
"""
return await self.api.post('sendTempMessage',
params={
'qq': qq,
'group': group,
'messageChain': message_chain.export()
})
return await self.api.post(
"sendTempMessage",
params={"qq": qq, "group": group, "messageChain": message_chain.export()},
)
@argument_validation
async def send_group_message(self,
group: int,
message_chain: MessageChain,
quote: Optional[int] = None):
async def send_group_message(
self, group: int, message_chain: MessageChain, quote: Optional[int] = None
):
"""
:说明:
@ -354,12 +371,14 @@ class Bot(BaseBot):
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
"""
return await self.api.post('sendGroupMessage',
params={
'group': group,
'messageChain': message_chain.export(),
'quote': quote
})
return await self.api.post(
"sendGroupMessage",
params={
"group": group,
"messageChain": message_chain.export(),
"quote": quote,
},
)
@argument_validation
async def recall(self, target: int):
@ -372,11 +391,12 @@ class Bot(BaseBot):
* ``target: int``: 需要撤回的消息的message_id
"""
return await self.api.post('recall', params={'target': target})
return await self.api.post("recall", params={"target": target})
@argument_validation
async def send_image_message(self, target: int, qq: int, group: int,
urls: List[str]) -> List[str]:
async def send_image_message(
self, target: int, qq: int, group: int, urls: List[str]
) -> List[str]:
"""
:说明:
@ -396,13 +416,10 @@ class Bot(BaseBot):
- ``List[str]``: 一个包含图片imageId的数组
"""
return await self.api.post('sendImageMessage',
params={
'target': target,
'qq': qq,
'group': group,
'urls': urls
})
return await self.api.post(
"sendImageMessage",
params={"target": target, "qq": qq, "group": group, "urls": urls},
)
@argument_validation
async def upload_image(self, type: str, img: BytesIO):
@ -416,11 +433,7 @@ class Bot(BaseBot):
* ``type: str``: "friend""group""temp"
* ``img: BytesIO``: 图片的BytesIO对象
"""
return await self.api.upload('uploadImage',
params={
'type': type,
'img': img
})
return await self.api.upload("uploadImage", params={"type": type, "img": img})
@argument_validation
async def upload_voice(self, type: str, voice: BytesIO):
@ -434,11 +447,9 @@ class Bot(BaseBot):
* ``type: str``: 当前仅支持 "group"
* ``voice: BytesIO``: 语音的BytesIO对象
"""
return await self.api.upload('uploadVoice',
params={
'type': type,
'voice': voice
})
return await self.api.upload(
"uploadVoice", params={"type": type, "voice": voice}
)
@argument_validation
async def fetch_message(self, count: int = 10):
@ -452,7 +463,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('fetchMessage', params={'count': count})
return await self.api.request("fetchMessage", params={"count": count})
@argument_validation
async def fetch_latest_message(self, count: int = 10):
@ -466,8 +477,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('fetchLatestMessage',
params={'count': count})
return await self.api.request("fetchLatestMessage", params={"count": count})
@argument_validation
async def peek_message(self, count: int = 10):
@ -481,7 +491,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('peekMessage', params={'count': count})
return await self.api.request("peekMessage", params={"count": count})
@argument_validation
async def peek_latest_message(self, count: int = 10):
@ -495,8 +505,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('peekLatestMessage',
params={'count': count})
return await self.api.request("peekLatestMessage", params={"count": count})
@argument_validation
async def messsage_from_id(self, id: int):
@ -510,7 +519,7 @@ class Bot(BaseBot):
* ``id: int``: 获取消息的message_id
"""
return await self.api.request('messageFromId', params={'id': id})
return await self.api.request("messageFromId", params={"id": id})
@argument_validation
async def count_message(self):
@ -519,7 +528,7 @@ class Bot(BaseBot):
使用此方法获取bot接收并缓存的消息总数注意不包含被删除的
"""
return await self.api.request('countMessage')
return await self.api.request("countMessage")
@argument_validation
async def friend_list(self) -> List[Dict[str, Any]]:
@ -532,7 +541,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的好友列表数据
"""
return await self.api.request('friendList')
return await self.api.request("friendList")
@argument_validation
async def group_list(self) -> List[Dict[str, Any]]:
@ -545,7 +554,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群列表数据
"""
return await self.api.request('groupList')
return await self.api.request("groupList")
@argument_validation
async def member_list(self, target: int) -> List[Dict[str, Any]]:
@ -562,7 +571,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群成员列表数据
"""
return await self.api.request('memberList', params={'target': target})
return await self.api.request("memberList", params={"target": target})
@argument_validation
async def mute(self, target: int, member_id: int, time: int):
@ -577,12 +586,9 @@ class Bot(BaseBot):
* ``member_id: int``: 指定群员QQ号
* ``time: int``: 禁言时长单位为秒最多30天
"""
return await self.api.post('mute',
params={
'target': target,
'memberId': member_id,
'time': time
})
return await self.api.post(
"mute", params={"target": target, "memberId": member_id, "time": time}
)
@argument_validation
async def unmute(self, target: int, member_id: int):
@ -596,11 +602,9 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
* ``member_id: int``: 指定群员QQ号
"""
return await self.api.post('unmute',
params={
'target': target,
'memberId': member_id
})
return await self.api.post(
"unmute", params={"target": target, "memberId": member_id}
)
@argument_validation
async def kick(self, target: int, member_id: int, msg: str):
@ -615,12 +619,9 @@ class Bot(BaseBot):
* ``member_id: int``: 指定群员QQ号
* ``msg: str``: 信息
"""
return await self.api.post('kick',
params={
'target': target,
'memberId': member_id,
'msg': msg
})
return await self.api.post(
"kick", params={"target": target, "memberId": member_id, "msg": msg}
)
@argument_validation
async def quit(self, target: int):
@ -633,7 +634,7 @@ class Bot(BaseBot):
* ``target: int``: 退出的群号
"""
return await self.api.post('quit', params={'target': target})
return await self.api.post("quit", params={"target": target})
@argument_validation
async def mute_all(self, target: int):
@ -646,7 +647,7 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
"""
return await self.api.post('muteAll', params={'target': target})
return await self.api.post("muteAll", params={"target": target})
@argument_validation
async def unmute_all(self, target: int):
@ -659,7 +660,7 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
"""
return await self.api.post('unmuteAll', params={'target': target})
return await self.api.post("unmuteAll", params={"target": target})
@argument_validation
async def group_config(self, target: int):
@ -685,7 +686,7 @@ class Bot(BaseBot):
"anonymousChat": true
}
"""
return await self.api.request('groupConfig', params={'target': target})
return await self.api.request("groupConfig", params={"target": target})
@argument_validation
async def modify_group_config(self, target: int, config: Dict[str, Any]):
@ -699,11 +700,9 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
* ``config: Dict[str, Any]``: 群设置, 格式见 ``group_config`` 的返回值
"""
return await self.api.post('groupConfig',
params={
'target': target,
'config': config
})
return await self.api.post(
"groupConfig", params={"target": target, "config": config}
)
@argument_validation
async def member_info(self, target: int, member_id: int):
@ -726,15 +725,14 @@ class Bot(BaseBot):
"specialTitle": "群头衔"
}
"""
return await self.api.request('memberInfo',
params={
'target': target,
'memberId': member_id
})
return await self.api.request(
"memberInfo", params={"target": target, "memberId": member_id}
)
@argument_validation
async def modify_member_info(self, target: int, member_id: int,
info: Dict[str, Any]):
async def modify_member_info(
self, target: int, member_id: int, info: Dict[str, Any]
):
"""
:说明:
@ -746,9 +744,6 @@ class Bot(BaseBot):
* ``member_id: int``: 群员QQ号
* ``info: Dict[str, Any]``: 群员资料, 格式见 ``member_info`` 的返回值
"""
return await self.api.post('memberInfo',
params={
'target': target,
'memberId': member_id,
'info': info
})
return await self.api.post(
"memberInfo", params={"target": target, "memberId": member_id, "info": info}
)

View File

@ -1,7 +1,7 @@
from ipaddress import IPv4Address
from typing import Optional
from ipaddress import IPv4Address
from pydantic import BaseModel, Extra, Field
from pydantic import Extra, Field, BaseModel
class Config(BaseModel):
@ -14,9 +14,10 @@ class Config(BaseModel):
- ``mirai_host``: mirai-api-http 的地址
- ``mirai_port``: mirai-api-http 的端口
"""
auth_key: Optional[str] = Field(None, alias='mirai_auth_key')
host: Optional[IPv4Address] = Field(None, alias='mirai_host')
port: Optional[int] = Field(None, alias='mirai_port')
auth_key: Optional[str] = Field(None, alias="mirai_auth_key")
host: Optional[IPv4Address] = Field(None, alias="mirai_host")
port: Optional[int] = Field(None, alias="mirai_port")
class Config:
extra = Extra.ignore

View File

@ -5,25 +5,56 @@ r"""
部分字段可能与文档在符号上不一致
\:\:\:
"""
from .base import (Event, GroupChatInfo, GroupInfo, PrivateChatInfo,
UserPermission)
from .message import *
from .notice import *
from .message import *
from .request import *
from .base import (
Event,
GroupInfo,
GroupChatInfo,
UserPermission,
PrivateChatInfo,
)
__all__ = [
'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission',
'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage',
'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent',
'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent',
'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent',
'MemberLeaveEventKick', 'MemberLeaveEventQuit', 'FriendRecallEvent',
'GroupRecallEvent', 'GroupStateChangeEvent', 'GroupNameChangeEvent',
'GroupEntranceAnnouncementChangeEvent', 'GroupMuteAllEvent',
'GroupAllowAnonymousChatEvent', 'GroupAllowConfessTalkEvent',
'GroupAllowMemberInviteEvent', 'MemberStateChangeEvent',
'MemberCardChangeEvent', 'MemberSpecialTitleChangeEvent',
'BotGroupPermissionChangeEvent', 'MemberPermissionChangeEvent',
'RequestEvent', 'NewFriendRequestEvent', 'MemberJoinRequestEvent',
'BotInvitedJoinGroupRequestEvent'
"Event",
"GroupChatInfo",
"GroupInfo",
"PrivateChatInfo",
"UserPermission",
"MessageSource",
"MessageEvent",
"GroupMessage",
"FriendMessage",
"TempMessage",
"NoticeEvent",
"MuteEvent",
"BotMuteEvent",
"BotUnmuteEvent",
"MemberMuteEvent",
"MemberUnmuteEvent",
"BotJoinGroupEvent",
"BotLeaveEventActive",
"BotLeaveEventKick",
"MemberJoinEvent",
"MemberLeaveEventKick",
"MemberLeaveEventQuit",
"FriendRecallEvent",
"GroupRecallEvent",
"GroupStateChangeEvent",
"GroupNameChangeEvent",
"GroupEntranceAnnouncementChangeEvent",
"GroupMuteAllEvent",
"GroupAllowAnonymousChatEvent",
"GroupAllowConfessTalkEvent",
"GroupAllowMemberInviteEvent",
"MemberStateChangeEvent",
"MemberCardChangeEvent",
"MemberSpecialTitleChangeEvent",
"BotGroupPermissionChangeEvent",
"MemberPermissionChangeEvent",
"RequestEvent",
"NewFriendRequestEvent",
"MemberJoinRequestEvent",
"BotInvitedJoinGroupRequestEvent",
]

View File

@ -22,9 +22,10 @@ class UserPermission(str, Enum):
* ``ADMINISTRATOR``: 群管理
* ``MEMBER``: 普通群成员
"""
OWNER = 'OWNER'
ADMINISTRATOR = 'ADMINISTRATOR'
MEMBER = 'MEMBER'
OWNER = "OWNER"
ADMINISTRATOR = "ADMINISTRATOR"
MEMBER = "MEMBER"
class NudgeSubjectKind(str, Enum):
@ -36,8 +37,9 @@ class NudgeSubjectKind(str, Enum):
* ``Group``: 群
* ``Friend``: 好友
"""
Group = 'Group'
Friend = 'Friend'
Group = "Group"
Friend = "Friend"
class GroupInfo(BaseModel):
@ -48,7 +50,7 @@ class GroupInfo(BaseModel):
class GroupChatInfo(BaseModel):
id: int
name: str = Field(alias='memberName')
name: str = Field(alias="memberName")
permission: UserPermission
group: GroupInfo
@ -71,6 +73,7 @@ class Event(BaseEvent):
.. _mirai-api-http 事件类型:
https://github.com/project-mirai/mirai-api-http/blob/master/docs/EventType.md
"""
self_id: int
type: str
@ -79,11 +82,12 @@ class Event(BaseEvent):
"""
此事件类的工厂函数, 能够通过事件数据选择合适的子类进行序列化
"""
type = data['type']
type = data["type"]
def all_subclasses(cls: Type[Event]):
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)])
[s for c in cls.__subclasses__() for s in all_subclasses(c)]
)
event_class: Optional[Type[Event]] = None
for subclass in all_subclasses(cls):
@ -99,23 +103,25 @@ class Event(BaseEvent):
return event_class.parse_obj(data)
except ValidationError as e:
logger.info(
f'Failed to parse {data} to class {event_class.__name__}: '
f'{e.errors()!r}. Fallback to parent class.')
f"Failed to parse {data} to class {event_class.__name__}: "
f"{e.errors()!r}. Fallback to parent class."
)
event_class = event_class.__base__ # type: ignore
raise ValueError(f'Failed to serialize {data}.')
raise ValueError(f"Failed to serialize {data}.")
@overrides(BaseEvent)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
from . import meta, notice, message, request
if isinstance(self, message.MessageEvent):
return 'message'
return "message"
elif isinstance(self, notice.NoticeEvent):
return 'notice'
return "notice"
elif isinstance(self, request.RequestEvent):
return 'request'
return "request"
else:
return 'meta_event'
return "meta_event"
@overrides(BaseEvent)
def get_event_name(self) -> str:

View File

@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic import Field, BaseModel
from nonebot.typing import overrides
@ -16,7 +16,8 @@ class MessageSource(BaseModel):
class MessageEvent(Event):
"""消息事件基类"""
message_chain: MessageChain = Field(alias='messageChain')
message_chain: MessageChain = Field(alias="messageChain")
source: Optional[MessageSource] = None
sender: Any
@ -39,12 +40,13 @@ class MessageEvent(Event):
class GroupMessage(MessageEvent):
"""群消息事件"""
sender: GroupChatInfo
to_me: bool = False
@overrides(MessageEvent)
def get_session_id(self) -> str:
return f'group_{self.sender.group.id}_' + self.get_user_id()
return f"group_{self.sender.group.id}_" + self.get_user_id()
@overrides(MessageEvent)
def get_user_id(self) -> str:
@ -57,6 +59,7 @@ class GroupMessage(MessageEvent):
class FriendMessage(MessageEvent):
"""好友消息事件"""
sender: PrivateChatInfo
@overrides(MessageEvent)
@ -65,7 +68,7 @@ class FriendMessage(MessageEvent):
@overrides(MessageEvent)
def get_session_id(self) -> str:
return 'friend_' + self.get_user_id()
return "friend_" + self.get_user_id()
@overrides(MessageEvent)
def is_tome(self) -> bool:
@ -74,11 +77,12 @@ class FriendMessage(MessageEvent):
class TempMessage(MessageEvent):
"""临时会话消息事件"""
sender: GroupChatInfo
@overrides(MessageEvent)
def get_session_id(self) -> str:
return f'temp_{self.sender.group.id}_' + self.get_user_id()
return f"temp_{self.sender.group.id}_" + self.get_user_id()
@overrides(MessageEvent)
def is_tome(self) -> bool:

View File

@ -3,29 +3,35 @@ from .base import Event
class MetaEvent(Event):
"""元事件基类"""
qq: int
class BotOnlineEvent(MetaEvent):
"""Bot登录成功"""
pass
class BotOfflineEventActive(MetaEvent):
"""Bot主动离线"""
pass
class BotOfflineEventForce(MetaEvent):
"""Bot被挤下线"""
pass
class BotOfflineEventDropped(MetaEvent):
"""Bot被服务器断开或因网络问题而掉线"""
pass
class BotReloginEvent(MetaEvent):
"""Bot主动重新登录"""
pass
pass

View File

@ -2,88 +2,103 @@ from typing import Any, Optional
from pydantic import Field
from .base import Event, GroupChatInfo, GroupInfo, NudgeSubject, UserPermission
from .base import Event, GroupInfo, NudgeSubject, GroupChatInfo, UserPermission
class NoticeEvent(Event):
"""通知事件基类"""
pass
class MuteEvent(NoticeEvent):
"""禁言类事件基类"""
operator: GroupChatInfo
class BotMuteEvent(MuteEvent):
"""Bot被禁言"""
pass
class BotUnmuteEvent(MuteEvent):
"""Bot被取消禁言"""
pass
class MemberMuteEvent(MuteEvent):
"""群成员被禁言事件该成员不是Bot"""
duration_seconds: int = Field(alias='durationSeconds')
duration_seconds: int = Field(alias="durationSeconds")
member: GroupChatInfo
operator: Optional[GroupChatInfo] = None
class MemberUnmuteEvent(MuteEvent):
"""群成员被取消禁言事件该成员不是Bot"""
member: GroupChatInfo
operator: Optional[GroupChatInfo] = None
class BotJoinGroupEvent(NoticeEvent):
"""Bot加入了一个新群"""
group: GroupInfo
class BotLeaveEventActive(BotJoinGroupEvent):
"""Bot主动退出一个群"""
pass
class BotLeaveEventKick(BotJoinGroupEvent):
"""Bot被踢出一个群"""
pass
class MemberJoinEvent(NoticeEvent):
"""新人入群的事件"""
member: GroupChatInfo
class MemberLeaveEventKick(MemberJoinEvent):
"""成员被踢出群该成员不是Bot"""
operator: Optional[GroupChatInfo] = None
class MemberLeaveEventQuit(MemberJoinEvent):
"""成员主动离群该成员不是Bot"""
pass
class FriendRecallEvent(NoticeEvent):
"""好友消息撤回"""
author_id: int = Field(alias='authorId')
message_id: int = Field(alias='messageId')
author_id: int = Field(alias="authorId")
message_id: int = Field(alias="messageId")
time: int
operator: int
class GroupRecallEvent(FriendRecallEvent):
"""群消息撤回"""
group: GroupInfo
operator: Optional[GroupChatInfo] = None
class GroupStateChangeEvent(NoticeEvent):
"""群变化事件基类"""
origin: Any
current: Any
group: GroupInfo
@ -92,73 +107,85 @@ class GroupStateChangeEvent(NoticeEvent):
class GroupNameChangeEvent(GroupStateChangeEvent):
"""某个群名改变"""
origin: str
current: str
class GroupEntranceAnnouncementChangeEvent(GroupStateChangeEvent):
"""某群入群公告改变"""
origin: str
current: str
class GroupMuteAllEvent(GroupStateChangeEvent):
"""全员禁言"""
origin: bool
current: bool
class GroupAllowAnonymousChatEvent(GroupStateChangeEvent):
"""匿名聊天"""
origin: bool
current: bool
class GroupAllowConfessTalkEvent(GroupStateChangeEvent):
"""坦白说"""
origin: bool
current: bool
class GroupAllowMemberInviteEvent(GroupStateChangeEvent):
"""允许群员邀请好友加群"""
origin: bool
current: bool
class MemberStateChangeEvent(NoticeEvent):
"""群成员变化事件基类"""
member: GroupChatInfo
operator: Optional[GroupChatInfo] = None
class MemberCardChangeEvent(MemberStateChangeEvent):
"""群名片改动"""
origin: str
current: str
class MemberSpecialTitleChangeEvent(MemberStateChangeEvent):
"""群头衔改动(只有群主有操作限权)"""
origin: str
current: str
class BotGroupPermissionChangeEvent(MemberStateChangeEvent):
"""Bot在群里的权限被改变"""
origin: UserPermission
current: UserPermission
class MemberPermissionChangeEvent(MemberStateChangeEvent):
"""成员权限改变的事件该成员不是Bot"""
origin: UserPermission
current: UserPermission
class NudgeEvent(NoticeEvent):
"""戳一戳触发事件"""
from_id: int = Field(alias='fromId')
from_id: int = Field(alias="fromId")
target: int
subject: NudgeSubject
action: str

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING
from typing_extensions import Literal
from pydantic import Field
from typing_extensions import Literal
from .base import Event
@ -11,15 +11,17 @@ if TYPE_CHECKING:
class RequestEvent(Event):
"""请求事件基类"""
event_id: int = Field(alias='eventId')
event_id: int = Field(alias="eventId")
message: str
nick: str
class NewFriendRequestEvent(RequestEvent):
"""添加好友申请"""
from_id: int = Field(alias='fromId')
group_id: int = Field(0, alias='groupId')
from_id: int = Field(alias="fromId")
group_id: int = Field(0, alias="groupId")
async def approve(self, bot: "Bot"):
"""
@ -31,19 +33,18 @@ class NewFriendRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象
"""
return await bot.api.post('/resp/newFriendRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': 0,
'message': ''
})
return await bot.api.post(
"/resp/newFriendRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": 0,
"message": "",
},
)
async def reject(self,
bot: "Bot",
operate: Literal[1, 2] = 1,
message: str = ''):
async def reject(self, bot: "Bot", operate: Literal[1, 2] = 1, message: str = ""):
"""
:说明:
@ -60,21 +61,24 @@ class NewFriendRequestEvent(RequestEvent):
* ``message: str``: 回复的信息
"""
assert operate > 0
return await bot.api.post('/resp/newFriendRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': operate,
'message': message
})
return await bot.api.post(
"/resp/newFriendRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": operate,
"message": message,
},
)
class MemberJoinRequestEvent(RequestEvent):
"""用户入群申请Bot需要有管理员权限"""
from_id: int = Field(alias='fromId')
group_id: int = Field(alias='groupId')
group_name: str = Field(alias='groupName')
from_id: int = Field(alias="fromId")
group_id: int = Field(alias="groupId")
group_name: str = Field(alias="groupName")
async def approve(self, bot: "Bot"):
"""
@ -86,19 +90,20 @@ class MemberJoinRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象
"""
return await bot.api.post('/resp/memberJoinRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': 0,
'message': ''
})
return await bot.api.post(
"/resp/memberJoinRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": 0,
"message": "",
},
)
async def reject(self,
bot: "Bot",
operate: Literal[1, 2, 3, 4] = 1,
message: str = ''):
async def reject(
self, bot: "Bot", operate: Literal[1, 2, 3, 4] = 1, message: str = ""
):
"""
:说明:
@ -117,21 +122,24 @@ class MemberJoinRequestEvent(RequestEvent):
* ``message: str``: 回复的信息
"""
assert operate > 0
return await bot.api.post('/resp/memberJoinRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': operate,
'message': message
})
return await bot.api.post(
"/resp/memberJoinRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": operate,
"message": message,
},
)
class BotInvitedJoinGroupRequestEvent(RequestEvent):
"""Bot被邀请入群申请"""
from_id: int = Field(alias='fromId')
group_id: int = Field(alias='groupId')
group_name: str = Field(alias='groupName')
from_id: int = Field(alias="fromId")
group_id: int = Field(alias="groupId")
group_name: str = Field(alias="groupName")
async def approve(self, bot: "Bot"):
"""
@ -143,14 +151,16 @@ class BotInvitedJoinGroupRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象
"""
return await bot.api.post('/resp/botInvitedJoinGroupRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': 0,
'message': ''
})
return await bot.api.post(
"/resp/botInvitedJoinGroupRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": 0,
"message": "",
},
)
async def reject(self, bot: "Bot", message: str = ""):
"""
@ -163,11 +173,13 @@ class BotInvitedJoinGroupRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象
* ``message: str``: 邀请消息
"""
return await bot.api.post('/resp/botInvitedJoinGroupRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': 1,
'message': message
})
return await bot.api.post(
"/resp/botInvitedJoinGroupRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": 1,
"message": message,
},
)

View File

@ -1,28 +1,29 @@
from enum import Enum
from typing import Any, List, Dict, Type, Iterable, Optional, Union
from typing import Any, Dict, List, Type, Union, Iterable, Optional
from pydantic import validate_arguments
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage
from nonebot.adapters import MessageSegment as BaseMessageSegment
from nonebot.typing import overrides
class MessageType(str, Enum):
"""消息类型枚举类"""
SOURCE = 'Source'
QUOTE = 'Quote'
AT = 'At'
AT_ALL = 'AtAll'
FACE = 'Face'
PLAIN = 'Plain'
IMAGE = 'Image'
FLASH_IMAGE = 'FlashImage'
VOICE = 'Voice'
XML = 'Xml'
JSON = 'Json'
APP = 'App'
POKE = 'Poke'
SOURCE = "Source"
QUOTE = "Quote"
AT = "At"
AT_ALL = "AtAll"
FACE = "Face"
PLAIN = "Plain"
IMAGE = "Image"
FLASH_IMAGE = "FlashImage"
VOICE = "Voice"
XML = "Xml"
JSON = "Json"
APP = "App"
POKE = "Poke"
class MessageSegment(BaseMessageSegment["MessageChain"]):
@ -43,21 +44,24 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
@validate_arguments
@overrides(BaseMessageSegment)
def __init__(self, type: MessageType, **data: Any):
super().__init__(type=type,
data={k: v for k, v in data.items() if v is not None})
super().__init__(
type=type, data={k: v for k, v in data.items() if v is not None}
)
@overrides(BaseMessageSegment)
def __str__(self) -> str:
return self.data['text'] if self.is_text() else repr(self)
return self.data["text"] if self.is_text() else repr(self)
def __repr__(self) -> str:
return '[mirai:%s]' % ','.join([
self.type.value,
*map(
lambda s: '%s=%r' % s,
self.data.items(),
),
])
return "[mirai:%s]" % ",".join(
[
self.type.value,
*map(
lambda s: "%s=%r" % s,
self.data.items(),
),
]
)
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
@ -65,15 +69,21 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
def as_dict(self) -> Dict[str, Any]:
"""导出可以被正常json序列化的结构体"""
return {'type': self.type.value, **self.data}
return {"type": self.type.value, **self.data}
@classmethod
def source(cls, id: int, time: int):
return cls(type=MessageType.SOURCE, id=id, time=time)
@classmethod
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
origin: "MessageChain"):
def quote(
cls,
id: int,
group_id: int,
sender_id: int,
target_id: int,
origin: "MessageChain",
):
"""
:说明:
@ -87,12 +97,14 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
* ``target_id: int``: 被引用回复的原消息的接收者者的QQ号或群号
* ``origin: MessageChain``: 被引用回复的原消息的消息链对象
"""
return cls(type=MessageType.QUOTE,
id=id,
groupId=group_id,
senderId=sender_id,
targetId=target_id,
origin=origin.export())
return cls(
type=MessageType.QUOTE,
id=id,
groupId=group_id,
senderId=sender_id,
targetId=target_id,
origin=origin.export(),
)
@classmethod
def at(cls, target: int):
@ -144,10 +156,12 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
return cls(type=MessageType.PLAIN, text=text)
@classmethod
def image(cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
def image(
cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None,
):
"""
:说明:
@ -162,10 +176,12 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
@classmethod
def flash_image(cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
def flash_image(
cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None,
):
"""
:说明:
@ -175,16 +191,15 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
同 ``image``
"""
return cls(type=MessageType.FLASH_IMAGE,
imageId=image_id,
url=url,
path=path)
return cls(type=MessageType.FLASH_IMAGE, imageId=image_id, url=url, path=path)
@classmethod
def voice(cls,
voice_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
def voice(
cls,
voice_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None,
):
"""
:说明:
@ -196,10 +211,7 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
* ``url: Optional[str]``: 语音的URL发送时可作网络语音的链接
* ``path: Optional[str]``: 语音的路径,发送本地语音
"""
return cls(type=MessageType.FLASH_IMAGE,
imageId=voice_id,
url=url,
path=path)
return cls(type=MessageType.FLASH_IMAGE, imageId=voice_id, url=url, path=path)
@classmethod
def xml(cls, xml: str):
@ -282,16 +294,14 @@ class MessageChain(BaseMessage[MessageSegment]):
return [MessageSegment.plain(text=message)]
return [
*map(
lambda x: x
if isinstance(x, MessageSegment) else MessageSegment(**x),
message)
lambda x: x if isinstance(x, MessageSegment) else MessageSegment(**x),
message,
)
]
def export(self) -> List[Dict[str, Any]]:
"""导出为可以被正常json序列化的数组"""
return [
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
]
return [*map(lambda segment: segment.as_dict(), self.copy())] # type: ignore
def extract_first(self, *type: MessageType) -> Optional[MessageSegment]:
"""
@ -311,4 +321,4 @@ class MessageChain(BaseMessage[MessageSegment]):
return None
def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>'
return f"<{self.__class__.__name__} {[*self.copy()]}>"

View File

@ -1,17 +1,17 @@
import re
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, Callable, Optional, Coroutine
import httpx
from pydantic import Extra, ValidationError, validate_arguments
import nonebot.exception as exception
from nonebot.log import logger
import nonebot.exception as exception
from nonebot.message import handle_event
from nonebot.utils import escape_tag, logger_wrapper
from .event import Event, GroupMessage, MessageEvent, MessageSource
from .message import MessageType, MessageSegment
from .event import Event, GroupMessage, MessageEvent, MessageSource
if TYPE_CHECKING:
from .bot import Bot
@ -21,28 +21,27 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable)
class Log:
@staticmethod
def log(level: str, message: str, exception: Optional[Exception] = None):
logger = logger_wrapper('MIRAI')
message = '<e>' + escape_tag(message) + '</e>'
logger = logger_wrapper("MIRAI")
message = "<e>" + escape_tag(message) + "</e>"
logger(level=level.upper(), message=message, exception=exception)
@classmethod
def info(cls, message: Any):
cls.log('INFO', str(message))
cls.log("INFO", str(message))
@classmethod
def debug(cls, message: Any):
cls.log('DEBUG', str(message))
cls.log("DEBUG", str(message))
@classmethod
def warn(cls, message: Any):
cls.log('WARNING', str(message))
cls.log("WARNING", str(message))
@classmethod
def error(cls, message: Any, exception: Optional[Exception] = None):
cls.log('ERROR', str(message), exception=exception)
cls.log("ERROR", str(message), exception=exception)
class ActionFailed(exception.ActionFailed):
@ -53,12 +52,13 @@ class ActionFailed(exception.ActionFailed):
"""
def __init__(self, **kwargs):
super().__init__('mirai')
super().__init__("mirai")
self.data = kwargs.copy()
def __repr__(self):
return self.__class__.__name__ + '(%s)' % ', '.join(
map(lambda m: '%s=%r' % m, self.data.items()))
return self.__class__.__name__ + "(%s)" % ", ".join(
map(lambda m: "%s=%r" % m, self.data.items())
)
class InvalidArgument(exception.AdapterException):
@ -69,7 +69,7 @@ class InvalidArgument(exception.AdapterException):
"""
def __init__(self, **kwargs):
super().__init__('mirai')
super().__init__("mirai")
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
@ -90,11 +90,12 @@ def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
try:
data = await function(*args, **kwargs)
except httpx.HTTPError:
raise exception.NetworkError('mirai')
logger.opt(colors=True).debug('<b>Mirai API returned data:</b> '
f'<y>{escape_tag(str(data))}</y>')
raise exception.NetworkError("mirai")
logger.opt(colors=True).debug(
"<b>Mirai API returned data:</b> " f"<y>{escape_tag(str(data))}</y>"
)
if isinstance(data, dict):
if data.get('code', 0) != 0:
if data.get("code", 0) != 0:
raise ActionFailed(**data)
return data
@ -109,10 +110,9 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable:
会在参数出错时释放 ``InvalidArgument`` 异常
"""
function = validate_arguments(config={
'arbitrary_types_allowed': True,
'extra': Extra.forbid
})(function)
function = validate_arguments(
config={"arbitrary_types_allowed": True, "extra": Extra.forbid}
)(function)
@wraps(function)
def wrapper(*args, **kwargs):
@ -134,12 +134,12 @@ def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent:
def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage:
at = event.message_chain.extract_first(MessageType.AT)
if at is not None:
if at.data['target'] == event.self_id:
if at.data["target"] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, at)
if not event.message_chain:
event.message_chain.append(MessageSegment.plain(''))
event.message_chain.append(MessageSegment.plain(""))
return event
@ -147,13 +147,13 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
plain = event.message_chain.extract_first(MessageType.PLAIN)
if plain is not None:
text = str(plain)
nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname))
nick_regex = "|".join(filter(lambda x: x, bot.config.nickname))
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is not None:
event.to_me = True
nickname = matched.group(1)
Log.info(f'User is calling me {nickname}')
plain.data['text'] = text[matched.end():]
Log.info(f"User is calling me {nickname}")
plain.data["text"] = text[matched.end() :]
event.message_chain.insert(0, plain)
return event
@ -161,7 +161,7 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
reply = event.message_chain.extract_first(MessageType.QUOTE)
if reply is not None:
if reply.data['senderId'] == event.self_id:
if reply.data["senderId"] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, reply)

View File

@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"