🎨 format code using black and isort

This commit is contained in:
yanyongyu
2021-11-22 23:21:26 +08:00
parent 602185a34e
commit a98d98cd12
86 changed files with 2893 additions and 2095 deletions

View File

@ -7,7 +7,7 @@ aiocache_logger.setLevel(logging.DEBUG)
aiocache_logger.handlers.clear()
aiocache_logger.addHandler(LoguruHandler())
from .bot import Bot as Bot
from .event import *
from .bot import Bot as Bot
from .message import Message as Message
from .message import MessageSegment as MessageSegment

View File

@ -1,24 +1,39 @@
import re
import json
from typing import (TYPE_CHECKING, Any, Dict, Tuple, Union, Iterable, Optional,
AsyncIterable, cast)
from typing import (
TYPE_CHECKING,
Any,
Dict,
Tuple,
Union,
Iterable,
Optional,
AsyncIterable,
cast,
)
import httpx
from aiocache import Cache, cached
from aiocache.serializers import PickleSerializer
from nonebot.log import logger
from .utils import AESCipher, log
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.message import handle_event
from .config import Config as FeishuConfig
from nonebot.adapters import Bot as BaseBot
from nonebot.drivers import Driver, HTTPRequest, HTTPResponse
from .utils import AESCipher, log
from .config import Config as FeishuConfig
from .message import Message, MessageSegment, MessageSerializer
from .exception import ActionFailed, NetworkError, ApiNotAvailable
from .event import (Event, MessageEvent, GroupMessageEvent, PrivateMessageEvent,
get_event_model)
from .event import (
Event,
MessageEvent,
GroupMessageEvent,
PrivateMessageEvent,
get_event_model,
)
if TYPE_CHECKING:
from nonebot.config import Config
@ -47,8 +62,10 @@ def _check_at_me(bot: "Bot", event: "Event"):
event.to_me = True
for index, segment in enumerate(message):
if segment.type == "at" and segment.data.get(
"user_name") in bot.config.nickname:
if (
segment.type == "at"
and segment.data.get("user_name") in bot.config.nickname
):
event.to_me = True
del event.event.message.content[index]
return
@ -57,7 +74,8 @@ def _check_at_me(bot: "Bot", event: "Event"):
if mention["name"] in bot.config.nickname:
event.to_me = True
segment.data["text"] = segment.data["text"].replace(
f"@{mention['name']}", "")
f"@{mention['name']}", ""
)
segment.data["text"] = segment.data["text"].lstrip()
break
else:
@ -92,18 +110,18 @@ def _check_nickname(bot: "Bot", event: "Event"):
if nicknames:
# check if the user is calling me with my nickname
nickname_regex = "|".join(nicknames)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text,
re.IGNORECASE)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text, re.IGNORECASE)
if m:
nickname = m.group(1)
log("DEBUG", f"User is calling me {nickname}")
event.to_me = True
first_msg_seg.data["text"] = first_text[m.end():]
first_msg_seg.data["text"] = first_text[m.end() :]
def _handle_api_result(
result: Union[Optional[Dict[str, Any]], str, bytes, Iterable[bytes],
AsyncIterable[bytes]]
result: Union[
Optional[Dict[str, Any]], str, bytes, Iterable[bytes], AsyncIterable[bytes]
]
) -> Any:
"""
:说明:
@ -155,13 +173,13 @@ class Bot(BaseBot):
@classmethod
@overrides(BaseBot)
async def check_permission(
cls, driver: Driver, request: HTTPRequest
cls, driver: Driver, request: HTTPRequest
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
if not isinstance(request, HTTPRequest):
log("WARNING",
"Unsupported connection type, available type: `http`")
log("WARNING", "Unsupported connection type, available type: `http`")
return None, HTTPResponse(
405, b"Unsupported connection type, available type: `http`")
405, b"Unsupported connection type, available type: `http`"
)
encrypt_key = cls.feishu_config.encrypt_key
if encrypt_key:
@ -174,16 +192,13 @@ class Bot(BaseBot):
challenge = data.get("challenge")
if challenge:
return data.get("token"), HTTPResponse(
200,
json.dumps({
"challenge": challenge
}).encode())
200, json.dumps({"challenge": challenge}).encode()
)
schema = data.get("schema")
if not schema:
return None, HTTPResponse(
400,
b"Missing `schema` in POST body, only accept event of version 2.0"
400, b"Missing `schema` in POST body, only accept event of version 2.0"
)
headers = data.get("header")
@ -196,15 +211,13 @@ class Bot(BaseBot):
if not token:
log("WARNING", "Missing `verification token` in POST body")
return None, HTTPResponse(
400, b"Missing `verification token` in POST body")
return None, HTTPResponse(400, b"Missing `verification token` in POST body")
else:
if token != cls.feishu_config.verification_token:
log("WARNING", "Verification token check failed")
return None, HTTPResponse(403,
b"Verification token check failed")
return None, HTTPResponse(403, b"Verification token check failed")
return app_id, HTTPResponse(200, b'')
return app_id, HTTPResponse(200, b"")
async def handle_message(self, message: bytes):
"""
@ -245,28 +258,32 @@ class Bot(BaseBot):
def _construct_url(self, path: str) -> str:
return self.api_root + path
@cached(ttl=60 * 60,
cache=Cache.MEMORY,
key="_feishu_tenant_access_token",
serializer=PickleSerializer())
@cached(
ttl=60 * 60,
cache=Cache.MEMORY,
key="_feishu_tenant_access_token",
serializer=PickleSerializer(),
)
async def _fetch_tenant_access_token(self) -> str:
try:
async with httpx.AsyncClient(follow_redirects=True) as client:
response = await client.post(
self._construct_url(
"auth/v3/tenant_access_token/internal/"),
self._construct_url("auth/v3/tenant_access_token/internal/"),
json={
"app_id": self.feishu_config.app_id,
"app_secret": self.feishu_config.app_secret
"app_secret": self.feishu_config.app_secret,
},
timeout=self.config.api_timeout)
timeout=self.config.api_timeout,
)
if 200 <= response.status_code < 300:
result = response.json()
return result["tenant_access_token"]
else:
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
raise NetworkError(
f"HTTP request received unexpected "
f"status code: {response.status_code}"
)
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
@ -280,30 +297,37 @@ class Bot(BaseBot):
raise ApiNotAvailable
headers = {}
self.feishu_config.tenant_access_token = await self._fetch_tenant_access_token(
self.feishu_config.tenant_access_token = (
await self._fetch_tenant_access_token()
)
headers["Authorization"] = (
"Bearer " + self.feishu_config.tenant_access_token
)
headers[
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
try:
async with httpx.AsyncClient(timeout=self.config.api_timeout,
follow_redirects=True) as client:
async with httpx.AsyncClient(
timeout=self.config.api_timeout, follow_redirects=True
) as client:
response = await client.send(
httpx.Request(data["method"],
self.api_root + api,
json=data.get("body", {}),
params=data.get("query", {}),
headers=headers))
httpx.Request(
data["method"],
self.api_root + api,
json=data.get("body", {}),
params=data.get("query", {}),
headers=headers,
)
)
if 200 <= response.status_code < 300:
if response.headers["content-type"].startswith(
"application/json"):
if response.headers["content-type"].startswith("application/json"):
result = response.json()
else:
result = response.content
return _handle_api_result(result)
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code} "
f"response body: {response.text}")
raise NetworkError(
f"HTTP request received unexpected "
f"status code: {response.status_code} "
f"response body: {response.text}"
)
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
@ -333,11 +357,13 @@ class Bot(BaseBot):
return await super().call_api(api, **data)
@overrides(BaseBot)
async def send(self,
event: Event,
message: Union[str, Message, MessageSegment],
at_sender: bool = False,
**kwargs) -> Any:
async def send(
self,
event: Event,
message: Union[str, Message, MessageSegment],
at_sender: bool = False,
**kwargs,
) -> Any:
msg = message if isinstance(message, Message) else Message(message)
if isinstance(event, GroupMessageEvent):
@ -346,7 +372,8 @@ class Bot(BaseBot):
receive_id, receive_id_type = event.get_user_id(), "open_id"
else:
raise ValueError(
"Cannot guess `receive_id` and `receive_id_type` to reply!")
"Cannot guess `receive_id` and `receive_id_type` to reply!"
)
at_sender = at_sender and bool(event.get_user_id())
@ -357,14 +384,12 @@ class Bot(BaseBot):
params = {
"method": "POST",
"query": {
"receive_id_type": receive_id_type
},
"query": {"receive_id_type": receive_id_type},
"body": {
"receive_id": receive_id,
"content": content,
"msg_type": msg_type
}
"msg_type": msg_type,
},
}
return await self.call_api(f"im/v1/messages", **params)

