🎨 format code using black and isort

This commit is contained in:
yanyongyu
2021-11-22 23:21:26 +08:00
parent 602185a34e
commit a98d98cd12
86 changed files with 2893 additions and 2095 deletions

View File

@ -12,8 +12,15 @@ from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.utils import DataclassEncoder, escape_tag
from nonebot.drivers import (Driver, WebSocket, HTTPRequest, HTTPResponse,
ForwardDriver, HTTPConnection, WebSocketSetup)
from nonebot.drivers import (
Driver,
WebSocket,
HTTPRequest,
HTTPResponse,
ForwardDriver,
HTTPConnection,
WebSocketSetup,
)
from .utils import log, escape
from .config import Config as CQHTTPConfig
@ -49,15 +56,12 @@ async def _check_reply(bot: "Bot", event: "Event"):
return
try:
index = list(map(lambda x: x.type == "reply",
event.message)).index(True)
index = list(map(lambda x: x.type == "reply", event.message)).index(True)
except ValueError:
return
msg_seg = event.message[index]
try:
event.reply = Reply.parse_obj(await
bot.get_msg(message_id=msg_seg.data["id"]
))
event.reply = Reply.parse_obj(await bot.get_msg(message_id=msg_seg.data["id"]))
except Exception as e:
log("WARNING", f"Error when getting message reply info: {repr(e)}", e)
return
@ -68,8 +72,7 @@ async def _check_reply(bot: "Bot", event: "Event"):
if len(event.message) > index and event.message[index].type == "at":
del event.message[index]
if len(event.message) > index and event.message[index].type == "text":
event.message[index].data["text"] = event.message[index].data[
"text"].lstrip()
event.message[index].data["text"] = event.message[index].data["text"].lstrip()
if not event.message[index].data["text"]:
del event.message[index]
if not event.message:
@ -99,23 +102,24 @@ def _check_at_me(bot: "Bot", event: "Event"):
else:
def _is_at_me_seg(segment: MessageSegment):
return segment.type == "at" and str(segment.data.get(
"qq", "")) == str(event.self_id)
return segment.type == "at" and str(segment.data.get("qq", "")) == str(
event.self_id
)
# check the first segment
if _is_at_me_seg(event.message[0]):
event.to_me = True
event.message.pop(0)
if event.message and event.message[0].type == "text":
event.message[0].data["text"] = event.message[0].data[
"text"].lstrip()
event.message[0].data["text"] = event.message[0].data["text"].lstrip()
if not event.message[0].data["text"]:
del event.message[0]
if event.message and _is_at_me_seg(event.message[0]):
event.message.pop(0)
if event.message and event.message[0].type == "text":
event.message[0].data["text"] = event.message[0].data[
"text"].lstrip()
event.message[0].data["text"] = (
event.message[0].data["text"].lstrip()
)
if not event.message[0].data["text"]:
del event.message[0]
@ -123,9 +127,11 @@ def _check_at_me(bot: "Bot", event: "Event"):
# check the last segment
i = -1
last_msg_seg = event.message[i]
if last_msg_seg.type == "text" and \
not last_msg_seg.data["text"].strip() and \
len(event.message) >= 2:
if (
last_msg_seg.type == "text"
and not last_msg_seg.data["text"].strip()
and len(event.message) >= 2
):
i -= 1
last_msg_seg = event.message[i]
@ -161,13 +167,12 @@ def _check_nickname(bot: "Bot", event: "Event"):
if nicknames:
# check if the user is calling me with my nickname
nickname_regex = "|".join(nicknames)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text,
re.IGNORECASE)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text, re.IGNORECASE)
if m:
nickname = m.group(1)
log("DEBUG", f"User is calling me {nickname}")
event.to_me = True
first_msg_seg.data["text"] = first_text[m.end():]
first_msg_seg.data["text"] = first_text[m.end() :]
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
@ -206,8 +211,9 @@ class ResultStore:
@classmethod
def add_result(cls, result: Dict[str, Any]):
if isinstance(result.get("echo"), dict) and \
isinstance(result["echo"].get("seq"), int):
if isinstance(result.get("echo"), dict) and isinstance(
result["echo"].get("seq"), int
):
future = cls._futures.get(result["echo"]["seq"])
if future:
future.set_result(result)
@ -228,6 +234,7 @@ class Bot(BaseBot):
"""
CQHTTP 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。
"""
cqhttp_config: CQHTTPConfig
@property
@ -249,22 +256,25 @@ class Bot(BaseBot):
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
for self_id, url in cls.cqhttp_config.ws_urls.items():
try:
headers = {
"authorization":
f"Bearer {cls.cqhttp_config.access_token}"
} if cls.cqhttp_config.access_token else {}
headers = (
{"authorization": f"Bearer {cls.cqhttp_config.access_token}"}
if cls.cqhttp_config.access_token
else {}
)
driver.setup_websocket(
WebSocketSetup("cqhttp", self_id, url, headers=headers))
WebSocketSetup("cqhttp", self_id, url, headers=headers)
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Bad url {escape_tag(url)} for bot {escape_tag(self_id)} "
"in cqhttp forward websocket</bg #f8bbd0></r>")
"in cqhttp forward websocket</bg #f8bbd0></r>"
)
@classmethod
@overrides(BaseBot)
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
cls, driver: Driver, request: HTTPConnection
) -> Tuple[Optional[str], HTTPResponse]:
"""
:说明:
@ -286,22 +296,26 @@ class Bot(BaseBot):
if not x_signature:
log("WARNING", "Missing Signature Header")
return None, HTTPResponse(401, b"Missing Signature")
sig = hmac.new(secret.encode("utf-8"), request.body,
"sha1").hexdigest()
sig = hmac.new(secret.encode("utf-8"), request.body, "sha1").hexdigest()
if x_signature != "sha1=" + sig:
log("WARNING", "Signature Header is invalid")
return None, HTTPResponse(403, b"Signature is invalid")
access_token = cqhttp_config.access_token
if access_token and access_token != token and isinstance(
request, WebSocket):
if access_token and access_token != token and isinstance(request, WebSocket):
log(
"WARNING", "Authorization Header is invalid"
if token else "Missing Authorization Header")
"WARNING",
"Authorization Header is invalid"
if token
else "Missing Authorization Header",
)
return None, HTTPResponse(
403, b"Authorization Header is invalid"
if token else b"Missing Authorization Header")
return str(x_self_id), HTTPResponse(204, b'')
403,
b"Authorization Header is invalid"
if token
else b"Missing Authorization Header",
)
return str(x_self_id), HTTPResponse(204, b"")
@overrides(BaseBot)
async def handle_message(self, message: bytes):
@ -320,7 +334,7 @@ class Bot(BaseBot):
return
try:
post_type = data['post_type']
post_type = data["post_type"]
detail_type = data.get(f"{post_type}_type")
detail_type = f".{detail_type}" if detail_type else ""
sub_type = data.get("sub_type")
@ -352,17 +366,13 @@ class Bot(BaseBot):
if isinstance(self.request, WebSocket):
seq = ResultStore.get_seq()
json_data = json.dumps(
{
"action": api,
"params": data,
"echo": {
"seq": seq
}
},
cls=DataclassEncoder)
{"action": api, "params": data, "echo": {"seq": seq}},
cls=DataclassEncoder,
)
await self.request.send(json_data)
return _handle_api_result(await ResultStore.fetch(
seq, self.config.api_timeout))
return _handle_api_result(
await ResultStore.fetch(seq, self.config.api_timeout)
)
elif isinstance(self.request, HTTPRequest):
api_root = self.config.api_root.get(self.self_id)
@ -373,22 +383,25 @@ class Bot(BaseBot):
headers = {"Content-Type": "application/json"}
if self.cqhttp_config.access_token is not None:
headers[
"Authorization"] = "Bearer " + self.cqhttp_config.access_token
headers["Authorization"] = "Bearer " + self.cqhttp_config.access_token
try:
async with httpx.AsyncClient(headers=headers,
follow_redirects=True) as client:
async with httpx.AsyncClient(
headers=headers, follow_redirects=True
) as client:
response = await client.post(
api_root + api,
content=json.dumps(data, cls=DataclassEncoder),
timeout=self.config.api_timeout)
timeout=self.config.api_timeout,
)
if 200 <= response.status_code < 300:
result = response.json()
return _handle_api_result(result)
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
raise NetworkError(
f"HTTP request received unexpected "
f"status code: {response.status_code}"
)
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
@ -418,11 +431,13 @@ class Bot(BaseBot):
return await super().call_api(api, **data)
@overrides(BaseBot)
async def send(self,
event: Event,
message: Union[str, Message, MessageSegment],
at_sender: bool = False,
**kwargs) -> Any:
async def send(
self,
event: Event,
message: Union[str, Message, MessageSegment],
at_sender: bool = False,
**kwargs,
) -> Any:
"""
:说明:
@ -445,8 +460,9 @@ class Bot(BaseBot):
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
message = escape(message, escape_comma=False) if isinstance(
message, str) else message
message = (
escape(message, escape_comma=False) if isinstance(message, str) else message
)
msg = message if isinstance(message, Message) else Message(message)
at_sender = at_sender and bool(getattr(event, "user_id", None))

View File

@ -8,7 +8,6 @@ from nonebot.drivers import Driver, WebSocket
from .event import Event
from .message import Message, MessageSegment
def get_auth_bearer(access_token: Optional[str] = ...) -> Optional[str]:
...

View File

