🏗️ change nonebot project structure

This commit is contained in:
yanyongyu
2021-02-11 17:29:16 +08:00
parent 3cc738c205
commit 31b8a5ff77
47 changed files with 2138 additions and 33 deletions

View File

@ -0,0 +1,16 @@
"""
钉钉群机器人 协议适配
============================
协议详情请看: `钉钉文档`_
.. _钉钉文档:
https://ding-doc.dingtalk.com/document#/org-dev-guide/elzz1p
"""
from .utils import log
from .bot import Bot
from .message import Message, MessageSegment
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent
from .exception import (DingAdapterException, ApiNotAvailable, NetworkError,
ActionFailed, SessionExpired)

View File

@ -0,0 +1,225 @@
import hmac
import base64
from datetime import datetime
from typing import Any, Union, Optional, TYPE_CHECKING
import httpx
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import RequestDenied
from .utils import log
from .config import Config as DingConfig
from .message import Message, MessageSegment
from .exception import NetworkError, ApiNotAvailable, ActionFailed, SessionExpired
from .event import MessageEvent, PrivateMessageEvent, GroupMessageEvent, ConversationType
if TYPE_CHECKING:
from nonebot.config import Config
from nonebot.drivers import Driver
SEND_BY_SESSION_WEBHOOK = "send_by_sessionWebhook"
class Bot(BaseBot):
"""
钉钉 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。
"""
ding_config: DingConfig
def __init__(self, connection_type: str, self_id: str, **kwargs):
super().__init__(connection_type, self_id, **kwargs)
@property
def type(self) -> str:
"""
- 返回: ``"ding"``
"""
return "ding"
@classmethod
def register(cls, driver: "Driver", config: "Config"):
super().register(driver, config)
cls.ding_config = DingConfig(**config.dict())
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> str:
"""
:说明:
钉钉协议鉴权。参考 `鉴权 <https://ding-doc.dingtalk.com/doc#/serverapi2/elzz1p>`_
"""
timestamp = headers.get("timestamp")
sign = headers.get("sign")
# 检查连接方式
if connection_type not in ["http"]:
raise RequestDenied(
405, "Unsupported connection type, available type: `http`")
# 检查 timestamp
if not timestamp:
raise RequestDenied(400, "Missing `timestamp` Header")
# 检查 sign
secret = cls.ding_config.secret
if secret:
if not sign:
log("WARNING", "Missing Signature Header")
raise RequestDenied(400, "Missing `sign` Header")
string_to_sign = f"{timestamp}\n{secret}"
sig = hmac.new(secret.encode("utf-8"),
string_to_sign.encode("utf-8"), "sha256").digest()
if sign != base64.b64encode(sig).decode("utf-8"):
log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid")
else:
log("WARNING", "Ding signature check ignored!")
return body["chatbotUserId"]
@overrides(BaseBot)
async def handle_message(self, message: dict):
if not message:
return
# 判断消息类型,生成不同的 Event
try:
conversation_type = message["conversationType"]
if conversation_type == ConversationType.private:
event = PrivateMessageEvent.parse_obj(message)
elif conversation_type == ConversationType.group:
event = GroupMessageEvent.parse_obj(message)
else:
raise ValueError("Unsupported conversation type")
except Exception as e:
log("ERROR", "Event Parser Error", e)
return
try:
await handle_event(self, event)
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Failed to handle event. Raw: {message}</bg #f8bbd0></r>"
)
return
@overrides(BaseBot)
async def call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
"""
:说明:
调用 钉钉 协议 API
:参数:
* ``api: str``: API 名称
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
if self.connection_type != "http":
log("ERROR", "Only support http connection.")
return
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot.call_api(api, **data)
log("DEBUG", f"Calling API <y>{api}</y>")
if api == SEND_BY_SESSION_WEBHOOK:
if event:
# 确保 sessionWebhook 没有过期
if int(datetime.now().timestamp()) > int(
event.sessionWebhookExpiredTime / 1000):
raise SessionExpired
target = event.sessionWebhook
else:
raise ApiNotAvailable
headers = {}
message: Message = data.get("message", None)
if not message:
raise ValueError("Message not found")
try:
async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(
target,
params={"access_token": self.ding_config.access_token},
json=message._produce(),
timeout=self.config.api_timeout)
if 200 <= response.status_code < 300:
result = response.json()
if isinstance(result, dict):
if result.get("errcode") != 0:
raise ActionFailed(errcode=result.get("errcode"),
errmsg=result.get("errmsg"))
return result
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def send(self,
event: MessageEvent,
message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False,
**kwargs) -> Any:
"""
:说明:
根据 ``event`` 向触发事件的主体发送消息。
:参数:
* ``event: Event``: Event 对象
* ``message: Union[str, Message, MessageSegment]``: 要发送的消息
* ``at_sender: bool``: 是否 @ 事件主体
* ``**kwargs``: 覆盖默认参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``ValueError``: 缺少 ``user_id``, ``group_id``
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
msg = message if isinstance(message, Message) else Message(message)
at_sender = at_sender and bool(event.senderId)
params = {}
params["event"] = event
params.update(kwargs)
if at_sender and event.conversationType != ConversationType.private:
params[
"message"] = f"@{event.senderId} " + msg + MessageSegment.atDingtalkIds(
event.senderId)
else:
params["message"] = msg
return await self.call_api(SEND_BY_SESSION_WEBHOOK, **params)

