🔀 Merge pull request #127

Update ding adapter
This commit is contained in:
Ju4tCode
2020-12-30 20:29:18 +08:00
committed by GitHub
13 changed files with 635 additions and 380 deletions

View File

@ -6,6 +6,7 @@
"""
import abc
from copy import copy
from typing_extensions import Literal
from functools import reduce, partial
from dataclasses import dataclass, field
@ -164,7 +165,7 @@ class MessageSegment(abc.ABC):
@abc.abstractmethod
def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment,
T_Message]) -> "T_Message":
T_Message]) -> T_Message:
"""你需要在这里实现不同消息段的合并:
比如:
if isinstance(other, str):
@ -198,6 +199,9 @@ class MessageSegment(abc.ABC):
def get(self, key, default=None):
return getattr(self, key, default)
def copy(self: T_MessageSegment) -> T_MessageSegment:
return copy(self)
@abc.abstractmethod
def is_text(self) -> bool:
raise NotImplementedError
@ -207,22 +211,22 @@ class Message(list, abc.ABC):
"""消息数组"""
def __init__(self,
message: Union[str, dict, list, T_MessageSegment,
T_Message] = None,
message: Union[str, list, dict, T_MessageSegment, T_Message,
Any] = None,
*args,
**kwargs):
"""
:参数:
* ``message: Union[str, dict, list, MessageSegment, Message]``: 消息内容
* ``message: Union[str, list, dict, MessageSegment, Message, Any]``: 消息内容
"""
super().__init__(*args, **kwargs)
if isinstance(message, (str, dict, list)):
self.extend(self._construct(message))
elif isinstance(message, Message):
if isinstance(message, Message):
self.extend(message)
elif isinstance(message, MessageSegment):
self.append(message)
else:
self.extend(self._construct(message))
def __str__(self):
return ''.join((str(seg) for seg in self))
@ -238,8 +242,7 @@ class Message(list, abc.ABC):
@staticmethod
@abc.abstractmethod
def _construct(
msg: Union[str, dict, list,
BaseModel]) -> Iterable[T_MessageSegment]:
msg: Union[str, list, dict, Any]) -> Iterable[T_MessageSegment]:
raise NotImplementedError
def __add__(self: T_Message, other: Union[str, T_MessageSegment,

View File

@ -6,7 +6,6 @@ import asyncio
from typing import Any, Dict, Union, Optional, TYPE_CHECKING
import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.typing import overrides

View File

@ -5,13 +5,13 @@
协议详情请看: `钉钉文档`_
.. _钉钉文档:
https://ding-doc.dingtalk.com/doc#/serverapi2/krgddi
https://ding-doc.dingtalk.com/document#/org-dev-guide/elzz1p/
"""
from .utils import log
from .bot import Bot
from .event import Event
from .message import Message, MessageSegment
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent
from .exception import (DingAdapterException, ApiNotAvailable, NetworkError,
ActionFailed, SessionExpired)

View File

@ -6,18 +6,20 @@ from typing import Any, Union, Optional, TYPE_CHECKING
import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import RequestDenied
from .utils import log
from .event import Event
from .model import MessageModel
from .message import Message, MessageSegment
from .exception import NetworkError, ApiNotAvailable, ActionFailed, SessionExpired
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent, ConversationType
if TYPE_CHECKING:
from nonebot.drivers import BaseDriver as Driver
from nonebot.drivers import Driver
SEND_BY_SESSION_WEBHOOK = "send_by_sessionWebhook"
class Bot(BaseBot):
@ -38,6 +40,7 @@ class Bot(BaseBot):
return "ding"
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> str:
"""
@ -50,7 +53,8 @@ class Bot(BaseBot):
# 检查连接方式
if connection_type not in ["http"]:
raise RequestDenied(405, "Unsupported connection type")
raise RequestDenied(
405, "Unsupported connection type, available type: `http`")
# 检查 timestamp
if not timestamp:
@ -72,13 +76,25 @@ class Bot(BaseBot):
log("WARNING", "Ding signature check ignored!")
return body["chatbotUserId"]
async def handle_message(self, body: dict):
message = MessageModel.parse_obj(body)
@overrides(BaseBot)
async def handle_message(self, message: dict):
if not message:
return
# 判断消息类型,生成不同的 Event
try:
conversation_type = message["conversationType"]
if conversation_type == ConversationType.private:
event = PrivateMessageEvent.parse_obj(message)
elif conversation_type == ConversationType.group:
event = GroupMessageEvent.parse_obj(message)
else:
raise ValueError("Unsupported conversation type")
except Exception as e:
log("ERROR", "Event Parser Error", e)
return
try:
event = Event(message)
await handle_event(self, event)
except Exception as e:
logger.opt(colors=True, exception=e).error(
@ -86,9 +102,10 @@ class Bot(BaseBot):
)
return
@overrides(BaseBot)
async def call_api(self,
api: str,
event: Optional[Event] = None,
event: Optional[MessageEvent] = None,
**data) -> Any:
"""
:说明:
@ -120,28 +137,27 @@ class Bot(BaseBot):
log("DEBUG", f"Calling API <y>{api}</y>")
if api == "send_message":
if api == SEND_BY_SESSION_WEBHOOK:
if event:
# 确保 sessionWebhook 没有过期
if int(datetime.now().timestamp()) > int(
event.raw_event.sessionWebhookExpiredTime / 1000):
event.sessionWebhookExpiredTime / 1000):
raise SessionExpired
target = event.raw_event.sessionWebhook
target = event.sessionWebhook
else:
target = None
if not target:
raise ApiNotAvailable
headers = {}
segment: MessageSegment = data["message"][0]
message: Message = data.get("message", None)
if not message:
raise ValueError("Message not found")
try:
async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(
target,
params={"access_token": self.config.access_token},
json=segment.data,
json=message._produce(),
timeout=self.config.api_timeout)
if 200 <= response.status_code < 300:
@ -158,8 +174,9 @@ class Bot(BaseBot):
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def send(self,
event: Event,
event: MessageEvent,
message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False,
**kwargs) -> Any:
@ -187,14 +204,14 @@ class Bot(BaseBot):
"""
msg = message if isinstance(message, Message) else Message(message)
at_sender = at_sender and bool(event.user_id)
at_sender = at_sender and bool(event.senderId)
params = {}
params["event"] = event
params.update(kwargs)
if at_sender and event.detail_type != "private":
params["message"] = f"@{event.user_id} " + msg
if at_sender and event.conversationType != ConversationType.private:
params["message"] = f"@{event.senderNick} " + msg
else:
params["message"] = msg
return await self.call_api("send_message", **params)
return await self.call_api(SEND_BY_SESSION_WEBHOOK, **params)

View File

@ -1,197 +1,142 @@
from typing import Union, Optional
from enum import Enum
from typing import List, Optional
from typing_extensions import Literal
from pydantic import BaseModel, root_validator
from nonebot.typing import overrides
from nonebot.adapters import Event as BaseEvent
from .message import Message
from .model import MessageModel, ConversationType, TextMessage
class Event(BaseEvent):
"""
钉钉 协议 Event 适配。继承属性参考 `BaseEvent <./#class-baseevent>`_ 。
钉钉 协议 Event 适配。各事件字段参考 `钉钉文档`_
.. _钉钉文档:
https://ding-doc.dingtalk.com/document#/org-dev-guide/elzz1p
"""
def __init__(self, message: MessageModel):
super().__init__(message)
chatbotUserId: str
@overrides(BaseEvent)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
raise ValueError("Event has no type!")
@overrides(BaseEvent)
def get_event_name(self) -> str:
raise ValueError("Event has no name!")
@overrides(BaseEvent)
def get_event_description(self) -> str:
raise ValueError("Event has no description!")
@overrides(BaseEvent)
def get_message(self) -> "Message":
raise ValueError("Event has no message!")
@overrides(BaseEvent)
def get_plaintext(self) -> str:
raise ValueError("Event has no plaintext!")
@overrides(BaseEvent)
def get_user_id(self) -> str:
raise ValueError("Event has no user_id!")
@overrides(BaseEvent)
def get_session_id(self) -> str:
raise ValueError("Event has no session_id!")
@overrides(BaseEvent)
def is_tome(self) -> bool:
return True
class TextMessage(BaseModel):
content: str
class AtUsersItem(BaseModel):
dingtalkId: str
staffId: Optional[str]
class ConversationType(str, Enum):
private = "1"
group = "2"
class MessageEvent(Event):
msgtype: str
text: TextMessage
msgId: str
createAt: int # ms
conversationType: ConversationType
conversationId: str
senderId: str
senderNick: str
senderCorpId: str
sessionWebhook: str
sessionWebhookExpiredTime: int
isAdmin: bool
message: Message
@root_validator(pre=True)
def gen_message(cls, values: dict):
assert "msgtype" in values, "msgtype must be specified"
# 其实目前钉钉机器人只能接收到 text 类型的消息
self._message = Message(getattr(message, message.msgtype or "text"))
assert values[
"msgtype"] in values, f"{values['msgtype']} must be specified"
content = values[values['msgtype']]['content']
# 如果是被 @,第一个字符将会为空格,移除特殊情况
if content[0] == ' ':
content = content[1:]
values["message"] = content
return values
@property
def raw_event(self) -> MessageModel:
"""原始上报消息"""
return self._raw_event
@property
def id(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 消息 ID
"""
return self.raw_event.msgId
@property
def name(self) -> str:
"""
- 类型: ``str``
- 说明: 事件名称,由 `type`.`detail_type` 组合而成
"""
return self.type + "." + self.detail_type
@property
def self_id(self) -> str:
"""
- 类型: ``str``
- 说明: 机器人自身 ID
"""
return str(self.raw_event.chatbotUserId)
@property
def time(self) -> int:
"""
- 类型: ``int``
- 说明: 消息的时间戳,单位 s
"""
# 单位 ms -> s
return int(self.raw_event.createAt / 1000)
@property
def type(self) -> str:
"""
- 类型: ``str``
- 说明: 事件类型
"""
@overrides(Event)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
return "message"
@type.setter
def type(self, value) -> None:
pass
@overrides(Event)
def get_event_name(self) -> str:
return f"{self.get_type()}.{self.conversationType.name}"
@property
def detail_type(self) -> str:
"""
- 类型: ``str``
- 说明: 事件详细类型
"""
return self.raw_event.conversationType.name
@overrides(Event)
def get_event_description(self) -> str:
return f'Message[{self.msgtype}] {self.msgId} from {self.senderId} "{self.text.content}"'
@detail_type.setter
def detail_type(self, value) -> None:
if value == "private":
self.raw_event.conversationType = ConversationType.private
if value == "group":
self.raw_event.conversationType = ConversationType.group
@overrides(Event)
def get_message(self) -> Message:
return self.message
@property
def sub_type(self) -> None:
"""
- 类型: ``None``
- 说明: 钉钉适配器无事件子类型
"""
return None
@overrides(Event)
def get_plaintext(self) -> str:
return self.text.content
@sub_type.setter
def sub_type(self, value) -> None:
pass
@overrides(Event)
def get_user_id(self) -> str:
return self.senderId
@property
def user_id(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 发送者 ID
"""
return self.raw_event.senderId
@overrides(Event)
def get_session_id(self) -> str:
return self.senderId
@user_id.setter
def user_id(self, value) -> None:
self.raw_event.senderId = value
@property
def group_id(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 事件主体群 ID
"""
return self.raw_event.conversationId
class PrivateMessageEvent(MessageEvent):
chatbotCorpId: str
senderStaffId: Optional[str]
conversationType: ConversationType = ConversationType.private
@group_id.setter
def group_id(self, value) -> None:
self.raw_event.conversationId = value
@property
def to_me(self) -> Optional[bool]:
"""
- 类型: ``Optional[bool]``
- 说明: 消息是否与机器人相关
"""
return self.detail_type == "private" or self.raw_event.isInAtList
class GroupMessageEvent(MessageEvent):
atUsers: List[AtUsersItem]
conversationType: ConversationType = ConversationType.group
conversationTitle: str
isInAtList: bool
@property
def message(self) -> Optional["Message"]:
"""
- 类型: ``Optional[Message]``
- 说明: 消息内容
"""
return self._message
@message.setter
def message(self, value) -> None:
self._message = value
@property
def reply(self) -> None:
"""
- 类型: ``None``
- 说明: 回复消息详情
"""
raise ValueError("暂不支持 reply")
@property
def raw_message(self) -> Optional[Union[TextMessage]]:
"""
- 类型: ``Optional[str]``
- 说明: 原始消息
"""
return getattr(self.raw_event, self.raw_event.msgtype)
@raw_message.setter
def raw_message(self, value) -> None:
setattr(self.raw_event, self.raw_event.msgtype, value)
@property
def plain_text(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 纯文本消息内容
"""
return self.message and self.message.extract_plain_text().strip()
@property
def sender(self) -> Optional[dict]:
"""
- 类型: ``Optional[dict]``
- 说明: 消息发送者信息
"""
result = {
# 加密的发送者ID。
"senderId": self.raw_event.senderId,
# 发送者昵称。
"senderNick": self.raw_event.senderNick,
# 企业内部群有的发送者当前群的企业 corpId。
"senderCorpId": self.raw_event.senderCorpId,
# 企业内部群有的发送者在企业内的 userId。
"senderStaffId": self.raw_event.senderStaffId,
"role": "admin" if self.raw_event.isAdmin else "member"
}
return result
@sender.setter
def sender(self, value) -> None:
def set_wrapper(name):
if value.get(name):
setattr(self.raw_event, name, value.get(name))
set_wrapper("senderId")
set_wrapper("senderNick")
set_wrapper("senderCorpId")
set_wrapper("senderStaffId")
@overrides(MessageEvent)
def is_tome(self) -> bool:
return self.isInAtList

View File

@ -37,7 +37,10 @@ class ActionFailed(BaseActionFailed, DingAdapterException):
self.errmsg = errmsg
def __repr__(self):
return f"<ApiError errcode={self.errcode} errmsg={self.errmsg}>"
return f"<ApiError errcode={self.errcode} errmsg=\"{self.errmsg}\">"
def __str__(self):
return self.__repr__()
class ApiNotAvailable(BaseApiNotAvailable, DingAdapterException):
@ -66,7 +69,7 @@ class NetworkError(BaseNetworkError, DingAdapterException):
return self.__repr__()
class SessionExpired(BaseApiNotAvailable, DingAdapterException):
class SessionExpired(ApiNotAvailable, DingAdapterException):
"""
:说明:
@ -75,3 +78,6 @@ class SessionExpired(BaseApiNotAvailable, DingAdapterException):
def __repr__(self) -> str:
return f"<Session Webhook is Expired>"
def __str__(self):
return self.__repr__()

View File

@ -1,82 +1,87 @@
from copy import copy
from typing import Any, Dict, Union, Iterable
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
from .utils import log
from .model import TextMessage
class MessageSegment(BaseMessageSegment):
"""
钉钉 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
"""
def __init__(self, type_: str, msg: Dict[str, Any]) -> None:
data = {
"msgtype": type_,
}
if msg:
data.update(msg)
log("DEBUG", f"data {data}")
@overrides(BaseMessageSegment)
def __init__(self, type_: str, data: Dict[str, Any]) -> None:
super().__init__(type=type_, data=data)
@classmethod
def from_segment(cls, segment: "MessageSegment"):
return MessageSegment(segment.type, segment.data)
@overrides(BaseMessageSegment)
def __str__(self):
log("DEBUG", f"__str__: self.type {self.type} data {self.data}")
if self.type == "text":
return str(self.data["text"]["content"].strip())
return str(self.data["content"])
elif self.type == "markdown":
return str(self.data["text"])
return ""
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
if isinstance(other, str):
if self.type == 'text':
self.data['text']['content'] += other
return MessageSegment.from_segment(self)
return Message(self) + other
def atMobile(self, mobileNumber):
self.data.setdefault("at", {})
self.data["at"].setdefault("atMobiles", [])
self.data["at"]["atMobiles"].append(mobileNumber)
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
return Message(other) + self
def atAll(self, value):
self.data.setdefault("at", {})
self.data["at"]["isAtAll"] = value
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
return self.type == "text"
@staticmethod
def text(text_: str) -> "MessageSegment":
return MessageSegment("text", {"text": {"content": text_.strip()}})
def atAll() -> "MessageSegment":
return MessageSegment("at", {"isAtAll": True})
@staticmethod
def atMobiles(*mobileNumber: str) -> "MessageSegment":
return MessageSegment("at", {"atMobiles": list(mobileNumber)})
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"content": text})
@staticmethod
def image(picURL: str) -> "MessageSegment":
return MessageSegment("image", {"picURL": picURL})
@staticmethod
def extension(dict_: dict) -> "MessageSegment":
""""标记 text 文本的 extension 属性,需要与 text 消息段相加。
"""
return MessageSegment("extension", dict_)
@staticmethod
def markdown(title: str, text: str) -> "MessageSegment":
return MessageSegment("markdown", {
"markdown": {
return MessageSegment(
"markdown",
{
"title": title,
"text": text,
},
})
)
@staticmethod
def actionCardSingleBtn(title: str, text: str, btnTitle: str,
btnUrl) -> "MessageSegment":
def actionCardSingleBtn(title: str, text: str, singleTitle: str,
singleURL) -> "MessageSegment":
return MessageSegment(
"actionCard", {
"actionCard": {
"title": title,
"text": text,
"singleTitle": btnTitle,
"singleURL": btnUrl
}
"title": title,
"text": text,
"singleTitle": singleTitle,
"singleURL": singleURL
})
@staticmethod
def actionCardSingleMultiBtns(
def actionCardMultiBtns(
title: str,
text: str,
btns: list = [],
btns: list,
hideAvatar: bool = False,
btnOrientation: str = '1',
) -> "MessageSegment":
@ -89,28 +94,36 @@ class MessageSegment(BaseMessageSegment):
"""
return MessageSegment(
"actionCard", {
"actionCard": {
"title": title,
"text": text,
"hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation,
"btns": btns
}
"title": title,
"text": text,
"hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation,
"btns": btns
})
@staticmethod
def feedCard(links: list = [],) -> "MessageSegment":
def feedCard(links: list) -> "MessageSegment":
"""
:参数:
* ``links``: [{ "title": xxx, "messageURL": xxx, "picURL": xxx }, ...]
"""
return MessageSegment("feedCard", {"feedCard": {"links": links}})
return MessageSegment("feedCard", {"links": links})
@staticmethod
def empty() -> "MessageSegment":
"""不想回复消息到群里"""
return MessageSegment("empty")
def raw(data) -> "MessageSegment":
return MessageSegment('raw', data)
def to_dict(self) -> dict:
# 让用户可以直接发送原始的消息格式
if self.type == "raw":
return copy(self.data)
# 不属于消息内容,只是作为消息段的辅助
if self.type in ["at", "extension"]:
return {self.type: copy(self.data)}
return {"msgtype": self.type, self.type: copy(self.data)}
class Message(BaseMessage):
@ -119,17 +132,24 @@ class Message(BaseMessage):
"""
@staticmethod
def _construct(
msg: Union[str, dict, list,
TextMessage]) -> Iterable[MessageSegment]:
@overrides(BaseMessage)
def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]:
if isinstance(msg, dict):
yield MessageSegment(msg["type"], msg.get("data") or {})
return
elif isinstance(msg, list):
for seg in msg:
yield MessageSegment(seg["type"], seg.get("data") or {})
return
elif isinstance(msg, TextMessage):
yield MessageSegment("text", {"text": msg.dict()})
elif isinstance(msg, str):
yield MessageSegment.text(msg)
def _produce(self) -> dict:
data = {}
for segment in self:
# text 可以和 text 合并
if segment.type == "text" and data.get("msgtype") == 'text':
data.setdefault("text", {})
data["text"]["content"] = data["text"].setdefault(
"content", "") + segment.data["content"]
else:
data.update(segment.to_dict())
return data

View File

@ -1,48 +0,0 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class Headers(BaseModel):
sign: str
token: str
# ms
timestamp: int
class TextMessage(BaseModel):
content: str
class AtUsersItem(BaseModel):
dingtalkId: str
staffId: Optional[str]
class ConversationType(str, Enum):
private = '1'
group = '2'
class MessageModel(BaseModel):
msgtype: str = None
text: Optional[TextMessage] = None
msgId: str
# ms
createAt: int = None
conversationType: ConversationType = None
conversationId: str = None
conversationTitle: str = None
senderId: str = None
senderNick: str = None
senderCorpId: str = None
senderStaffId: str = None
chatbotUserId: str = None
chatbotCorpId: str = None
atUsers: List[AtUsersItem] = None
sessionWebhook: str = None
# ms
sessionWebhookExpiredTime: int = None
isAdmin: bool = None
isInAtList: bool = None

View File

@ -113,7 +113,7 @@ class Matcher(metaclass=MatcherMeta):
self.state = self._default_state.copy()
def __repr__(self) -> str:
return (f"<Matcher from {self.module or 'unknow'}, type={self.type}, "
return (f"<Matcher from {self.module or 'unknown'}, type={self.type}, "
f"priority={self.priority}, temp={self.temp}>")
def __str__(self) -> str:
@ -460,13 +460,23 @@ class Matcher(metaclass=MatcherMeta):
if not hasattr(handler, "__params__"):
self.process_handler(handler)
params = getattr(handler, "__params__")
BotType = ((params["bot"] is not inspect.Parameter.empty) and
inspect.isclass(params["bot"]) and params["bot"])
if BotType and not isinstance(bot, BotType):
logger.debug(
f"Matcher {self} bot type {type(bot)} not match annotation {BotType}, ignored"
)
return
EventType = ((params["event"] is not inspect.Parameter.empty) and
inspect.isclass(params["event"]) and params["event"])
if (BotType and not isinstance(bot, BotType)) or (
EventType and not isinstance(event, EventType)):
if EventType and not isinstance(event, EventType):
logger.debug(
f"Matcher {self} event type {type(event)} not match annotation {EventType}, ignored"
)
return
args = {"bot": bot, "event": event, "state": state, "matcher": self}
await handler(
**{k: v for k, v in args.items() if params[k] is not None})