🎨 format code using black and isort

This commit is contained in:
yanyongyu
2021-11-22 23:21:26 +08:00
parent 602185a34e
commit a98d98cd12
86 changed files with 2893 additions and 2095 deletions

View File

@ -12,8 +12,15 @@ from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import ApiNotAvailable
from nonebot.drivers import (Driver, WebSocket, HTTPResponse, ForwardDriver,
ReverseDriver, HTTPConnection, WebSocketSetup)
from nonebot.drivers import (
Driver,
WebSocket,
HTTPResponse,
ForwardDriver,
ReverseDriver,
HTTPConnection,
WebSocketSetup,
)
from .config import Config as MiraiConfig
from .message import MessageChain, MessageSegment
@ -23,16 +30,14 @@ from .utils import Log, process_event, argument_validation, catch_network_error
class SessionManager:
"""Bot会话管理器, 提供API主动调用接口"""
sessions: Dict[int, Tuple[str, httpx.AsyncClient]] = {}
def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client
@catch_network_error
async def post(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Any:
async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None) -> Any:
"""
:说明:
@ -51,7 +56,7 @@ class SessionManager:
path,
json={
**(params or {}),
'sessionKey': self.session_key,
"sessionKey": self.session_key,
},
timeout=3,
)
@ -59,10 +64,9 @@ class SessionManager:
return response.json()
@catch_network_error
async def request(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Any:
async def request(
self, path: str, *, params: Optional[Dict[str, Any]] = None
) -> Any:
"""
:说明:
@ -77,7 +81,7 @@ class SessionManager:
path,
params={
**(params or {}),
'sessionKey': self.session_key,
"sessionKey": self.session_key,
},
timeout=3,
)
@ -98,7 +102,7 @@ class SessionManager:
"""
files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
form = {k: v for k, v in params.items() if k not in files}
form['sessionKey'] = self.session_key
form["sessionKey"] = self.session_key
response = await self.client.post(
path,
data=form,
@ -109,25 +113,25 @@ class SessionManager:
return response.json()
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
auth_key: str) -> "SessionManager":
async def new(
cls, self_id: int, *, host: IPv4Address, port: int, auth_key: str
) -> "SessionManager":
session = cls.get(self_id)
if session is not None:
return session
client = httpx.AsyncClient(base_url=f'http://{host}:{port}',
follow_redirects=True)
response = await client.post('/auth', json={'authKey': auth_key})
client = httpx.AsyncClient(
base_url=f"http://{host}:{port}", follow_redirects=True
)
response = await client.post("/auth", json={"authKey": auth_key})
response.raise_for_status()
auth = response.json()
assert auth['code'] == 0
session_key = auth['session']
response = await client.post('/verify',
json={
'sessionKey': session_key,
'qq': self_id
})
assert response.json()['code'] == 0
assert auth["code"] == 0
session_key = auth["session"]
response = await client.post(
"/verify", json={"sessionKey": session_key, "qq": self_id}
)
assert response.json()["code"] == 0
cls.sessions[self_id] = session_key, client
return cls(session_key, client)
@ -152,7 +156,7 @@ class Bot(BaseBot):
"""
_type = 'mirai'
_type = "mirai"
@property
@overrides(BaseBot)
@ -166,37 +170,42 @@ class Bot(BaseBot):
if api is None:
if isinstance(self.request, WebSocket):
asyncio.create_task(self.request.close(1000))
assert api is not None, 'SessionManager has not been initialized'
assert api is not None, "SessionManager has not been initialized"
return api
@classmethod
@overrides(BaseBot)
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
cls, driver: Driver, request: HTTPConnection
) -> Tuple[Optional[str], HTTPResponse]:
if isinstance(request, WebSocket):
return None, HTTPResponse(
501, b'Websocket connection is not implemented')
self_id: Optional[str] = request.headers.get('bot')
return None, HTTPResponse(501, b"Websocket connection is not implemented")
self_id: Optional[str] = request.headers.get("bot")
if self_id is None:
return None, HTTPResponse(400, b'Header `Bot` is required.')
return None, HTTPResponse(400, b"Header `Bot` is required.")
self_id = str(self_id).strip()
await SessionManager.new(
int(self_id),
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, #type: ignore
auth_key=cls.mirai_config.auth_key) # type: ignore
return self_id, HTTPResponse(204, b'')
port=cls.mirai_config.port, # type: ignore
auth_key=cls.mirai_config.auth_key, # type: ignore
)
return self_id, HTTPResponse(204, b"")
@classmethod
@overrides(BaseBot)
def register(cls,
driver: Driver,
config: "Config",
qq: Optional[Union[int, List[int]]] = None):
def register(
cls,
driver: Driver,
config: "Config",
qq: Optional[Union[int, List[int]]] = None,
):
cls.mirai_config = MiraiConfig(**config.dict())
if (cls.mirai_config.auth_key and cls.mirai_config.host and
cls.mirai_config.port) is None:
if (
cls.mirai_config.auth_key
and cls.mirai_config.host
and cls.mirai_config.port
) is None:
raise ApiNotAvailable(cls._type)
super().register(driver, config)
@ -209,17 +218,25 @@ class Bot(BaseBot):
self_ids = [qq] if isinstance(qq, int) else qq
async def url_factory(qq: int):
assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key
assert (
cls.mirai_config.host
and cls.mirai_config.port
and cls.mirai_config.auth_key
)
session = await SessionManager.new(
qq,
host=cls.mirai_config.host,
port=cls.mirai_config.port,
auth_key=cls.mirai_config.auth_key)
auth_key=cls.mirai_config.auth_key,
)
return WebSocketSetup(
adapter=cls._type,
self_id=str(qq),
url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}'
f'/all?sessionKey={session.session_key}'))
url=(
f"ws://{cls.mirai_config.host}:{cls.mirai_config.port}"
f"/all?sessionKey={session.session_key}"
),
)
for self_id in self_ids:
driver.setup_websocket(partial(url_factory, qq=self_id))
@ -234,13 +251,15 @@ class Bot(BaseBot):
try:
await process_event(
bot=self,
event=Event.new({
**json.loads(message),
'self_id': self.self_id,
}),
event=Event.new(
{
**json.loads(message),
"self_id": self.self_id,
}
),
)
except Exception as e:
Log.error(f'Failed to handle message: {message}', e)
Log.error(f"Failed to handle message: {message}", e)
@overrides(BaseBot)
async def _call_api(self, api: str, **data) -> NoReturn:
@ -266,10 +285,12 @@ class Bot(BaseBot):
@overrides(BaseBot)
@argument_validation
async def send(self,
event: Event,
message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False):
async def send(
self,
event: Event,
message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False,
):
"""
:说明:
@ -284,23 +305,24 @@ class Bot(BaseBot):
if not isinstance(message, MessageChain):
message = MessageChain(message)
if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)
return await self.send_friend_message(
target=event.sender.id, message_chain=message
)
elif isinstance(event, GroupMessage):
if at_sender:
message = MessageSegment.at(event.sender.id) + message
return await self.send_group_message(group=event.sender.group.id,
message_chain=message)
return await self.send_group_message(
group=event.sender.group.id, message_chain=message
)
elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id,
message_chain=message)
return await self.send_temp_message(
qq=event.sender.id, group=event.sender.group.id, message_chain=message
)
else:
raise ValueError(f'Unsupported event type {event!r}.')
raise ValueError(f"Unsupported event type {event!r}.")
@argument_validation
async def send_friend_message(self, target: int,
message_chain: MessageChain):
async def send_friend_message(self, target: int, message_chain: MessageChain):
"""
:说明:
@ -311,15 +333,13 @@ class Bot(BaseBot):
* ``target: int``: 发送消息目标好友的 QQ 号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
"""
return await self.api.post('sendFriendMessage',
params={
'target': target,
'messageChain': message_chain.export()
})
return await self.api.post(
"sendFriendMessage",
params={"target": target, "messageChain": message_chain.export()},
)
@argument_validation
async def send_temp_message(self, qq: int, group: int,
message_chain: MessageChain):
async def send_temp_message(self, qq: int, group: int, message_chain: MessageChain):
"""
:说明:
@ -331,18 +351,15 @@ class Bot(BaseBot):
* ``group: int``: 临时会话群号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
"""
return await self.api.post('sendTempMessage',
params={
'qq': qq,
'group': group,
'messageChain': message_chain.export()
})
return await self.api.post(
"sendTempMessage",
params={"qq": qq, "group": group, "messageChain": message_chain.export()},
)
@argument_validation
async def send_group_message(self,
group: int,
message_chain: MessageChain,
quote: Optional[int] = None):
async def send_group_message(
self, group: int, message_chain: MessageChain, quote: Optional[int] = None
):
"""
:说明:
@ -354,12 +371,14 @@ class Bot(BaseBot):
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
"""
return await self.api.post('sendGroupMessage',
params={
'group': group,
'messageChain': message_chain.export(),
'quote': quote
})
return await self.api.post(
"sendGroupMessage",
params={
"group": group,
"messageChain": message_chain.export(),
"quote": quote,
},
)
@argument_validation
async def recall(self, target: int):
@ -372,11 +391,12 @@ class Bot(BaseBot):
* ``target: int``: 需要撤回的消息的message_id
"""
return await self.api.post('recall', params={'target': target})
return await self.api.post("recall", params={"target": target})
@argument_validation
async def send_image_message(self, target: int, qq: int, group: int,
urls: List[str]) -> List[str]:
async def send_image_message(
self, target: int, qq: int, group: int, urls: List[str]
) -> List[str]:
"""
:说明:
@ -396,13 +416,10 @@ class Bot(BaseBot):
- ``List[str]``: 一个包含图片imageId的数组
"""
return await self.api.post('sendImageMessage',
params={
'target': target,
'qq': qq,
'group': group,
'urls': urls
})
return await self.api.post(
"sendImageMessage",
params={"target": target, "qq": qq, "group": group, "urls": urls},
)
@argument_validation
async def upload_image(self, type: str, img: BytesIO):
@ -416,11 +433,7 @@ class Bot(BaseBot):
* ``type: str``: "friend""group""temp"
* ``img: BytesIO``: 图片的BytesIO对象
"""
return await self.api.upload('uploadImage',
params={
'type': type,
'img': img
})
return await self.api.upload("uploadImage", params={"type": type, "img": img})
@argument_validation
async def upload_voice(self, type: str, voice: BytesIO):
@ -434,11 +447,9 @@ class Bot(BaseBot):
* ``type: str``: 当前仅支持 "group"
* ``voice: BytesIO``: 语音的BytesIO对象
"""
return await self.api.upload('uploadVoice',
params={
'type': type,
'voice': voice
})
return await self.api.upload(
"uploadVoice", params={"type": type, "voice": voice}
)
@argument_validation
async def fetch_message(self, count: int = 10):
@ -452,7 +463,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('fetchMessage', params={'count': count})
return await self.api.request("fetchMessage", params={"count": count})
@argument_validation
async def fetch_latest_message(self, count: int = 10):
@ -466,8 +477,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('fetchLatestMessage',
params={'count': count})
return await self.api.request("fetchLatestMessage", params={"count": count})
@argument_validation
async def peek_message(self, count: int = 10):
@ -481,7 +491,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('peekMessage', params={'count': count})
return await self.api.request("peekMessage", params={"count": count})
@argument_validation
async def peek_latest_message(self, count: int = 10):
@ -495,8 +505,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('peekLatestMessage',
params={'count': count})
return await self.api.request("peekLatestMessage", params={"count": count})
@argument_validation
async def messsage_from_id(self, id: int):
@ -510,7 +519,7 @@ class Bot(BaseBot):
* ``id: int``: 获取消息的message_id
"""
return await self.api.request('messageFromId', params={'id': id})
return await self.api.request("messageFromId", params={"id": id})
@argument_validation
async def count_message(self):
@ -519,7 +528,7 @@ class Bot(BaseBot):
使用此方法获取bot接收并缓存的消息总数注意不包含被删除的
"""
return await self.api.request('countMessage')
return await self.api.request("countMessage")
@argument_validation
async def friend_list(self) -> List[Dict[str, Any]]:
@ -532,7 +541,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的好友列表数据
"""
return await self.api.request('friendList')
return await self.api.request("friendList")
@argument_validation
async def group_list(self) -> List[Dict[str, Any]]:
@ -545,7 +554,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群列表数据
"""
return await self.api.request('groupList')
return await self.api.request("groupList")
@argument_validation
async def member_list(self, target: int) -> List[Dict[str, Any]]:
@ -562,7 +571,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群成员列表数据
"""
return await self.api.request('memberList', params={'target': target})
return await self.api.request("memberList", params={"target": target})
@argument_validation
async def mute(self, target: int, member_id: int, time: int):
@ -577,12 +586,9 @@ class Bot(BaseBot):
* ``member_id: int``: 指定群员QQ号
* ``time: int``: 禁言时长单位为秒最多30天
"""
return await self.api.post('mute',
params={
'target': target,
'memberId': member_id,
'time': time
})
return await self.api.post(
"mute", params={"target": target, "memberId": member_id, "time": time}
)
@argument_validation
async def unmute(self, target: int, member_id: int):
@ -596,11 +602,9 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
* ``member_id: int``: 指定群员QQ号
"""
return await self.api.post('unmute',
params={
'target': target,
'memberId': member_id
})
return await self.api.post(
"unmute", params={"target": target, "memberId": member_id}
)
@argument_validation
async def kick(self, target: int, member_id: int, msg: str):
@ -615,12 +619,9 @@ class Bot(BaseBot):
* ``member_id: int``: 指定群员QQ号
* ``msg: str``: 信息
"""
return await self.api.post('kick',
params={
'target': target,
'memberId': member_id,
'msg': msg
})
return await self.api.post(
"kick", params={"target": target, "memberId": member_id, "msg": msg}
)
@argument_validation
async def quit(self, target: int):
@ -633,7 +634,7 @@ class Bot(BaseBot):
* ``target: int``: 退出的群号
"""
return await self.api.post('quit', params={'target': target})
return await self.api.post("quit", params={"target": target})
@argument_validation
async def mute_all(self, target: int):
@ -646,7 +647,7 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
"""
return await self.api.post('muteAll', params={'target': target})
return await self.api.post("muteAll", params={"target": target})
@argument_validation
async def unmute_all(self, target: int):
@ -659,7 +660,7 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
"""
return await self.api.post('unmuteAll', params={'target': target})
return await self.api.post("unmuteAll", params={"target": target})
@argument_validation
async def group_config(self, target: int):
@ -685,7 +686,7 @@ class Bot(BaseBot):
"anonymousChat": true
}
"""
return await self.api.request('groupConfig', params={'target': target})
return await self.api.request("groupConfig", params={"target": target})
@argument_validation
async def modify_group_config(self, target: int, config: Dict[str, Any]):
@ -699,11 +700,9 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
* ``config: Dict[str, Any]``: 群设置, 格式见 ``group_config`` 的返回值
"""
return await self.api.post('groupConfig',
params={
'target': target,
'config': config
})
return await self.api.post(
"groupConfig", params={"target": target, "config": config}
)
@argument_validation
async def member_info(self, target: int, member_id: int):
@ -726,15 +725,14 @@ class Bot(BaseBot):
"specialTitle": "群头衔"
}
"""
return await self.api.request('memberInfo',
params={
'target': target,
'memberId': member_id
})
return await self.api.request(
"memberInfo", params={"target": target, "memberId": member_id}
)
@argument_validation
async def modify_member_info(self, target: int, member_id: int,
info: Dict[str, Any]):
async def modify_member_info(
self, target: int, member_id: int, info: Dict[str, Any]
):
"""
:说明:
@ -746,9 +744,6 @@ class Bot(BaseBot):
* ``member_id: int``: 群员QQ号
* ``info: Dict[str, Any]``: 群员资料, 格式见 ``member_info`` 的返回值
"""
return await self.api.post('memberInfo',
params={
'target': target,
'memberId': member_id,
'info': info
})
return await self.api.post(
"memberInfo", params={"target": target, "memberId": member_id, "info": info}
)