View File

@ -0,0 +1,19 @@
from typing import Optional
from pydantic import Field, BaseModel
class Config(BaseModel):
"""
钉钉配置类
:配置项:
- ``access_token`` / ``ding_access_token``: 钉钉令牌
- ``secret`` / ``ding_secret``: 钉钉 HTTP 上报数据签名口令
"""
secret: Optional[str] = Field(default=None, alias="ding_secret")
access_token: Optional[str] = Field(default=None, alias="ding_access_token")
class Config:
extra = "ignore"

View File

@ -0,0 +1,145 @@
from enum import Enum
from typing import List, Optional
from typing_extensions import Literal
from pydantic import BaseModel, root_validator
from nonebot.typing import overrides
from nonebot.adapters import Event as BaseEvent
from .message import Message
class Event(BaseEvent):
"""
钉钉协议事件。各事件字段参考 `钉钉文档`_
.. _钉钉文档:
https://ding-doc.dingtalk.com/document#/org-dev-guide/elzz1p
"""
chatbotUserId: str
@overrides(BaseEvent)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
raise ValueError("Event has no type!")
@overrides(BaseEvent)
def get_event_name(self) -> str:
raise ValueError("Event has no name!")
@overrides(BaseEvent)
def get_event_description(self) -> str:
raise ValueError("Event has no description!")
@overrides(BaseEvent)
def get_message(self) -> "Message":
raise ValueError("Event has no message!")
@overrides(BaseEvent)
def get_plaintext(self) -> str:
raise ValueError("Event has no plaintext!")
@overrides(BaseEvent)
def get_user_id(self) -> str:
raise ValueError("Event has no user_id!")
@overrides(BaseEvent)
def get_session_id(self) -> str:
raise ValueError("Event has no session_id!")
@overrides(BaseEvent)
def is_tome(self) -> bool:
return True
class TextMessage(BaseModel):
content: str
class AtUsersItem(BaseModel):
dingtalkId: str
staffId: Optional[str]
class ConversationType(str, Enum):
private = "1"
group = "2"
class MessageEvent(Event):
"""消息事件"""
msgtype: str
text: TextMessage
msgId: str
createAt: int # ms
conversationType: ConversationType
conversationId: str
senderId: str
senderNick: str
senderCorpId: Optional[str]
sessionWebhook: str
sessionWebhookExpiredTime: int
isAdmin: bool
message: Message
@root_validator(pre=True)
def gen_message(cls, values: dict):
assert "msgtype" in values, "msgtype must be specified"
# 其实目前钉钉机器人只能接收到 text 类型的消息
assert values[
"msgtype"] in values, f"{values['msgtype']} must be specified"
content = values[values['msgtype']]['content']
# 如果是被 @,第一个字符将会为空格,移除特殊情况
if content[0] == ' ':
content = content[1:]
values["message"] = content
return values
@overrides(Event)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
return "message"
@overrides(Event)
def get_event_name(self) -> str:
return f"{self.get_type()}.{self.conversationType.name}"
@overrides(Event)
def get_event_description(self) -> str:
return f'Message[{self.msgtype}] {self.msgId} from {self.senderId} "{self.text.content}"'
@overrides(Event)
def get_message(self) -> Message:
return self.message
@overrides(Event)
def get_plaintext(self) -> str:
return self.text.content
@overrides(Event)
def get_user_id(self) -> str:
return self.senderId
@overrides(Event)
def get_session_id(self) -> str:
return self.senderId
class PrivateMessageEvent(MessageEvent):
"""私聊消息事件"""
chatbotCorpId: str
senderStaffId: Optional[str]
conversationType: ConversationType = ConversationType.private
class GroupMessageEvent(MessageEvent):
"""群消息事件"""
atUsers: List[AtUsersItem]
conversationType: ConversationType = ConversationType.group
conversationTitle: str
isInAtList: bool
@overrides(MessageEvent)
def is_tome(self) -> bool:
return self.isInAtList