@ -1,6 +1,6 @@
from typing import Dict, Optional
from pydantic import Field, BaseModel, AnyUrl
from pydantic import Field, AnyUrl, BaseModel
# priority: alias > origin
@ -14,11 +14,10 @@ class Config(BaseModel):
- ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令
- ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID、目标 URL 字典
"""
access_token: Optional[str] = Field(default=None,
alias="cqhttp_access_token")
access_token: Optional[str] = Field(default=None, alias="cqhttp_access_token")
secret: Optional[str] = Field(default=None, alias="cqhttp_secret")
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set,
alias="cqhttp_ws_urls")
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set, alias="cqhttp_ws_urls")
class Config:
extra = "ignore"

View File

@ -5,12 +5,13 @@ from typing import TYPE_CHECKING, List, Type, Optional
from pydantic import BaseModel
from pygtrie import StringTrie
from .message import Message
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from .exception import NoLogException
from nonebot.adapters import Event as BaseEvent
from .message import Message
from .exception import NoLogException
if TYPE_CHECKING:
from .bot import Bot
@ -22,6 +23,7 @@ class Event(BaseEvent):
.. _CQHTTP 文档:
https://github.com/howmanybots/onebot/blob/master/README.md
"""
__event__ = ""
time: int
self_id: int
@ -118,6 +120,7 @@ class Status(BaseModel):
# Message Events
class MessageEvent(Event):
"""消息事件"""
__event__ = "message"
post_type: Literal["message"]
sub_type: str
@ -144,8 +147,9 @@ class MessageEvent(Event):
@overrides(Event)
def get_event_name(self) -> str:
sub_type = getattr(self, "sub_type", None)
return f"{self.post_type}.{self.message_type}" + (f".{sub_type}"
if sub_type else "")
return f"{self.post_type}.{self.message_type}" + (
f".{sub_type}" if sub_type else ""
)
@overrides(Event)
def get_message(self) -> Message:
@ -170,20 +174,29 @@ class MessageEvent(Event):
class PrivateMessageEvent(MessageEvent):
"""私聊消息"""
__event__ = "message.private"
message_type: Literal["private"]
@overrides(Event)
def get_event_description(self) -> str:
return (f'Message {self.message_id} from {self.user_id} "' + "".join(
map(
lambda x: escape_tag(str(x))
if x.is_text() else f"<le>{escape_tag(str(x))}</le>",
self.message)) + '"')
return (
f'Message {self.message_id} from {self.user_id} "'
+ "".join(
map(
lambda x: escape_tag(str(x))
if x.is_text()
else f"<le>{escape_tag(str(x))}</le>",
self.message,
)
)
+ '"'
)
class GroupMessageEvent(MessageEvent):
"""群消息"""
__event__ = "message.group"
message_type: Literal["group"]
group_id: int
@ -196,8 +209,13 @@ class GroupMessageEvent(MessageEvent):
+ "".join(
map(
lambda x: escape_tag(str(x))
if x.is_text() else f"<le>{escape_tag(str(x))}</le>",
self.message)) + '"')
if x.is_text()
else f"<le>{escape_tag(str(x))}</le>",
self.message,
)
)
+ '"'
)
@overrides(MessageEvent)
def get_session_id(self) -> str:
@ -207,6 +225,7 @@ class GroupMessageEvent(MessageEvent):
# Notice Events
class NoticeEvent(Event):
"""通知事件"""
__event__ = "notice"
post_type: Literal["notice"]
notice_type: str
@ -214,12 +233,14 @@ class NoticeEvent(Event):
@overrides(Event)
def get_event_name(self) -> str:
sub_type = getattr(self, "sub_type", None)
return f"{self.post_type}.{self.notice_type}" + (f".{sub_type}"
if sub_type else "")
return f"{self.post_type}.{self.notice_type}" + (
f".{sub_type}" if sub_type else ""
)
class GroupUploadNoticeEvent(NoticeEvent):
"""群文件上传事件"""
__event__ = "notice.group_upload"
notice_type: Literal["group_upload"]
user_id: int
@ -237,6 +258,7 @@ class GroupUploadNoticeEvent(NoticeEvent):
class GroupAdminNoticeEvent(NoticeEvent):
"""群管理员变动"""
__event__ = "notice.group_admin"
notice_type: Literal["group_admin"]
sub_type: str
@ -258,6 +280,7 @@ class GroupAdminNoticeEvent(NoticeEvent):
class GroupDecreaseNoticeEvent(NoticeEvent):
"""群成员减少事件"""
__event__ = "notice.group_decrease"
notice_type: Literal["group_decrease"]
sub_type: str
@ -280,6 +303,7 @@ class GroupDecreaseNoticeEvent(NoticeEvent):
class GroupIncreaseNoticeEvent(NoticeEvent):
"""群成员增加事件"""
__event__ = "notice.group_increase"
notice_type: Literal["group_increase"]
sub_type: str
@ -302,6 +326,7 @@ class GroupIncreaseNoticeEvent(NoticeEvent):
class GroupBanNoticeEvent(NoticeEvent):
"""群禁言事件"""
__event__ = "notice.group_ban"
notice_type: Literal["group_ban"]
sub_type: str
@ -325,6 +350,7 @@ class GroupBanNoticeEvent(NoticeEvent):
class FriendAddNoticeEvent(NoticeEvent):
"""好友添加事件"""
__event__ = "notice.friend_add"
notice_type: Literal["friend_add"]
user_id: int
@ -340,6 +366,7 @@ class FriendAddNoticeEvent(NoticeEvent):
class GroupRecallNoticeEvent(NoticeEvent):
"""群消息撤回事件"""
__event__ = "notice.group_recall"
notice_type: Literal["group_recall"]
user_id: int
@ -362,6 +389,7 @@ class GroupRecallNoticeEvent(NoticeEvent):
class FriendRecallNoticeEvent(NoticeEvent):
"""好友消息撤回事件"""
__event__ = "notice.friend_recall"
notice_type: Literal["friend_recall"]
user_id: int
@ -378,6 +406,7 @@ class FriendRecallNoticeEvent(NoticeEvent):
class NotifyEvent(NoticeEvent):
"""提醒事件"""
__event__ = "notice.notify"
notice_type: Literal["notify"]
sub_type: str
@ -395,6 +424,7 @@ class NotifyEvent(NoticeEvent):
class PokeNotifyEvent(NotifyEvent):
"""戳一戳提醒事件"""
__event__ = "notice.notify.poke"
sub_type: Literal["poke"]
target_id: int
@ -413,6 +443,7 @@ class PokeNotifyEvent(NotifyEvent):
class LuckyKingNotifyEvent(NotifyEvent):
"""群红包运气王提醒事件"""
__event__ = "notice.notify.lucky_king"
sub_type: Literal["lucky_king"]
target_id: int
@ -432,6 +463,7 @@ class LuckyKingNotifyEvent(NotifyEvent):
class HonorNotifyEvent(NotifyEvent):
"""群荣誉变更提醒事件"""
__event__ = "notice.notify.honor"
sub_type: Literal["honor"]
honor_type: str
@ -444,6 +476,7 @@ class HonorNotifyEvent(NotifyEvent):
# Request Events
class RequestEvent(Event):
"""请求事件"""
__event__ = "request"
post_type: Literal["request"]
request_type: str
@ -451,12 +484,14 @@ class RequestEvent(Event):
@overrides(Event)
def get_event_name(self) -> str:
sub_type = getattr(self, "sub_type", None)
return f"{self.post_type}.{self.request_type}" + (f".{sub_type}"
if sub_type else "")
return f"{self.post_type}.{self.request_type}" + (
f".{sub_type}" if sub_type else ""
)
class FriendRequestEvent(RequestEvent):
"""加好友请求事件"""
__event__ = "request.friend"
request_type: Literal["friend"]
user_id: int
@ -472,9 +507,9 @@ class FriendRequestEvent(RequestEvent):
return str(self.user_id)
async def approve(self, bot: "Bot", remark: str = ""):
return await bot.set_friend_add_request(flag=self.flag,
approve=True,
remark=remark)
return await bot.set_friend_add_request(
flag=self.flag, approve=True, remark=remark
)
async def reject(self, bot: "Bot"):
return await bot.set_friend_add_request(flag=self.flag, approve=False)
@ -482,6 +517,7 @@ class FriendRequestEvent(RequestEvent):
class GroupRequestEvent(RequestEvent):
"""加群请求/邀请事件"""
__event__ = "request.group"
request_type: Literal["group"]
sub_type: str
@ -499,20 +535,20 @@ class GroupRequestEvent(RequestEvent):
return f"group_{self.group_id}_{self.user_id}"
async def approve(self, bot: "Bot"):
return await bot.set_group_add_request(flag=self.flag,
sub_type=self.sub_type,
approve=True)
return await bot.set_group_add_request(
flag=self.flag, sub_type=self.sub_type, approve=True
)
async def reject(self, bot: "Bot", reason: str = ""):
return await bot.set_group_add_request(flag=self.flag,
sub_type=self.sub_type,
approve=False,
reason=reason)
return await bot.set_group_add_request(
flag=self.flag, sub_type=self.sub_type, approve=False, reason=reason
)
# Meta Events
class MetaEvent(Event):
"""元事件"""
__event__ = "meta_event"
post_type: Literal["meta_event"]
meta_event_type: str
@ -520,8 +556,9 @@ class MetaEvent(Event):
@overrides(Event)
def get_event_name(self) -> str:
sub_type = getattr(self, "sub_type", None)
return f"{self.post_type}.{self.meta_event_type}" + (f".{sub_type}" if
sub_type else "")
return f"{self.post_type}.{self.meta_event_type}" + (
f".{sub_type}" if sub_type else ""
)
@overrides(Event)
def get_log_string(self) -> str:
@ -530,6 +567,7 @@ class MetaEvent(Event):
class LifecycleMetaEvent(MetaEvent):
"""生命周期元事件"""
__event__ = "meta_event.lifecycle"
meta_event_type: Literal["lifecycle"]
sub_type: str
@ -537,6 +575,7 @@ class LifecycleMetaEvent(MetaEvent):
class HeartbeatMetaEvent(MetaEvent):
"""心跳元事件"""
__event__ = "meta_event.heartbeat"
meta_event_type: Literal["heartbeat"]
status: Status
@ -567,12 +606,28 @@ def get_event_model(event_name) -> List[Type[Event]]:
__all__ = [
"Event", "MessageEvent", "PrivateMessageEvent", "GroupMessageEvent",
"NoticeEvent", "GroupUploadNoticeEvent", "GroupAdminNoticeEvent",
"GroupDecreaseNoticeEvent", "GroupIncreaseNoticeEvent",
"GroupBanNoticeEvent", "FriendAddNoticeEvent", "GroupRecallNoticeEvent",
"FriendRecallNoticeEvent", "NotifyEvent", "PokeNotifyEvent",
"LuckyKingNotifyEvent", "HonorNotifyEvent", "RequestEvent",
"FriendRequestEvent", "GroupRequestEvent", "MetaEvent",
"LifecycleMetaEvent", "HeartbeatMetaEvent", "get_event_model"
"Event",
"MessageEvent",
"PrivateMessageEvent",
"GroupMessageEvent",
"NoticeEvent",
"GroupUploadNoticeEvent",
"GroupAdminNoticeEvent",
"GroupDecreaseNoticeEvent",
"GroupIncreaseNoticeEvent",
"GroupBanNoticeEvent",
"FriendAddNoticeEvent",
"GroupRecallNoticeEvent",
"FriendRecallNoticeEvent",
"NotifyEvent",
"PokeNotifyEvent",
"LuckyKingNotifyEvent",
"HonorNotifyEvent",
"RequestEvent",
"FriendRequestEvent",
"GroupRequestEvent",
"MetaEvent",
"LifecycleMetaEvent",
"HeartbeatMetaEvent",
"get_event_model",
]

View File

@ -8,7 +8,6 @@ from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class CQHTTPAdapterException(AdapterException):
def __init__(self):
super().__init__("cqhttp")
@ -33,8 +32,11 @@ class ActionFailed(BaseActionFailed, CQHTTPAdapterException):
self.info = kwargs
def __repr__(self):
return f"<ActionFailed " + ", ".join(
f"{k}={v}" for k, v in self.info.items()) + ">"
return (
f"<ActionFailed "
+ ", ".join(f"{k}={v}" for k, v in self.info.items())
+ ">"
)
def __str__(self):
return self.__repr__()

View File

