🎨 format code using black and isort

This commit is contained in:
yanyongyu
2021-11-22 23:21:26 +08:00
parent 602185a34e
commit a98d98cd12
86 changed files with 2893 additions and 2095 deletions

View File

@ -16,10 +16,18 @@ from nonebot.drivers import Driver, HTTPRequest, HTTPResponse, HTTPConnection
from .config import Config as DingConfig
from .utils import log, calc_hmac_base64
from .message import Message, MessageSegment
from .exception import (ActionFailed, NetworkError, SessionExpired,
ApiNotAvailable)
from .event import (MessageEvent, ConversationType, GroupMessageEvent,
PrivateMessageEvent)
from .exception import (
ActionFailed,
NetworkError,
SessionExpired,
ApiNotAvailable,
)
from .event import (
MessageEvent,
ConversationType,
GroupMessageEvent,
PrivateMessageEvent,
)
if TYPE_CHECKING:
from nonebot.config import Config
@ -31,6 +39,7 @@ class Bot(BaseBot):
"""
钉钉 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。
"""
ding_config: DingConfig
@property
@ -48,8 +57,8 @@ class Bot(BaseBot):
@classmethod
@overrides(BaseBot)
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
cls, driver: Driver, request: HTTPConnection
) -> Tuple[Optional[str], HTTPResponse]:
"""
:说明:
@ -61,7 +70,8 @@ class Bot(BaseBot):
# 检查连接方式
if not isinstance(request, HTTPRequest):
return None, HTTPResponse(
405, b"Unsupported connection type, available type: `http`")
405, b"Unsupported connection type, available type: `http`"
)
# 检查 timestamp
if not timestamp:
@ -74,13 +84,15 @@ class Bot(BaseBot):
log("WARNING", "Missing Signature Header")
return None, HTTPResponse(400, b"Missing `sign` Header")
sign_base64 = calc_hmac_base64(str(timestamp), secret)
if sign != sign_base64.decode('utf-8'):
if sign != sign_base64.decode("utf-8"):
log("WARNING", "Signature Header is invalid")
return None, HTTPResponse(403, b"Signature is invalid")
else:
log("WARNING", "Ding signature check ignored!")
return (json.loads(request.body.decode())["chatbotUserId"],
HTTPResponse(204, b''))
return (
json.loads(request.body.decode())["chatbotUserId"],
HTTPResponse(204, b""),
)
@overrides(BaseBot)
async def handle_message(self, message: bytes):
@ -111,10 +123,9 @@ class Bot(BaseBot):
return
@overrides(BaseBot)
async def _call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
async def _call_api(
self, api: str, event: Optional[MessageEvent] = None, **data
) -> Any:
if not isinstance(self.request, HTTPRequest):
log("ERROR", "Only support http connection.")
return
@ -138,7 +149,8 @@ class Bot(BaseBot):
if event:
# 确保 sessionWebhook 没有过期
if int(datetime.now().timestamp()) > int(
event.sessionWebhookExpiredTime / 1000):
event.sessionWebhookExpiredTime / 1000
):
raise SessionExpired
webhook = event.sessionWebhook
@ -150,32 +162,37 @@ class Bot(BaseBot):
if not message:
raise ValueError("Message not found")
try:
async with httpx.AsyncClient(headers=headers,
follow_redirects=True) as client:
response = await client.post(webhook,
params=params,
json=message._produce(),
timeout=self.config.api_timeout)
async with httpx.AsyncClient(
headers=headers, follow_redirects=True
) as client:
response = await client.post(
webhook,
params=params,
json=message._produce(),
timeout=self.config.api_timeout,
)
if 200 <= response.status_code < 300:
result = response.json()
if isinstance(result, dict):
if result.get("errcode") != 0:
raise ActionFailed(errcode=result.get("errcode"),
errmsg=result.get("errmsg"))
raise ActionFailed(
errcode=result.get("errcode"), errmsg=result.get("errmsg")
)
return result
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
raise NetworkError(
f"HTTP request received unexpected "
f"status code: {response.status_code}"
)
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
async def call_api(
self, api: str, event: Optional[MessageEvent] = None, **data
) -> Any:
"""
:说明:
@ -199,13 +216,15 @@ class Bot(BaseBot):
return await super().call_api(api, event=event, **data)
@overrides(BaseBot)
async def send(self,
event: MessageEvent,
message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False,
webhook: Optional[str] = None,
secret: Optional[str] = None,
**kwargs) -> Any:
async def send(
self,
event: MessageEvent,
message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False,
webhook: Optional[str] = None,
secret: Optional[str] = None,
**kwargs,
) -> Any:
"""
:说明:
@ -241,9 +260,11 @@ class Bot(BaseBot):
params.update(kwargs)
if at_sender and event.conversationType != ConversationType.private:
params[
"message"] = f"@{event.senderId} " + msg + MessageSegment.atDingtalkIds(
event.senderId)
params["message"] = (
f"@{event.senderId} "
+ msg
+ MessageSegment.atDingtalkIds(event.senderId)
)
else:
params["message"] = msg

View File

@ -12,6 +12,7 @@ class Config(BaseModel):
- ``access_token`` / ``ding_access_token``: 钉钉令牌
- ``secret`` / ``ding_secret``: 钉钉 HTTP 上报数据签名口令
"""
secret: Optional[str] = Field(default=None, alias="ding_secret")
access_token: Optional[str] = Field(default=None, alias="ding_access_token")

View File