View File

@ -0,0 +1,83 @@
from typing import Optional
from nonebot.exception import (AdapterException, ActionFailed as
BaseActionFailed, ApiNotAvailable as
BaseApiNotAvailable, NetworkError as
BaseNetworkError)
class DingAdapterException(AdapterException):
"""
:说明:
钉钉 Adapter 错误基类
"""
def __init__(self) -> None:
super().__init__("ding")
class ActionFailed(BaseActionFailed, DingAdapterException):
"""
:说明:
API 请求返回错误信息。
:参数:
* ``errcode: Optional[int]``: 错误码
* ``errmsg: Optional[str]``: 错误信息
"""
def __init__(self,
errcode: Optional[int] = None,
errmsg: Optional[str] = None):
super().__init__()
self.errcode = errcode
self.errmsg = errmsg
def __repr__(self):
return f"<ApiError errcode={self.errcode} errmsg=\"{self.errmsg}\">"
def __str__(self):
return self.__repr__()
class ApiNotAvailable(BaseApiNotAvailable, DingAdapterException):
pass
class NetworkError(BaseNetworkError, DingAdapterException):
"""
:说明:
网络错误。
:参数:
* ``retcode: Optional[int]``: 错误码
"""
def __init__(self, msg: Optional[str] = None):
super().__init__()
self.msg = msg
def __repr__(self):
return f"<NetWorkError message={self.msg}>"
def __str__(self):
return self.__repr__()
class SessionExpired(ApiNotAvailable, DingAdapterException):
"""
:说明:
发消息的 session 已经过期。
"""
def __repr__(self) -> str:
return f"<Session Webhook is Expired>"
def __str__(self):
return self.__repr__()

View File