View File

@ -17,13 +17,16 @@ class Config(BaseModel):
- ``is_lark`` / ``feishu_is_lark``: 是否使用Lark飞书海外版默认为 false
"""
app_id: Optional[str] = Field(default=None, alias="feishu_app_id")
app_secret: Optional[str] = Field(default=None, alias="feishu_app_secret")
encrypt_key: Optional[str] = Field(default=None, alias="feishu_encrypt_key")
verification_token: Optional[str] = Field(default=None,
alias="feishu_verification_token")
verification_token: Optional[str] = Field(
default=None, alias="feishu_verification_token"
)
tenant_access_token: Optional[str] = Field(
default=None, alias="feishu_tenant_access_token")
default=None, alias="feishu_tenant_access_token"
)
is_lark: Optional[str] = Field(default=False, alias="feishu_is_lark")
class Config:

View File

@ -1,12 +1,12 @@
import inspect
import json
from typing import Any, Dict, List, Literal, Optional, Type
import inspect
from typing import Any, Dict, List, Type, Literal, Optional
from pydantic import BaseModel, Field, root_validator
from pygtrie import StringTrie
from pydantic import Field, BaseModel, root_validator
from nonebot.adapters import Event as BaseEvent
from nonebot.typing import overrides
from nonebot.adapters import Event as BaseEvent
from .message import Message, MessageDeserializer

View File

@ -1,13 +1,12 @@
from typing import Optional
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import AdapterException
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import NetworkError as BaseNetworkError
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class FeishuAdapterException(AdapterException):
def __init__(self):
super().__init__("feishu")
@ -28,8 +27,11 @@ class ActionFailed(BaseActionFailed, FeishuAdapterException):
self.info = kwargs
def __repr__(self):
return f"<ActionFailed " + ", ".join(
f"{k}={v}" for k, v in self.info.items()) + ">"
return (
f"<ActionFailed "
+ ", ".join(f"{k}={v}" for k, v in self.info.items())
+ ">"
)
def __str__(self):
return self.__repr__()

View File

@ -1,8 +1,18 @@
import json
import itertools
from dataclasses import dataclass
from typing import (Any, Dict, List, Type, Tuple, Union, Mapping, Iterable,
Optional, cast)
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
Iterable,
Optional,
cast,
)
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage
@ -34,7 +44,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
"share_user": "[个人名片]",
"system": "[系统消息]",
"location": "[位置]",
"video_chat": "[视频通话]"
"video_chat": "[视频通话]",
}
def __str__(self) -> str:
@ -47,24 +57,26 @@ class MessageSegment(BaseMessageSegment["Message"]):
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + (MessageSegment.text(other) if isinstance(
other, str) else other)
return Message(self) + (
MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
return (MessageSegment.text(other)
if isinstance(other, str) else Message(other)) + self
return (
MessageSegment.text(other) if isinstance(other, str) else Message(other)
) + self
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
return self.type == "text"
#接收消息
# 接收消息
@staticmethod
def at(user_id: str) -> "MessageSegment":
return MessageSegment("at", {"user_id": user_id})
#发送消息
# 发送消息
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": text})
@ -79,10 +91,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod
def interactive(title: str, elements: list) -> "MessageSegment":
return MessageSegment("interactive", {
"title": title,
"elements": elements
})
return MessageSegment("interactive", {"title": title, "elements": elements})
@staticmethod
def share_chat(chat_id: str) -> "MessageSegment":
@ -94,28 +103,25 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod
def audio(file_key: str, duration: int) -> "MessageSegment":
return MessageSegment("audio", {
"file_key": file_key,
"duration": duration
})
return MessageSegment("audio", {"file_key": file_key, "duration": duration})
@staticmethod
def media(file_key: str, image_key: str, file_name: str,
duration: int) -> "MessageSegment":
def media(
file_key: str, image_key: str, file_name: str, duration: int
) -> "MessageSegment":
return MessageSegment(
"media", {
"media",
{
"file_key": file_key,
"image_key": image_key,
"file_name": file_name,
"duration": duration
})
"duration": duration,
},
)
@staticmethod
def file(file_key: str, file_name: str) -> "MessageSegment":
return MessageSegment("file", {
"file_key": file_key,
"file_name": file_name
})
return MessageSegment("file", {"file_key": file_key, "file_name": file_name})
@staticmethod
def sticker(file_key) -> "MessageSegment":
@ -133,22 +139,22 @@ class Message(BaseMessage[MessageSegment]):
return MessageSegment
@overrides(BaseMessage)
def __add__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other)
MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessage)
def __radd__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other)
MessageSegment.text(other) if isinstance(other, str) else other
)
@staticmethod
@overrides(BaseMessage)
def _construct(
msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]:
msg: Union[str, Mapping, Iterable[Mapping]]
) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping):
msg = cast(Mapping[str, Any], msg)
yield MessageSegment(msg["type"], msg.get("data") or {})
@ -169,7 +175,8 @@ class Message(BaseMessage[MessageSegment]):
for i, seg in enumerate(self):
if seg.type == "text" and i != 0 and msg[-1].type == "text":
msg[-1] = MessageSegment(
"text", {"text": msg[-1].data["text"] + seg.data["text"]})
"text", {"text": msg[-1].data["text"] + seg.data["text"]}
)
else:
msg.append(seg)
return Message(msg)
@ -184,6 +191,7 @@ class MessageSerializer:
"""
飞书 协议 Message 序列化器。
"""
message: Message
def serialize(self) -> Tuple[str, str]:
@ -198,10 +206,12 @@ class MessageSerializer:
else:
if last_segment_type == "image":
msg["content"].append([])
msg["content"][-1].append({
"tag": segment.type if segment.type != "image" else "img",
**segment.data
})
msg["content"][-1].append(
{
"tag": segment.type if segment.type != "image" else "img",
**segment.data,
}
)
last_segment_type = segment.type
return "post", json.dumps({"zh_cn": {**msg}})
@ -214,6 +224,7 @@ class MessageDeserializer:
"""
飞书 协议 Message 反序列化器。
"""
type: str
data: Dict[str, Any]
mentions: Optional[List[dict]]
@ -227,14 +238,13 @@ class MessageDeserializer:
if self.type == "post":
msg = Message()
if self.data["title"] != "":
msg += MessageSegment("text", {'text': self.data["title"]})
msg += MessageSegment("text", {"text": self.data["title"]})
for seg in itertools.chain(*self.data["content"]):
tag = seg.pop("tag")
if tag == "at":
seg["user_name"] = dict_mention[seg["user_id"]]["name"]
seg["user_id"] = dict_mention[
seg["user_id"]]["id"]["open_id"]
seg["user_id"] = dict_mention[seg["user_id"]]["id"]["open_id"]
msg += MessageSegment(tag if tag != "img" else "image", seg)
@ -242,7 +252,8 @@ class MessageDeserializer:
elif self.type == "text":
for key, mention in dict_mention.items():
self.data["text"] = self.data["text"].replace(
key, f"@{mention['name']}")
key, f"@{mention['name']}"
)
self.data["mentions"] = dict_mention
return Message(MessageSegment(self.type, self.data))

View File

@ -9,27 +9,26 @@ log = logger_wrapper("FEISHU")
class AESCipher(object):
def __init__(self, key):
self.block_size = AES.block_size
self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
@staticmethod
def str_to_bytes(data):
u_type = type(b"".decode('utf8'))
u_type = type(b"".decode("utf8"))
if isinstance(data, u_type):
return data.encode('utf8')
return data.encode("utf8")
return data
@staticmethod
def _unpad(s):
return s[:-ord(s[len(s) - 1:])]
return s[: -ord(s[len(s) - 1 :])]
def decrypt(self, enc):
iv = enc[:AES.block_size]
iv = enc[: AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size:]))
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
def decrypt_string(self, enc):
enc = base64.b64decode(enc)
return self.decrypt(enc).decode('utf8')
return self.decrypt(enc).decode("utf8")

View File

@ -36,6 +36,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"