@ -5,10 +5,11 @@ from base64 import b64encode
from typing import Any, Type, Tuple, Union, Mapping, Iterable, Optional, cast
from nonebot.typing import overrides
from .utils import log, _b2s, escape, unescape
from nonebot.adapters import Message as BaseMessage
from nonebot.adapters import MessageSegment as BaseMessageSegment
from .utils import log, _b2s, escape, unescape
class MessageSegment(BaseMessageSegment["Message"]):
"""
@ -27,23 +28,24 @@ class MessageSegment(BaseMessageSegment["Message"]):
# process special types
if type_ == "text":
return escape(
data.get("text", ""), # type: ignore
escape_comma=False)
return escape(data.get("text", ""), escape_comma=False) # type: ignore
params = ",".join(
[f"{k}={escape(str(v))}" for k, v in data.items() if v is not None])
[f"{k}={escape(str(v))}" for k, v in data.items() if v is not None]
)
return f"[CQ:{type_}{',' if params else ''}{params}]"
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + (MessageSegment.text(other) if isinstance(
other, str) else other)
return Message(self) + (
MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
return (MessageSegment.text(other)
if isinstance(other, str) else Message(other)) + self
return (
MessageSegment.text(other) if isinstance(other, str) else Message(other)
) + self
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
@ -83,11 +85,13 @@ class MessageSegment(BaseMessageSegment["Message"]):
return MessageSegment("forward", {"id": id_})
@staticmethod
def image(file: Union[str, bytes, BytesIO, Path],
type_: Optional[str] = None,
cache: bool = True,
proxy: bool = True,
timeout: Optional[int] = None) -> "MessageSegment":
def image(
file: Union[str, bytes, BytesIO, Path],
type_: Optional[str] = None,
cache: bool = True,
proxy: bool = True,
timeout: Optional[int] = None,
) -> "MessageSegment":
if isinstance(file, BytesIO):
file = file.getvalue()
if isinstance(file, bytes):
@ -95,74 +99,85 @@ class MessageSegment(BaseMessageSegment["Message"]):
elif isinstance(file, Path):
file = f"file:///{file.resolve()}"
return MessageSegment(
"image", {
"image",
{
"file": file,
"type": type_,
"cache": _b2s(cache),
"proxy": _b2s(proxy),
"timeout": timeout
})
"timeout": timeout,
},
)
@staticmethod
def json(data: str) -> "MessageSegment":
return MessageSegment("json", {"data": data})
@staticmethod
def location(latitude: float,
longitude: float,
title: Optional[str] = None,
content: Optional[str] = None) -> "MessageSegment":
def location(
latitude: float,
longitude: float,
title: Optional[str] = None,
content: Optional[str] = None,
) -> "MessageSegment":
return MessageSegment(
"location", {
"location",
{
"lat": str(latitude),
"lon": str(longitude),
"title": title,
"content": content
})
"content": content,
},
)
@staticmethod
def music(type_: str, id_: int) -> "MessageSegment":
return MessageSegment("music", {"type": type_, "id": id_})
@staticmethod
def music_custom(url: str,
audio: str,
title: str,
content: Optional[str] = None,
img_url: Optional[str] = None) -> "MessageSegment":
def music_custom(
url: str,
audio: str,
title: str,
content: Optional[str] = None,
img_url: Optional[str] = None,
) -> "MessageSegment":
return MessageSegment(
"music", {
"music",
{
"type": "custom",
"url": url,
"audio": audio,
"title": title,
"content": content,
"image": img_url
})
"image": img_url,
},
)
@staticmethod
def node(id_: int) -> "MessageSegment":
return MessageSegment("node", {"id": str(id_)})
@staticmethod
def node_custom(user_id: int, nickname: str,
content: Union[str, "Message"]) -> "MessageSegment":
return MessageSegment("node", {
"user_id": str(user_id),
"nickname": nickname,
"content": content
})
def node_custom(
user_id: int, nickname: str, content: Union[str, "Message"]
) -> "MessageSegment":
return MessageSegment(
"node", {"user_id": str(user_id), "nickname": nickname, "content": content}
)
@staticmethod
def poke(type_: str, id_: str) -> "MessageSegment":
return MessageSegment("poke", {"type": type_, "id": id_})
@staticmethod
def record(file: Union[str, bytes, BytesIO, Path],
magic: Optional[bool] = None,
cache: Optional[bool] = None,
proxy: Optional[bool] = None,
timeout: Optional[int] = None) -> "MessageSegment":
def record(
file: Union[str, bytes, BytesIO, Path],
magic: Optional[bool] = None,
cache: Optional[bool] = None,
proxy: Optional[bool] = None,
timeout: Optional[int] = None,
) -> "MessageSegment":
if isinstance(file, BytesIO):
file = file.getvalue()
if isinstance(file, bytes):
@ -170,13 +185,15 @@ class MessageSegment(BaseMessageSegment["Message"]):
elif isinstance(file, Path):
file = f"file:///{file.resolve()}"
return MessageSegment(
"record", {
"record",
{
"file": file,
"magic": _b2s(magic),
"cache": _b2s(cache),
"proxy": _b2s(proxy),
"timeout": timeout
})
"timeout": timeout,
},
)
@staticmethod
def reply(id_: int) -> "MessageSegment":
@ -191,26 +208,27 @@ class MessageSegment(BaseMessageSegment["Message"]):
return MessageSegment("shake", {})
@staticmethod
def share(url: str = "",
title: str = "",
content: Optional[str] = None,
image: Optional[str] = None) -> "MessageSegment":
return MessageSegment("share", {
"url": url,
"title": title,
"content": content,
"image": image
})
def share(
url: str = "",
title: str = "",
content: Optional[str] = None,
image: Optional[str] = None,
) -> "MessageSegment":
return MessageSegment(
"share", {"url": url, "title": title, "content": content, "image": image}
)
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": text})
@staticmethod
def video(file: Union[str, bytes, BytesIO, Path],
cache: Optional[bool] = None,
proxy: Optional[bool] = None,
timeout: Optional[int] = None) -> "MessageSegment":
def video(
file: Union[str, bytes, BytesIO, Path],
cache: Optional[bool] = None,
proxy: Optional[bool] = None,
timeout: Optional[int] = None,
) -> "MessageSegment":
if isinstance(file, BytesIO):
file = file.getvalue()
if isinstance(file, bytes):
@ -218,12 +236,14 @@ class MessageSegment(BaseMessageSegment["Message"]):
elif isinstance(file, Path):
file = f"file:///{file.resolve()}"
return MessageSegment(
"video", {
"video",
{
"file": file,
"cache": _b2s(cache),
"proxy": _b2s(proxy),
"timeout": timeout
})
"timeout": timeout,
},
)
@staticmethod
def xml(data: str) -> "MessageSegment":
@ -241,22 +261,22 @@ class Message(BaseMessage[MessageSegment]):
return MessageSegment
@overrides(BaseMessage)
def __add__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other)
MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessage)
def __radd__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other)
MessageSegment.text(other) if isinstance(other, str) else other
)
@staticmethod
@overrides(BaseMessage)
def _construct(
msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]:
msg: Union[str, Mapping, Iterable[Mapping]]
) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping):
msg = cast(Mapping[str, Any], msg)
yield MessageSegment(msg["type"], msg.get("data") or {})
@ -270,14 +290,15 @@ class Message(BaseMessage[MessageSegment]):
def _iter_message(msg: str) -> Iterable[Tuple[str, str]]:
text_begin = 0
for cqcode in re.finditer(
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
r"(?P<params>"
r"(?:,[a-zA-Z0-9-_.]+=[^,\]]+)*"
r"),?\]", msg):
yield "text", msg[text_begin:cqcode.pos + cqcode.start()]
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
r"(?P<params>"
r"(?:,[a-zA-Z0-9-_.]+=[^,\]]+)*"
r"),?\]",
msg,
):
yield "text", msg[text_begin : cqcode.pos + cqcode.start()]
text_begin = cqcode.pos + cqcode.end()
yield cqcode.group("type"), cqcode.group("params").lstrip(
",")
yield cqcode.group("type"), cqcode.group("params").lstrip(",")
yield "text", msg[text_begin:]
for type_, data in _iter_message(msg):
@ -287,10 +308,11 @@ class Message(BaseMessage[MessageSegment]):
yield MessageSegment(type_, {"text": unescape(data)})
else:
data = {
k: unescape(v) for k, v in map(
k: unescape(v)
for k, v in map(
lambda x: x.split("=", maxsplit=1),
filter(lambda x: x, (
x.lstrip() for x in data.split(","))))
filter(lambda x: x, (x.lstrip() for x in data.split(","))),
)
}
yield MessageSegment(type_, data)

View File

@ -1,5 +1,6 @@
from nonebot.adapters import Event
from nonebot.permission import Permission
from .event import GroupMessageEvent, PrivateMessageEvent
@ -42,8 +43,7 @@ async def _group(event: Event) -> bool:
async def _group_member(event: Event) -> bool:
return isinstance(event,
GroupMessageEvent) and event.sender.role == "member"
return isinstance(event, GroupMessageEvent) and event.sender.role == "member"
async def _group_admin(event: Event) -> bool:
@ -76,6 +76,12 @@ GROUP_OWNER = Permission(_group_owner)
"""
__all__ = [
"PRIVATE", "PRIVATE_FRIEND", "PRIVATE_GROUP", "PRIVATE_OTHER", "GROUP",
"GROUP_MEMBER", "GROUP_ADMIN", "GROUP_OWNER"
"PRIVATE",
"PRIVATE_FRIEND",
"PRIVATE_GROUP",
"PRIVATE_OTHER",
"GROUP",
"GROUP_MEMBER",
"GROUP_ADMIN",
"GROUP_OWNER",
]

View File

@ -16,9 +16,7 @@ def escape(s: str, *, escape_comma: bool = True) -> str:
* ``s: str``: 需要转义的字符串
* ``escape_comma: bool``: 是否转义逗号(``,``)。
"""
s = s.replace("&", "&amp;") \
.replace("[", "&#91;") \
.replace("]", "&#93;")
s = s.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
if escape_comma:
s = s.replace(",", "&#44;")
return s
@ -34,10 +32,12 @@ def unescape(s: str) -> str:
* ``s: str``: 需要转义的字符串
"""
return s.replace("&#44;", ",") \
.replace("&#91;", "[") \
.replace("&#93;", "]") \
return (
s.replace("&#44;", ",")
.replace("&#91;", "[")
.replace("&#93;", "]")
.replace("&amp;", "&")
)
def _b2s(b: Optional[bool]) -> Optional[str]:

View File

@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

View File

@ -16,10 +16,18 @@ from nonebot.drivers import Driver, HTTPRequest, HTTPResponse, HTTPConnection
from .config import Config as DingConfig
from .utils import log, calc_hmac_base64
from .message import Message, MessageSegment
from .exception import (ActionFailed, NetworkError, SessionExpired,
ApiNotAvailable)
from .event import (MessageEvent, ConversationType, GroupMessageEvent,
PrivateMessageEvent)
from .exception import (
ActionFailed,
NetworkError,
SessionExpired,
ApiNotAvailable,
)
from .event import (
MessageEvent,
ConversationType,
GroupMessageEvent,
PrivateMessageEvent,
)
if TYPE_CHECKING:
from nonebot.config import Config
@ -31,6 +39,7 @@ class Bot(BaseBot):
"""
钉钉 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。
"""
ding_config: DingConfig
@property
@ -48,8 +57,8 @@ class Bot(BaseBot):
@classmethod
@overrides(BaseBot)
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
cls, driver: Driver, request: HTTPConnection
) -> Tuple[Optional[str], HTTPResponse]:
"""
:说明:
@ -61,7 +70,8 @@ class Bot(BaseBot):
# 检查连接方式
if not isinstance(request, HTTPRequest):
return None, HTTPResponse(
405, b"Unsupported connection type, available type: `http`")
405, b"Unsupported connection type, available type: `http`"
)
# 检查 timestamp
if not timestamp:
@ -74,13 +84,15 @@ class Bot(BaseBot):
log("WARNING", "Missing Signature Header")
return None, HTTPResponse(400, b"Missing `sign` Header")
sign_base64 = calc_hmac_base64(str(timestamp), secret)
if sign != sign_base64.decode('utf-8'):
if sign != sign_base64.decode("utf-8"):
log("WARNING", "Signature Header is invalid")
return None, HTTPResponse(403, b"Signature is invalid")
else:
log("WARNING", "Ding signature check ignored!")
return (json.loads(request.body.decode())["chatbotUserId"],
HTTPResponse(204, b''))
return (
json.loads(request.body.decode())["chatbotUserId"],
HTTPResponse(204, b""),
)
@overrides(BaseBot)
async def handle_message(self, message: bytes):
@ -111,10 +123,9 @@ class Bot(BaseBot):
return
@overrides(BaseBot)
async def _call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
async def _call_api(
self, api: str, event: Optional[MessageEvent] = None, **data
) -> Any:
if not isinstance(self.request, HTTPRequest):
log("ERROR", "Only support http connection.")
return
@ -138,7 +149,8 @@ class Bot(BaseBot):
if event:
# 确保 sessionWebhook 没有过期
if int(datetime.now().timestamp()) > int(
event.sessionWebhookExpiredTime / 1000):
event.sessionWebhookExpiredTime / 1000
):
raise SessionExpired
webhook = event.sessionWebhook
@ -150,32 +162,37 @@ class Bot(BaseBot):
if not message:
raise ValueError("Message not found")
try:
async with httpx.AsyncClient(headers=headers,
follow_redirects=True) as client:
response = await client.post(webhook,
params=params,
json=message._produce(),
timeout=self.config.api_timeout)
async with httpx.AsyncClient(
headers=headers, follow_redirects=True
) as client:
response = await client.post(
webhook,
params=params,
json=message._produce(),
timeout=self.config.api_timeout,
)
if 200 <= response.status_code < 300:
result = response.json()
if isinstance(result, dict):
if result.get("errcode") != 0:
raise ActionFailed(errcode=result.get("errcode"),
errmsg=result.get("errmsg"))
raise ActionFailed(
errcode=result.get("errcode"), errmsg=result.get("errmsg")
)
return result
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
raise NetworkError(
f"HTTP request received unexpected "
f"status code: {response.status_code}"
)
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
async def call_api(
self, api: str, event: Optional[MessageEvent] = None, **data
) -> Any:
"""
:说明:
@ -199,13 +216,15 @@ class Bot(BaseBot):
return await super().call_api(api, event=event, **data)
@overrides(BaseBot)
async def send(self,
event: MessageEvent,
message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False,
webhook: Optional[str] = None,
secret: Optional[str] = None,
**kwargs) -> Any:
async def send(
self,
event: MessageEvent,
message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False,
webhook: Optional[str] = None,
secret: Optional[str] = None,
**kwargs,
) -> Any:
"""
:说明:
@ -241,9 +260,11 @@ class Bot(BaseBot):
params.update(kwargs)
if at_sender and event.conversationType != ConversationType.private:
params[
"message"] = f"@{event.senderId} " + msg + MessageSegment.atDingtalkIds(
event.senderId)
params["message"] = (
f"@{event.senderId} "
+ msg
+ MessageSegment.atDingtalkIds(event.senderId)
)
else:
params["message"] = msg

View File

@ -12,6 +12,7 @@ class Config(BaseModel):
- ``access_token`` / ``ding_access_token``: 钉钉令牌
- ``secret`` / ``ding_secret``: 钉钉 HTTP 上报数据签名口令
"""
secret: Optional[str] = Field(default=None, alias="ding_secret")
access_token: Optional[str] = Field(default=None, alias="ding_access_token")

View File

@ -69,6 +69,7 @@ class ConversationType(str, Enum):
class MessageEvent(Event):
"""消息事件"""
msgtype: str
text: TextMessage
msgId: str
@ -88,11 +89,10 @@ class MessageEvent(Event):
def gen_message(cls, values: dict):
assert "msgtype" in values, "msgtype must be specified"
# 其实目前钉钉机器人只能接收到 text 类型的消息
assert values[
"msgtype"] in values, f"{values['msgtype']} must be specified"
content = values[values['msgtype']]['content']
assert values["msgtype"] in values, f"{values['msgtype']} must be specified"
content = values[values["msgtype"]]["content"]
# 如果是被 @,第一个字符将会为空格,移除特殊情况
if content[0] == ' ':
if content[0] == " ":
content = content[1:]
values["message"] = content
return values
@ -128,6 +128,7 @@ class MessageEvent(Event):
class PrivateMessageEvent(MessageEvent):
"""私聊消息事件"""
chatbotCorpId: str
senderStaffId: Optional[str]
conversationType: ConversationType = ConversationType.private
@ -135,6 +136,7 @@ class PrivateMessageEvent(MessageEvent):
class GroupMessageEvent(MessageEvent):
"""群消息事件"""
atUsers: List[AtUsersItem]
conversationType: ConversationType = ConversationType.group
conversationTitle: str