@ -69,6 +69,7 @@ class ConversationType(str, Enum):
class MessageEvent(Event):
"""消息事件"""
msgtype: str
text: TextMessage
msgId: str
@ -88,11 +89,10 @@ class MessageEvent(Event):
def gen_message(cls, values: dict):
assert "msgtype" in values, "msgtype must be specified"
# 其实目前钉钉机器人只能接收到 text 类型的消息
assert values[
"msgtype"] in values, f"{values['msgtype']} must be specified"
content = values[values['msgtype']]['content']
assert values["msgtype"] in values, f"{values['msgtype']} must be specified"
content = values[values["msgtype"]]["content"]
# 如果是被 @,第一个字符将会为空格,移除特殊情况
if content[0] == ' ':
if content[0] == " ":
content = content[1:]
values["message"] = content
return values
@ -128,6 +128,7 @@ class MessageEvent(Event):
class PrivateMessageEvent(MessageEvent):
"""私聊消息事件"""
chatbotCorpId: str
senderStaffId: Optional[str]
conversationType: ConversationType = ConversationType.private
@ -135,6 +136,7 @@ class PrivateMessageEvent(MessageEvent):
class GroupMessageEvent(MessageEvent):
"""群消息事件"""
atUsers: List[AtUsersItem]
conversationType: ConversationType = ConversationType.group
conversationTitle: str

View File

@ -1,9 +1,9 @@
from typing import Optional
from nonebot.exception import (AdapterException, ActionFailed as
BaseActionFailed, ApiNotAvailable as
BaseApiNotAvailable, NetworkError as
BaseNetworkError)
from nonebot.exception import AdapterException
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import NetworkError as BaseNetworkError
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class DingAdapterException(AdapterException):
@ -29,15 +29,13 @@ class ActionFailed(BaseActionFailed, DingAdapterException):
* ``errmsg: Optional[str]``: 错误信息
"""
def __init__(self,
errcode: Optional[int] = None,
errmsg: Optional[str] = None):
def __init__(self, errcode: Optional[int] = None, errmsg: Optional[str] = None):
super().__init__()
self.errcode = errcode
self.errmsg = errmsg
def __repr__(self):
return f"<ApiError errcode={self.errcode} errmsg=\"{self.errmsg}\">"
return f'<ApiError errcode={self.errcode} errmsg="{self.errmsg}">'
def __str__(self):
return self.__repr__()

View File

@ -77,10 +77,9 @@ class MessageSegment(BaseMessageSegment["Message"]):
def code(code_language: str, code: str) -> "Message":
"""发送 code 消息段"""
message = MessageSegment.text(code)
message += MessageSegment.extension({
"text_type": "code_snippet",
"code_language": code_language
})
message += MessageSegment.extension(
{"text_type": "code_snippet", "code_language": code_language}
)
return message
@staticmethod
@ -95,16 +94,19 @@ class MessageSegment(BaseMessageSegment["Message"]):
)
@staticmethod
def actionCardSingleBtn(title: str, text: str, singleTitle: str,
singleURL) -> "MessageSegment":
def actionCardSingleBtn(
title: str, text: str, singleTitle: str, singleURL
) -> "MessageSegment":
"""发送 ``actionCardSingleBtn`` 类型消息"""
return MessageSegment(
"actionCard", {
"actionCard",
{
"title": title,
"text": text,
"singleTitle": singleTitle,
"singleURL": singleURL
})
"singleURL": singleURL,
},
)
@staticmethod
def actionCardMultiBtns(
@ -112,7 +114,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
text: str,
btns: list,
hideAvatar: bool = False,
btnOrientation: str = '1',
btnOrientation: str = "1",
) -> "MessageSegment":
"""
发送 ``actionCardMultiBtn`` 类型消息
@ -123,13 +125,15 @@ class MessageSegment(BaseMessageSegment["Message"]):
* ``btns``: ``[{ "title": title, "actionURL": actionURL }, ...]``
"""
return MessageSegment(
"actionCard", {
"actionCard",
{
"title": title,
"text": text,
"hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation,
"btns": btns
})
"btns": btns,
},
)
@staticmethod
def feedCard(links: list) -> "MessageSegment":
@ -144,7 +148,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod
def raw(data) -> "MessageSegment":
return MessageSegment('raw', data)
return MessageSegment("raw", data)
def to_dict(self) -> Dict[str, Any]:
# 让用户可以直接发送原始的消息格式
@ -171,8 +175,8 @@ class Message(BaseMessage[MessageSegment]):
@staticmethod
@overrides(BaseMessage)
def _construct(
msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]:
msg: Union[str, Mapping, Iterable[Mapping]]
) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping):
msg = cast(Mapping[str, Any], msg)
yield MessageSegment(msg["type"], msg.get("data") or {})
@ -187,10 +191,11 @@ class Message(BaseMessage[MessageSegment]):
segment: MessageSegment
for segment in self:
# text 可以和 text 合并
if segment.type == "text" and data.get("msgtype") == 'text':
if segment.type == "text" and data.get("msgtype") == "text":
data.setdefault("text", {})
data["text"]["content"] = data["text"].setdefault(
"content", "") + segment.data["content"]
data["text"]["content"] = (
data["text"].setdefault("content", "") + segment.data["content"]
)
else:
data.update(segment.to_dict())
return data

View File

@ -8,10 +8,10 @@ log = logger_wrapper("DING")
def calc_hmac_base64(timestamp: str, secret: str):
secret_enc = secret.encode('utf-8')
string_to_sign = '{}\n{}'.format(timestamp, secret)
string_to_sign_enc = string_to_sign.encode('utf-8')
hmac_code = hmac.new(secret_enc,
string_to_sign_enc,
digestmod=hashlib.sha256).digest()
secret_enc = secret.encode("utf-8")
string_to_sign = "{}\n{}".format(timestamp, secret)
string_to_sign_enc = string_to_sign.encode("utf-8")
hmac_code = hmac.new(
secret_enc, string_to_sign_enc, digestmod=hashlib.sha256
).digest()
return base64.b64encode(hmac_code)