mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-16 11:00:54 +00:00
🎨 format code using black and isort
This commit is contained in:
@ -12,8 +12,15 @@ from nonebot.typing import overrides
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.utils import DataclassEncoder, escape_tag
|
||||
from nonebot.drivers import (Driver, WebSocket, HTTPRequest, HTTPResponse,
|
||||
ForwardDriver, HTTPConnection, WebSocketSetup)
|
||||
from nonebot.drivers import (
|
||||
Driver,
|
||||
WebSocket,
|
||||
HTTPRequest,
|
||||
HTTPResponse,
|
||||
ForwardDriver,
|
||||
HTTPConnection,
|
||||
WebSocketSetup,
|
||||
)
|
||||
|
||||
from .utils import log, escape
|
||||
from .config import Config as CQHTTPConfig
|
||||
@ -49,15 +56,12 @@ async def _check_reply(bot: "Bot", event: "Event"):
|
||||
return
|
||||
|
||||
try:
|
||||
index = list(map(lambda x: x.type == "reply",
|
||||
event.message)).index(True)
|
||||
index = list(map(lambda x: x.type == "reply", event.message)).index(True)
|
||||
except ValueError:
|
||||
return
|
||||
msg_seg = event.message[index]
|
||||
try:
|
||||
event.reply = Reply.parse_obj(await
|
||||
bot.get_msg(message_id=msg_seg.data["id"]
|
||||
))
|
||||
event.reply = Reply.parse_obj(await bot.get_msg(message_id=msg_seg.data["id"]))
|
||||
except Exception as e:
|
||||
log("WARNING", f"Error when getting message reply info: {repr(e)}", e)
|
||||
return
|
||||
@ -68,8 +72,7 @@ async def _check_reply(bot: "Bot", event: "Event"):
|
||||
if len(event.message) > index and event.message[index].type == "at":
|
||||
del event.message[index]
|
||||
if len(event.message) > index and event.message[index].type == "text":
|
||||
event.message[index].data["text"] = event.message[index].data[
|
||||
"text"].lstrip()
|
||||
event.message[index].data["text"] = event.message[index].data["text"].lstrip()
|
||||
if not event.message[index].data["text"]:
|
||||
del event.message[index]
|
||||
if not event.message:
|
||||
@ -99,23 +102,24 @@ def _check_at_me(bot: "Bot", event: "Event"):
|
||||
else:
|
||||
|
||||
def _is_at_me_seg(segment: MessageSegment):
|
||||
return segment.type == "at" and str(segment.data.get(
|
||||
"qq", "")) == str(event.self_id)
|
||||
return segment.type == "at" and str(segment.data.get("qq", "")) == str(
|
||||
event.self_id
|
||||
)
|
||||
|
||||
# check the first segment
|
||||
if _is_at_me_seg(event.message[0]):
|
||||
event.to_me = True
|
||||
event.message.pop(0)
|
||||
if event.message and event.message[0].type == "text":
|
||||
event.message[0].data["text"] = event.message[0].data[
|
||||
"text"].lstrip()
|
||||
event.message[0].data["text"] = event.message[0].data["text"].lstrip()
|
||||
if not event.message[0].data["text"]:
|
||||
del event.message[0]
|
||||
if event.message and _is_at_me_seg(event.message[0]):
|
||||
event.message.pop(0)
|
||||
if event.message and event.message[0].type == "text":
|
||||
event.message[0].data["text"] = event.message[0].data[
|
||||
"text"].lstrip()
|
||||
event.message[0].data["text"] = (
|
||||
event.message[0].data["text"].lstrip()
|
||||
)
|
||||
if not event.message[0].data["text"]:
|
||||
del event.message[0]
|
||||
|
||||
@ -123,9 +127,11 @@ def _check_at_me(bot: "Bot", event: "Event"):
|
||||
# check the last segment
|
||||
i = -1
|
||||
last_msg_seg = event.message[i]
|
||||
if last_msg_seg.type == "text" and \
|
||||
not last_msg_seg.data["text"].strip() and \
|
||||
len(event.message) >= 2:
|
||||
if (
|
||||
last_msg_seg.type == "text"
|
||||
and not last_msg_seg.data["text"].strip()
|
||||
and len(event.message) >= 2
|
||||
):
|
||||
i -= 1
|
||||
last_msg_seg = event.message[i]
|
||||
|
||||
@ -161,13 +167,12 @@ def _check_nickname(bot: "Bot", event: "Event"):
|
||||
if nicknames:
|
||||
# check if the user is calling me with my nickname
|
||||
nickname_regex = "|".join(nicknames)
|
||||
m = re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text,
|
||||
re.IGNORECASE)
|
||||
m = re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text, re.IGNORECASE)
|
||||
if m:
|
||||
nickname = m.group(1)
|
||||
log("DEBUG", f"User is calling me {nickname}")
|
||||
event.to_me = True
|
||||
first_msg_seg.data["text"] = first_text[m.end():]
|
||||
first_msg_seg.data["text"] = first_text[m.end() :]
|
||||
|
||||
|
||||
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
|
||||
@ -206,8 +211,9 @@ class ResultStore:
|
||||
|
||||
@classmethod
|
||||
def add_result(cls, result: Dict[str, Any]):
|
||||
if isinstance(result.get("echo"), dict) and \
|
||||
isinstance(result["echo"].get("seq"), int):
|
||||
if isinstance(result.get("echo"), dict) and isinstance(
|
||||
result["echo"].get("seq"), int
|
||||
):
|
||||
future = cls._futures.get(result["echo"]["seq"])
|
||||
if future:
|
||||
future.set_result(result)
|
||||
@ -228,6 +234,7 @@ class Bot(BaseBot):
|
||||
"""
|
||||
CQHTTP 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。
|
||||
"""
|
||||
|
||||
cqhttp_config: CQHTTPConfig
|
||||
|
||||
@property
|
||||
@ -249,22 +256,25 @@ class Bot(BaseBot):
|
||||
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
|
||||
for self_id, url in cls.cqhttp_config.ws_urls.items():
|
||||
try:
|
||||
headers = {
|
||||
"authorization":
|
||||
f"Bearer {cls.cqhttp_config.access_token}"
|
||||
} if cls.cqhttp_config.access_token else {}
|
||||
headers = (
|
||||
{"authorization": f"Bearer {cls.cqhttp_config.access_token}"}
|
||||
if cls.cqhttp_config.access_token
|
||||
else {}
|
||||
)
|
||||
driver.setup_websocket(
|
||||
WebSocketSetup("cqhttp", self_id, url, headers=headers))
|
||||
WebSocketSetup("cqhttp", self_id, url, headers=headers)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Bad url {escape_tag(url)} for bot {escape_tag(self_id)} "
|
||||
"in cqhttp forward websocket</bg #f8bbd0></r>")
|
||||
"in cqhttp forward websocket</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
async def check_permission(
|
||||
cls, driver: Driver,
|
||||
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
|
||||
cls, driver: Driver, request: HTTPConnection
|
||||
) -> Tuple[Optional[str], HTTPResponse]:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -286,22 +296,26 @@ class Bot(BaseBot):
|
||||
if not x_signature:
|
||||
log("WARNING", "Missing Signature Header")
|
||||
return None, HTTPResponse(401, b"Missing Signature")
|
||||
sig = hmac.new(secret.encode("utf-8"), request.body,
|
||||
"sha1").hexdigest()
|
||||
sig = hmac.new(secret.encode("utf-8"), request.body, "sha1").hexdigest()
|
||||
if x_signature != "sha1=" + sig:
|
||||
log("WARNING", "Signature Header is invalid")
|
||||
return None, HTTPResponse(403, b"Signature is invalid")
|
||||
|
||||
access_token = cqhttp_config.access_token
|
||||
if access_token and access_token != token and isinstance(
|
||||
request, WebSocket):
|
||||
if access_token and access_token != token and isinstance(request, WebSocket):
|
||||
log(
|
||||
"WARNING", "Authorization Header is invalid"
|
||||
if token else "Missing Authorization Header")
|
||||
"WARNING",
|
||||
"Authorization Header is invalid"
|
||||
if token
|
||||
else "Missing Authorization Header",
|
||||
)
|
||||
return None, HTTPResponse(
|
||||
403, b"Authorization Header is invalid"
|
||||
if token else b"Missing Authorization Header")
|
||||
return str(x_self_id), HTTPResponse(204, b'')
|
||||
403,
|
||||
b"Authorization Header is invalid"
|
||||
if token
|
||||
else b"Missing Authorization Header",
|
||||
)
|
||||
return str(x_self_id), HTTPResponse(204, b"")
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def handle_message(self, message: bytes):
|
||||
@ -320,7 +334,7 @@ class Bot(BaseBot):
|
||||
return
|
||||
|
||||
try:
|
||||
post_type = data['post_type']
|
||||
post_type = data["post_type"]
|
||||
detail_type = data.get(f"{post_type}_type")
|
||||
detail_type = f".{detail_type}" if detail_type else ""
|
||||
sub_type = data.get("sub_type")
|
||||
@ -352,17 +366,13 @@ class Bot(BaseBot):
|
||||
if isinstance(self.request, WebSocket):
|
||||
seq = ResultStore.get_seq()
|
||||
json_data = json.dumps(
|
||||
{
|
||||
"action": api,
|
||||
"params": data,
|
||||
"echo": {
|
||||
"seq": seq
|
||||
}
|
||||
},
|
||||
cls=DataclassEncoder)
|
||||
{"action": api, "params": data, "echo": {"seq": seq}},
|
||||
cls=DataclassEncoder,
|
||||
)
|
||||
await self.request.send(json_data)
|
||||
return _handle_api_result(await ResultStore.fetch(
|
||||
seq, self.config.api_timeout))
|
||||
return _handle_api_result(
|
||||
await ResultStore.fetch(seq, self.config.api_timeout)
|
||||
)
|
||||
|
||||
elif isinstance(self.request, HTTPRequest):
|
||||
api_root = self.config.api_root.get(self.self_id)
|
||||
@ -373,22 +383,25 @@ class Bot(BaseBot):
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.cqhttp_config.access_token is not None:
|
||||
headers[
|
||||
"Authorization"] = "Bearer " + self.cqhttp_config.access_token
|
||||
headers["Authorization"] = "Bearer " + self.cqhttp_config.access_token
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(headers=headers,
|
||||
follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(
|
||||
headers=headers, follow_redirects=True
|
||||
) as client:
|
||||
response = await client.post(
|
||||
api_root + api,
|
||||
content=json.dumps(data, cls=DataclassEncoder),
|
||||
timeout=self.config.api_timeout)
|
||||
timeout=self.config.api_timeout,
|
||||
)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
result = response.json()
|
||||
return _handle_api_result(result)
|
||||
raise NetworkError(f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code}")
|
||||
raise NetworkError(
|
||||
f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code}"
|
||||
)
|
||||
except httpx.InvalidURL:
|
||||
raise NetworkError("API root url invalid")
|
||||
except httpx.HTTPError:
|
||||
@ -418,11 +431,13 @@ class Bot(BaseBot):
|
||||
return await super().call_api(api, **data)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def send(self,
|
||||
event: Event,
|
||||
message: Union[str, Message, MessageSegment],
|
||||
at_sender: bool = False,
|
||||
**kwargs) -> Any:
|
||||
async def send(
|
||||
self,
|
||||
event: Event,
|
||||
message: Union[str, Message, MessageSegment],
|
||||
at_sender: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -445,8 +460,9 @@ class Bot(BaseBot):
|
||||
- ``NetworkError``: 网络错误
|
||||
- ``ActionFailed``: API 调用失败
|
||||
"""
|
||||
message = escape(message, escape_comma=False) if isinstance(
|
||||
message, str) else message
|
||||
message = (
|
||||
escape(message, escape_comma=False) if isinstance(message, str) else message
|
||||
)
|
||||
msg = message if isinstance(message, Message) else Message(message)
|
||||
|
||||
at_sender = at_sender and bool(getattr(event, "user_id", None))
|
||||
|
@ -8,7 +8,6 @@ from nonebot.drivers import Driver, WebSocket
|
||||
from .event import Event
|
||||
from .message import Message, MessageSegment
|
||||
|
||||
|
||||
def get_auth_bearer(access_token: Optional[str] = ...) -> Optional[str]:
|
||||
...
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field, BaseModel, AnyUrl
|
||||
from pydantic import Field, AnyUrl, BaseModel
|
||||
|
||||
|
||||
# priority: alias > origin
|
||||
@ -14,11 +14,10 @@ class Config(BaseModel):
|
||||
- ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令
|
||||
- ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID、目标 URL 字典
|
||||
"""
|
||||
access_token: Optional[str] = Field(default=None,
|
||||
alias="cqhttp_access_token")
|
||||
|
||||
access_token: Optional[str] = Field(default=None, alias="cqhttp_access_token")
|
||||
secret: Optional[str] = Field(default=None, alias="cqhttp_secret")
|
||||
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set,
|
||||
alias="cqhttp_ws_urls")
|
||||
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set, alias="cqhttp_ws_urls")
|
||||
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
|
@ -5,12 +5,13 @@ from typing import TYPE_CHECKING, List, Type, Optional
|
||||
from pydantic import BaseModel
|
||||
from pygtrie import StringTrie
|
||||
|
||||
from .message import Message
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from .exception import NoLogException
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
|
||||
from .message import Message
|
||||
from .exception import NoLogException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bot import Bot
|
||||
|
||||
@ -22,6 +23,7 @@ class Event(BaseEvent):
|
||||
.. _CQHTTP 文档:
|
||||
https://github.com/howmanybots/onebot/blob/master/README.md
|
||||
"""
|
||||
|
||||
__event__ = ""
|
||||
time: int
|
||||
self_id: int
|
||||
@ -118,6 +120,7 @@ class Status(BaseModel):
|
||||
# Message Events
|
||||
class MessageEvent(Event):
|
||||
"""消息事件"""
|
||||
|
||||
__event__ = "message"
|
||||
post_type: Literal["message"]
|
||||
sub_type: str
|
||||
@ -144,8 +147,9 @@ class MessageEvent(Event):
|
||||
@overrides(Event)
|
||||
def get_event_name(self) -> str:
|
||||
sub_type = getattr(self, "sub_type", None)
|
||||
return f"{self.post_type}.{self.message_type}" + (f".{sub_type}"
|
||||
if sub_type else "")
|
||||
return f"{self.post_type}.{self.message_type}" + (
|
||||
f".{sub_type}" if sub_type else ""
|
||||
)
|
||||
|
||||
@overrides(Event)
|
||||
def get_message(self) -> Message:
|
||||
@ -170,20 +174,29 @@ class MessageEvent(Event):
|
||||
|
||||
class PrivateMessageEvent(MessageEvent):
|
||||
"""私聊消息"""
|
||||
|
||||
__event__ = "message.private"
|
||||
message_type: Literal["private"]
|
||||
|
||||
@overrides(Event)
|
||||
def get_event_description(self) -> str:
|
||||
return (f'Message {self.message_id} from {self.user_id} "' + "".join(
|
||||
map(
|
||||
lambda x: escape_tag(str(x))
|
||||
if x.is_text() else f"<le>{escape_tag(str(x))}</le>",
|
||||
self.message)) + '"')
|
||||
return (
|
||||
f'Message {self.message_id} from {self.user_id} "'
|
||||
+ "".join(
|
||||
map(
|
||||
lambda x: escape_tag(str(x))
|
||||
if x.is_text()
|
||||
else f"<le>{escape_tag(str(x))}</le>",
|
||||
self.message,
|
||||
)
|
||||
)
|
||||
+ '"'
|
||||
)
|
||||
|
||||
|
||||
class GroupMessageEvent(MessageEvent):
|
||||
"""群消息"""
|
||||
|
||||
__event__ = "message.group"
|
||||
message_type: Literal["group"]
|
||||
group_id: int
|
||||
@ -196,8 +209,13 @@ class GroupMessageEvent(MessageEvent):
|
||||
+ "".join(
|
||||
map(
|
||||
lambda x: escape_tag(str(x))
|
||||
if x.is_text() else f"<le>{escape_tag(str(x))}</le>",
|
||||
self.message)) + '"')
|
||||
if x.is_text()
|
||||
else f"<le>{escape_tag(str(x))}</le>",
|
||||
self.message,
|
||||
)
|
||||
)
|
||||
+ '"'
|
||||
)
|
||||
|
||||
@overrides(MessageEvent)
|
||||
def get_session_id(self) -> str:
|
||||
@ -207,6 +225,7 @@ class GroupMessageEvent(MessageEvent):
|
||||
# Notice Events
|
||||
class NoticeEvent(Event):
|
||||
"""通知事件"""
|
||||
|
||||
__event__ = "notice"
|
||||
post_type: Literal["notice"]
|
||||
notice_type: str
|
||||
@ -214,12 +233,14 @@ class NoticeEvent(Event):
|
||||
@overrides(Event)
|
||||
def get_event_name(self) -> str:
|
||||
sub_type = getattr(self, "sub_type", None)
|
||||
return f"{self.post_type}.{self.notice_type}" + (f".{sub_type}"
|
||||
if sub_type else "")
|
||||
return f"{self.post_type}.{self.notice_type}" + (
|
||||
f".{sub_type}" if sub_type else ""
|
||||
)
|
||||
|
||||
|
||||
class GroupUploadNoticeEvent(NoticeEvent):
|
||||
"""群文件上传事件"""
|
||||
|
||||
__event__ = "notice.group_upload"
|
||||
notice_type: Literal["group_upload"]
|
||||
user_id: int
|
||||
@ -237,6 +258,7 @@ class GroupUploadNoticeEvent(NoticeEvent):
|
||||
|
||||
class GroupAdminNoticeEvent(NoticeEvent):
|
||||
"""群管理员变动"""
|
||||
|
||||
__event__ = "notice.group_admin"
|
||||
notice_type: Literal["group_admin"]
|
||||
sub_type: str
|
||||
@ -258,6 +280,7 @@ class GroupAdminNoticeEvent(NoticeEvent):
|
||||
|
||||
class GroupDecreaseNoticeEvent(NoticeEvent):
|
||||
"""群成员减少事件"""
|
||||
|
||||
__event__ = "notice.group_decrease"
|
||||
notice_type: Literal["group_decrease"]
|
||||
sub_type: str
|
||||
@ -280,6 +303,7 @@ class GroupDecreaseNoticeEvent(NoticeEvent):
|
||||
|
||||
class GroupIncreaseNoticeEvent(NoticeEvent):
|
||||
"""群成员增加事件"""
|
||||
|
||||
__event__ = "notice.group_increase"
|
||||
notice_type: Literal["group_increase"]
|
||||
sub_type: str
|
||||
@ -302,6 +326,7 @@ class GroupIncreaseNoticeEvent(NoticeEvent):
|
||||
|
||||
class GroupBanNoticeEvent(NoticeEvent):
|
||||
"""群禁言事件"""
|
||||
|
||||
__event__ = "notice.group_ban"
|
||||
notice_type: Literal["group_ban"]
|
||||
sub_type: str
|
||||
@ -325,6 +350,7 @@ class GroupBanNoticeEvent(NoticeEvent):
|
||||
|
||||
class FriendAddNoticeEvent(NoticeEvent):
|
||||
"""好友添加事件"""
|
||||
|
||||
__event__ = "notice.friend_add"
|
||||
notice_type: Literal["friend_add"]
|
||||
user_id: int
|
||||
@ -340,6 +366,7 @@ class FriendAddNoticeEvent(NoticeEvent):
|
||||
|
||||
class GroupRecallNoticeEvent(NoticeEvent):
|
||||
"""群消息撤回事件"""
|
||||
|
||||
__event__ = "notice.group_recall"
|
||||
notice_type: Literal["group_recall"]
|
||||
user_id: int
|
||||
@ -362,6 +389,7 @@ class GroupRecallNoticeEvent(NoticeEvent):
|
||||
|
||||
class FriendRecallNoticeEvent(NoticeEvent):
|
||||
"""好友消息撤回事件"""
|
||||
|
||||
__event__ = "notice.friend_recall"
|
||||
notice_type: Literal["friend_recall"]
|
||||
user_id: int
|
||||
@ -378,6 +406,7 @@ class FriendRecallNoticeEvent(NoticeEvent):
|
||||
|
||||
class NotifyEvent(NoticeEvent):
|
||||
"""提醒事件"""
|
||||
|
||||
__event__ = "notice.notify"
|
||||
notice_type: Literal["notify"]
|
||||
sub_type: str
|
||||
@ -395,6 +424,7 @@ class NotifyEvent(NoticeEvent):
|
||||
|
||||
class PokeNotifyEvent(NotifyEvent):
|
||||
"""戳一戳提醒事件"""
|
||||
|
||||
__event__ = "notice.notify.poke"
|
||||
sub_type: Literal["poke"]
|
||||
target_id: int
|
||||
@ -413,6 +443,7 @@ class PokeNotifyEvent(NotifyEvent):
|
||||
|
||||
class LuckyKingNotifyEvent(NotifyEvent):
|
||||
"""群红包运气王提醒事件"""
|
||||
|
||||
__event__ = "notice.notify.lucky_king"
|
||||
sub_type: Literal["lucky_king"]
|
||||
target_id: int
|
||||
@ -432,6 +463,7 @@ class LuckyKingNotifyEvent(NotifyEvent):
|
||||
|
||||
class HonorNotifyEvent(NotifyEvent):
|
||||
"""群荣誉变更提醒事件"""
|
||||
|
||||
__event__ = "notice.notify.honor"
|
||||
sub_type: Literal["honor"]
|
||||
honor_type: str
|
||||
@ -444,6 +476,7 @@ class HonorNotifyEvent(NotifyEvent):
|
||||
# Request Events
|
||||
class RequestEvent(Event):
|
||||
"""请求事件"""
|
||||
|
||||
__event__ = "request"
|
||||
post_type: Literal["request"]
|
||||
request_type: str
|
||||
@ -451,12 +484,14 @@ class RequestEvent(Event):
|
||||
@overrides(Event)
|
||||
def get_event_name(self) -> str:
|
||||
sub_type = getattr(self, "sub_type", None)
|
||||
return f"{self.post_type}.{self.request_type}" + (f".{sub_type}"
|
||||
if sub_type else "")
|
||||
return f"{self.post_type}.{self.request_type}" + (
|
||||
f".{sub_type}" if sub_type else ""
|
||||
)
|
||||
|
||||
|
||||
class FriendRequestEvent(RequestEvent):
|
||||
"""加好友请求事件"""
|
||||
|
||||
__event__ = "request.friend"
|
||||
request_type: Literal["friend"]
|
||||
user_id: int
|
||||
@ -472,9 +507,9 @@ class FriendRequestEvent(RequestEvent):
|
||||
return str(self.user_id)
|
||||
|
||||
async def approve(self, bot: "Bot", remark: str = ""):
|
||||
return await bot.set_friend_add_request(flag=self.flag,
|
||||
approve=True,
|
||||
remark=remark)
|
||||
return await bot.set_friend_add_request(
|
||||
flag=self.flag, approve=True, remark=remark
|
||||
)
|
||||
|
||||
async def reject(self, bot: "Bot"):
|
||||
return await bot.set_friend_add_request(flag=self.flag, approve=False)
|
||||
@ -482,6 +517,7 @@ class FriendRequestEvent(RequestEvent):
|
||||
|
||||
class GroupRequestEvent(RequestEvent):
|
||||
"""加群请求/邀请事件"""
|
||||
|
||||
__event__ = "request.group"
|
||||
request_type: Literal["group"]
|
||||
sub_type: str
|
||||
@ -499,20 +535,20 @@ class GroupRequestEvent(RequestEvent):
|
||||
return f"group_{self.group_id}_{self.user_id}"
|
||||
|
||||
async def approve(self, bot: "Bot"):
|
||||
return await bot.set_group_add_request(flag=self.flag,
|
||||
sub_type=self.sub_type,
|
||||
approve=True)
|
||||
return await bot.set_group_add_request(
|
||||
flag=self.flag, sub_type=self.sub_type, approve=True
|
||||
)
|
||||
|
||||
async def reject(self, bot: "Bot", reason: str = ""):
|
||||
return await bot.set_group_add_request(flag=self.flag,
|
||||
sub_type=self.sub_type,
|
||||
approve=False,
|
||||
reason=reason)
|
||||
return await bot.set_group_add_request(
|
||||
flag=self.flag, sub_type=self.sub_type, approve=False, reason=reason
|
||||
)
|
||||
|
||||
|
||||
# Meta Events
|
||||
class MetaEvent(Event):
|
||||
"""元事件"""
|
||||
|
||||
__event__ = "meta_event"
|
||||
post_type: Literal["meta_event"]
|
||||
meta_event_type: str
|
||||
@ -520,8 +556,9 @@ class MetaEvent(Event):
|
||||
@overrides(Event)
|
||||
def get_event_name(self) -> str:
|
||||
sub_type = getattr(self, "sub_type", None)
|
||||
return f"{self.post_type}.{self.meta_event_type}" + (f".{sub_type}" if
|
||||
sub_type else "")
|
||||
return f"{self.post_type}.{self.meta_event_type}" + (
|
||||
f".{sub_type}" if sub_type else ""
|
||||
)
|
||||
|
||||
@overrides(Event)
|
||||
def get_log_string(self) -> str:
|
||||
@ -530,6 +567,7 @@ class MetaEvent(Event):
|
||||
|
||||
class LifecycleMetaEvent(MetaEvent):
|
||||
"""生命周期元事件"""
|
||||
|
||||
__event__ = "meta_event.lifecycle"
|
||||
meta_event_type: Literal["lifecycle"]
|
||||
sub_type: str
|
||||
@ -537,6 +575,7 @@ class LifecycleMetaEvent(MetaEvent):
|
||||
|
||||
class HeartbeatMetaEvent(MetaEvent):
|
||||
"""心跳元事件"""
|
||||
|
||||
__event__ = "meta_event.heartbeat"
|
||||
meta_event_type: Literal["heartbeat"]
|
||||
status: Status
|
||||
@ -567,12 +606,28 @@ def get_event_model(event_name) -> List[Type[Event]]:
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Event", "MessageEvent", "PrivateMessageEvent", "GroupMessageEvent",
|
||||
"NoticeEvent", "GroupUploadNoticeEvent", "GroupAdminNoticeEvent",
|
||||
"GroupDecreaseNoticeEvent", "GroupIncreaseNoticeEvent",
|
||||
"GroupBanNoticeEvent", "FriendAddNoticeEvent", "GroupRecallNoticeEvent",
|
||||
"FriendRecallNoticeEvent", "NotifyEvent", "PokeNotifyEvent",
|
||||
"LuckyKingNotifyEvent", "HonorNotifyEvent", "RequestEvent",
|
||||
"FriendRequestEvent", "GroupRequestEvent", "MetaEvent",
|
||||
"LifecycleMetaEvent", "HeartbeatMetaEvent", "get_event_model"
|
||||
"Event",
|
||||
"MessageEvent",
|
||||
"PrivateMessageEvent",
|
||||
"GroupMessageEvent",
|
||||
"NoticeEvent",
|
||||
"GroupUploadNoticeEvent",
|
||||
"GroupAdminNoticeEvent",
|
||||
"GroupDecreaseNoticeEvent",
|
||||
"GroupIncreaseNoticeEvent",
|
||||
"GroupBanNoticeEvent",
|
||||
"FriendAddNoticeEvent",
|
||||
"GroupRecallNoticeEvent",
|
||||
"FriendRecallNoticeEvent",
|
||||
"NotifyEvent",
|
||||
"PokeNotifyEvent",
|
||||
"LuckyKingNotifyEvent",
|
||||
"HonorNotifyEvent",
|
||||
"RequestEvent",
|
||||
"FriendRequestEvent",
|
||||
"GroupRequestEvent",
|
||||
"MetaEvent",
|
||||
"LifecycleMetaEvent",
|
||||
"HeartbeatMetaEvent",
|
||||
"get_event_model",
|
||||
]
|
||||
|
@ -8,7 +8,6 @@ from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
|
||||
|
||||
|
||||
class CQHTTPAdapterException(AdapterException):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("cqhttp")
|
||||
|
||||
@ -33,8 +32,11 @@ class ActionFailed(BaseActionFailed, CQHTTPAdapterException):
|
||||
self.info = kwargs
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ActionFailed " + ", ".join(
|
||||
f"{k}={v}" for k, v in self.info.items()) + ">"
|
||||
return (
|
||||
f"<ActionFailed "
|
||||
+ ", ".join(f"{k}={v}" for k, v in self.info.items())
|
||||
+ ">"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
@ -5,10 +5,11 @@ from base64 import b64encode
|
||||
from typing import Any, Type, Tuple, Union, Mapping, Iterable, Optional, cast
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from .utils import log, _b2s, escape, unescape
|
||||
from nonebot.adapters import Message as BaseMessage
|
||||
from nonebot.adapters import MessageSegment as BaseMessageSegment
|
||||
|
||||
from .utils import log, _b2s, escape, unescape
|
||||
|
||||
|
||||
class MessageSegment(BaseMessageSegment["Message"]):
|
||||
"""
|
||||
@ -27,23 +28,24 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
|
||||
# process special types
|
||||
if type_ == "text":
|
||||
return escape(
|
||||
data.get("text", ""), # type: ignore
|
||||
escape_comma=False)
|
||||
return escape(data.get("text", ""), escape_comma=False) # type: ignore
|
||||
|
||||
params = ",".join(
|
||||
[f"{k}={escape(str(v))}" for k, v in data.items() if v is not None])
|
||||
[f"{k}={escape(str(v))}" for k, v in data.items() if v is not None]
|
||||
)
|
||||
return f"[CQ:{type_}{',' if params else ''}{params}]"
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __add__(self, other) -> "Message":
|
||||
return Message(self) + (MessageSegment.text(other) if isinstance(
|
||||
other, str) else other)
|
||||
return Message(self) + (
|
||||
MessageSegment.text(other) if isinstance(other, str) else other
|
||||
)
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __radd__(self, other) -> "Message":
|
||||
return (MessageSegment.text(other)
|
||||
if isinstance(other, str) else Message(other)) + self
|
||||
return (
|
||||
MessageSegment.text(other) if isinstance(other, str) else Message(other)
|
||||
) + self
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def is_text(self) -> bool:
|
||||
@ -83,11 +85,13 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
return MessageSegment("forward", {"id": id_})
|
||||
|
||||
@staticmethod
|
||||
def image(file: Union[str, bytes, BytesIO, Path],
|
||||
type_: Optional[str] = None,
|
||||
cache: bool = True,
|
||||
proxy: bool = True,
|
||||
timeout: Optional[int] = None) -> "MessageSegment":
|
||||
def image(
|
||||
file: Union[str, bytes, BytesIO, Path],
|
||||
type_: Optional[str] = None,
|
||||
cache: bool = True,
|
||||
proxy: bool = True,
|
||||
timeout: Optional[int] = None,
|
||||
) -> "MessageSegment":
|
||||
if isinstance(file, BytesIO):
|
||||
file = file.getvalue()
|
||||
if isinstance(file, bytes):
|
||||
@ -95,74 +99,85 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
elif isinstance(file, Path):
|
||||
file = f"file:///{file.resolve()}"
|
||||
return MessageSegment(
|
||||
"image", {
|
||||
"image",
|
||||
{
|
||||
"file": file,
|
||||
"type": type_,
|
||||
"cache": _b2s(cache),
|
||||
"proxy": _b2s(proxy),
|
||||
"timeout": timeout
|
||||
})
|
||||
"timeout": timeout,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def json(data: str) -> "MessageSegment":
|
||||
return MessageSegment("json", {"data": data})
|
||||
|
||||
@staticmethod
|
||||
def location(latitude: float,
|
||||
longitude: float,
|
||||
title: Optional[str] = None,
|
||||
content: Optional[str] = None) -> "MessageSegment":
|
||||
def location(
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
title: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
) -> "MessageSegment":
|
||||
return MessageSegment(
|
||||
"location", {
|
||||
"location",
|
||||
{
|
||||
"lat": str(latitude),
|
||||
"lon": str(longitude),
|
||||
"title": title,
|
||||
"content": content
|
||||
})
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def music(type_: str, id_: int) -> "MessageSegment":
|
||||
return MessageSegment("music", {"type": type_, "id": id_})
|
||||
|
||||
@staticmethod
|
||||
def music_custom(url: str,
|
||||
audio: str,
|
||||
title: str,
|
||||
content: Optional[str] = None,
|
||||
img_url: Optional[str] = None) -> "MessageSegment":
|
||||
def music_custom(
|
||||
url: str,
|
||||
audio: str,
|
||||
title: str,
|
||||
content: Optional[str] = None,
|
||||
img_url: Optional[str] = None,
|
||||
) -> "MessageSegment":
|
||||
return MessageSegment(
|
||||
"music", {
|
||||
"music",
|
||||
{
|
||||
"type": "custom",
|
||||
"url": url,
|
||||
"audio": audio,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"image": img_url
|
||||
})
|
||||
"image": img_url,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def node(id_: int) -> "MessageSegment":
|
||||
return MessageSegment("node", {"id": str(id_)})
|
||||
|
||||
@staticmethod
|
||||
def node_custom(user_id: int, nickname: str,
|
||||
content: Union[str, "Message"]) -> "MessageSegment":
|
||||
return MessageSegment("node", {
|
||||
"user_id": str(user_id),
|
||||
"nickname": nickname,
|
||||
"content": content
|
||||
})
|
||||
def node_custom(
|
||||
user_id: int, nickname: str, content: Union[str, "Message"]
|
||||
) -> "MessageSegment":
|
||||
return MessageSegment(
|
||||
"node", {"user_id": str(user_id), "nickname": nickname, "content": content}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def poke(type_: str, id_: str) -> "MessageSegment":
|
||||
return MessageSegment("poke", {"type": type_, "id": id_})
|
||||
|
||||
@staticmethod
|
||||
def record(file: Union[str, bytes, BytesIO, Path],
|
||||
magic: Optional[bool] = None,
|
||||
cache: Optional[bool] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
timeout: Optional[int] = None) -> "MessageSegment":
|
||||
def record(
|
||||
file: Union[str, bytes, BytesIO, Path],
|
||||
magic: Optional[bool] = None,
|
||||
cache: Optional[bool] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> "MessageSegment":
|
||||
if isinstance(file, BytesIO):
|
||||
file = file.getvalue()
|
||||
if isinstance(file, bytes):
|
||||
@ -170,13 +185,15 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
elif isinstance(file, Path):
|
||||
file = f"file:///{file.resolve()}"
|
||||
return MessageSegment(
|
||||
"record", {
|
||||
"record",
|
||||
{
|
||||
"file": file,
|
||||
"magic": _b2s(magic),
|
||||
"cache": _b2s(cache),
|
||||
"proxy": _b2s(proxy),
|
||||
"timeout": timeout
|
||||
})
|
||||
"timeout": timeout,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def reply(id_: int) -> "MessageSegment":
|
||||
@ -191,26 +208,27 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
return MessageSegment("shake", {})
|
||||
|
||||
@staticmethod
|
||||
def share(url: str = "",
|
||||
title: str = "",
|
||||
content: Optional[str] = None,
|
||||
image: Optional[str] = None) -> "MessageSegment":
|
||||
return MessageSegment("share", {
|
||||
"url": url,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"image": image
|
||||
})
|
||||
def share(
|
||||
url: str = "",
|
||||
title: str = "",
|
||||
content: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
) -> "MessageSegment":
|
||||
return MessageSegment(
|
||||
"share", {"url": url, "title": title, "content": content, "image": image}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def text(text: str) -> "MessageSegment":
|
||||
return MessageSegment("text", {"text": text})
|
||||
|
||||
@staticmethod
|
||||
def video(file: Union[str, bytes, BytesIO, Path],
|
||||
cache: Optional[bool] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
timeout: Optional[int] = None) -> "MessageSegment":
|
||||
def video(
|
||||
file: Union[str, bytes, BytesIO, Path],
|
||||
cache: Optional[bool] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> "MessageSegment":
|
||||
if isinstance(file, BytesIO):
|
||||
file = file.getvalue()
|
||||
if isinstance(file, bytes):
|
||||
@ -218,12 +236,14 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
elif isinstance(file, Path):
|
||||
file = f"file:///{file.resolve()}"
|
||||
return MessageSegment(
|
||||
"video", {
|
||||
"video",
|
||||
{
|
||||
"file": file,
|
||||
"cache": _b2s(cache),
|
||||
"proxy": _b2s(proxy),
|
||||
"timeout": timeout
|
||||
})
|
||||
"timeout": timeout,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def xml(data: str) -> "MessageSegment":
|
||||
@ -241,22 +261,22 @@ class Message(BaseMessage[MessageSegment]):
|
||||
return MessageSegment
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __add__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> "Message":
|
||||
def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
|
||||
return super(Message, self).__add__(
|
||||
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||
MessageSegment.text(other) if isinstance(other, str) else other
|
||||
)
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __radd__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> "Message":
|
||||
def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
|
||||
return super(Message, self).__radd__(
|
||||
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||
MessageSegment.text(other) if isinstance(other, str) else other
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@overrides(BaseMessage)
|
||||
def _construct(
|
||||
msg: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
||||
msg: Union[str, Mapping, Iterable[Mapping]]
|
||||
) -> Iterable[MessageSegment]:
|
||||
if isinstance(msg, Mapping):
|
||||
msg = cast(Mapping[str, Any], msg)
|
||||
yield MessageSegment(msg["type"], msg.get("data") or {})
|
||||
@ -270,14 +290,15 @@ class Message(BaseMessage[MessageSegment]):
|
||||
def _iter_message(msg: str) -> Iterable[Tuple[str, str]]:
|
||||
text_begin = 0
|
||||
for cqcode in re.finditer(
|
||||
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
|
||||
r"(?P<params>"
|
||||
r"(?:,[a-zA-Z0-9-_.]+=[^,\]]+)*"
|
||||
r"),?\]", msg):
|
||||
yield "text", msg[text_begin:cqcode.pos + cqcode.start()]
|
||||
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
|
||||
r"(?P<params>"
|
||||
r"(?:,[a-zA-Z0-9-_.]+=[^,\]]+)*"
|
||||
r"),?\]",
|
||||
msg,
|
||||
):
|
||||
yield "text", msg[text_begin : cqcode.pos + cqcode.start()]
|
||||
text_begin = cqcode.pos + cqcode.end()
|
||||
yield cqcode.group("type"), cqcode.group("params").lstrip(
|
||||
",")
|
||||
yield cqcode.group("type"), cqcode.group("params").lstrip(",")
|
||||
yield "text", msg[text_begin:]
|
||||
|
||||
for type_, data in _iter_message(msg):
|
||||
@ -287,10 +308,11 @@ class Message(BaseMessage[MessageSegment]):
|
||||
yield MessageSegment(type_, {"text": unescape(data)})
|
||||
else:
|
||||
data = {
|
||||
k: unescape(v) for k, v in map(
|
||||
k: unescape(v)
|
||||
for k, v in map(
|
||||
lambda x: x.split("=", maxsplit=1),
|
||||
filter(lambda x: x, (
|
||||
x.lstrip() for x in data.split(","))))
|
||||
filter(lambda x: x, (x.lstrip() for x in data.split(","))),
|
||||
)
|
||||
}
|
||||
yield MessageSegment(type_, data)
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.permission import Permission
|
||||
|
||||
from .event import GroupMessageEvent, PrivateMessageEvent
|
||||
|
||||
|
||||
@ -42,8 +43,7 @@ async def _group(event: Event) -> bool:
|
||||
|
||||
|
||||
async def _group_member(event: Event) -> bool:
|
||||
return isinstance(event,
|
||||
GroupMessageEvent) and event.sender.role == "member"
|
||||
return isinstance(event, GroupMessageEvent) and event.sender.role == "member"
|
||||
|
||||
|
||||
async def _group_admin(event: Event) -> bool:
|
||||
@ -76,6 +76,12 @@ GROUP_OWNER = Permission(_group_owner)
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"PRIVATE", "PRIVATE_FRIEND", "PRIVATE_GROUP", "PRIVATE_OTHER", "GROUP",
|
||||
"GROUP_MEMBER", "GROUP_ADMIN", "GROUP_OWNER"
|
||||
"PRIVATE",
|
||||
"PRIVATE_FRIEND",
|
||||
"PRIVATE_GROUP",
|
||||
"PRIVATE_OTHER",
|
||||
"GROUP",
|
||||
"GROUP_MEMBER",
|
||||
"GROUP_ADMIN",
|
||||
"GROUP_OWNER",
|
||||
]
|
||||
|
@ -16,9 +16,7 @@ def escape(s: str, *, escape_comma: bool = True) -> str:
|
||||
* ``s: str``: 需要转义的字符串
|
||||
* ``escape_comma: bool``: 是否转义逗号(``,``)。
|
||||
"""
|
||||
s = s.replace("&", "&") \
|
||||
.replace("[", "[") \
|
||||
.replace("]", "]")
|
||||
s = s.replace("&", "&").replace("[", "[").replace("]", "]")
|
||||
if escape_comma:
|
||||
s = s.replace(",", ",")
|
||||
return s
|
||||
@ -34,10 +32,12 @@ def unescape(s: str) -> str:
|
||||
|
||||
* ``s: str``: 需要转义的字符串
|
||||
"""
|
||||
return s.replace(",", ",") \
|
||||
.replace("[", "[") \
|
||||
.replace("]", "]") \
|
||||
return (
|
||||
s.replace(",", ",")
|
||||
.replace("[", "[")
|
||||
.replace("]", "]")
|
||||
.replace("&", "&")
|
||||
)
|
||||
|
||||
|
||||
def _b2s(b: Optional[bool]) -> Optional[str]:
|
||||
|
@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
|
||||
# url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
# default = true
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py37", "py38", "py39"]
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 80
|
||||
length_sort = true
|
||||
skip_gitignore = true
|
||||
force_sort_within_sections = true
|
||||
extra_standard_library = ["typing_extensions"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
@ -16,10 +16,18 @@ from nonebot.drivers import Driver, HTTPRequest, HTTPResponse, HTTPConnection
|
||||
from .config import Config as DingConfig
|
||||
from .utils import log, calc_hmac_base64
|
||||
from .message import Message, MessageSegment
|
||||
from .exception import (ActionFailed, NetworkError, SessionExpired,
|
||||
ApiNotAvailable)
|
||||
from .event import (MessageEvent, ConversationType, GroupMessageEvent,
|
||||
PrivateMessageEvent)
|
||||
from .exception import (
|
||||
ActionFailed,
|
||||
NetworkError,
|
||||
SessionExpired,
|
||||
ApiNotAvailable,
|
||||
)
|
||||
from .event import (
|
||||
MessageEvent,
|
||||
ConversationType,
|
||||
GroupMessageEvent,
|
||||
PrivateMessageEvent,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.config import Config
|
||||
@ -31,6 +39,7 @@ class Bot(BaseBot):
|
||||
"""
|
||||
钉钉 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。
|
||||
"""
|
||||
|
||||
ding_config: DingConfig
|
||||
|
||||
@property
|
||||
@ -48,8 +57,8 @@ class Bot(BaseBot):
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
async def check_permission(
|
||||
cls, driver: Driver,
|
||||
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
|
||||
cls, driver: Driver, request: HTTPConnection
|
||||
) -> Tuple[Optional[str], HTTPResponse]:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -61,7 +70,8 @@ class Bot(BaseBot):
|
||||
# 检查连接方式
|
||||
if not isinstance(request, HTTPRequest):
|
||||
return None, HTTPResponse(
|
||||
405, b"Unsupported connection type, available type: `http`")
|
||||
405, b"Unsupported connection type, available type: `http`"
|
||||
)
|
||||
|
||||
# 检查 timestamp
|
||||
if not timestamp:
|
||||
@ -74,13 +84,15 @@ class Bot(BaseBot):
|
||||
log("WARNING", "Missing Signature Header")
|
||||
return None, HTTPResponse(400, b"Missing `sign` Header")
|
||||
sign_base64 = calc_hmac_base64(str(timestamp), secret)
|
||||
if sign != sign_base64.decode('utf-8'):
|
||||
if sign != sign_base64.decode("utf-8"):
|
||||
log("WARNING", "Signature Header is invalid")
|
||||
return None, HTTPResponse(403, b"Signature is invalid")
|
||||
else:
|
||||
log("WARNING", "Ding signature check ignored!")
|
||||
return (json.loads(request.body.decode())["chatbotUserId"],
|
||||
HTTPResponse(204, b''))
|
||||
return (
|
||||
json.loads(request.body.decode())["chatbotUserId"],
|
||||
HTTPResponse(204, b""),
|
||||
)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def handle_message(self, message: bytes):
|
||||
@ -111,10 +123,9 @@ class Bot(BaseBot):
|
||||
return
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def _call_api(self,
|
||||
api: str,
|
||||
event: Optional[MessageEvent] = None,
|
||||
**data) -> Any:
|
||||
async def _call_api(
|
||||
self, api: str, event: Optional[MessageEvent] = None, **data
|
||||
) -> Any:
|
||||
if not isinstance(self.request, HTTPRequest):
|
||||
log("ERROR", "Only support http connection.")
|
||||
return
|
||||
@ -138,7 +149,8 @@ class Bot(BaseBot):
|
||||
if event:
|
||||
# 确保 sessionWebhook 没有过期
|
||||
if int(datetime.now().timestamp()) > int(
|
||||
event.sessionWebhookExpiredTime / 1000):
|
||||
event.sessionWebhookExpiredTime / 1000
|
||||
):
|
||||
raise SessionExpired
|
||||
|
||||
webhook = event.sessionWebhook
|
||||
@ -150,32 +162,37 @@ class Bot(BaseBot):
|
||||
if not message:
|
||||
raise ValueError("Message not found")
|
||||
try:
|
||||
async with httpx.AsyncClient(headers=headers,
|
||||
follow_redirects=True) as client:
|
||||
response = await client.post(webhook,
|
||||
params=params,
|
||||
json=message._produce(),
|
||||
timeout=self.config.api_timeout)
|
||||
async with httpx.AsyncClient(
|
||||
headers=headers, follow_redirects=True
|
||||
) as client:
|
||||
response = await client.post(
|
||||
webhook,
|
||||
params=params,
|
||||
json=message._produce(),
|
||||
timeout=self.config.api_timeout,
|
||||
)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
result = response.json()
|
||||
if isinstance(result, dict):
|
||||
if result.get("errcode") != 0:
|
||||
raise ActionFailed(errcode=result.get("errcode"),
|
||||
errmsg=result.get("errmsg"))
|
||||
raise ActionFailed(
|
||||
errcode=result.get("errcode"), errmsg=result.get("errmsg")
|
||||
)
|
||||
return result
|
||||
raise NetworkError(f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code}")
|
||||
raise NetworkError(
|
||||
f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code}"
|
||||
)
|
||||
except httpx.InvalidURL:
|
||||
raise NetworkError("API root url invalid")
|
||||
except httpx.HTTPError:
|
||||
raise NetworkError("HTTP request failed")
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self,
|
||||
api: str,
|
||||
event: Optional[MessageEvent] = None,
|
||||
**data) -> Any:
|
||||
async def call_api(
|
||||
self, api: str, event: Optional[MessageEvent] = None, **data
|
||||
) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -199,13 +216,15 @@ class Bot(BaseBot):
|
||||
return await super().call_api(api, event=event, **data)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def send(self,
|
||||
event: MessageEvent,
|
||||
message: Union[str, "Message", "MessageSegment"],
|
||||
at_sender: bool = False,
|
||||
webhook: Optional[str] = None,
|
||||
secret: Optional[str] = None,
|
||||
**kwargs) -> Any:
|
||||
async def send(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
message: Union[str, "Message", "MessageSegment"],
|
||||
at_sender: bool = False,
|
||||
webhook: Optional[str] = None,
|
||||
secret: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -241,9 +260,11 @@ class Bot(BaseBot):
|
||||
params.update(kwargs)
|
||||
|
||||
if at_sender and event.conversationType != ConversationType.private:
|
||||
params[
|
||||
"message"] = f"@{event.senderId} " + msg + MessageSegment.atDingtalkIds(
|
||||
event.senderId)
|
||||
params["message"] = (
|
||||
f"@{event.senderId} "
|
||||
+ msg
|
||||
+ MessageSegment.atDingtalkIds(event.senderId)
|
||||
)
|
||||
else:
|
||||
params["message"] = msg
|
||||
|
||||
|
@ -12,6 +12,7 @@ class Config(BaseModel):
|
||||
- ``access_token`` / ``ding_access_token``: 钉钉令牌
|
||||
- ``secret`` / ``ding_secret``: 钉钉 HTTP 上报数据签名口令
|
||||
"""
|
||||
|
||||
secret: Optional[str] = Field(default=None, alias="ding_secret")
|
||||
access_token: Optional[str] = Field(default=None, alias="ding_access_token")
|
||||
|
||||
|
@ -69,6 +69,7 @@ class ConversationType(str, Enum):
|
||||
|
||||
class MessageEvent(Event):
|
||||
"""消息事件"""
|
||||
|
||||
msgtype: str
|
||||
text: TextMessage
|
||||
msgId: str
|
||||
@ -88,11 +89,10 @@ class MessageEvent(Event):
|
||||
def gen_message(cls, values: dict):
|
||||
assert "msgtype" in values, "msgtype must be specified"
|
||||
# 其实目前钉钉机器人只能接收到 text 类型的消息
|
||||
assert values[
|
||||
"msgtype"] in values, f"{values['msgtype']} must be specified"
|
||||
content = values[values['msgtype']]['content']
|
||||
assert values["msgtype"] in values, f"{values['msgtype']} must be specified"
|
||||
content = values[values["msgtype"]]["content"]
|
||||
# 如果是被 @,第一个字符将会为空格,移除特殊情况
|
||||
if content[0] == ' ':
|
||||
if content[0] == " ":
|
||||
content = content[1:]
|
||||
values["message"] = content
|
||||
return values
|
||||
@ -128,6 +128,7 @@ class MessageEvent(Event):
|
||||
|
||||
class PrivateMessageEvent(MessageEvent):
|
||||
"""私聊消息事件"""
|
||||
|
||||
chatbotCorpId: str
|
||||
senderStaffId: Optional[str]
|
||||
conversationType: ConversationType = ConversationType.private
|
||||
@ -135,6 +136,7 @@ class PrivateMessageEvent(MessageEvent):
|
||||
|
||||
class GroupMessageEvent(MessageEvent):
|
||||
"""群消息事件"""
|
||||
|
||||
atUsers: List[AtUsersItem]
|
||||
conversationType: ConversationType = ConversationType.group
|
||||
conversationTitle: str
|
||||
|
@ -1,9 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from nonebot.exception import (AdapterException, ActionFailed as
|
||||
BaseActionFailed, ApiNotAvailable as
|
||||
BaseApiNotAvailable, NetworkError as
|
||||
BaseNetworkError)
|
||||
from nonebot.exception import AdapterException
|
||||
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||
from nonebot.exception import NetworkError as BaseNetworkError
|
||||
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
|
||||
|
||||
|
||||
class DingAdapterException(AdapterException):
|
||||
@ -29,15 +29,13 @@ class ActionFailed(BaseActionFailed, DingAdapterException):
|
||||
* ``errmsg: Optional[str]``: 错误信息
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
errcode: Optional[int] = None,
|
||||
errmsg: Optional[str] = None):
|
||||
def __init__(self, errcode: Optional[int] = None, errmsg: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.errcode = errcode
|
||||
self.errmsg = errmsg
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ApiError errcode={self.errcode} errmsg=\"{self.errmsg}\">"
|
||||
return f'<ApiError errcode={self.errcode} errmsg="{self.errmsg}">'
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
@ -77,10 +77,9 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
def code(code_language: str, code: str) -> "Message":
|
||||
"""发送 code 消息段"""
|
||||
message = MessageSegment.text(code)
|
||||
message += MessageSegment.extension({
|
||||
"text_type": "code_snippet",
|
||||
"code_language": code_language
|
||||
})
|
||||
message += MessageSegment.extension(
|
||||
{"text_type": "code_snippet", "code_language": code_language}
|
||||
)
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
@ -95,16 +94,19 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def actionCardSingleBtn(title: str, text: str, singleTitle: str,
|
||||
singleURL) -> "MessageSegment":
|
||||
def actionCardSingleBtn(
|
||||
title: str, text: str, singleTitle: str, singleURL
|
||||
) -> "MessageSegment":
|
||||
"""发送 ``actionCardSingleBtn`` 类型消息"""
|
||||
return MessageSegment(
|
||||
"actionCard", {
|
||||
"actionCard",
|
||||
{
|
||||
"title": title,
|
||||
"text": text,
|
||||
"singleTitle": singleTitle,
|
||||
"singleURL": singleURL
|
||||
})
|
||||
"singleURL": singleURL,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def actionCardMultiBtns(
|
||||
@ -112,7 +114,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
text: str,
|
||||
btns: list,
|
||||
hideAvatar: bool = False,
|
||||
btnOrientation: str = '1',
|
||||
btnOrientation: str = "1",
|
||||
) -> "MessageSegment":
|
||||
"""
|
||||
发送 ``actionCardMultiBtn`` 类型消息
|
||||
@ -123,13 +125,15 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
* ``btns``: ``[{ "title": title, "actionURL": actionURL }, ...]``
|
||||
"""
|
||||
return MessageSegment(
|
||||
"actionCard", {
|
||||
"actionCard",
|
||||
{
|
||||
"title": title,
|
||||
"text": text,
|
||||
"hideAvatar": "1" if hideAvatar else "0",
|
||||
"btnOrientation": btnOrientation,
|
||||
"btns": btns
|
||||
})
|
||||
"btns": btns,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def feedCard(links: list) -> "MessageSegment":
|
||||
@ -144,7 +148,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
|
||||
@staticmethod
|
||||
def raw(data) -> "MessageSegment":
|
||||
return MessageSegment('raw', data)
|
||||
return MessageSegment("raw", data)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
# 让用户可以直接发送原始的消息格式
|
||||
@ -171,8 +175,8 @@ class Message(BaseMessage[MessageSegment]):
|
||||
@staticmethod
|
||||
@overrides(BaseMessage)
|
||||
def _construct(
|
||||
msg: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
||||
msg: Union[str, Mapping, Iterable[Mapping]]
|
||||
) -> Iterable[MessageSegment]:
|
||||
if isinstance(msg, Mapping):
|
||||
msg = cast(Mapping[str, Any], msg)
|
||||
yield MessageSegment(msg["type"], msg.get("data") or {})
|
||||
@ -187,10 +191,11 @@ class Message(BaseMessage[MessageSegment]):
|
||||
segment: MessageSegment
|
||||
for segment in self:
|
||||
# text 可以和 text 合并
|
||||
if segment.type == "text" and data.get("msgtype") == 'text':
|
||||
if segment.type == "text" and data.get("msgtype") == "text":
|
||||
data.setdefault("text", {})
|
||||
data["text"]["content"] = data["text"].setdefault(
|
||||
"content", "") + segment.data["content"]
|
||||
data["text"]["content"] = (
|
||||
data["text"].setdefault("content", "") + segment.data["content"]
|
||||
)
|
||||
else:
|
||||
data.update(segment.to_dict())
|
||||
return data
|
||||
|
@ -8,10 +8,10 @@ log = logger_wrapper("DING")
|
||||
|
||||
|
||||
def calc_hmac_base64(timestamp: str, secret: str):
|
||||
secret_enc = secret.encode('utf-8')
|
||||
string_to_sign = '{}\n{}'.format(timestamp, secret)
|
||||
string_to_sign_enc = string_to_sign.encode('utf-8')
|
||||
hmac_code = hmac.new(secret_enc,
|
||||
string_to_sign_enc,
|
||||
digestmod=hashlib.sha256).digest()
|
||||
secret_enc = secret.encode("utf-8")
|
||||
string_to_sign = "{}\n{}".format(timestamp, secret)
|
||||
string_to_sign_enc = string_to_sign.encode("utf-8")
|
||||
hmac_code = hmac.new(
|
||||
secret_enc, string_to_sign_enc, digestmod=hashlib.sha256
|
||||
).digest()
|
||||
return base64.b64encode(hmac_code)
|
||||
|
@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
|
||||
# url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
# default = true
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py37", "py38", "py39"]
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 80
|
||||
length_sort = true
|
||||
skip_gitignore = true
|
||||
force_sort_within_sections = true
|
||||
extra_standard_library = ["typing_extensions"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
@ -7,7 +7,7 @@ aiocache_logger.setLevel(logging.DEBUG)
|
||||
aiocache_logger.handlers.clear()
|
||||
aiocache_logger.addHandler(LoguruHandler())
|
||||
|
||||
from .bot import Bot as Bot
|
||||
from .event import *
|
||||
from .bot import Bot as Bot
|
||||
from .message import Message as Message
|
||||
from .message import MessageSegment as MessageSegment
|
||||
|
@ -1,24 +1,39 @@
|
||||
import re
|
||||
import json
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Tuple, Union, Iterable, Optional,
|
||||
AsyncIterable, cast)
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Tuple,
|
||||
Union,
|
||||
Iterable,
|
||||
Optional,
|
||||
AsyncIterable,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from aiocache import Cache, cached
|
||||
from aiocache.serializers import PickleSerializer
|
||||
|
||||
from nonebot.log import logger
|
||||
from .utils import AESCipher, log
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.message import handle_event
|
||||
from .config import Config as FeishuConfig
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.drivers import Driver, HTTPRequest, HTTPResponse
|
||||
|
||||
from .utils import AESCipher, log
|
||||
from .config import Config as FeishuConfig
|
||||
from .message import Message, MessageSegment, MessageSerializer
|
||||
from .exception import ActionFailed, NetworkError, ApiNotAvailable
|
||||
from .event import (Event, MessageEvent, GroupMessageEvent, PrivateMessageEvent,
|
||||
get_event_model)
|
||||
from .event import (
|
||||
Event,
|
||||
MessageEvent,
|
||||
GroupMessageEvent,
|
||||
PrivateMessageEvent,
|
||||
get_event_model,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.config import Config
|
||||
@ -47,8 +62,10 @@ def _check_at_me(bot: "Bot", event: "Event"):
|
||||
event.to_me = True
|
||||
|
||||
for index, segment in enumerate(message):
|
||||
if segment.type == "at" and segment.data.get(
|
||||
"user_name") in bot.config.nickname:
|
||||
if (
|
||||
segment.type == "at"
|
||||
and segment.data.get("user_name") in bot.config.nickname
|
||||
):
|
||||
event.to_me = True
|
||||
del event.event.message.content[index]
|
||||
return
|
||||
@ -57,7 +74,8 @@ def _check_at_me(bot: "Bot", event: "Event"):
|
||||
if mention["name"] in bot.config.nickname:
|
||||
event.to_me = True
|
||||
segment.data["text"] = segment.data["text"].replace(
|
||||
f"@{mention['name']}", "")
|
||||
f"@{mention['name']}", ""
|
||||
)
|
||||
segment.data["text"] = segment.data["text"].lstrip()
|
||||
break
|
||||
else:
|
||||
@ -92,18 +110,18 @@ def _check_nickname(bot: "Bot", event: "Event"):
|
||||
if nicknames:
|
||||
# check if the user is calling me with my nickname
|
||||
nickname_regex = "|".join(nicknames)
|
||||
m = re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text,
|
||||
re.IGNORECASE)
|
||||
m = re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text, re.IGNORECASE)
|
||||
if m:
|
||||
nickname = m.group(1)
|
||||
log("DEBUG", f"User is calling me {nickname}")
|
||||
event.to_me = True
|
||||
first_msg_seg.data["text"] = first_text[m.end():]
|
||||
first_msg_seg.data["text"] = first_text[m.end() :]
|
||||
|
||||
|
||||
def _handle_api_result(
|
||||
result: Union[Optional[Dict[str, Any]], str, bytes, Iterable[bytes],
|
||||
AsyncIterable[bytes]]
|
||||
result: Union[
|
||||
Optional[Dict[str, Any]], str, bytes, Iterable[bytes], AsyncIterable[bytes]
|
||||
]
|
||||
) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
@ -155,13 +173,13 @@ class Bot(BaseBot):
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
async def check_permission(
|
||||
cls, driver: Driver, request: HTTPRequest
|
||||
cls, driver: Driver, request: HTTPRequest
|
||||
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
|
||||
if not isinstance(request, HTTPRequest):
|
||||
log("WARNING",
|
||||
"Unsupported connection type, available type: `http`")
|
||||
log("WARNING", "Unsupported connection type, available type: `http`")
|
||||
return None, HTTPResponse(
|
||||
405, b"Unsupported connection type, available type: `http`")
|
||||
405, b"Unsupported connection type, available type: `http`"
|
||||
)
|
||||
|
||||
encrypt_key = cls.feishu_config.encrypt_key
|
||||
if encrypt_key:
|
||||
@ -174,16 +192,13 @@ class Bot(BaseBot):
|
||||
challenge = data.get("challenge")
|
||||
if challenge:
|
||||
return data.get("token"), HTTPResponse(
|
||||
200,
|
||||
json.dumps({
|
||||
"challenge": challenge
|
||||
}).encode())
|
||||
200, json.dumps({"challenge": challenge}).encode()
|
||||
)
|
||||
|
||||
schema = data.get("schema")
|
||||
if not schema:
|
||||
return None, HTTPResponse(
|
||||
400,
|
||||
b"Missing `schema` in POST body, only accept event of version 2.0"
|
||||
400, b"Missing `schema` in POST body, only accept event of version 2.0"
|
||||
)
|
||||
|
||||
headers = data.get("header")
|
||||
@ -196,15 +211,13 @@ class Bot(BaseBot):
|
||||
|
||||
if not token:
|
||||
log("WARNING", "Missing `verification token` in POST body")
|
||||
return None, HTTPResponse(
|
||||
400, b"Missing `verification token` in POST body")
|
||||
return None, HTTPResponse(400, b"Missing `verification token` in POST body")
|
||||
else:
|
||||
if token != cls.feishu_config.verification_token:
|
||||
log("WARNING", "Verification token check failed")
|
||||
return None, HTTPResponse(403,
|
||||
b"Verification token check failed")
|
||||
return None, HTTPResponse(403, b"Verification token check failed")
|
||||
|
||||
return app_id, HTTPResponse(200, b'')
|
||||
return app_id, HTTPResponse(200, b"")
|
||||
|
||||
async def handle_message(self, message: bytes):
|
||||
"""
|
||||
@ -245,28 +258,32 @@ class Bot(BaseBot):
|
||||
def _construct_url(self, path: str) -> str:
|
||||
return self.api_root + path
|
||||
|
||||
@cached(ttl=60 * 60,
|
||||
cache=Cache.MEMORY,
|
||||
key="_feishu_tenant_access_token",
|
||||
serializer=PickleSerializer())
|
||||
@cached(
|
||||
ttl=60 * 60,
|
||||
cache=Cache.MEMORY,
|
||||
key="_feishu_tenant_access_token",
|
||||
serializer=PickleSerializer(),
|
||||
)
|
||||
async def _fetch_tenant_access_token(self) -> str:
|
||||
try:
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
response = await client.post(
|
||||
self._construct_url(
|
||||
"auth/v3/tenant_access_token/internal/"),
|
||||
self._construct_url("auth/v3/tenant_access_token/internal/"),
|
||||
json={
|
||||
"app_id": self.feishu_config.app_id,
|
||||
"app_secret": self.feishu_config.app_secret
|
||||
"app_secret": self.feishu_config.app_secret,
|
||||
},
|
||||
timeout=self.config.api_timeout)
|
||||
timeout=self.config.api_timeout,
|
||||
)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
result = response.json()
|
||||
return result["tenant_access_token"]
|
||||
else:
|
||||
raise NetworkError(f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code}")
|
||||
raise NetworkError(
|
||||
f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code}"
|
||||
)
|
||||
except httpx.InvalidURL:
|
||||
raise NetworkError("API root url invalid")
|
||||
except httpx.HTTPError:
|
||||
@ -280,30 +297,37 @@ class Bot(BaseBot):
|
||||
raise ApiNotAvailable
|
||||
|
||||
headers = {}
|
||||
self.feishu_config.tenant_access_token = await self._fetch_tenant_access_token(
|
||||
self.feishu_config.tenant_access_token = (
|
||||
await self._fetch_tenant_access_token()
|
||||
)
|
||||
headers["Authorization"] = (
|
||||
"Bearer " + self.feishu_config.tenant_access_token
|
||||
)
|
||||
headers[
|
||||
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.config.api_timeout,
|
||||
follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=self.config.api_timeout, follow_redirects=True
|
||||
) as client:
|
||||
response = await client.send(
|
||||
httpx.Request(data["method"],
|
||||
self.api_root + api,
|
||||
json=data.get("body", {}),
|
||||
params=data.get("query", {}),
|
||||
headers=headers))
|
||||
httpx.Request(
|
||||
data["method"],
|
||||
self.api_root + api,
|
||||
json=data.get("body", {}),
|
||||
params=data.get("query", {}),
|
||||
headers=headers,
|
||||
)
|
||||
)
|
||||
if 200 <= response.status_code < 300:
|
||||
if response.headers["content-type"].startswith(
|
||||
"application/json"):
|
||||
if response.headers["content-type"].startswith("application/json"):
|
||||
result = response.json()
|
||||
else:
|
||||
result = response.content
|
||||
return _handle_api_result(result)
|
||||
raise NetworkError(f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code} "
|
||||
f"response body: {response.text}")
|
||||
raise NetworkError(
|
||||
f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code} "
|
||||
f"response body: {response.text}"
|
||||
)
|
||||
except httpx.InvalidURL:
|
||||
raise NetworkError("API root url invalid")
|
||||
except httpx.HTTPError:
|
||||
@ -333,11 +357,13 @@ class Bot(BaseBot):
|
||||
return await super().call_api(api, **data)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def send(self,
|
||||
event: Event,
|
||||
message: Union[str, Message, MessageSegment],
|
||||
at_sender: bool = False,
|
||||
**kwargs) -> Any:
|
||||
async def send(
|
||||
self,
|
||||
event: Event,
|
||||
message: Union[str, Message, MessageSegment],
|
||||
at_sender: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
msg = message if isinstance(message, Message) else Message(message)
|
||||
|
||||
if isinstance(event, GroupMessageEvent):
|
||||
@ -346,7 +372,8 @@ class Bot(BaseBot):
|
||||
receive_id, receive_id_type = event.get_user_id(), "open_id"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot guess `receive_id` and `receive_id_type` to reply!")
|
||||
"Cannot guess `receive_id` and `receive_id_type` to reply!"
|
||||
)
|
||||
|
||||
at_sender = at_sender and bool(event.get_user_id())
|
||||
|
||||
@ -357,14 +384,12 @@ class Bot(BaseBot):
|
||||
|
||||
params = {
|
||||
"method": "POST",
|
||||
"query": {
|
||||
"receive_id_type": receive_id_type
|
||||
},
|
||||
"query": {"receive_id_type": receive_id_type},
|
||||
"body": {
|
||||
"receive_id": receive_id,
|
||||
"content": content,
|
||||
"msg_type": msg_type
|
||||
}
|
||||
"msg_type": msg_type,
|
||||
},
|
||||
}
|
||||
|
||||
return await self.call_api(f"im/v1/messages", **params)
|
||||
|
@ -17,13 +17,16 @@ class Config(BaseModel):
|
||||
- ``is_lark`` / ``feishu_is_lark``: 是否使用Lark(飞书海外版),默认为 false
|
||||
|
||||
"""
|
||||
|
||||
app_id: Optional[str] = Field(default=None, alias="feishu_app_id")
|
||||
app_secret: Optional[str] = Field(default=None, alias="feishu_app_secret")
|
||||
encrypt_key: Optional[str] = Field(default=None, alias="feishu_encrypt_key")
|
||||
verification_token: Optional[str] = Field(default=None,
|
||||
alias="feishu_verification_token")
|
||||
verification_token: Optional[str] = Field(
|
||||
default=None, alias="feishu_verification_token"
|
||||
)
|
||||
tenant_access_token: Optional[str] = Field(
|
||||
default=None, alias="feishu_tenant_access_token")
|
||||
default=None, alias="feishu_tenant_access_token"
|
||||
)
|
||||
is_lark: Optional[str] = Field(default=False, alias="feishu_is_lark")
|
||||
|
||||
class Config:
|
||||
|
@ -1,12 +1,12 @@
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Type
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Type, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from pygtrie import StringTrie
|
||||
from pydantic import Field, BaseModel, root_validator
|
||||
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
|
||||
from .message import Message, MessageDeserializer
|
||||
|
||||
|
@ -1,13 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||
from nonebot.exception import AdapterException
|
||||
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
|
||||
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||
from nonebot.exception import NetworkError as BaseNetworkError
|
||||
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
|
||||
|
||||
|
||||
class FeishuAdapterException(AdapterException):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("feishu")
|
||||
|
||||
@ -28,8 +27,11 @@ class ActionFailed(BaseActionFailed, FeishuAdapterException):
|
||||
self.info = kwargs
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ActionFailed " + ", ".join(
|
||||
f"{k}={v}" for k, v in self.info.items()) + ">"
|
||||
return (
|
||||
f"<ActionFailed "
|
||||
+ ", ".join(f"{k}={v}" for k, v in self.info.items())
|
||||
+ ">"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
@ -1,8 +1,18 @@
|
||||
import json
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import (Any, Dict, List, Type, Tuple, Union, Mapping, Iterable,
|
||||
Optional, cast)
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Type,
|
||||
Tuple,
|
||||
Union,
|
||||
Mapping,
|
||||
Iterable,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.adapters import Message as BaseMessage
|
||||
@ -34,7 +44,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
"share_user": "[个人名片]",
|
||||
"system": "[系统消息]",
|
||||
"location": "[位置]",
|
||||
"video_chat": "[视频通话]"
|
||||
"video_chat": "[视频通话]",
|
||||
}
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -47,24 +57,26 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __add__(self, other) -> "Message":
|
||||
return Message(self) + (MessageSegment.text(other) if isinstance(
|
||||
other, str) else other)
|
||||
return Message(self) + (
|
||||
MessageSegment.text(other) if isinstance(other, str) else other
|
||||
)
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __radd__(self, other) -> "Message":
|
||||
return (MessageSegment.text(other)
|
||||
if isinstance(other, str) else Message(other)) + self
|
||||
return (
|
||||
MessageSegment.text(other) if isinstance(other, str) else Message(other)
|
||||
) + self
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def is_text(self) -> bool:
|
||||
return self.type == "text"
|
||||
|
||||
#接收消息
|
||||
# 接收消息
|
||||
@staticmethod
|
||||
def at(user_id: str) -> "MessageSegment":
|
||||
return MessageSegment("at", {"user_id": user_id})
|
||||
|
||||
#发送消息
|
||||
# 发送消息
|
||||
@staticmethod
|
||||
def text(text: str) -> "MessageSegment":
|
||||
return MessageSegment("text", {"text": text})
|
||||
@ -79,10 +91,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
|
||||
@staticmethod
|
||||
def interactive(title: str, elements: list) -> "MessageSegment":
|
||||
return MessageSegment("interactive", {
|
||||
"title": title,
|
||||
"elements": elements
|
||||
})
|
||||
return MessageSegment("interactive", {"title": title, "elements": elements})
|
||||
|
||||
@staticmethod
|
||||
def share_chat(chat_id: str) -> "MessageSegment":
|
||||
@ -94,28 +103,25 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
|
||||
@staticmethod
|
||||
def audio(file_key: str, duration: int) -> "MessageSegment":
|
||||
return MessageSegment("audio", {
|
||||
"file_key": file_key,
|
||||
"duration": duration
|
||||
})
|
||||
return MessageSegment("audio", {"file_key": file_key, "duration": duration})
|
||||
|
||||
@staticmethod
|
||||
def media(file_key: str, image_key: str, file_name: str,
|
||||
duration: int) -> "MessageSegment":
|
||||
def media(
|
||||
file_key: str, image_key: str, file_name: str, duration: int
|
||||
) -> "MessageSegment":
|
||||
return MessageSegment(
|
||||
"media", {
|
||||
"media",
|
||||
{
|
||||
"file_key": file_key,
|
||||
"image_key": image_key,
|
||||
"file_name": file_name,
|
||||
"duration": duration
|
||||
})
|
||||
"duration": duration,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def file(file_key: str, file_name: str) -> "MessageSegment":
|
||||
return MessageSegment("file", {
|
||||
"file_key": file_key,
|
||||
"file_name": file_name
|
||||
})
|
||||
return MessageSegment("file", {"file_key": file_key, "file_name": file_name})
|
||||
|
||||
@staticmethod
|
||||
def sticker(file_key) -> "MessageSegment":
|
||||
@ -133,22 +139,22 @@ class Message(BaseMessage[MessageSegment]):
|
||||
return MessageSegment
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __add__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> "Message":
|
||||
def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
|
||||
return super(Message, self).__add__(
|
||||
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||
MessageSegment.text(other) if isinstance(other, str) else other
|
||||
)
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __radd__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> "Message":
|
||||
def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
|
||||
return super(Message, self).__radd__(
|
||||
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||
MessageSegment.text(other) if isinstance(other, str) else other
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@overrides(BaseMessage)
|
||||
def _construct(
|
||||
msg: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
||||
msg: Union[str, Mapping, Iterable[Mapping]]
|
||||
) -> Iterable[MessageSegment]:
|
||||
if isinstance(msg, Mapping):
|
||||
msg = cast(Mapping[str, Any], msg)
|
||||
yield MessageSegment(msg["type"], msg.get("data") or {})
|
||||
@ -169,7 +175,8 @@ class Message(BaseMessage[MessageSegment]):
|
||||
for i, seg in enumerate(self):
|
||||
if seg.type == "text" and i != 0 and msg[-1].type == "text":
|
||||
msg[-1] = MessageSegment(
|
||||
"text", {"text": msg[-1].data["text"] + seg.data["text"]})
|
||||
"text", {"text": msg[-1].data["text"] + seg.data["text"]}
|
||||
)
|
||||
else:
|
||||
msg.append(seg)
|
||||
return Message(msg)
|
||||
@ -184,6 +191,7 @@ class MessageSerializer:
|
||||
"""
|
||||
飞书 协议 Message 序列化器。
|
||||
"""
|
||||
|
||||
message: Message
|
||||
|
||||
def serialize(self) -> Tuple[str, str]:
|
||||
@ -198,10 +206,12 @@ class MessageSerializer:
|
||||
else:
|
||||
if last_segment_type == "image":
|
||||
msg["content"].append([])
|
||||
msg["content"][-1].append({
|
||||
"tag": segment.type if segment.type != "image" else "img",
|
||||
**segment.data
|
||||
})
|
||||
msg["content"][-1].append(
|
||||
{
|
||||
"tag": segment.type if segment.type != "image" else "img",
|
||||
**segment.data,
|
||||
}
|
||||
)
|
||||
last_segment_type = segment.type
|
||||
return "post", json.dumps({"zh_cn": {**msg}})
|
||||
|
||||
@ -214,6 +224,7 @@ class MessageDeserializer:
|
||||
"""
|
||||
飞书 协议 Message 反序列化器。
|
||||
"""
|
||||
|
||||
type: str
|
||||
data: Dict[str, Any]
|
||||
mentions: Optional[List[dict]]
|
||||
@ -227,14 +238,13 @@ class MessageDeserializer:
|
||||
if self.type == "post":
|
||||
msg = Message()
|
||||
if self.data["title"] != "":
|
||||
msg += MessageSegment("text", {'text': self.data["title"]})
|
||||
msg += MessageSegment("text", {"text": self.data["title"]})
|
||||
|
||||
for seg in itertools.chain(*self.data["content"]):
|
||||
tag = seg.pop("tag")
|
||||
if tag == "at":
|
||||
seg["user_name"] = dict_mention[seg["user_id"]]["name"]
|
||||
seg["user_id"] = dict_mention[
|
||||
seg["user_id"]]["id"]["open_id"]
|
||||
seg["user_id"] = dict_mention[seg["user_id"]]["id"]["open_id"]
|
||||
|
||||
msg += MessageSegment(tag if tag != "img" else "image", seg)
|
||||
|
||||
@ -242,7 +252,8 @@ class MessageDeserializer:
|
||||
elif self.type == "text":
|
||||
for key, mention in dict_mention.items():
|
||||
self.data["text"] = self.data["text"].replace(
|
||||
key, f"@{mention['name']}")
|
||||
key, f"@{mention['name']}"
|
||||
)
|
||||
self.data["mentions"] = dict_mention
|
||||
|
||||
return Message(MessageSegment(self.type, self.data))
|
||||
|
@ -9,27 +9,26 @@ log = logger_wrapper("FEISHU")
|
||||
|
||||
|
||||
class AESCipher(object):
|
||||
|
||||
def __init__(self, key):
|
||||
self.block_size = AES.block_size
|
||||
self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
|
||||
|
||||
@staticmethod
|
||||
def str_to_bytes(data):
|
||||
u_type = type(b"".decode('utf8'))
|
||||
u_type = type(b"".decode("utf8"))
|
||||
if isinstance(data, u_type):
|
||||
return data.encode('utf8')
|
||||
return data.encode("utf8")
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _unpad(s):
|
||||
return s[:-ord(s[len(s) - 1:])]
|
||||
return s[: -ord(s[len(s) - 1 :])]
|
||||
|
||||
def decrypt(self, enc):
|
||||
iv = enc[:AES.block_size]
|
||||
iv = enc[: AES.block_size]
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
return self._unpad(cipher.decrypt(enc[AES.block_size:]))
|
||||
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
|
||||
|
||||
def decrypt_string(self, enc):
|
||||
enc = base64.b64decode(enc)
|
||||
return self.decrypt(enc).decode('utf8')
|
||||
return self.decrypt(enc).decode("utf8")
|
||||
|
@ -36,6 +36,21 @@ nonebot2 = { path = "../../", develop = true }
|
||||
# url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
# default = true
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py37", "py38", "py39"]
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 80
|
||||
length_sort = true
|
||||
skip_gitignore = true
|
||||
force_sort_within_sections = true
|
||||
extra_standard_library = ["typing_extensions"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
@ -12,8 +12,15 @@ from nonebot.config import Config
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.exception import ApiNotAvailable
|
||||
from nonebot.drivers import (Driver, WebSocket, HTTPResponse, ForwardDriver,
|
||||
ReverseDriver, HTTPConnection, WebSocketSetup)
|
||||
from nonebot.drivers import (
|
||||
Driver,
|
||||
WebSocket,
|
||||
HTTPResponse,
|
||||
ForwardDriver,
|
||||
ReverseDriver,
|
||||
HTTPConnection,
|
||||
WebSocketSetup,
|
||||
)
|
||||
|
||||
from .config import Config as MiraiConfig
|
||||
from .message import MessageChain, MessageSegment
|
||||
@ -23,16 +30,14 @@ from .utils import Log, process_event, argument_validation, catch_network_error
|
||||
|
||||
class SessionManager:
|
||||
"""Bot会话管理器, 提供API主动调用接口"""
|
||||
|
||||
sessions: Dict[int, Tuple[str, httpx.AsyncClient]] = {}
|
||||
|
||||
def __init__(self, session_key: str, client: httpx.AsyncClient):
|
||||
self.session_key, self.client = session_key, client
|
||||
|
||||
@catch_network_error
|
||||
async def post(self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -51,7 +56,7 @@ class SessionManager:
|
||||
path,
|
||||
json={
|
||||
**(params or {}),
|
||||
'sessionKey': self.session_key,
|
||||
"sessionKey": self.session_key,
|
||||
},
|
||||
timeout=3,
|
||||
)
|
||||
@ -59,10 +64,9 @@ class SessionManager:
|
||||
return response.json()
|
||||
|
||||
@catch_network_error
|
||||
async def request(self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
async def request(
|
||||
self, path: str, *, params: Optional[Dict[str, Any]] = None
|
||||
) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -77,7 +81,7 @@ class SessionManager:
|
||||
path,
|
||||
params={
|
||||
**(params or {}),
|
||||
'sessionKey': self.session_key,
|
||||
"sessionKey": self.session_key,
|
||||
},
|
||||
timeout=3,
|
||||
)
|
||||
@ -98,7 +102,7 @@ class SessionManager:
|
||||
"""
|
||||
files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
|
||||
form = {k: v for k, v in params.items() if k not in files}
|
||||
form['sessionKey'] = self.session_key
|
||||
form["sessionKey"] = self.session_key
|
||||
response = await self.client.post(
|
||||
path,
|
||||
data=form,
|
||||
@ -109,25 +113,25 @@ class SessionManager:
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
||||
auth_key: str) -> "SessionManager":
|
||||
async def new(
|
||||
cls, self_id: int, *, host: IPv4Address, port: int, auth_key: str
|
||||
) -> "SessionManager":
|
||||
session = cls.get(self_id)
|
||||
if session is not None:
|
||||
return session
|
||||
|
||||
client = httpx.AsyncClient(base_url=f'http://{host}:{port}',
|
||||
follow_redirects=True)
|
||||
response = await client.post('/auth', json={'authKey': auth_key})
|
||||
client = httpx.AsyncClient(
|
||||
base_url=f"http://{host}:{port}", follow_redirects=True
|
||||
)
|
||||
response = await client.post("/auth", json={"authKey": auth_key})
|
||||
response.raise_for_status()
|
||||
auth = response.json()
|
||||
assert auth['code'] == 0
|
||||
session_key = auth['session']
|
||||
response = await client.post('/verify',
|
||||
json={
|
||||
'sessionKey': session_key,
|
||||
'qq': self_id
|
||||
})
|
||||
assert response.json()['code'] == 0
|
||||
assert auth["code"] == 0
|
||||
session_key = auth["session"]
|
||||
response = await client.post(
|
||||
"/verify", json={"sessionKey": session_key, "qq": self_id}
|
||||
)
|
||||
assert response.json()["code"] == 0
|
||||
cls.sessions[self_id] = session_key, client
|
||||
|
||||
return cls(session_key, client)
|
||||
@ -152,7 +156,7 @@ class Bot(BaseBot):
|
||||
|
||||
"""
|
||||
|
||||
_type = 'mirai'
|
||||
_type = "mirai"
|
||||
|
||||
@property
|
||||
@overrides(BaseBot)
|
||||
@ -166,37 +170,42 @@ class Bot(BaseBot):
|
||||
if api is None:
|
||||
if isinstance(self.request, WebSocket):
|
||||
asyncio.create_task(self.request.close(1000))
|
||||
assert api is not None, 'SessionManager has not been initialized'
|
||||
assert api is not None, "SessionManager has not been initialized"
|
||||
return api
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
async def check_permission(
|
||||
cls, driver: Driver,
|
||||
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
|
||||
cls, driver: Driver, request: HTTPConnection
|
||||
) -> Tuple[Optional[str], HTTPResponse]:
|
||||
if isinstance(request, WebSocket):
|
||||
return None, HTTPResponse(
|
||||
501, b'Websocket connection is not implemented')
|
||||
self_id: Optional[str] = request.headers.get('bot')
|
||||
return None, HTTPResponse(501, b"Websocket connection is not implemented")
|
||||
self_id: Optional[str] = request.headers.get("bot")
|
||||
if self_id is None:
|
||||
return None, HTTPResponse(400, b'Header `Bot` is required.')
|
||||
return None, HTTPResponse(400, b"Header `Bot` is required.")
|
||||
self_id = str(self_id).strip()
|
||||
await SessionManager.new(
|
||||
int(self_id),
|
||||
host=cls.mirai_config.host, # type: ignore
|
||||
port=cls.mirai_config.port, #type: ignore
|
||||
auth_key=cls.mirai_config.auth_key) # type: ignore
|
||||
return self_id, HTTPResponse(204, b'')
|
||||
port=cls.mirai_config.port, # type: ignore
|
||||
auth_key=cls.mirai_config.auth_key, # type: ignore
|
||||
)
|
||||
return self_id, HTTPResponse(204, b"")
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
def register(cls,
|
||||
driver: Driver,
|
||||
config: "Config",
|
||||
qq: Optional[Union[int, List[int]]] = None):
|
||||
def register(
|
||||
cls,
|
||||
driver: Driver,
|
||||
config: "Config",
|
||||
qq: Optional[Union[int, List[int]]] = None,
|
||||
):
|
||||
cls.mirai_config = MiraiConfig(**config.dict())
|
||||
if (cls.mirai_config.auth_key and cls.mirai_config.host and
|
||||
cls.mirai_config.port) is None:
|
||||
if (
|
||||
cls.mirai_config.auth_key
|
||||
and cls.mirai_config.host
|
||||
and cls.mirai_config.port
|
||||
) is None:
|
||||
raise ApiNotAvailable(cls._type)
|
||||
|
||||
super().register(driver, config)
|
||||
@ -209,17 +218,25 @@ class Bot(BaseBot):
|
||||
self_ids = [qq] if isinstance(qq, int) else qq
|
||||
|
||||
async def url_factory(qq: int):
|
||||
assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key
|
||||
assert (
|
||||
cls.mirai_config.host
|
||||
and cls.mirai_config.port
|
||||
and cls.mirai_config.auth_key
|
||||
)
|
||||
session = await SessionManager.new(
|
||||
qq,
|
||||
host=cls.mirai_config.host,
|
||||
port=cls.mirai_config.port,
|
||||
auth_key=cls.mirai_config.auth_key)
|
||||
auth_key=cls.mirai_config.auth_key,
|
||||
)
|
||||
return WebSocketSetup(
|
||||
adapter=cls._type,
|
||||
self_id=str(qq),
|
||||
url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}'
|
||||
f'/all?sessionKey={session.session_key}'))
|
||||
url=(
|
||||
f"ws://{cls.mirai_config.host}:{cls.mirai_config.port}"
|
||||
f"/all?sessionKey={session.session_key}"
|
||||
),
|
||||
)
|
||||
|
||||
for self_id in self_ids:
|
||||
driver.setup_websocket(partial(url_factory, qq=self_id))
|
||||
@ -234,13 +251,15 @@ class Bot(BaseBot):
|
||||
try:
|
||||
await process_event(
|
||||
bot=self,
|
||||
event=Event.new({
|
||||
**json.loads(message),
|
||||
'self_id': self.self_id,
|
||||
}),
|
||||
event=Event.new(
|
||||
{
|
||||
**json.loads(message),
|
||||
"self_id": self.self_id,
|
||||
}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
Log.error(f'Failed to handle message: {message}', e)
|
||||
Log.error(f"Failed to handle message: {message}", e)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def _call_api(self, api: str, **data) -> NoReturn:
|
||||
@ -266,10 +285,12 @@ class Bot(BaseBot):
|
||||
|
||||
@overrides(BaseBot)
|
||||
@argument_validation
|
||||
async def send(self,
|
||||
event: Event,
|
||||
message: Union[MessageChain, MessageSegment, str],
|
||||
at_sender: bool = False):
|
||||
async def send(
|
||||
self,
|
||||
event: Event,
|
||||
message: Union[MessageChain, MessageSegment, str],
|
||||
at_sender: bool = False,
|
||||
):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -284,23 +305,24 @@ class Bot(BaseBot):
|
||||
if not isinstance(message, MessageChain):
|
||||
message = MessageChain(message)
|
||||
if isinstance(event, FriendMessage):
|
||||
return await self.send_friend_message(target=event.sender.id,
|
||||
message_chain=message)
|
||||
return await self.send_friend_message(
|
||||
target=event.sender.id, message_chain=message
|
||||
)
|
||||
elif isinstance(event, GroupMessage):
|
||||
if at_sender:
|
||||
message = MessageSegment.at(event.sender.id) + message
|
||||
return await self.send_group_message(group=event.sender.group.id,
|
||||
message_chain=message)
|
||||
return await self.send_group_message(
|
||||
group=event.sender.group.id, message_chain=message
|
||||
)
|
||||
elif isinstance(event, TempMessage):
|
||||
return await self.send_temp_message(qq=event.sender.id,
|
||||
group=event.sender.group.id,
|
||||
message_chain=message)
|
||||
return await self.send_temp_message(
|
||||
qq=event.sender.id, group=event.sender.group.id, message_chain=message
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unsupported event type {event!r}.')
|
||||
raise ValueError(f"Unsupported event type {event!r}.")
|
||||
|
||||
@argument_validation
|
||||
async def send_friend_message(self, target: int,
|
||||
message_chain: MessageChain):
|
||||
async def send_friend_message(self, target: int, message_chain: MessageChain):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -311,15 +333,13 @@ class Bot(BaseBot):
|
||||
* ``target: int``: 发送消息目标好友的 QQ 号
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
"""
|
||||
return await self.api.post('sendFriendMessage',
|
||||
params={
|
||||
'target': target,
|
||||
'messageChain': message_chain.export()
|
||||
})
|
||||
return await self.api.post(
|
||||
"sendFriendMessage",
|
||||
params={"target": target, "messageChain": message_chain.export()},
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def send_temp_message(self, qq: int, group: int,
|
||||
message_chain: MessageChain):
|
||||
async def send_temp_message(self, qq: int, group: int, message_chain: MessageChain):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -331,18 +351,15 @@ class Bot(BaseBot):
|
||||
* ``group: int``: 临时会话群号
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
"""
|
||||
return await self.api.post('sendTempMessage',
|
||||
params={
|
||||
'qq': qq,
|
||||
'group': group,
|
||||
'messageChain': message_chain.export()
|
||||
})
|
||||
return await self.api.post(
|
||||
"sendTempMessage",
|
||||
params={"qq": qq, "group": group, "messageChain": message_chain.export()},
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def send_group_message(self,
|
||||
group: int,
|
||||
message_chain: MessageChain,
|
||||
quote: Optional[int] = None):
|
||||
async def send_group_message(
|
||||
self, group: int, message_chain: MessageChain, quote: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -354,12 +371,14 @@ class Bot(BaseBot):
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
|
||||
"""
|
||||
return await self.api.post('sendGroupMessage',
|
||||
params={
|
||||
'group': group,
|
||||
'messageChain': message_chain.export(),
|
||||
'quote': quote
|
||||
})
|
||||
return await self.api.post(
|
||||
"sendGroupMessage",
|
||||
params={
|
||||
"group": group,
|
||||
"messageChain": message_chain.export(),
|
||||
"quote": quote,
|
||||
},
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def recall(self, target: int):
|
||||
@ -372,11 +391,12 @@ class Bot(BaseBot):
|
||||
|
||||
* ``target: int``: 需要撤回的消息的message_id
|
||||
"""
|
||||
return await self.api.post('recall', params={'target': target})
|
||||
return await self.api.post("recall", params={"target": target})
|
||||
|
||||
@argument_validation
|
||||
async def send_image_message(self, target: int, qq: int, group: int,
|
||||
urls: List[str]) -> List[str]:
|
||||
async def send_image_message(
|
||||
self, target: int, qq: int, group: int, urls: List[str]
|
||||
) -> List[str]:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -396,13 +416,10 @@ class Bot(BaseBot):
|
||||
|
||||
- ``List[str]``: 一个包含图片imageId的数组
|
||||
"""
|
||||
return await self.api.post('sendImageMessage',
|
||||
params={
|
||||
'target': target,
|
||||
'qq': qq,
|
||||
'group': group,
|
||||
'urls': urls
|
||||
})
|
||||
return await self.api.post(
|
||||
"sendImageMessage",
|
||||
params={"target": target, "qq": qq, "group": group, "urls": urls},
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def upload_image(self, type: str, img: BytesIO):
|
||||
@ -416,11 +433,7 @@ class Bot(BaseBot):
|
||||
* ``type: str``: "friend" 或 "group" 或 "temp"
|
||||
* ``img: BytesIO``: 图片的BytesIO对象
|
||||
"""
|
||||
return await self.api.upload('uploadImage',
|
||||
params={
|
||||
'type': type,
|
||||
'img': img
|
||||
})
|
||||
return await self.api.upload("uploadImage", params={"type": type, "img": img})
|
||||
|
||||
@argument_validation
|
||||
async def upload_voice(self, type: str, voice: BytesIO):
|
||||
@ -434,11 +447,9 @@ class Bot(BaseBot):
|
||||
* ``type: str``: 当前仅支持 "group"
|
||||
* ``voice: BytesIO``: 语音的BytesIO对象
|
||||
"""
|
||||
return await self.api.upload('uploadVoice',
|
||||
params={
|
||||
'type': type,
|
||||
'voice': voice
|
||||
})
|
||||
return await self.api.upload(
|
||||
"uploadVoice", params={"type": type, "voice": voice}
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def fetch_message(self, count: int = 10):
|
||||
@ -452,7 +463,7 @@ class Bot(BaseBot):
|
||||
|
||||
* ``count: int``: 获取消息和事件的数量
|
||||
"""
|
||||
return await self.api.request('fetchMessage', params={'count': count})
|
||||
return await self.api.request("fetchMessage", params={"count": count})
|
||||
|
||||
@argument_validation
|
||||
async def fetch_latest_message(self, count: int = 10):
|
||||
@ -466,8 +477,7 @@ class Bot(BaseBot):
|
||||
|
||||
* ``count: int``: 获取消息和事件的数量
|
||||
"""
|
||||
return await self.api.request('fetchLatestMessage',
|
||||
params={'count': count})
|
||||
return await self.api.request("fetchLatestMessage", params={"count": count})
|
||||
|
||||
@argument_validation
|
||||
async def peek_message(self, count: int = 10):
|
||||
@ -481,7 +491,7 @@ class Bot(BaseBot):
|
||||
|
||||
* ``count: int``: 获取消息和事件的数量
|
||||
"""
|
||||
return await self.api.request('peekMessage', params={'count': count})
|
||||
return await self.api.request("peekMessage", params={"count": count})
|
||||
|
||||
@argument_validation
|
||||
async def peek_latest_message(self, count: int = 10):
|
||||
@ -495,8 +505,7 @@ class Bot(BaseBot):
|
||||
|
||||
* ``count: int``: 获取消息和事件的数量
|
||||
"""
|
||||
return await self.api.request('peekLatestMessage',
|
||||
params={'count': count})
|
||||
return await self.api.request("peekLatestMessage", params={"count": count})
|
||||
|
||||
@argument_validation
|
||||
async def messsage_from_id(self, id: int):
|
||||
@ -510,7 +519,7 @@ class Bot(BaseBot):
|
||||
|
||||
* ``id: int``: 获取消息的message_id
|
||||
"""
|
||||
return await self.api.request('messageFromId', params={'id': id})
|
||||
return await self.api.request("messageFromId", params={"id": id})
|
||||
|
||||
@argument_validation
|
||||
async def count_message(self):
|
||||
@ -519,7 +528,7 @@ class Bot(BaseBot):
|
||||
|
||||
使用此方法获取bot接收并缓存的消息总数,注意不包含被删除的
|
||||
"""
|
||||
return await self.api.request('countMessage')
|
||||
return await self.api.request("countMessage")
|
||||
|
||||
@argument_validation
|
||||
async def friend_list(self) -> List[Dict[str, Any]]:
|
||||
@ -532,7 +541,7 @@ class Bot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的好友列表数据
|
||||
"""
|
||||
return await self.api.request('friendList')
|
||||
return await self.api.request("friendList")
|
||||
|
||||
@argument_validation
|
||||
async def group_list(self) -> List[Dict[str, Any]]:
|
||||
@ -545,7 +554,7 @@ class Bot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的群列表数据
|
||||
"""
|
||||
return await self.api.request('groupList')
|
||||
return await self.api.request("groupList")
|
||||
|
||||
@argument_validation
|
||||
async def member_list(self, target: int) -> List[Dict[str, Any]]:
|
||||
@ -562,7 +571,7 @@ class Bot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的群成员列表数据
|
||||
"""
|
||||
return await self.api.request('memberList', params={'target': target})
|
||||
return await self.api.request("memberList", params={"target": target})
|
||||
|
||||
@argument_validation
|
||||
async def mute(self, target: int, member_id: int, time: int):
|
||||
@ -577,12 +586,9 @@ class Bot(BaseBot):
|
||||
* ``member_id: int``: 指定群员QQ号
|
||||
* ``time: int``: 禁言时长,单位为秒,最多30天
|
||||
"""
|
||||
return await self.api.post('mute',
|
||||
params={
|
||||
'target': target,
|
||||
'memberId': member_id,
|
||||
'time': time
|
||||
})
|
||||
return await self.api.post(
|
||||
"mute", params={"target": target, "memberId": member_id, "time": time}
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def unmute(self, target: int, member_id: int):
|
||||
@ -596,11 +602,9 @@ class Bot(BaseBot):
|
||||
* ``target: int``: 指定群的群号
|
||||
* ``member_id: int``: 指定群员QQ号
|
||||
"""
|
||||
return await self.api.post('unmute',
|
||||
params={
|
||||
'target': target,
|
||||
'memberId': member_id
|
||||
})
|
||||
return await self.api.post(
|
||||
"unmute", params={"target": target, "memberId": member_id}
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def kick(self, target: int, member_id: int, msg: str):
|
||||
@ -615,12 +619,9 @@ class Bot(BaseBot):
|
||||
* ``member_id: int``: 指定群员QQ号
|
||||
* ``msg: str``: 信息
|
||||
"""
|
||||
return await self.api.post('kick',
|
||||
params={
|
||||
'target': target,
|
||||
'memberId': member_id,
|
||||
'msg': msg
|
||||
})
|
||||
return await self.api.post(
|
||||
"kick", params={"target": target, "memberId": member_id, "msg": msg}
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def quit(self, target: int):
|
||||
@ -633,7 +634,7 @@ class Bot(BaseBot):
|
||||
|
||||
* ``target: int``: 退出的群号
|
||||
"""
|
||||
return await self.api.post('quit', params={'target': target})
|
||||
return await self.api.post("quit", params={"target": target})
|
||||
|
||||
@argument_validation
|
||||
async def mute_all(self, target: int):
|
||||
@ -646,7 +647,7 @@ class Bot(BaseBot):
|
||||
|
||||
* ``target: int``: 指定群的群号
|
||||
"""
|
||||
return await self.api.post('muteAll', params={'target': target})
|
||||
return await self.api.post("muteAll", params={"target": target})
|
||||
|
||||
@argument_validation
|
||||
async def unmute_all(self, target: int):
|
||||
@ -659,7 +660,7 @@ class Bot(BaseBot):
|
||||
|
||||
* ``target: int``: 指定群的群号
|
||||
"""
|
||||
return await self.api.post('unmuteAll', params={'target': target})
|
||||
return await self.api.post("unmuteAll", params={"target": target})
|
||||
|
||||
@argument_validation
|
||||
async def group_config(self, target: int):
|
||||
@ -685,7 +686,7 @@ class Bot(BaseBot):
|
||||
"anonymousChat": true
|
||||
}
|
||||
"""
|
||||
return await self.api.request('groupConfig', params={'target': target})
|
||||
return await self.api.request("groupConfig", params={"target": target})
|
||||
|
||||
@argument_validation
|
||||
async def modify_group_config(self, target: int, config: Dict[str, Any]):
|
||||
@ -699,11 +700,9 @@ class Bot(BaseBot):
|
||||
* ``target: int``: 指定群的群号
|
||||
* ``config: Dict[str, Any]``: 群设置, 格式见 ``group_config`` 的返回值
|
||||
"""
|
||||
return await self.api.post('groupConfig',
|
||||
params={
|
||||
'target': target,
|
||||
'config': config
|
||||
})
|
||||
return await self.api.post(
|
||||
"groupConfig", params={"target": target, "config": config}
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def member_info(self, target: int, member_id: int):
|
||||
@ -726,15 +725,14 @@ class Bot(BaseBot):
|
||||
"specialTitle": "群头衔"
|
||||
}
|
||||
"""
|
||||
return await self.api.request('memberInfo',
|
||||
params={
|
||||
'target': target,
|
||||
'memberId': member_id
|
||||
})
|
||||
return await self.api.request(
|
||||
"memberInfo", params={"target": target, "memberId": member_id}
|
||||
)
|
||||
|
||||
@argument_validation
|
||||
async def modify_member_info(self, target: int, member_id: int,
|
||||
info: Dict[str, Any]):
|
||||
async def modify_member_info(
|
||||
self, target: int, member_id: int, info: Dict[str, Any]
|
||||
):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -746,9 +744,6 @@ class Bot(BaseBot):
|
||||
* ``member_id: int``: 群员QQ号
|
||||
* ``info: Dict[str, Any]``: 群员资料, 格式见 ``member_info`` 的返回值
|
||||
"""
|
||||
return await self.api.post('memberInfo',
|
||||
params={
|
||||
'target': target,
|
||||
'memberId': member_id,
|
||||
'info': info
|
||||
})
|
||||
return await self.api.post(
|
||||
"memberInfo", params={"target": target, "memberId": member_id, "info": info}
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
from pydantic import Extra, Field, BaseModel
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
@ -14,9 +14,10 @@ class Config(BaseModel):
|
||||
- ``mirai_host``: mirai-api-http 的地址
|
||||
- ``mirai_port``: mirai-api-http 的端口
|
||||
"""
|
||||
auth_key: Optional[str] = Field(None, alias='mirai_auth_key')
|
||||
host: Optional[IPv4Address] = Field(None, alias='mirai_host')
|
||||
port: Optional[int] = Field(None, alias='mirai_port')
|
||||
|
||||
auth_key: Optional[str] = Field(None, alias="mirai_auth_key")
|
||||
host: Optional[IPv4Address] = Field(None, alias="mirai_host")
|
||||
port: Optional[int] = Field(None, alias="mirai_port")
|
||||
|
||||
class Config:
|
||||
extra = Extra.ignore
|
||||
|
@ -5,25 +5,56 @@ r"""
|
||||
部分字段可能与文档在符号上不一致
|
||||
\:\:\:
|
||||
"""
|
||||
from .base import (Event, GroupChatInfo, GroupInfo, PrivateChatInfo,
|
||||
UserPermission)
|
||||
from .message import *
|
||||
from .notice import *
|
||||
from .message import *
|
||||
from .request import *
|
||||
from .base import (
|
||||
Event,
|
||||
GroupInfo,
|
||||
GroupChatInfo,
|
||||
UserPermission,
|
||||
PrivateChatInfo,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission',
|
||||
'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage',
|
||||
'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent',
|
||||
'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent',
|
||||
'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent',
|
||||
'MemberLeaveEventKick', 'MemberLeaveEventQuit', 'FriendRecallEvent',
|
||||
'GroupRecallEvent', 'GroupStateChangeEvent', 'GroupNameChangeEvent',
|
||||
'GroupEntranceAnnouncementChangeEvent', 'GroupMuteAllEvent',
|
||||
'GroupAllowAnonymousChatEvent', 'GroupAllowConfessTalkEvent',
|
||||
'GroupAllowMemberInviteEvent', 'MemberStateChangeEvent',
|
||||
'MemberCardChangeEvent', 'MemberSpecialTitleChangeEvent',
|
||||
'BotGroupPermissionChangeEvent', 'MemberPermissionChangeEvent',
|
||||
'RequestEvent', 'NewFriendRequestEvent', 'MemberJoinRequestEvent',
|
||||
'BotInvitedJoinGroupRequestEvent'
|
||||
"Event",
|
||||
"GroupChatInfo",
|
||||
"GroupInfo",
|
||||
"PrivateChatInfo",
|
||||
"UserPermission",
|
||||
"MessageSource",
|
||||
"MessageEvent",
|
||||
"GroupMessage",
|
||||
"FriendMessage",
|
||||
"TempMessage",
|
||||
"NoticeEvent",
|
||||
"MuteEvent",
|
||||
"BotMuteEvent",
|
||||
"BotUnmuteEvent",
|
||||
"MemberMuteEvent",
|
||||
"MemberUnmuteEvent",
|
||||
"BotJoinGroupEvent",
|
||||
"BotLeaveEventActive",
|
||||
"BotLeaveEventKick",
|
||||
"MemberJoinEvent",
|
||||
"MemberLeaveEventKick",
|
||||
"MemberLeaveEventQuit",
|
||||
"FriendRecallEvent",
|
||||
"GroupRecallEvent",
|
||||
"GroupStateChangeEvent",
|
||||
"GroupNameChangeEvent",
|
||||
"GroupEntranceAnnouncementChangeEvent",
|
||||
"GroupMuteAllEvent",
|
||||
"GroupAllowAnonymousChatEvent",
|
||||
"GroupAllowConfessTalkEvent",
|
||||
"GroupAllowMemberInviteEvent",
|
||||
"MemberStateChangeEvent",
|
||||
"MemberCardChangeEvent",
|
||||
"MemberSpecialTitleChangeEvent",
|
||||
"BotGroupPermissionChangeEvent",
|
||||
"MemberPermissionChangeEvent",
|
||||
"RequestEvent",
|
||||
"NewFriendRequestEvent",
|
||||
"MemberJoinRequestEvent",
|
||||
"BotInvitedJoinGroupRequestEvent",
|
||||
]
|
||||
|
@ -22,9 +22,10 @@ class UserPermission(str, Enum):
|
||||
* ``ADMINISTRATOR``: 群管理
|
||||
* ``MEMBER``: 普通群成员
|
||||
"""
|
||||
OWNER = 'OWNER'
|
||||
ADMINISTRATOR = 'ADMINISTRATOR'
|
||||
MEMBER = 'MEMBER'
|
||||
|
||||
OWNER = "OWNER"
|
||||
ADMINISTRATOR = "ADMINISTRATOR"
|
||||
MEMBER = "MEMBER"
|
||||
|
||||
|
||||
class NudgeSubjectKind(str, Enum):
|
||||
@ -36,8 +37,9 @@ class NudgeSubjectKind(str, Enum):
|
||||
* ``Group``: 群
|
||||
* ``Friend``: 好友
|
||||
"""
|
||||
Group = 'Group'
|
||||
Friend = 'Friend'
|
||||
|
||||
Group = "Group"
|
||||
Friend = "Friend"
|
||||
|
||||
|
||||
class GroupInfo(BaseModel):
|
||||
@ -48,7 +50,7 @@ class GroupInfo(BaseModel):
|
||||
|
||||
class GroupChatInfo(BaseModel):
|
||||
id: int
|
||||
name: str = Field(alias='memberName')
|
||||
name: str = Field(alias="memberName")
|
||||
permission: UserPermission
|
||||
group: GroupInfo
|
||||
|
||||
@ -71,6 +73,7 @@ class Event(BaseEvent):
|
||||
.. _mirai-api-http 事件类型:
|
||||
https://github.com/project-mirai/mirai-api-http/blob/master/docs/EventType.md
|
||||
"""
|
||||
|
||||
self_id: int
|
||||
type: str
|
||||
|
||||
@ -79,11 +82,12 @@ class Event(BaseEvent):
|
||||
"""
|
||||
此事件类的工厂函数, 能够通过事件数据选择合适的子类进行序列化
|
||||
"""
|
||||
type = data['type']
|
||||
type = data["type"]
|
||||
|
||||
def all_subclasses(cls: Type[Event]):
|
||||
return set(cls.__subclasses__()).union(
|
||||
[s for c in cls.__subclasses__() for s in all_subclasses(c)])
|
||||
[s for c in cls.__subclasses__() for s in all_subclasses(c)]
|
||||
)
|
||||
|
||||
event_class: Optional[Type[Event]] = None
|
||||
for subclass in all_subclasses(cls):
|
||||
@ -99,23 +103,25 @@ class Event(BaseEvent):
|
||||
return event_class.parse_obj(data)
|
||||
except ValidationError as e:
|
||||
logger.info(
|
||||
f'Failed to parse {data} to class {event_class.__name__}: '
|
||||
f'{e.errors()!r}. Fallback to parent class.')
|
||||
f"Failed to parse {data} to class {event_class.__name__}: "
|
||||
f"{e.errors()!r}. Fallback to parent class."
|
||||
)
|
||||
event_class = event_class.__base__ # type: ignore
|
||||
|
||||
raise ValueError(f'Failed to serialize {data}.')
|
||||
raise ValueError(f"Failed to serialize {data}.")
|
||||
|
||||
@overrides(BaseEvent)
|
||||
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
|
||||
from . import meta, notice, message, request
|
||||
|
||||
if isinstance(self, message.MessageEvent):
|
||||
return 'message'
|
||||
return "message"
|
||||
elif isinstance(self, notice.NoticeEvent):
|
||||
return 'notice'
|
||||
return "notice"
|
||||
elif isinstance(self, request.RequestEvent):
|
||||
return 'request'
|
||||
return "request"
|
||||
else:
|
||||
return 'meta_event'
|
||||
return "meta_event"
|
||||
|
||||
@overrides(BaseEvent)
|
||||
def get_event_name(self) -> str:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from nonebot.typing import overrides
|
||||
|
||||
@ -16,7 +16,8 @@ class MessageSource(BaseModel):
|
||||
|
||||
class MessageEvent(Event):
|
||||
"""消息事件基类"""
|
||||
message_chain: MessageChain = Field(alias='messageChain')
|
||||
|
||||
message_chain: MessageChain = Field(alias="messageChain")
|
||||
source: Optional[MessageSource] = None
|
||||
sender: Any
|
||||
|
||||
@ -39,12 +40,13 @@ class MessageEvent(Event):
|
||||
|
||||
class GroupMessage(MessageEvent):
|
||||
"""群消息事件"""
|
||||
|
||||
sender: GroupChatInfo
|
||||
to_me: bool = False
|
||||
|
||||
@overrides(MessageEvent)
|
||||
def get_session_id(self) -> str:
|
||||
return f'group_{self.sender.group.id}_' + self.get_user_id()
|
||||
return f"group_{self.sender.group.id}_" + self.get_user_id()
|
||||
|
||||
@overrides(MessageEvent)
|
||||
def get_user_id(self) -> str:
|
||||
@ -57,6 +59,7 @@ class GroupMessage(MessageEvent):
|
||||
|
||||
class FriendMessage(MessageEvent):
|
||||
"""好友消息事件"""
|
||||
|
||||
sender: PrivateChatInfo
|
||||
|
||||
@overrides(MessageEvent)
|
||||
@ -65,7 +68,7 @@ class FriendMessage(MessageEvent):
|
||||
|
||||
@overrides(MessageEvent)
|
||||
def get_session_id(self) -> str:
|
||||
return 'friend_' + self.get_user_id()
|
||||
return "friend_" + self.get_user_id()
|
||||
|
||||
@overrides(MessageEvent)
|
||||
def is_tome(self) -> bool:
|
||||
@ -74,11 +77,12 @@ class FriendMessage(MessageEvent):
|
||||
|
||||
class TempMessage(MessageEvent):
|
||||
"""临时会话消息事件"""
|
||||
|
||||
sender: GroupChatInfo
|
||||
|
||||
@overrides(MessageEvent)
|
||||
def get_session_id(self) -> str:
|
||||
return f'temp_{self.sender.group.id}_' + self.get_user_id()
|
||||
return f"temp_{self.sender.group.id}_" + self.get_user_id()
|
||||
|
||||
@overrides(MessageEvent)
|
||||
def is_tome(self) -> bool:
|
||||
|
@ -3,29 +3,35 @@ from .base import Event
|
||||
|
||||
class MetaEvent(Event):
|
||||
"""元事件基类"""
|
||||
|
||||
qq: int
|
||||
|
||||
|
||||
class BotOnlineEvent(MetaEvent):
|
||||
"""Bot登录成功"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BotOfflineEventActive(MetaEvent):
|
||||
"""Bot主动离线"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BotOfflineEventForce(MetaEvent):
|
||||
"""Bot被挤下线"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BotOfflineEventDropped(MetaEvent):
|
||||
"""Bot被服务器断开或因网络问题而掉线"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BotReloginEvent(MetaEvent):
|
||||
"""Bot主动重新登录"""
|
||||
pass
|
||||
|
||||
pass
|
||||
|
@ -2,88 +2,103 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import Event, GroupChatInfo, GroupInfo, NudgeSubject, UserPermission
|
||||
from .base import Event, GroupInfo, NudgeSubject, GroupChatInfo, UserPermission
|
||||
|
||||
|
||||
class NoticeEvent(Event):
|
||||
"""通知事件基类"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MuteEvent(NoticeEvent):
|
||||
"""禁言类事件基类"""
|
||||
|
||||
operator: GroupChatInfo
|
||||
|
||||
|
||||
class BotMuteEvent(MuteEvent):
|
||||
"""Bot被禁言"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BotUnmuteEvent(MuteEvent):
|
||||
"""Bot被取消禁言"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MemberMuteEvent(MuteEvent):
|
||||
"""群成员被禁言事件(该成员不是Bot)"""
|
||||
duration_seconds: int = Field(alias='durationSeconds')
|
||||
|
||||
duration_seconds: int = Field(alias="durationSeconds")
|
||||
member: GroupChatInfo
|
||||
operator: Optional[GroupChatInfo] = None
|
||||
|
||||
|
||||
class MemberUnmuteEvent(MuteEvent):
|
||||
"""群成员被取消禁言事件(该成员不是Bot)"""
|
||||
|
||||
member: GroupChatInfo
|
||||
operator: Optional[GroupChatInfo] = None
|
||||
|
||||
|
||||
class BotJoinGroupEvent(NoticeEvent):
|
||||
"""Bot加入了一个新群"""
|
||||
|
||||
group: GroupInfo
|
||||
|
||||
|
||||
class BotLeaveEventActive(BotJoinGroupEvent):
|
||||
"""Bot主动退出一个群"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BotLeaveEventKick(BotJoinGroupEvent):
|
||||
"""Bot被踢出一个群"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MemberJoinEvent(NoticeEvent):
|
||||
"""新人入群的事件"""
|
||||
|
||||
member: GroupChatInfo
|
||||
|
||||
|
||||
class MemberLeaveEventKick(MemberJoinEvent):
|
||||
"""成员被踢出群(该成员不是Bot)"""
|
||||
|
||||
operator: Optional[GroupChatInfo] = None
|
||||
|
||||
|
||||
class MemberLeaveEventQuit(MemberJoinEvent):
|
||||
"""成员主动离群(该成员不是Bot)"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FriendRecallEvent(NoticeEvent):
|
||||
"""好友消息撤回"""
|
||||
author_id: int = Field(alias='authorId')
|
||||
message_id: int = Field(alias='messageId')
|
||||
|
||||
author_id: int = Field(alias="authorId")
|
||||
message_id: int = Field(alias="messageId")
|
||||
time: int
|
||||
operator: int
|
||||
|
||||
|
||||
class GroupRecallEvent(FriendRecallEvent):
|
||||
"""群消息撤回"""
|
||||
|
||||
group: GroupInfo
|
||||
operator: Optional[GroupChatInfo] = None
|
||||
|
||||
|
||||
class GroupStateChangeEvent(NoticeEvent):
|
||||
"""群变化事件基类"""
|
||||
|
||||
origin: Any
|
||||
current: Any
|
||||
group: GroupInfo
|
||||
@ -92,73 +107,85 @@ class GroupStateChangeEvent(NoticeEvent):
|
||||
|
||||
class GroupNameChangeEvent(GroupStateChangeEvent):
|
||||
"""某个群名改变"""
|
||||
|
||||
origin: str
|
||||
current: str
|
||||
|
||||
|
||||
class GroupEntranceAnnouncementChangeEvent(GroupStateChangeEvent):
|
||||
"""某群入群公告改变"""
|
||||
|
||||
origin: str
|
||||
current: str
|
||||
|
||||
|
||||
class GroupMuteAllEvent(GroupStateChangeEvent):
|
||||
"""全员禁言"""
|
||||
|
||||
origin: bool
|
||||
current: bool
|
||||
|
||||
|
||||
class GroupAllowAnonymousChatEvent(GroupStateChangeEvent):
|
||||
"""匿名聊天"""
|
||||
|
||||
origin: bool
|
||||
current: bool
|
||||
|
||||
|
||||
class GroupAllowConfessTalkEvent(GroupStateChangeEvent):
|
||||
"""坦白说"""
|
||||
|
||||
origin: bool
|
||||
current: bool
|
||||
|
||||
|
||||
class GroupAllowMemberInviteEvent(GroupStateChangeEvent):
|
||||
"""允许群员邀请好友加群"""
|
||||
|
||||
origin: bool
|
||||
current: bool
|
||||
|
||||
|
||||
class MemberStateChangeEvent(NoticeEvent):
|
||||
"""群成员变化事件基类"""
|
||||
|
||||
member: GroupChatInfo
|
||||
operator: Optional[GroupChatInfo] = None
|
||||
|
||||
|
||||
class MemberCardChangeEvent(MemberStateChangeEvent):
|
||||
"""群名片改动"""
|
||||
|
||||
origin: str
|
||||
current: str
|
||||
|
||||
|
||||
class MemberSpecialTitleChangeEvent(MemberStateChangeEvent):
|
||||
"""群头衔改动(只有群主有操作限权)"""
|
||||
|
||||
origin: str
|
||||
current: str
|
||||
|
||||
|
||||
class BotGroupPermissionChangeEvent(MemberStateChangeEvent):
|
||||
"""Bot在群里的权限被改变"""
|
||||
|
||||
origin: UserPermission
|
||||
current: UserPermission
|
||||
|
||||
|
||||
class MemberPermissionChangeEvent(MemberStateChangeEvent):
|
||||
"""成员权限改变的事件(该成员不是Bot)"""
|
||||
|
||||
origin: UserPermission
|
||||
current: UserPermission
|
||||
|
||||
|
||||
class NudgeEvent(NoticeEvent):
|
||||
"""戳一戳触发事件"""
|
||||
from_id: int = Field(alias='fromId')
|
||||
|
||||
from_id: int = Field(alias="fromId")
|
||||
target: int
|
||||
subject: NudgeSubject
|
||||
action: str
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing_extensions import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .base import Event
|
||||
|
||||
@ -11,15 +11,17 @@ if TYPE_CHECKING:
|
||||
|
||||
class RequestEvent(Event):
|
||||
"""请求事件基类"""
|
||||
event_id: int = Field(alias='eventId')
|
||||
|
||||
event_id: int = Field(alias="eventId")
|
||||
message: str
|
||||
nick: str
|
||||
|
||||
|
||||
class NewFriendRequestEvent(RequestEvent):
|
||||
"""添加好友申请"""
|
||||
from_id: int = Field(alias='fromId')
|
||||
group_id: int = Field(0, alias='groupId')
|
||||
|
||||
from_id: int = Field(alias="fromId")
|
||||
group_id: int = Field(0, alias="groupId")
|
||||
|
||||
async def approve(self, bot: "Bot"):
|
||||
"""
|
||||
@ -31,19 +33,18 @@ class NewFriendRequestEvent(RequestEvent):
|
||||
|
||||
* ``bot: Bot``: 当前的 ``Bot`` 对象
|
||||
"""
|
||||
return await bot.api.post('/resp/newFriendRequestEvent',
|
||||
params={
|
||||
'eventId': self.event_id,
|
||||
'groupId': self.group_id,
|
||||
'fromId': self.from_id,
|
||||
'operate': 0,
|
||||
'message': ''
|
||||
})
|
||||
return await bot.api.post(
|
||||
"/resp/newFriendRequestEvent",
|
||||
params={
|
||||
"eventId": self.event_id,
|
||||
"groupId": self.group_id,
|
||||
"fromId": self.from_id,
|
||||
"operate": 0,
|
||||
"message": "",
|
||||
},
|
||||
)
|
||||
|
||||
async def reject(self,
|
||||
bot: "Bot",
|
||||
operate: Literal[1, 2] = 1,
|
||||
message: str = ''):
|
||||
async def reject(self, bot: "Bot", operate: Literal[1, 2] = 1, message: str = ""):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -60,21 +61,24 @@ class NewFriendRequestEvent(RequestEvent):
|
||||
* ``message: str``: 回复的信息
|
||||
"""
|
||||
assert operate > 0
|
||||
return await bot.api.post('/resp/newFriendRequestEvent',
|
||||
params={
|
||||
'eventId': self.event_id,
|
||||
'groupId': self.group_id,
|
||||
'fromId': self.from_id,
|
||||
'operate': operate,
|
||||
'message': message
|
||||
})
|
||||
return await bot.api.post(
|
||||
"/resp/newFriendRequestEvent",
|
||||
params={
|
||||
"eventId": self.event_id,
|
||||
"groupId": self.group_id,
|
||||
"fromId": self.from_id,
|
||||
"operate": operate,
|
||||
"message": message,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MemberJoinRequestEvent(RequestEvent):
|
||||
"""用户入群申请(Bot需要有管理员权限)"""
|
||||
from_id: int = Field(alias='fromId')
|
||||
group_id: int = Field(alias='groupId')
|
||||
group_name: str = Field(alias='groupName')
|
||||
|
||||
from_id: int = Field(alias="fromId")
|
||||
group_id: int = Field(alias="groupId")
|
||||
group_name: str = Field(alias="groupName")
|
||||
|
||||
async def approve(self, bot: "Bot"):
|
||||
"""
|
||||
@ -86,19 +90,20 @@ class MemberJoinRequestEvent(RequestEvent):
|
||||
|
||||
* ``bot: Bot``: 当前的 ``Bot`` 对象
|
||||
"""
|
||||
return await bot.api.post('/resp/memberJoinRequestEvent',
|
||||
params={
|
||||
'eventId': self.event_id,
|
||||
'groupId': self.group_id,
|
||||
'fromId': self.from_id,
|
||||
'operate': 0,
|
||||
'message': ''
|
||||
})
|
||||
return await bot.api.post(
|
||||
"/resp/memberJoinRequestEvent",
|
||||
params={
|
||||
"eventId": self.event_id,
|
||||
"groupId": self.group_id,
|
||||
"fromId": self.from_id,
|
||||
"operate": 0,
|
||||
"message": "",
|
||||
},
|
||||
)
|
||||
|
||||
async def reject(self,
|
||||
bot: "Bot",
|
||||
operate: Literal[1, 2, 3, 4] = 1,
|
||||
message: str = ''):
|
||||
async def reject(
|
||||
self, bot: "Bot", operate: Literal[1, 2, 3, 4] = 1, message: str = ""
|
||||
):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -117,21 +122,24 @@ class MemberJoinRequestEvent(RequestEvent):
|
||||
* ``message: str``: 回复的信息
|
||||
"""
|
||||
assert operate > 0
|
||||
return await bot.api.post('/resp/memberJoinRequestEvent',
|
||||
params={
|
||||
'eventId': self.event_id,
|
||||
'groupId': self.group_id,
|
||||
'fromId': self.from_id,
|
||||
'operate': operate,
|
||||
'message': message
|
||||
})
|
||||
return await bot.api.post(
|
||||
"/resp/memberJoinRequestEvent",
|
||||
params={
|
||||
"eventId": self.event_id,
|
||||
"groupId": self.group_id,
|
||||
"fromId": self.from_id,
|
||||
"operate": operate,
|
||||
"message": message,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class BotInvitedJoinGroupRequestEvent(RequestEvent):
|
||||
"""Bot被邀请入群申请"""
|
||||
from_id: int = Field(alias='fromId')
|
||||
group_id: int = Field(alias='groupId')
|
||||
group_name: str = Field(alias='groupName')
|
||||
|
||||
from_id: int = Field(alias="fromId")
|
||||
group_id: int = Field(alias="groupId")
|
||||
group_name: str = Field(alias="groupName")
|
||||
|
||||
async def approve(self, bot: "Bot"):
|
||||
"""
|
||||
@ -143,14 +151,16 @@ class BotInvitedJoinGroupRequestEvent(RequestEvent):
|
||||
|
||||
* ``bot: Bot``: 当前的 ``Bot`` 对象
|
||||
"""
|
||||
return await bot.api.post('/resp/botInvitedJoinGroupRequestEvent',
|
||||
params={
|
||||
'eventId': self.event_id,
|
||||
'groupId': self.group_id,
|
||||
'fromId': self.from_id,
|
||||
'operate': 0,
|
||||
'message': ''
|
||||
})
|
||||
return await bot.api.post(
|
||||
"/resp/botInvitedJoinGroupRequestEvent",
|
||||
params={
|
||||
"eventId": self.event_id,
|
||||
"groupId": self.group_id,
|
||||
"fromId": self.from_id,
|
||||
"operate": 0,
|
||||
"message": "",
|
||||
},
|
||||
)
|
||||
|
||||
async def reject(self, bot: "Bot", message: str = ""):
|
||||
"""
|
||||
@ -163,11 +173,13 @@ class BotInvitedJoinGroupRequestEvent(RequestEvent):
|
||||
* ``bot: Bot``: 当前的 ``Bot`` 对象
|
||||
* ``message: str``: 邀请消息
|
||||
"""
|
||||
return await bot.api.post('/resp/botInvitedJoinGroupRequestEvent',
|
||||
params={
|
||||
'eventId': self.event_id,
|
||||
'groupId': self.group_id,
|
||||
'fromId': self.from_id,
|
||||
'operate': 1,
|
||||
'message': message
|
||||
})
|
||||
return await bot.api.post(
|
||||
"/resp/botInvitedJoinGroupRequestEvent",
|
||||
params={
|
||||
"eventId": self.event_id,
|
||||
"groupId": self.group_id,
|
||||
"fromId": self.from_id,
|
||||
"operate": 1,
|
||||
"message": message,
|
||||
},
|
||||
)
|
||||
|
@ -1,28 +1,29 @@
|
||||
from enum import Enum
|
||||
from typing import Any, List, Dict, Type, Iterable, Optional, Union
|
||||
from typing import Any, Dict, List, Type, Union, Iterable, Optional
|
||||
|
||||
from pydantic import validate_arguments
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.adapters import Message as BaseMessage
|
||||
from nonebot.adapters import MessageSegment as BaseMessageSegment
|
||||
from nonebot.typing import overrides
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""消息类型枚举类"""
|
||||
SOURCE = 'Source'
|
||||
QUOTE = 'Quote'
|
||||
AT = 'At'
|
||||
AT_ALL = 'AtAll'
|
||||
FACE = 'Face'
|
||||
PLAIN = 'Plain'
|
||||
IMAGE = 'Image'
|
||||
FLASH_IMAGE = 'FlashImage'
|
||||
VOICE = 'Voice'
|
||||
XML = 'Xml'
|
||||
JSON = 'Json'
|
||||
APP = 'App'
|
||||
POKE = 'Poke'
|
||||
|
||||
SOURCE = "Source"
|
||||
QUOTE = "Quote"
|
||||
AT = "At"
|
||||
AT_ALL = "AtAll"
|
||||
FACE = "Face"
|
||||
PLAIN = "Plain"
|
||||
IMAGE = "Image"
|
||||
FLASH_IMAGE = "FlashImage"
|
||||
VOICE = "Voice"
|
||||
XML = "Xml"
|
||||
JSON = "Json"
|
||||
APP = "App"
|
||||
POKE = "Poke"
|
||||
|
||||
|
||||
class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
@ -43,21 +44,24 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
@validate_arguments
|
||||
@overrides(BaseMessageSegment)
|
||||
def __init__(self, type: MessageType, **data: Any):
|
||||
super().__init__(type=type,
|
||||
data={k: v for k, v in data.items() if v is not None})
|
||||
super().__init__(
|
||||
type=type, data={k: v for k, v in data.items() if v is not None}
|
||||
)
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __str__(self) -> str:
|
||||
return self.data['text'] if self.is_text() else repr(self)
|
||||
return self.data["text"] if self.is_text() else repr(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '[mirai:%s]' % ','.join([
|
||||
self.type.value,
|
||||
*map(
|
||||
lambda s: '%s=%r' % s,
|
||||
self.data.items(),
|
||||
),
|
||||
])
|
||||
return "[mirai:%s]" % ",".join(
|
||||
[
|
||||
self.type.value,
|
||||
*map(
|
||||
lambda s: "%s=%r" % s,
|
||||
self.data.items(),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def is_text(self) -> bool:
|
||||
@ -65,15 +69,21 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
"""导出可以被正常json序列化的结构体"""
|
||||
return {'type': self.type.value, **self.data}
|
||||
return {"type": self.type.value, **self.data}
|
||||
|
||||
@classmethod
|
||||
def source(cls, id: int, time: int):
|
||||
return cls(type=MessageType.SOURCE, id=id, time=time)
|
||||
|
||||
@classmethod
|
||||
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
|
||||
origin: "MessageChain"):
|
||||
def quote(
|
||||
cls,
|
||||
id: int,
|
||||
group_id: int,
|
||||
sender_id: int,
|
||||
target_id: int,
|
||||
origin: "MessageChain",
|
||||
):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -87,12 +97,14 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
* ``target_id: int``: 被引用回复的原消息的接收者者的QQ号(或群号)
|
||||
* ``origin: MessageChain``: 被引用回复的原消息的消息链对象
|
||||
"""
|
||||
return cls(type=MessageType.QUOTE,
|
||||
id=id,
|
||||
groupId=group_id,
|
||||
senderId=sender_id,
|
||||
targetId=target_id,
|
||||
origin=origin.export())
|
||||
return cls(
|
||||
type=MessageType.QUOTE,
|
||||
id=id,
|
||||
groupId=group_id,
|
||||
senderId=sender_id,
|
||||
targetId=target_id,
|
||||
origin=origin.export(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def at(cls, target: int):
|
||||
@ -144,10 +156,12 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
return cls(type=MessageType.PLAIN, text=text)
|
||||
|
||||
@classmethod
|
||||
def image(cls,
|
||||
image_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None):
|
||||
def image(
|
||||
cls,
|
||||
image_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -162,10 +176,12 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
|
||||
|
||||
@classmethod
|
||||
def flash_image(cls,
|
||||
image_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None):
|
||||
def flash_image(
|
||||
cls,
|
||||
image_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -175,16 +191,15 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
|
||||
同 ``image``
|
||||
"""
|
||||
return cls(type=MessageType.FLASH_IMAGE,
|
||||
imageId=image_id,
|
||||
url=url,
|
||||
path=path)
|
||||
return cls(type=MessageType.FLASH_IMAGE, imageId=image_id, url=url, path=path)
|
||||
|
||||
@classmethod
|
||||
def voice(cls,
|
||||
voice_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None):
|
||||
def voice(
|
||||
cls,
|
||||
voice_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -196,10 +211,7 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
* ``url: Optional[str]``: 语音的URL,发送时可作网络语音的链接
|
||||
* ``path: Optional[str]``: 语音的路径,发送本地语音
|
||||
"""
|
||||
return cls(type=MessageType.FLASH_IMAGE,
|
||||
imageId=voice_id,
|
||||
url=url,
|
||||
path=path)
|
||||
return cls(type=MessageType.FLASH_IMAGE, imageId=voice_id, url=url, path=path)
|
||||
|
||||
@classmethod
|
||||
def xml(cls, xml: str):
|
||||
@ -282,16 +294,14 @@ class MessageChain(BaseMessage[MessageSegment]):
|
||||
return [MessageSegment.plain(text=message)]
|
||||
return [
|
||||
*map(
|
||||
lambda x: x
|
||||
if isinstance(x, MessageSegment) else MessageSegment(**x),
|
||||
message)
|
||||
lambda x: x if isinstance(x, MessageSegment) else MessageSegment(**x),
|
||||
message,
|
||||
)
|
||||
]
|
||||
|
||||
def export(self) -> List[Dict[str, Any]]:
|
||||
"""导出为可以被正常json序列化的数组"""
|
||||
return [
|
||||
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
|
||||
]
|
||||
return [*map(lambda segment: segment.as_dict(), self.copy())] # type: ignore
|
||||
|
||||
def extract_first(self, *type: MessageType) -> Optional[MessageSegment]:
|
||||
"""
|
||||
@ -311,4 +321,4 @@ class MessageChain(BaseMessage[MessageSegment]):
|
||||
return None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} {[*self.copy()]}>'
|
||||
return f"<{self.__class__.__name__} {[*self.copy()]}>"
|
||||
|
@ -1,17 +1,17 @@
|
||||
import re
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, Callable, Optional, Coroutine
|
||||
|
||||
import httpx
|
||||
from pydantic import Extra, ValidationError, validate_arguments
|
||||
|
||||
import nonebot.exception as exception
|
||||
from nonebot.log import logger
|
||||
import nonebot.exception as exception
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.utils import escape_tag, logger_wrapper
|
||||
|
||||
from .event import Event, GroupMessage, MessageEvent, MessageSource
|
||||
from .message import MessageType, MessageSegment
|
||||
from .event import Event, GroupMessage, MessageEvent, MessageSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bot import Bot
|
||||
@ -21,28 +21,27 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable)
|
||||
|
||||
|
||||
class Log:
|
||||
|
||||
@staticmethod
|
||||
def log(level: str, message: str, exception: Optional[Exception] = None):
|
||||
logger = logger_wrapper('MIRAI')
|
||||
message = '<e>' + escape_tag(message) + '</e>'
|
||||
logger = logger_wrapper("MIRAI")
|
||||
message = "<e>" + escape_tag(message) + "</e>"
|
||||
logger(level=level.upper(), message=message, exception=exception)
|
||||
|
||||
@classmethod
|
||||
def info(cls, message: Any):
|
||||
cls.log('INFO', str(message))
|
||||
cls.log("INFO", str(message))
|
||||
|
||||
@classmethod
|
||||
def debug(cls, message: Any):
|
||||
cls.log('DEBUG', str(message))
|
||||
cls.log("DEBUG", str(message))
|
||||
|
||||
@classmethod
|
||||
def warn(cls, message: Any):
|
||||
cls.log('WARNING', str(message))
|
||||
cls.log("WARNING", str(message))
|
||||
|
||||
@classmethod
|
||||
def error(cls, message: Any, exception: Optional[Exception] = None):
|
||||
cls.log('ERROR', str(message), exception=exception)
|
||||
cls.log("ERROR", str(message), exception=exception)
|
||||
|
||||
|
||||
class ActionFailed(exception.ActionFailed):
|
||||
@ -53,12 +52,13 @@ class ActionFailed(exception.ActionFailed):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__('mirai')
|
||||
super().__init__("mirai")
|
||||
self.data = kwargs.copy()
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(%s)' % ', '.join(
|
||||
map(lambda m: '%s=%r' % m, self.data.items()))
|
||||
return self.__class__.__name__ + "(%s)" % ", ".join(
|
||||
map(lambda m: "%s=%r" % m, self.data.items())
|
||||
)
|
||||
|
||||
|
||||
class InvalidArgument(exception.AdapterException):
|
||||
@ -69,7 +69,7 @@ class InvalidArgument(exception.AdapterException):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__('mirai')
|
||||
super().__init__("mirai")
|
||||
|
||||
|
||||
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
|
||||
@ -90,11 +90,12 @@ def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
|
||||
try:
|
||||
data = await function(*args, **kwargs)
|
||||
except httpx.HTTPError:
|
||||
raise exception.NetworkError('mirai')
|
||||
logger.opt(colors=True).debug('<b>Mirai API returned data:</b> '
|
||||
f'<y>{escape_tag(str(data))}</y>')
|
||||
raise exception.NetworkError("mirai")
|
||||
logger.opt(colors=True).debug(
|
||||
"<b>Mirai API returned data:</b> " f"<y>{escape_tag(str(data))}</y>"
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
if data.get('code', 0) != 0:
|
||||
if data.get("code", 0) != 0:
|
||||
raise ActionFailed(**data)
|
||||
return data
|
||||
|
||||
@ -109,10 +110,9 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable:
|
||||
|
||||
会在参数出错时释放 ``InvalidArgument`` 异常
|
||||
"""
|
||||
function = validate_arguments(config={
|
||||
'arbitrary_types_allowed': True,
|
||||
'extra': Extra.forbid
|
||||
})(function)
|
||||
function = validate_arguments(
|
||||
config={"arbitrary_types_allowed": True, "extra": Extra.forbid}
|
||||
)(function)
|
||||
|
||||
@wraps(function)
|
||||
def wrapper(*args, **kwargs):
|
||||
@ -134,12 +134,12 @@ def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent:
|
||||
def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage:
|
||||
at = event.message_chain.extract_first(MessageType.AT)
|
||||
if at is not None:
|
||||
if at.data['target'] == event.self_id:
|
||||
if at.data["target"] == event.self_id:
|
||||
event.to_me = True
|
||||
else:
|
||||
event.message_chain.insert(0, at)
|
||||
if not event.message_chain:
|
||||
event.message_chain.append(MessageSegment.plain(''))
|
||||
event.message_chain.append(MessageSegment.plain(""))
|
||||
return event
|
||||
|
||||
|
||||
@ -147,13 +147,13 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
|
||||
plain = event.message_chain.extract_first(MessageType.PLAIN)
|
||||
if plain is not None:
|
||||
text = str(plain)
|
||||
nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname))
|
||||
nick_regex = "|".join(filter(lambda x: x, bot.config.nickname))
|
||||
matched = re.search(rf"^({nick_regex})([\s,,]*|$)", text, re.IGNORECASE)
|
||||
if matched is not None:
|
||||
event.to_me = True
|
||||
nickname = matched.group(1)
|
||||
Log.info(f'User is calling me {nickname}')
|
||||
plain.data['text'] = text[matched.end():]
|
||||
Log.info(f"User is calling me {nickname}")
|
||||
plain.data["text"] = text[matched.end() :]
|
||||
event.message_chain.insert(0, plain)
|
||||
return event
|
||||
|
||||
@ -161,7 +161,7 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
|
||||
def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
|
||||
reply = event.message_chain.extract_first(MessageType.QUOTE)
|
||||
if reply is not None:
|
||||
if reply.data['senderId'] == event.self_id:
|
||||
if reply.data["senderId"] == event.self_id:
|
||||
event.to_me = True
|
||||
else:
|
||||
event.message_chain.insert(0, reply)
|
||||
|
@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
|
||||
# url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
# default = true
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py37", "py38", "py39"]
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 80
|
||||
length_sort = true
|
||||
skip_gitignore = true
|
||||
force_sort_within_sections = true
|
||||
extra_standard_library = ["typing_extensions"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
@ -7,8 +7,7 @@ from nonebot.log import logger
|
||||
def init():
|
||||
driver = nonebot.get_driver()
|
||||
try:
|
||||
_module = importlib.import_module(
|
||||
f"nonebot_plugin_docs.drivers.{driver.type}")
|
||||
_module = importlib.import_module(f"nonebot_plugin_docs.drivers.{driver.type}")
|
||||
except ImportError:
|
||||
logger.warning(f"Driver {driver.type} not supported")
|
||||
return
|
||||
@ -18,8 +17,9 @@ def init():
|
||||
port = driver.config.port
|
||||
if host in ["0.0.0.0", "127.0.0.1"]:
|
||||
host = "localhost"
|
||||
logger.opt(colors=True).info(f"Nonebot docs will be running at: "
|
||||
f"<b><u>http://{host}:{port}/docs/</u></b>")
|
||||
logger.opt(colors=True).info(
|
||||
f"Nonebot docs will be running at: " f"<b><u>http://{host}:{port}/docs/</u></b>"
|
||||
)
|
||||
|
||||
|
||||
init()
|
||||
|
@ -9,6 +9,4 @@ def register_route(driver: Driver):
|
||||
|
||||
static_path = str((Path(__file__).parent / ".." / "dist").resolve())
|
||||
|
||||
app.mount("/docs",
|
||||
StaticFiles(directory=static_path, html=True),
|
||||
name="docs")
|
||||
app.mount("/docs", StaticFiles(directory=static_path, html=True), name="docs")
|
||||
|
@ -18,6 +18,21 @@ nonebot2 = "^2.0.0-alpha.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py37", "py38", "py39"]
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 80
|
||||
length_sort = true
|
||||
skip_gitignore = true
|
||||
force_sort_within_sections = true
|
||||
extra_standard_library = ["typing_extensions"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
Reference in New Issue
Block a user