View File

@ -1,9 +1,9 @@
from typing import Optional
from nonebot.exception import (AdapterException, ActionFailed as
BaseActionFailed, ApiNotAvailable as
BaseApiNotAvailable, NetworkError as
BaseNetworkError)
from nonebot.exception import AdapterException
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import NetworkError as BaseNetworkError
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class DingAdapterException(AdapterException):
@ -29,15 +29,13 @@ class ActionFailed(BaseActionFailed, DingAdapterException):
* ``errmsg: Optional[str]``: 错误信息
"""
def __init__(self,
errcode: Optional[int] = None,
errmsg: Optional[str] = None):
def __init__(self, errcode: Optional[int] = None, errmsg: Optional[str] = None):
super().__init__()
self.errcode = errcode
self.errmsg = errmsg
def __repr__(self):
return f"<ApiError errcode={self.errcode} errmsg=\"{self.errmsg}\">"
return f'<ApiError errcode={self.errcode} errmsg="{self.errmsg}">'
def __str__(self):
return self.__repr__()

View File

@ -77,10 +77,9 @@ class MessageSegment(BaseMessageSegment["Message"]):
def code(code_language: str, code: str) -> "Message":
"""发送 code 消息段"""
message = MessageSegment.text(code)
message += MessageSegment.extension({
"text_type": "code_snippet",
"code_language": code_language
})
message += MessageSegment.extension(
{"text_type": "code_snippet", "code_language": code_language}
)
return message
@staticmethod
@ -95,16 +94,19 @@ class MessageSegment(BaseMessageSegment["Message"]):
)
@staticmethod
def actionCardSingleBtn(title: str, text: str, singleTitle: str,
singleURL) -> "MessageSegment":
def actionCardSingleBtn(
title: str, text: str, singleTitle: str, singleURL
) -> "MessageSegment":
"""发送 ``actionCardSingleBtn`` 类型消息"""
return MessageSegment(
"actionCard", {
"actionCard",
{
"title": title,
"text": text,
"singleTitle": singleTitle,
"singleURL": singleURL
})
"singleURL": singleURL,
},
)
@staticmethod
def actionCardMultiBtns(
@ -112,7 +114,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
text: str,
btns: list,
hideAvatar: bool = False,
btnOrientation: str = '1',
btnOrientation: str = "1",
) -> "MessageSegment":
"""
发送 ``actionCardMultiBtn`` 类型消息
@ -123,13 +125,15 @@ class MessageSegment(BaseMessageSegment["Message"]):
* ``btns``: ``[{ "title": title, "actionURL": actionURL }, ...]``
"""
return MessageSegment(
"actionCard", {
"actionCard",
{
"title": title,
"text": text,
"hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation,
"btns": btns
})
"btns": btns,
},
)
@staticmethod
def feedCard(links: list) -> "MessageSegment":
@ -144,7 +148,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod
def raw(data) -> "MessageSegment":
return MessageSegment('raw', data)
return MessageSegment("raw", data)
def to_dict(self) -> Dict[str, Any]:
# 让用户可以直接发送原始的消息格式
@ -171,8 +175,8 @@ class Message(BaseMessage[MessageSegment]):
@staticmethod
@overrides(BaseMessage)
def _construct(
msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]:
msg: Union[str, Mapping, Iterable[Mapping]]
) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping):
msg = cast(Mapping[str, Any], msg)
yield MessageSegment(msg["type"], msg.get("data") or {})
@ -187,10 +191,11 @@ class Message(BaseMessage[MessageSegment]):
segment: MessageSegment
for segment in self:
# text 可以和 text 合并
if segment.type == "text" and data.get("msgtype") == 'text':
if segment.type == "text" and data.get("msgtype") == "text":
data.setdefault("text", {})
data["text"]["content"] = data["text"].setdefault(
"content", "") + segment.data["content"]
data["text"]["content"] = (
data["text"].setdefault("content", "") + segment.data["content"]
)
else:
data.update(segment.to_dict())
return data

View File

@ -8,10 +8,10 @@ log = logger_wrapper("DING")
def calc_hmac_base64(timestamp: str, secret: str):
secret_enc = secret.encode('utf-8')
string_to_sign = '{}\n{}'.format(timestamp, secret)
string_to_sign_enc = string_to_sign.encode('utf-8')
hmac_code = hmac.new(secret_enc,
string_to_sign_enc,
digestmod=hashlib.sha256).digest()
secret_enc = secret.encode("utf-8")
string_to_sign = "{}\n{}".format(timestamp, secret)
string_to_sign_enc = string_to_sign.encode("utf-8")
hmac_code = hmac.new(
secret_enc, string_to_sign_enc, digestmod=hashlib.sha256
).digest()
return base64.b64encode(hmac_code)

View File

@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

View File

@ -7,7 +7,7 @@ aiocache_logger.setLevel(logging.DEBUG)
aiocache_logger.handlers.clear()
aiocache_logger.addHandler(LoguruHandler())
from .bot import Bot as Bot
from .event import *
from .bot import Bot as Bot
from .message import Message as Message
from .message import MessageSegment as MessageSegment

View File

@ -1,24 +1,39 @@
import re
import json
from typing import (TYPE_CHECKING, Any, Dict, Tuple, Union, Iterable, Optional,
AsyncIterable, cast)
from typing import (
TYPE_CHECKING,
Any,
Dict,
Tuple,
Union,
Iterable,
Optional,
AsyncIterable,
cast,
)
import httpx
from aiocache import Cache, cached
from aiocache.serializers import PickleSerializer
from nonebot.log import logger
from .utils import AESCipher, log
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.message import handle_event
from .config import Config as FeishuConfig
from nonebot.adapters import Bot as BaseBot
from nonebot.drivers import Driver, HTTPRequest, HTTPResponse
from .utils import AESCipher, log
from .config import Config as FeishuConfig
from .message import Message, MessageSegment, MessageSerializer
from .exception import ActionFailed, NetworkError, ApiNotAvailable
from .event import (Event, MessageEvent, GroupMessageEvent, PrivateMessageEvent,
get_event_model)
from .event import (
Event,
MessageEvent,
GroupMessageEvent,
PrivateMessageEvent,
get_event_model,
)
if TYPE_CHECKING:
from nonebot.config import Config
@ -47,8 +62,10 @@ def _check_at_me(bot: "Bot", event: "Event"):
event.to_me = True
for index, segment in enumerate(message):
if segment.type == "at" and segment.data.get(
"user_name") in bot.config.nickname:
if (
segment.type == "at"
and segment.data.get("user_name") in bot.config.nickname
):
event.to_me = True
del event.event.message.content[index]
return
@ -57,7 +74,8 @@ def _check_at_me(bot: "Bot", event: "Event"):
if mention["name"] in bot.config.nickname:
event.to_me = True
segment.data["text"] = segment.data["text"].replace(
f"@{mention['name']}", "")
f"@{mention['name']}", ""
)
segment.data["text"] = segment.data["text"].lstrip()
break
else:
@ -92,18 +110,18 @@ def _check_nickname(bot: "Bot", event: "Event"):
if nicknames:
# check if the user is calling me with my nickname
nickname_regex = "|".join(nicknames)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text,
re.IGNORECASE)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text, re.IGNORECASE)
if m:
nickname = m.group(1)
log("DEBUG", f"User is calling me {nickname}")
event.to_me = True
first_msg_seg.data["text"] = first_text[m.end():]
first_msg_seg.data["text"] = first_text[m.end() :]
def _handle_api_result(
result: Union[Optional[Dict[str, Any]], str, bytes, Iterable[bytes],
AsyncIterable[bytes]]
result: Union[
Optional[Dict[str, Any]], str, bytes, Iterable[bytes], AsyncIterable[bytes]
]
) -> Any:
"""
:说明:
@ -155,13 +173,13 @@ class Bot(BaseBot):
@classmethod
@overrides(BaseBot)
async def check_permission(
cls, driver: Driver, request: HTTPRequest
cls, driver: Driver, request: HTTPRequest
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
if not isinstance(request, HTTPRequest):
log("WARNING",
"Unsupported connection type, available type: `http`")
log("WARNING", "Unsupported connection type, available type: `http`")
return None, HTTPResponse(
405, b"Unsupported connection type, available type: `http`")
405, b"Unsupported connection type, available type: `http`"
)
encrypt_key = cls.feishu_config.encrypt_key
if encrypt_key:
@ -174,16 +192,13 @@ class Bot(BaseBot):
challenge = data.get("challenge")
if challenge:
return data.get("token"), HTTPResponse(
200,
json.dumps({
"challenge": challenge
}).encode())
200, json.dumps({"challenge": challenge}).encode()
)
schema = data.get("schema")
if not schema:
return None, HTTPResponse(
400,
b"Missing `schema` in POST body, only accept event of version 2.0"
400, b"Missing `schema` in POST body, only accept event of version 2.0"
)
headers = data.get("header")
@ -196,15 +211,13 @@ class Bot(BaseBot):
if not token:
log("WARNING", "Missing `verification token` in POST body")
return None, HTTPResponse(
400, b"Missing `verification token` in POST body")
return None, HTTPResponse(400, b"Missing `verification token` in POST body")
else:
if token != cls.feishu_config.verification_token:
log("WARNING", "Verification token check failed")
return None, HTTPResponse(403,
b"Verification token check failed")
return None, HTTPResponse(403, b"Verification token check failed")
return app_id, HTTPResponse(200, b'')
return app_id, HTTPResponse(200, b"")
async def handle_message(self, message: bytes):
"""
@ -245,28 +258,32 @@ class Bot(BaseBot):
def _construct_url(self, path: str) -> str:
return self.api_root + path
@cached(ttl=60 * 60,
cache=Cache.MEMORY,
key="_feishu_tenant_access_token",
serializer=PickleSerializer())
@cached(
ttl=60 * 60,
cache=Cache.MEMORY,
key="_feishu_tenant_access_token",
serializer=PickleSerializer(),
)
async def _fetch_tenant_access_token(self) -> str:
try:
async with httpx.AsyncClient(follow_redirects=True) as client:
response = await client.post(
self._construct_url(
"auth/v3/tenant_access_token/internal/"),
self._construct_url("auth/v3/tenant_access_token/internal/"),
json={
"app_id": self.feishu_config.app_id,
"app_secret": self.feishu_config.app_secret
"app_secret": self.feishu_config.app_secret,
},
timeout=self.config.api_timeout)
timeout=self.config.api_timeout,
)
if 200 <= response.status_code < 300:
result = response.json()
return result["tenant_access_token"]
else:
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
raise NetworkError(
f"HTTP request received unexpected "
f"status code: {response.status_code}"
)
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
@ -280,30 +297,37 @@ class Bot(BaseBot):
raise ApiNotAvailable
headers = {}
self.feishu_config.tenant_access_token = await self._fetch_tenant_access_token(
self.feishu_config.tenant_access_token = (
await self._fetch_tenant_access_token()
)
headers["Authorization"] = (
"Bearer " + self.feishu_config.tenant_access_token
)
headers[
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
try:
async with httpx.AsyncClient(timeout=self.config.api_timeout,
follow_redirects=True) as client:
async with httpx.AsyncClient(
timeout=self.config.api_timeout, follow_redirects=True
) as client:
response = await client.send(
httpx.Request(data["method"],
self.api_root + api,
json=data.get("body", {}),
params=data.get("query", {}),
headers=headers))
httpx.Request(
data["method"],
self.api_root + api,
json=data.get("body", {}),
params=data.get("query", {}),
headers=headers,
)
)
if 200 <= response.status_code < 300:
if response.headers["content-type"].startswith(
"application/json"):
if response.headers["content-type"].startswith("application/json"):
result = response.json()
else:
result = response.content
return _handle_api_result(result)
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code} "
f"response body: {response.text}")
raise NetworkError(
f"HTTP request received unexpected "
f"status code: {response.status_code} "
f"response body: {response.text}"
)
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
@ -333,11 +357,13 @@ class Bot(BaseBot):
return await super().call_api(api, **data)
@overrides(BaseBot)
async def send(self,
event: Event,
message: Union[str, Message, MessageSegment],
at_sender: bool = False,
**kwargs) -> Any:
async def send(
self,
event: Event,
message: Union[str, Message, MessageSegment],
at_sender: bool = False,
**kwargs,
) -> Any:
msg = message if isinstance(message, Message) else Message(message)
if isinstance(event, GroupMessageEvent):
@ -346,7 +372,8 @@ class Bot(BaseBot):
receive_id, receive_id_type = event.get_user_id(), "open_id"
else:
raise ValueError(
"Cannot guess `receive_id` and `receive_id_type` to reply!")
"Cannot guess `receive_id` and `receive_id_type` to reply!"
)
at_sender = at_sender and bool(event.get_user_id())
@ -357,14 +384,12 @@ class Bot(BaseBot):
params = {
"method": "POST",
"query": {
"receive_id_type": receive_id_type
},
"query": {"receive_id_type": receive_id_type},
"body": {
"receive_id": receive_id,
"content": content,
"msg_type": msg_type
}
"msg_type": msg_type,
},
}
return await self.call_api(f"im/v1/messages", **params)

View File

@ -17,13 +17,16 @@ class Config(BaseModel):
- ``is_lark`` / ``feishu_is_lark``: 是否使用Lark飞书海外版默认为 false
"""
app_id: Optional[str] = Field(default=None, alias="feishu_app_id")
app_secret: Optional[str] = Field(default=None, alias="feishu_app_secret")
encrypt_key: Optional[str] = Field(default=None, alias="feishu_encrypt_key")
verification_token: Optional[str] = Field(default=None,
alias="feishu_verification_token")
verification_token: Optional[str] = Field(
default=None, alias="feishu_verification_token"
)
tenant_access_token: Optional[str] = Field(
default=None, alias="feishu_tenant_access_token")
default=None, alias="feishu_tenant_access_token"
)
is_lark: Optional[str] = Field(default=False, alias="feishu_is_lark")
class Config:

View File

@ -1,12 +1,12 @@
import inspect
import json
from typing import Any, Dict, List, Literal, Optional, Type
import inspect
from typing import Any, Dict, List, Type, Literal, Optional
from pydantic import BaseModel, Field, root_validator
from pygtrie import StringTrie
from pydantic import Field, BaseModel, root_validator
from nonebot.adapters import Event as BaseEvent
from nonebot.typing import overrides
from nonebot.adapters import Event as BaseEvent
from .message import Message, MessageDeserializer

View File

@ -1,13 +1,12 @@
from typing import Optional
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import AdapterException
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import NetworkError as BaseNetworkError
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class FeishuAdapterException(AdapterException):
def __init__(self):
super().__init__("feishu")
@ -28,8 +27,11 @@ class ActionFailed(BaseActionFailed, FeishuAdapterException):
self.info = kwargs
def __repr__(self):
return f"<ActionFailed " + ", ".join(
f"{k}={v}" for k, v in self.info.items()) + ">"
return (
f"<ActionFailed "
+ ", ".join(f"{k}={v}" for k, v in self.info.items())
+ ">"
)
def __str__(self):
return self.__repr__()

View File

@ -1,8 +1,18 @@
import json
import itertools
from dataclasses import dataclass
from typing import (Any, Dict, List, Type, Tuple, Union, Mapping, Iterable,
Optional, cast)
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
Iterable,
Optional,
cast,
)
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage
@ -34,7 +44,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
"share_user": "[个人名片]",
"system": "[系统消息]",
"location": "[位置]",
"video_chat": "[视频通话]"
"video_chat": "[视频通话]",
}
def __str__(self) -> str:
@ -47,24 +57,26 @@ class MessageSegment(BaseMessageSegment["Message"]):
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + (MessageSegment.text(other) if isinstance(
other, str) else other)
return Message(self) + (
MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
return (MessageSegment.text(other)
if isinstance(other, str) else Message(other)) + self
return (
MessageSegment.text(other) if isinstance(other, str) else Message(other)
) + self
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
return self.type == "text"
#接收消息
# 接收消息
@staticmethod
def at(user_id: str) -> "MessageSegment":
return MessageSegment("at", {"user_id": user_id})
#发送消息
# 发送消息
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": text})
@ -79,10 +91,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod
def interactive(title: str, elements: list) -> "MessageSegment":
return MessageSegment("interactive", {
"title": title,
"elements": elements
})
return MessageSegment("interactive", {"title": title, "elements": elements})
@staticmethod
def share_chat(chat_id: str) -> "MessageSegment":
@ -94,28 +103,25 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod
def audio(file_key: str, duration: int) -> "MessageSegment":
return MessageSegment("audio", {
"file_key": file_key,
"duration": duration
})
return MessageSegment("audio", {"file_key": file_key, "duration": duration})
@staticmethod
def media(file_key: str, image_key: str, file_name: str,
duration: int) -> "MessageSegment":
def media(
file_key: str, image_key: str, file_name: str, duration: int
) -> "MessageSegment":
return MessageSegment(
"media", {
"media",
{
"file_key": file_key,
"image_key": image_key,
"file_name": file_name,
"duration": duration
})
"duration": duration,
},
)
@staticmethod
def file(file_key: str, file_name: str) -> "MessageSegment":
return MessageSegment("file", {
"file_key": file_key,
"file_name": file_name
})
return MessageSegment("file", {"file_key": file_key, "file_name": file_name})
@staticmethod
def sticker(file_key) -> "MessageSegment":
@ -133,22 +139,22 @@ class Message(BaseMessage[MessageSegment]):
return MessageSegment
@overrides(BaseMessage)
def __add__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other)
MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessage)
def __radd__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other)
MessageSegment.text(other) if isinstance(other, str) else other
)
@staticmethod
@overrides(BaseMessage)
def _construct(
msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]:
msg: Union[str, Mapping, Iterable[Mapping]]
) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping):
msg = cast(Mapping[str, Any], msg)
yield MessageSegment(msg["type"], msg.get("data") or {})
@ -169,7 +175,8 @@ class Message(BaseMessage[MessageSegment]):
for i, seg in enumerate(self):
if seg.type == "text" and i != 0 and msg[-1].type == "text":
msg[-1] = MessageSegment(
"text", {"text": msg[-1].data["text"] + seg.data["text"]})
"text", {"text": msg[-1].data["text"] + seg.data["text"]}
)
else:
msg.append(seg)
return Message(msg)
@ -184,6 +191,7 @@ class MessageSerializer:
"""
飞书 协议 Message 序列化器。
"""
message: Message
def serialize(self) -> Tuple[str, str]:
@ -198,10 +206,12 @@ class MessageSerializer:
else:
if last_segment_type == "image":
msg["content"].append([])
msg["content"][-1].append({
"tag": segment.type if segment.type != "image" else "img",
**segment.data
})
msg["content"][-1].append(
{
"tag": segment.type if segment.type != "image" else "img",
**segment.data,
}
)
last_segment_type = segment.type
return "post", json.dumps({"zh_cn": {**msg}})
@ -214,6 +224,7 @@ class MessageDeserializer:
"""
飞书 协议 Message 反序列化器。
"""
type: str
data: Dict[str, Any]
mentions: Optional[List[dict]]
@ -227,14 +238,13 @@ class MessageDeserializer:
if self.type == "post":
msg = Message()
if self.data["title"] != "":
msg += MessageSegment("text", {'text': self.data["title"]})
msg += MessageSegment("text", {"text": self.data["title"]})
for seg in itertools.chain(*self.data["content"]):
tag = seg.pop("tag")
if tag == "at":
seg["user_name"] = dict_mention[seg["user_id"]]["name"]
seg["user_id"] = dict_mention[
seg["user_id"]]["id"]["open_id"]
seg["user_id"] = dict_mention[seg["user_id"]]["id"]["open_id"]
msg += MessageSegment(tag if tag != "img" else "image", seg)
@ -242,7 +252,8 @@ class MessageDeserializer:
elif self.type == "text":
for key, mention in dict_mention.items():
self.data["text"] = self.data["text"].replace(
key, f"@{mention['name']}")
key, f"@{mention['name']}"
)
self.data["mentions"] = dict_mention
return Message(MessageSegment(self.type, self.data))

View File

@ -9,27 +9,26 @@ log = logger_wrapper("FEISHU")
class AESCipher(object):
def __init__(self, key):
self.block_size = AES.block_size
self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
@staticmethod
def str_to_bytes(data):
u_type = type(b"".decode('utf8'))
u_type = type(b"".decode("utf8"))
if isinstance(data, u_type):
return data.encode('utf8')
return data.encode("utf8")
return data
@staticmethod
def _unpad(s):
return s[:-ord(s[len(s) - 1:])]
return s[: -ord(s[len(s) - 1 :])]
def decrypt(self, enc):
iv = enc[:AES.block_size]
iv = enc[: AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size:]))
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
def decrypt_string(self, enc):
enc = base64.b64decode(enc)
return self.decrypt(enc).decode('utf8')
return self.decrypt(enc).decode("utf8")

View File

@ -36,6 +36,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

View File