@ -0,0 +1,187 @@
from copy import copy
from typing import Any, Dict, Union, Mapping, Iterable
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
class MessageSegment(BaseMessageSegment):
"""
钉钉 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
"""
@overrides(BaseMessageSegment)
def __init__(self, type_: str, data: Dict[str, Any]) -> None:
super().__init__(type=type_, data=data)
@overrides(BaseMessageSegment)
def __str__(self):
if self.type == "text":
return str(self.data["content"])
elif self.type == "markdown":
return str(self.data["text"])
return ""
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + other
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
return Message(other) + self
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
return self.type == "text"
@staticmethod
def atAll() -> "MessageSegment":
"""@全体"""
return MessageSegment("at", {"isAtAll": True})
@staticmethod
def atMobiles(*mobileNumber: str) -> "MessageSegment":
"""@指定手机号人员"""
return MessageSegment("at", {"atMobiles": list(mobileNumber)})
@staticmethod
def atDingtalkIds(*dingtalkIds: str) -> "MessageSegment":
"""@指定 id@ 默认会在消息段末尾。
所以你可以在消息中使用 @{senderId} 占位,发送出去之后 @ 就会出现在占位的位置:
```python
message = MessageSegment.text(f"@{event.senderId},你好")
message += MessageSegment.atDingtalkIds(event.senderId)
```
"""
return MessageSegment("at", {"atDingtalkIds": list(dingtalkIds)})
@staticmethod
def text(text: str) -> "MessageSegment":
"""发送 ``text`` 类型消息"""
return MessageSegment("text", {"content": text})
@staticmethod
def image(picURL: str) -> "MessageSegment":
"""发送 ``image`` 类型消息"""
return MessageSegment("image", {"picURL": picURL})
@staticmethod
def extension(dict_: dict) -> "MessageSegment":
""""标记 text 文本的 extension 属性,需要与 text 消息段相加。"""
return MessageSegment("extension", dict_)
@staticmethod
def code(code_language: str, code: str) -> "Message":
""""发送 code 消息段"""
message = MessageSegment.text(code)
message += MessageSegment.extension({
"text_type": "code_snippet",
"code_language": code_language
})
return message
@staticmethod
def markdown(title: str, text: str) -> "MessageSegment":
"""发送 ``markdown`` 类型消息"""
return MessageSegment(
"markdown",
{
"title": title,
"text": text,
},
)
@staticmethod
def actionCardSingleBtn(title: str, text: str, singleTitle: str,
singleURL) -> "MessageSegment":
"""发送 ``actionCardSingleBtn`` 类型消息"""
return MessageSegment(
"actionCard", {
"title": title,
"text": text,
"singleTitle": singleTitle,
"singleURL": singleURL
})
@staticmethod
def actionCardMultiBtns(
title: str,
text: str,
btns: list,
hideAvatar: bool = False,
btnOrientation: str = '1',
) -> "MessageSegment":
"""
发送 ``actionCardMultiBtn`` 类型消息
:参数:
* ``btnOrientation``: 0按钮竖直排列 1按钮横向排列
* ``btns``: [{ "title": title, "actionURL": actionURL }, ...]
"""
return MessageSegment(
"actionCard", {
"title": title,
"text": text,
"hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation,
"btns": btns
})
@staticmethod
def feedCard(links: list) -> "MessageSegment":
"""
发送 ``feedCard`` 类型消息
:参数:
* ``links``: [{ "title": xxx, "messageURL": xxx, "picURL": xxx }, ...]
"""
return MessageSegment("feedCard", {"links": links})
@staticmethod
def raw(data) -> "MessageSegment":
return MessageSegment('raw', data)
def to_dict(self) -> dict:
# 让用户可以直接发送原始的消息格式
if self.type == "raw":
return copy(self.data)
# 不属于消息内容,只是作为消息段的辅助
if self.type in ["at", "extension"]:
return {self.type: copy(self.data)}
return {"msgtype": self.type, self.type: copy(self.data)}
class Message(BaseMessage):
"""
钉钉 协议 Message 适配。
"""
@staticmethod
@overrides(BaseMessage)
def _construct(
msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping):
yield MessageSegment(msg["type"], msg.get("data") or {})
elif isinstance(msg, str):
yield MessageSegment.text(msg)
elif isinstance(msg, Iterable):
for seg in msg:
yield MessageSegment(seg["type"], seg.get("data") or {})
def _produce(self) -> dict:
data = {}
segment: MessageSegment
for segment in self:
# text 可以和 text 合并
if segment.type == "text" and data.get("msgtype") == 'text':
data.setdefault("text", {})
data["text"]["content"] = data["text"].setdefault(
"content", "") + segment.data["content"]
else:
data.update(segment.to_dict())
return data

View File

@ -0,0 +1,3 @@
from nonebot.utils import logger_wrapper
log = logger_wrapper("DING")