@ -12,8 +12,15 @@ from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import ApiNotAvailable
from nonebot.drivers import (Driver, WebSocket, HTTPResponse, ForwardDriver,
ReverseDriver, HTTPConnection, WebSocketSetup)
from nonebot.drivers import (
Driver,
WebSocket,
HTTPResponse,
ForwardDriver,
ReverseDriver,
HTTPConnection,
WebSocketSetup,
)
from .config import Config as MiraiConfig
from .message import MessageChain, MessageSegment
@ -23,16 +30,14 @@ from .utils import Log, process_event, argument_validation, catch_network_error
class SessionManager:
"""Bot会话管理器, 提供API主动调用接口"""
sessions: Dict[int, Tuple[str, httpx.AsyncClient]] = {}
def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client
@catch_network_error
async def post(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Any:
async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None) -> Any:
"""
:说明:
@ -51,7 +56,7 @@ class SessionManager:
path,
json={
**(params or {}),
'sessionKey': self.session_key,
"sessionKey": self.session_key,
},
timeout=3,
)
@ -59,10 +64,9 @@ class SessionManager:
return response.json()
@catch_network_error
async def request(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Any:
async def request(
self, path: str, *, params: Optional[Dict[str, Any]] = None
) -> Any:
"""
:说明:
@ -77,7 +81,7 @@ class SessionManager:
path,
params={
**(params or {}),
'sessionKey': self.session_key,
"sessionKey": self.session_key,
},
timeout=3,
)
@ -98,7 +102,7 @@ class SessionManager:
"""
files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
form = {k: v for k, v in params.items() if k not in files}
form['sessionKey'] = self.session_key
form["sessionKey"] = self.session_key
response = await self.client.post(
path,
data=form,
@ -109,25 +113,25 @@ class SessionManager:
return response.json()
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
auth_key: str) -> "SessionManager":
async def new(
cls, self_id: int, *, host: IPv4Address, port: int, auth_key: str
) -> "SessionManager":
session = cls.get(self_id)
if session is not None:
return session
client = httpx.AsyncClient(base_url=f'http://{host}:{port}',
follow_redirects=True)
response = await client.post('/auth', json={'authKey': auth_key})
client = httpx.AsyncClient(
base_url=f"http://{host}:{port}", follow_redirects=True
)
response = await client.post("/auth", json={"authKey": auth_key})
response.raise_for_status()
auth = response.json()
assert auth['code'] == 0
session_key = auth['session']
response = await client.post('/verify',
json={
'sessionKey': session_key,
'qq': self_id
})
assert response.json()['code'] == 0
assert auth["code"] == 0
session_key = auth["session"]
response = await client.post(
"/verify", json={"sessionKey": session_key, "qq": self_id}
)
assert response.json()["code"] == 0
cls.sessions[self_id] = session_key, client
return cls(session_key, client)
@ -152,7 +156,7 @@ class Bot(BaseBot):
"""
_type = 'mirai'
_type = "mirai"
@property
@overrides(BaseBot)
@ -166,37 +170,42 @@ class Bot(BaseBot):
if api is None:
if isinstance(self.request, WebSocket):
asyncio.create_task(self.request.close(1000))
assert api is not None, 'SessionManager has not been initialized'
assert api is not None, "SessionManager has not been initialized"
return api
@classmethod
@overrides(BaseBot)
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
cls, driver: Driver, request: HTTPConnection
) -> Tuple[Optional[str], HTTPResponse]:
if isinstance(request, WebSocket):
return None, HTTPResponse(
501, b'Websocket connection is not implemented')
self_id: Optional[str] = request.headers.get('bot')
return None, HTTPResponse(501, b"Websocket connection is not implemented")
self_id: Optional[str] = request.headers.get("bot")
if self_id is None:
return None, HTTPResponse(400, b'Header `Bot` is required.')
return None, HTTPResponse(400, b"Header `Bot` is required.")
self_id = str(self_id).strip()
await SessionManager.new(
int(self_id),
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, #type: ignore
auth_key=cls.mirai_config.auth_key) # type: ignore
return self_id, HTTPResponse(204, b'')
port=cls.mirai_config.port, # type: ignore
auth_key=cls.mirai_config.auth_key, # type: ignore
)
return self_id, HTTPResponse(204, b"")
@classmethod
@overrides(BaseBot)
def register(cls,
driver: Driver,
config: "Config",
qq: Optional[Union[int, List[int]]] = None):
def register(
cls,
driver: Driver,
config: "Config",
qq: Optional[Union[int, List[int]]] = None,
):
cls.mirai_config = MiraiConfig(**config.dict())
if (cls.mirai_config.auth_key and cls.mirai_config.host and
cls.mirai_config.port) is None:
if (
cls.mirai_config.auth_key
and cls.mirai_config.host
and cls.mirai_config.port
) is None:
raise ApiNotAvailable(cls._type)
super().register(driver, config)
@ -209,17 +218,25 @@ class Bot(BaseBot):
self_ids = [qq] if isinstance(qq, int) else qq
async def url_factory(qq: int):
assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key
assert (
cls.mirai_config.host
and cls.mirai_config.port
and cls.mirai_config.auth_key
)
session = await SessionManager.new(
qq,
host=cls.mirai_config.host,
port=cls.mirai_config.port,
auth_key=cls.mirai_config.auth_key)
auth_key=cls.mirai_config.auth_key,
)
return WebSocketSetup(
adapter=cls._type,
self_id=str(qq),
url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}'
f'/all?sessionKey={session.session_key}'))
url=(
f"ws://{cls.mirai_config.host}:{cls.mirai_config.port}"
f"/all?sessionKey={session.session_key}"
),
)
for self_id in self_ids:
driver.setup_websocket(partial(url_factory, qq=self_id))
@ -234,13 +251,15 @@ class Bot(BaseBot):
try:
await process_event(
bot=self,
event=Event.new({
**json.loads(message),
'self_id': self.self_id,
}),
event=Event.new(
{
**json.loads(message),
"self_id": self.self_id,
}
),
)
except Exception as e:
Log.error(f'Failed to handle message: {message}', e)
Log.error(f"Failed to handle message: {message}", e)
@overrides(BaseBot)
async def _call_api(self, api: str, **data) -> NoReturn:
@ -266,10 +285,12 @@ class Bot(BaseBot):
@overrides(BaseBot)
@argument_validation
async def send(self,
event: Event,
message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False):
async def send(
self,
event: Event,
message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False,
):
"""
:说明:
@ -284,23 +305,24 @@ class Bot(BaseBot):
if not isinstance(message, MessageChain):
message = MessageChain(message)
if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)
return await self.send_friend_message(
target=event.sender.id, message_chain=message
)
elif isinstance(event, GroupMessage):
if at_sender:
message = MessageSegment.at(event.sender.id) + message
return await self.send_group_message(group=event.sender.group.id,
message_chain=message)
return await self.send_group_message(
group=event.sender.group.id, message_chain=message
)
elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id,
message_chain=message)
return await self.send_temp_message(
qq=event.sender.id, group=event.sender.group.id, message_chain=message
)
else:
raise ValueError(f'Unsupported event type {event!r}.')
raise ValueError(f"Unsupported event type {event!r}.")
@argument_validation
async def send_friend_message(self, target: int,
message_chain: MessageChain):
async def send_friend_message(self, target: int, message_chain: MessageChain):
"""
:说明:
@ -311,15 +333,13 @@ class Bot(BaseBot):
* ``target: int``: 发送消息目标好友的 QQ 号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
"""
return await self.api.post('sendFriendMessage',
params={
'target': target,
'messageChain': message_chain.export()
})
return await self.api.post(
"sendFriendMessage",
params={"target": target, "messageChain": message_chain.export()},
)
@argument_validation
async def send_temp_message(self, qq: int, group: int,
message_chain: MessageChain):
async def send_temp_message(self, qq: int, group: int, message_chain: MessageChain):
"""
:说明:
@ -331,18 +351,15 @@ class Bot(BaseBot):
* ``group: int``: 临时会话群号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
"""
return await self.api.post('sendTempMessage',
params={
'qq': qq,
'group': group,
'messageChain': message_chain.export()
})
return await self.api.post(
"sendTempMessage",
params={"qq": qq, "group": group, "messageChain": message_chain.export()},
)
@argument_validation
async def send_group_message(self,
group: int,
message_chain: MessageChain,
quote: Optional[int] = None):
async def send_group_message(
self, group: int, message_chain: MessageChain, quote: Optional[int] = None
):
"""
:说明:
@ -354,12 +371,14 @@ class Bot(BaseBot):
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
"""
return await self.api.post('sendGroupMessage',
params={
'group': group,
'messageChain': message_chain.export(),
'quote': quote
})
return await self.api.post(
"sendGroupMessage",
params={
"group": group,
"messageChain": message_chain.export(),
"quote": quote,
},
)
@argument_validation
async def recall(self, target: int):
@ -372,11 +391,12 @@ class Bot(BaseBot):
* ``target: int``: 需要撤回的消息的message_id
"""
return await self.api.post('recall', params={'target': target})
return await self.api.post("recall", params={"target": target})
@argument_validation
async def send_image_message(self, target: int, qq: int, group: int,
urls: List[str]) -> List[str]:
async def send_image_message(
self, target: int, qq: int, group: int, urls: List[str]
) -> List[str]:
"""
:说明:
@ -396,13 +416,10 @@ class Bot(BaseBot):
- ``List[str]``: 一个包含图片imageId的数组
"""
return await self.api.post('sendImageMessage',
params={
'target': target,
'qq': qq,
'group': group,
'urls': urls
})
return await self.api.post(
"sendImageMessage",
params={"target": target, "qq": qq, "group": group, "urls": urls},
)
@argument_validation
async def upload_image(self, type: str, img: BytesIO):
@ -416,11 +433,7 @@ class Bot(BaseBot):
* ``type: str``: "friend""group""temp"
* ``img: BytesIO``: 图片的BytesIO对象
"""
return await self.api.upload('uploadImage',
params={
'type': type,
'img': img
})
return await self.api.upload("uploadImage", params={"type": type, "img": img})
@argument_validation
async def upload_voice(self, type: str, voice: BytesIO):
@ -434,11 +447,9 @@ class Bot(BaseBot):
* ``type: str``: 当前仅支持 "group"
* ``voice: BytesIO``: 语音的BytesIO对象
"""
return await self.api.upload('uploadVoice',
params={
'type': type,
'voice': voice
})
return await self.api.upload(
"uploadVoice", params={"type": type, "voice": voice}
)
@argument_validation
async def fetch_message(self, count: int = 10):
@ -452,7 +463,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('fetchMessage', params={'count': count})
return await self.api.request("fetchMessage", params={"count": count})
@argument_validation
async def fetch_latest_message(self, count: int = 10):
@ -466,8 +477,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('fetchLatestMessage',
params={'count': count})
return await self.api.request("fetchLatestMessage", params={"count": count})
@argument_validation
async def peek_message(self, count: int = 10):
@ -481,7 +491,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('peekMessage', params={'count': count})
return await self.api.request("peekMessage", params={"count": count})
@argument_validation
async def peek_latest_message(self, count: int = 10):
@ -495,8 +505,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量
"""
return await self.api.request('peekLatestMessage',
params={'count': count})
return await self.api.request("peekLatestMessage", params={"count": count})
@argument_validation
async def messsage_from_id(self, id: int):
@ -510,7 +519,7 @@ class Bot(BaseBot):
* ``id: int``: 获取消息的message_id
"""
return await self.api.request('messageFromId', params={'id': id})
return await self.api.request("messageFromId", params={"id": id})
@argument_validation
async def count_message(self):
@ -519,7 +528,7 @@ class Bot(BaseBot):
使用此方法获取bot接收并缓存的消息总数注意不包含被删除的
"""
return await self.api.request('countMessage')
return await self.api.request("countMessage")
@argument_validation
async def friend_list(self) -> List[Dict[str, Any]]:
@ -532,7 +541,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的好友列表数据
"""
return await self.api.request('friendList')
return await self.api.request("friendList")
@argument_validation
async def group_list(self) -> List[Dict[str, Any]]:
@ -545,7 +554,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群列表数据
"""
return await self.api.request('groupList')
return await self.api.request("groupList")
@argument_validation
async def member_list(self, target: int) -> List[Dict[str, Any]]:
@ -562,7 +571,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群成员列表数据
"""
return await self.api.request('memberList', params={'target': target})
return await self.api.request("memberList", params={"target": target})
@argument_validation
async def mute(self, target: int, member_id: int, time: int):
@ -577,12 +586,9 @@ class Bot(BaseBot):
* ``member_id: int``: 指定群员QQ号
* ``time: int``: 禁言时长单位为秒最多30天
"""
return await self.api.post('mute',
params={
'target': target,
'memberId': member_id,
'time': time
})
return await self.api.post(
"mute", params={"target": target, "memberId": member_id, "time": time}
)
@argument_validation
async def unmute(self, target: int, member_id: int):
@ -596,11 +602,9 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
* ``member_id: int``: 指定群员QQ号
"""
return await self.api.post('unmute',
params={
'target': target,
'memberId': member_id
})
return await self.api.post(
"unmute", params={"target": target, "memberId": member_id}
)
@argument_validation
async def kick(self, target: int, member_id: int, msg: str):
@ -615,12 +619,9 @@ class Bot(BaseBot):
* ``member_id: int``: 指定群员QQ号
* ``msg: str``: 信息
"""
return await self.api.post('kick',
params={
'target': target,
'memberId': member_id,
'msg': msg
})
return await self.api.post(
"kick", params={"target": target, "memberId": member_id, "msg": msg}
)
@argument_validation
async def quit(self, target: int):
@ -633,7 +634,7 @@ class Bot(BaseBot):
* ``target: int``: 退出的群号
"""
return await self.api.post('quit', params={'target': target})
return await self.api.post("quit", params={"target": target})
@argument_validation
async def mute_all(self, target: int):
@ -646,7 +647,7 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
"""
return await self.api.post('muteAll', params={'target': target})
return await self.api.post("muteAll", params={"target": target})
@argument_validation
async def unmute_all(self, target: int):
@ -659,7 +660,7 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
"""
return await self.api.post('unmuteAll', params={'target': target})
return await self.api.post("unmuteAll", params={"target": target})
@argument_validation
async def group_config(self, target: int):
@ -685,7 +686,7 @@ class Bot(BaseBot):
"anonymousChat": true
}
"""
return await self.api.request('groupConfig', params={'target': target})
return await self.api.request("groupConfig", params={"target": target})
@argument_validation
async def modify_group_config(self, target: int, config: Dict[str, Any]):
@ -699,11 +700,9 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号
* ``config: Dict[str, Any]``: 群设置, 格式见 ``group_config`` 的返回值
"""
return await self.api.post('groupConfig',
params={
'target': target,
'config': config
})
return await self.api.post(
"groupConfig", params={"target": target, "config": config}
)
@argument_validation
async def member_info(self, target: int, member_id: int):
@ -726,15 +725,14 @@ class Bot(BaseBot):
"specialTitle": "群头衔"
}
"""
return await self.api.request('memberInfo',
params={
'target': target,
'memberId': member_id
})
return await self.api.request(
"memberInfo", params={"target": target, "memberId": member_id}
)
@argument_validation
async def modify_member_info(self, target: int, member_id: int,
info: Dict[str, Any]):
async def modify_member_info(
self, target: int, member_id: int, info: Dict[str, Any]
):
"""
:说明:
@ -746,9 +744,6 @@ class Bot(BaseBot):
* ``member_id: int``: 群员QQ号
* ``info: Dict[str, Any]``: 群员资料, 格式见 ``member_info`` 的返回值
"""
return await self.api.post('memberInfo',
params={
'target': target,
'memberId': member_id,
'info': info
})
return await self.api.post(
"memberInfo", params={"target": target, "memberId": member_id, "info": info}
)

View File

@ -1,7 +1,7 @@
from ipaddress import IPv4Address
from typing import Optional
from ipaddress import IPv4Address
from pydantic import BaseModel, Extra, Field
from pydantic import Extra, Field, BaseModel
class Config(BaseModel):
@ -14,9 +14,10 @@ class Config(BaseModel):
- ``mirai_host``: mirai-api-http 的地址
- ``mirai_port``: mirai-api-http 的端口
"""
auth_key: Optional[str] = Field(None, alias='mirai_auth_key')
host: Optional[IPv4Address] = Field(None, alias='mirai_host')
port: Optional[int] = Field(None, alias='mirai_port')
auth_key: Optional[str] = Field(None, alias="mirai_auth_key")
host: Optional[IPv4Address] = Field(None, alias="mirai_host")
port: Optional[int] = Field(None, alias="mirai_port")
class Config:
extra = Extra.ignore

View File

@ -5,25 +5,56 @@ r"""
部分字段可能与文档在符号上不一致
\:\:\:
"""
from .base import (Event, GroupChatInfo, GroupInfo, PrivateChatInfo,
UserPermission)
from .message import *
from .notice import *
from .message import *
from .request import *
from .base import (
Event,
GroupInfo,
GroupChatInfo,
UserPermission,
PrivateChatInfo,
)
__all__ = [
'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission',
'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage',
'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent',
'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent',
'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent',
'MemberLeaveEventKick', 'MemberLeaveEventQuit', 'FriendRecallEvent',
'GroupRecallEvent', 'GroupStateChangeEvent', 'GroupNameChangeEvent',
'GroupEntranceAnnouncementChangeEvent', 'GroupMuteAllEvent',
'GroupAllowAnonymousChatEvent', 'GroupAllowConfessTalkEvent',
'GroupAllowMemberInviteEvent', 'MemberStateChangeEvent',
'MemberCardChangeEvent', 'MemberSpecialTitleChangeEvent',
'BotGroupPermissionChangeEvent', 'MemberPermissionChangeEvent',
'RequestEvent', 'NewFriendRequestEvent', 'MemberJoinRequestEvent',
'BotInvitedJoinGroupRequestEvent'
"Event",
"GroupChatInfo",
"GroupInfo",
"PrivateChatInfo",
"UserPermission",
"MessageSource",
"MessageEvent",
"GroupMessage",
"FriendMessage",
"TempMessage",
"NoticeEvent",
"MuteEvent",
"BotMuteEvent",
"BotUnmuteEvent",
"MemberMuteEvent",
"MemberUnmuteEvent",
"BotJoinGroupEvent",
"BotLeaveEventActive",
"BotLeaveEventKick",
"MemberJoinEvent",
"MemberLeaveEventKick",
"MemberLeaveEventQuit",
"FriendRecallEvent",
"GroupRecallEvent",
"GroupStateChangeEvent",
"GroupNameChangeEvent",
"GroupEntranceAnnouncementChangeEvent",
"GroupMuteAllEvent",
"GroupAllowAnonymousChatEvent",
"GroupAllowConfessTalkEvent",
"GroupAllowMemberInviteEvent",
"MemberStateChangeEvent",
"MemberCardChangeEvent",
"MemberSpecialTitleChangeEvent",
"BotGroupPermissionChangeEvent",
"MemberPermissionChangeEvent",
"RequestEvent",
"NewFriendRequestEvent",
"MemberJoinRequestEvent",
"BotInvitedJoinGroupRequestEvent",
]

View File

@ -22,9 +22,10 @@ class UserPermission(str, Enum):
* ``ADMINISTRATOR``: 群管理
* ``MEMBER``: 普通群成员
"""
OWNER = 'OWNER'
ADMINISTRATOR = 'ADMINISTRATOR'
MEMBER = 'MEMBER'
OWNER = "OWNER"
ADMINISTRATOR = "ADMINISTRATOR"
MEMBER = "MEMBER"
class NudgeSubjectKind(str, Enum):
@ -36,8 +37,9 @@ class NudgeSubjectKind(str, Enum):
* ``Group``: 群
* ``Friend``: 好友
"""
Group = 'Group'
Friend = 'Friend'
Group = "Group"
Friend = "Friend"
class GroupInfo(BaseModel):
@ -48,7 +50,7 @@ class GroupInfo(BaseModel):
class GroupChatInfo(BaseModel):
id: int
name: str = Field(alias='memberName')
name: str = Field(alias="memberName")
permission: UserPermission
group: GroupInfo
@ -71,6 +73,7 @@ class Event(BaseEvent):
.. _mirai-api-http 事件类型:
https://github.com/project-mirai/mirai-api-http/blob/master/docs/EventType.md
"""
self_id: int
type: str
@ -79,11 +82,12 @@ class Event(BaseEvent):
"""
此事件类的工厂函数, 能够通过事件数据选择合适的子类进行序列化
"""
type = data['type']
type = data["type"]
def all_subclasses(cls: Type[Event]):
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)])
[s for c in cls.__subclasses__() for s in all_subclasses(c)]
)
event_class: Optional[Type[Event]] = None
for subclass in all_subclasses(cls):
@ -99,23 +103,25 @@ class Event(BaseEvent):
return event_class.parse_obj(data)
except ValidationError as e:
logger.info(
f'Failed to parse {data} to class {event_class.__name__}: '
f'{e.errors()!r}. Fallback to parent class.')
f"Failed to parse {data} to class {event_class.__name__}: "
f"{e.errors()!r}. Fallback to parent class."
)
event_class = event_class.__base__ # type: ignore
raise ValueError(f'Failed to serialize {data}.')
raise ValueError(f"Failed to serialize {data}.")
@overrides(BaseEvent)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
from . import meta, notice, message, request
if isinstance(self, message.MessageEvent):
return 'message'
return "message"
elif isinstance(self, notice.NoticeEvent):
return 'notice'
return "notice"
elif isinstance(self, request.RequestEvent):
return 'request'
return "request"
else:
return 'meta_event'
return "meta_event"
@overrides(BaseEvent)
def get_event_name(self) -> str:

View File

@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic import Field, BaseModel
from nonebot.typing import overrides
@ -16,7 +16,8 @@ class MessageSource(BaseModel):
class MessageEvent(Event):
"""消息事件基类"""
message_chain: MessageChain = Field(alias='messageChain')
message_chain: MessageChain = Field(alias="messageChain")
source: Optional[MessageSource] = None
sender: Any
@ -39,12 +40,13 @@ class MessageEvent(Event):
class GroupMessage(MessageEvent):
"""群消息事件"""
sender: GroupChatInfo
to_me: bool = False
@overrides(MessageEvent)
def get_session_id(self) -> str:
return f'group_{self.sender.group.id}_' + self.get_user_id()
return f"group_{self.sender.group.id}_" + self.get_user_id()
@overrides(MessageEvent)
def get_user_id(self) -> str:
@ -57,6 +59,7 @@ class GroupMessage(MessageEvent):
class FriendMessage(MessageEvent):
"""好友消息事件"""
sender: PrivateChatInfo
@overrides(MessageEvent)
@ -65,7 +68,7 @@ class FriendMessage(MessageEvent):
@overrides(MessageEvent)
def get_session_id(self) -> str:
return 'friend_' + self.get_user_id()
return "friend_" + self.get_user_id()
@overrides(MessageEvent)
def is_tome(self) -> bool:
@ -74,11 +77,12 @@ class FriendMessage(MessageEvent):
class TempMessage(MessageEvent):
"""临时会话消息事件"""
sender: GroupChatInfo
@overrides(MessageEvent)
def get_session_id(self) -> str:
return f'temp_{self.sender.group.id}_' + self.get_user_id()
return f"temp_{self.sender.group.id}_" + self.get_user_id()
@overrides(MessageEvent)
def is_tome(self) -> bool:

View File

@ -3,29 +3,35 @@ from .base import Event
class MetaEvent(Event):
"""元事件基类"""
qq: int
class BotOnlineEvent(MetaEvent):
"""Bot登录成功"""
pass
class BotOfflineEventActive(MetaEvent):
"""Bot主动离线"""
pass
class BotOfflineEventForce(MetaEvent):
"""Bot被挤下线"""
pass
class BotOfflineEventDropped(MetaEvent):
"""Bot被服务器断开或因网络问题而掉线"""
pass
class BotReloginEvent(MetaEvent):
"""Bot主动重新登录"""
pass
pass

View File

@ -2,88 +2,103 @@ from typing import Any, Optional
from pydantic import Field
from .base import Event, GroupChatInfo, GroupInfo, NudgeSubject, UserPermission
from .base import Event, GroupInfo, NudgeSubject, GroupChatInfo, UserPermission
class NoticeEvent(Event):
"""通知事件基类"""
pass
class MuteEvent(NoticeEvent):
"""禁言类事件基类"""
operator: GroupChatInfo
class BotMuteEvent(MuteEvent):
"""Bot被禁言"""
pass
class BotUnmuteEvent(MuteEvent):
"""Bot被取消禁言"""
pass
class MemberMuteEvent(MuteEvent):
"""群成员被禁言事件该成员不是Bot"""
duration_seconds: int = Field(alias='durationSeconds')
duration_seconds: int = Field(alias="durationSeconds")
member: GroupChatInfo
operator: Optional[GroupChatInfo] = None
class MemberUnmuteEvent(MuteEvent):
"""群成员被取消禁言事件该成员不是Bot"""
member: GroupChatInfo
operator: Optional[GroupChatInfo] = None
class BotJoinGroupEvent(NoticeEvent):
"""Bot加入了一个新群"""
group: GroupInfo
class BotLeaveEventActive(BotJoinGroupEvent):
"""Bot主动退出一个群"""
pass
class BotLeaveEventKick(BotJoinGroupEvent):
"""Bot被踢出一个群"""
pass
class MemberJoinEvent(NoticeEvent):
"""新人入群的事件"""
member: GroupChatInfo
class MemberLeaveEventKick(MemberJoinEvent):
"""成员被踢出群该成员不是Bot"""
operator: Optional[GroupChatInfo] = None
class MemberLeaveEventQuit(MemberJoinEvent):
"""成员主动离群该成员不是Bot"""
pass
class FriendRecallEvent(NoticeEvent):
"""好友消息撤回"""
author_id: int = Field(alias='authorId')
message_id: int = Field(alias='messageId')
author_id: int = Field(alias="authorId")
message_id: int = Field(alias="messageId")
time: int
operator: int
class GroupRecallEvent(FriendRecallEvent):
"""群消息撤回"""
group: GroupInfo
operator: Optional[GroupChatInfo] = None
class GroupStateChangeEvent(NoticeEvent):
"""群变化事件基类"""
origin: Any
current: Any
group: GroupInfo
@ -92,73 +107,85 @@ class GroupStateChangeEvent(NoticeEvent):
class GroupNameChangeEvent(GroupStateChangeEvent):
"""某个群名改变"""
origin: str
current: str
class GroupEntranceAnnouncementChangeEvent(GroupStateChangeEvent):
"""某群入群公告改变"""
origin: str
current: str
class GroupMuteAllEvent(GroupStateChangeEvent):
"""全员禁言"""
origin: bool
current: bool
class GroupAllowAnonymousChatEvent(GroupStateChangeEvent):
"""匿名聊天"""
origin: bool
current: bool
class GroupAllowConfessTalkEvent(GroupStateChangeEvent):
"""坦白说"""
origin: bool
current: bool
class GroupAllowMemberInviteEvent(GroupStateChangeEvent):
"""允许群员邀请好友加群"""
origin: bool
current: bool
class MemberStateChangeEvent(NoticeEvent):
"""群成员变化事件基类"""
member: GroupChatInfo
operator: Optional[GroupChatInfo] = None
class MemberCardChangeEvent(MemberStateChangeEvent):
"""群名片改动"""
origin: str
current: str
class MemberSpecialTitleChangeEvent(MemberStateChangeEvent):
"""群头衔改动(只有群主有操作限权)"""
origin: str
current: str
class BotGroupPermissionChangeEvent(MemberStateChangeEvent):
"""Bot在群里的权限被改变"""
origin: UserPermission
current: UserPermission
class MemberPermissionChangeEvent(MemberStateChangeEvent):
"""成员权限改变的事件该成员不是Bot"""
origin: UserPermission
current: UserPermission
class NudgeEvent(NoticeEvent):
"""戳一戳触发事件"""
from_id: int = Field(alias='fromId')
from_id: int = Field(alias="fromId")
target: int
subject: NudgeSubject
action: str

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING
from typing_extensions import Literal
from pydantic import Field
from typing_extensions import Literal
from .base import Event
@ -11,15 +11,17 @@ if TYPE_CHECKING:
class RequestEvent(Event):
"""请求事件基类"""
event_id: int = Field(alias='eventId')
event_id: int = Field(alias="eventId")
message: str
nick: str
class NewFriendRequestEvent(RequestEvent):
"""添加好友申请"""
from_id: int = Field(alias='fromId')
group_id: int = Field(0, alias='groupId')
from_id: int = Field(alias="fromId")
group_id: int = Field(0, alias="groupId")
async def approve(self, bot: "Bot"):
"""
@ -31,19 +33,18 @@ class NewFriendRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象
"""
return await bot.api.post('/resp/newFriendRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': 0,
'message': ''
})
return await bot.api.post(
"/resp/newFriendRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": 0,
"message": "",
},
)
async def reject(self,
bot: "Bot",
operate: Literal[1, 2] = 1,
message: str = ''):
async def reject(self, bot: "Bot", operate: Literal[1, 2] = 1, message: str = ""):
"""
:说明:
@ -60,21 +61,24 @@ class NewFriendRequestEvent(RequestEvent):
* ``message: str``: 回复的信息
"""
assert operate > 0
return await bot.api.post('/resp/newFriendRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': operate,
'message': message
})
return await bot.api.post(
"/resp/newFriendRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": operate,
"message": message,
},
)
class MemberJoinRequestEvent(RequestEvent):
"""用户入群申请Bot需要有管理员权限"""
from_id: int = Field(alias='fromId')
group_id: int = Field(alias='groupId')
group_name: str = Field(alias='groupName')
from_id: int = Field(alias="fromId")
group_id: int = Field(alias="groupId")
group_name: str = Field(alias="groupName")
async def approve(self, bot: "Bot"):
"""
@ -86,19 +90,20 @@ class MemberJoinRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象
"""
return await bot.api.post('/resp/memberJoinRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': 0,
'message': ''
})
return await bot.api.post(
"/resp/memberJoinRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": 0,
"message": "",
},
)
async def reject(self,
bot: "Bot",
operate: Literal[1, 2, 3, 4] = 1,
message: str = ''):
async def reject(
self, bot: "Bot", operate: Literal[1, 2, 3, 4] = 1, message: str = ""
):
"""
:说明:
@ -117,21 +122,24 @@ class MemberJoinRequestEvent(RequestEvent):
* ``message: str``: 回复的信息
"""
assert operate > 0
return await bot.api.post('/resp/memberJoinRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': operate,
'message': message
})
return await bot.api.post(
"/resp/memberJoinRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": operate,
"message": message,
},
)
class BotInvitedJoinGroupRequestEvent(RequestEvent):
"""Bot被邀请入群申请"""
from_id: int = Field(alias='fromId')
group_id: int = Field(alias='groupId')
group_name: str = Field(alias='groupName')
from_id: int = Field(alias="fromId")
group_id: int = Field(alias="groupId")
group_name: str = Field(alias="groupName")
async def approve(self, bot: "Bot"):
"""
@ -143,14 +151,16 @@ class BotInvitedJoinGroupRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象
"""
return await bot.api.post('/resp/botInvitedJoinGroupRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': 0,
'message': ''
})
return await bot.api.post(
"/resp/botInvitedJoinGroupRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": 0,
"message": "",
},
)
async def reject(self, bot: "Bot", message: str = ""):
"""
@ -163,11 +173,13 @@ class BotInvitedJoinGroupRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象
* ``message: str``: 邀请消息
"""
return await bot.api.post('/resp/botInvitedJoinGroupRequestEvent',
params={
'eventId': self.event_id,
'groupId': self.group_id,
'fromId': self.from_id,
'operate': 1,
'message': message
})
return await bot.api.post(
"/resp/botInvitedJoinGroupRequestEvent",
params={
"eventId": self.event_id,
"groupId": self.group_id,
"fromId": self.from_id,
"operate": 1,
"message": message,
},
)

View File

@ -1,28 +1,29 @@
from enum import Enum
from typing import Any, List, Dict, Type, Iterable, Optional, Union
from typing import Any, Dict, List, Type, Union, Iterable, Optional
from pydantic import validate_arguments
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage
from nonebot.adapters import MessageSegment as BaseMessageSegment
from nonebot.typing import overrides
class MessageType(str, Enum):
"""消息类型枚举类"""
SOURCE = 'Source'
QUOTE = 'Quote'
AT = 'At'
AT_ALL = 'AtAll'
FACE = 'Face'
PLAIN = 'Plain'
IMAGE = 'Image'
FLASH_IMAGE = 'FlashImage'
VOICE = 'Voice'
XML = 'Xml'
JSON = 'Json'
APP = 'App'
POKE = 'Poke'
SOURCE = "Source"
QUOTE = "Quote"
AT = "At"
AT_ALL = "AtAll"
FACE = "Face"
PLAIN = "Plain"
IMAGE = "Image"
FLASH_IMAGE = "FlashImage"
VOICE = "Voice"
XML = "Xml"
JSON = "Json"
APP = "App"
POKE = "Poke"
class MessageSegment(BaseMessageSegment["MessageChain"]):
@ -43,21 +44,24 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
@validate_arguments
@overrides(BaseMessageSegment)
def __init__(self, type: MessageType, **data: Any):
super().__init__(type=type,
data={k: v for k, v in data.items() if v is not None})
super().__init__(
type=type, data={k: v for k, v in data.items() if v is not None}
)
@overrides(BaseMessageSegment)
def __str__(self) -> str:
return self.data['text'] if self.is_text() else repr(self)
return self.data["text"] if self.is_text() else repr(self)
def __repr__(self) -> str:
return '[mirai:%s]' % ','.join([
self.type.value,
*map(
lambda s: '%s=%r' % s,
self.data.items(),
),
])
return "[mirai:%s]" % ",".join(
[
self.type.value,
*map(
lambda s: "%s=%r" % s,
self.data.items(),
),
]
)
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
@ -65,15 +69,21 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
def as_dict(self) -> Dict[str, Any]:
"""导出可以被正常json序列化的结构体"""
return {'type': self.type.value, **self.data}
return {"type": self.type.value, **self.data}
@classmethod
def source(cls, id: int, time: int):
return cls(type=MessageType.SOURCE, id=id, time=time)
@classmethod
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
origin: "MessageChain"):
def quote(
cls,
id: int,
group_id: int,
sender_id: int,
target_id: int,
origin: "MessageChain",
):
"""
:说明:
@ -87,12 +97,14 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
* ``target_id: int``: 被引用回复的原消息的接收者者的QQ号或群号
* ``origin: MessageChain``: 被引用回复的原消息的消息链对象
"""
return cls(type=MessageType.QUOTE,
id=id,
groupId=group_id,
senderId=sender_id,
targetId=target_id,
origin=origin.export())
return cls(
type=MessageType.QUOTE,
id=id,
groupId=group_id,
senderId=sender_id,
targetId=target_id,
origin=origin.export(),
)
@classmethod
def at(cls, target: int):
@ -144,10 +156,12 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
return cls(type=MessageType.PLAIN, text=text)
@classmethod
def image(cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
def image(
cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None,
):
"""
:说明:
@ -162,10 +176,12 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
@classmethod
def flash_image(cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
def flash_image(
cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None,
):
"""
:说明:
@ -175,16 +191,15 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
同 ``image``
"""
return cls(type=MessageType.FLASH_IMAGE,
imageId=image_id,
url=url,
path=path)
return cls(type=MessageType.FLASH_IMAGE, imageId=image_id, url=url, path=path)
@classmethod
def voice(cls,
voice_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
def voice(
cls,
voice_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None,
):
"""
:说明:
@ -196,10 +211,7 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
* ``url: Optional[str]``: 语音的URL发送时可作网络语音的链接
* ``path: Optional[str]``: 语音的路径,发送本地语音
"""
return cls(type=MessageType.FLASH_IMAGE,
imageId=voice_id,
url=url,
path=path)
return cls(type=MessageType.FLASH_IMAGE, imageId=voice_id, url=url, path=path)
@classmethod
def xml(cls, xml: str):
@ -282,16 +294,14 @@ class MessageChain(BaseMessage[MessageSegment]):
return [MessageSegment.plain(text=message)]
return [
*map(
lambda x: x
if isinstance(x, MessageSegment) else MessageSegment(**x),
message)
lambda x: x if isinstance(x, MessageSegment) else MessageSegment(**x),
message,
)
]
def export(self) -> List[Dict[str, Any]]:
"""导出为可以被正常json序列化的数组"""
return [
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
]
return [*map(lambda segment: segment.as_dict(), self.copy())] # type: ignore
def extract_first(self, *type: MessageType) -> Optional[MessageSegment]:
"""
@ -311,4 +321,4 @@ class MessageChain(BaseMessage[MessageSegment]):
return None
def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>'
return f"<{self.__class__.__name__} {[*self.copy()]}>"

View File

@ -1,17 +1,17 @@
import re
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, Callable, Optional, Coroutine
import httpx
from pydantic import Extra, ValidationError, validate_arguments
import nonebot.exception as exception
from nonebot.log import logger
import nonebot.exception as exception
from nonebot.message import handle_event
from nonebot.utils import escape_tag, logger_wrapper
from .event import Event, GroupMessage, MessageEvent, MessageSource
from .message import MessageType, MessageSegment
from .event import Event, GroupMessage, MessageEvent, MessageSource
if TYPE_CHECKING:
from .bot import Bot
@ -21,28 +21,27 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable)
class Log:
@staticmethod
def log(level: str, message: str, exception: Optional[Exception] = None):
logger = logger_wrapper('MIRAI')
message = '<e>' + escape_tag(message) + '</e>'
logger = logger_wrapper("MIRAI")
message = "<e>" + escape_tag(message) + "</e>"
logger(level=level.upper(), message=message, exception=exception)
@classmethod
def info(cls, message: Any):
cls.log('INFO', str(message))
cls.log("INFO", str(message))
@classmethod
def debug(cls, message: Any):
cls.log('DEBUG', str(message))
cls.log("DEBUG", str(message))
@classmethod
def warn(cls, message: Any):
cls.log('WARNING', str(message))
cls.log("WARNING", str(message))
@classmethod
def error(cls, message: Any, exception: Optional[Exception] = None):
cls.log('ERROR', str(message), exception=exception)
cls.log("ERROR", str(message), exception=exception)
class ActionFailed(exception.ActionFailed):
@ -53,12 +52,13 @@ class ActionFailed(exception.ActionFailed):
"""
def __init__(self, **kwargs):
super().__init__('mirai')
super().__init__("mirai")
self.data = kwargs.copy()
def __repr__(self):
return self.__class__.__name__ + '(%s)' % ', '.join(
map(lambda m: '%s=%r' % m, self.data.items()))
return self.__class__.__name__ + "(%s)" % ", ".join(
map(lambda m: "%s=%r" % m, self.data.items())
)
class InvalidArgument(exception.AdapterException):
@ -69,7 +69,7 @@ class InvalidArgument(exception.AdapterException):
"""
def __init__(self, **kwargs):
super().__init__('mirai')
super().__init__("mirai")
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
@ -90,11 +90,12 @@ def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
try:
data = await function(*args, **kwargs)
except httpx.HTTPError:
raise exception.NetworkError('mirai')
logger.opt(colors=True).debug('<b>Mirai API returned data:</b> '
f'<y>{escape_tag(str(data))}</y>')
raise exception.NetworkError("mirai")
logger.opt(colors=True).debug(
"<b>Mirai API returned data:</b> " f"<y>{escape_tag(str(data))}</y>"
)
if isinstance(data, dict):
if data.get('code', 0) != 0:
if data.get("code", 0) != 0:
raise ActionFailed(**data)
return data
@ -109,10 +110,9 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable:
会在参数出错时释放 ``InvalidArgument`` 异常
"""
function = validate_arguments(config={
'arbitrary_types_allowed': True,
'extra': Extra.forbid
})(function)
function = validate_arguments(
config={"arbitrary_types_allowed": True, "extra": Extra.forbid}
)(function)
@wraps(function)
def wrapper(*args, **kwargs):
@ -134,12 +134,12 @@ def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent:
def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage:
at = event.message_chain.extract_first(MessageType.AT)
if at is not None:
if at.data['target'] == event.self_id:
if at.data["target"] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, at)
if not event.message_chain:
event.message_chain.append(MessageSegment.plain(''))
event.message_chain.append(MessageSegment.plain(""))
return event
@ -147,13 +147,13 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
plain = event.message_chain.extract_first(MessageType.PLAIN)
if plain is not None:
text = str(plain)
nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname))
nick_regex = "|".join(filter(lambda x: x, bot.config.nickname))
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is not None:
event.to_me = True
nickname = matched.group(1)
Log.info(f'User is calling me {nickname}')
plain.data['text'] = text[matched.end():]
Log.info(f"User is calling me {nickname}")
plain.data["text"] = text[matched.end() :]
event.message_chain.insert(0, plain)
return event
@ -161,7 +161,7 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
reply = event.message_chain.extract_first(MessageType.QUOTE)
if reply is not None:
if reply.data['senderId'] == event.self_id:
if reply.data["senderId"] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, reply)

View File

@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

View File

@ -7,8 +7,7 @@ from nonebot.log import logger
def init():
driver = nonebot.get_driver()
try:
_module = importlib.import_module(
f"nonebot_plugin_docs.drivers.{driver.type}")
_module = importlib.import_module(f"nonebot_plugin_docs.drivers.{driver.type}")
except ImportError:
logger.warning(f"Driver {driver.type} not supported")
return
@ -18,8 +17,9 @@ def init():
port = driver.config.port
if host in ["0.0.0.0", "127.0.0.1"]:
host = "localhost"
logger.opt(colors=True).info(f"Nonebot docs will be running at: "
f"<b><u>http://{host}:{port}/docs/</u></b>")
logger.opt(colors=True).info(
f"Nonebot docs will be running at: " f"<b><u>http://{host}:{port}/docs/</u></b>"
)
init()

View File

@ -9,6 +9,4 @@ def register_route(driver: Driver):
static_path = str((Path(__file__).parent / ".." / "dist").resolve())
app.mount("/docs",
StaticFiles(directory=static_path, html=True),
name="docs")
app.mount("/docs", StaticFiles(directory=static_path, html=True), name="docs")

View File

@ -18,6 +18,21 @@ nonebot2 = "^2.0.0-alpha.1"
[tool.poetry.dev-dependencies]
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"