mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-28 16:51:26 +00:00
✨ support custom response
This commit is contained in:
@ -3,14 +3,15 @@ import sys
|
||||
import hmac
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any, Dict, Union, Optional, TYPE_CHECKING
|
||||
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
|
||||
|
||||
from .utils import log, escape
|
||||
from .config import Config as CQHTTPConfig
|
||||
@ -20,7 +21,6 @@ from .exception import NetworkError, ApiNotAvailable, ActionFailed
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
|
||||
|
||||
def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]:
|
||||
@ -28,7 +28,7 @@ def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]:
|
||||
return None
|
||||
scheme, _, param = access_token.partition(" ")
|
||||
if scheme.lower() not in ["bearer", "token"]:
|
||||
raise RequestDenied(401, "Not authenticated")
|
||||
return None
|
||||
return param
|
||||
|
||||
|
||||
@ -225,14 +225,6 @@ class Bot(BaseBot):
|
||||
"""
|
||||
cqhttp_config: CQHTTPConfig
|
||||
|
||||
def __init__(self,
|
||||
connection_type: str,
|
||||
self_id: str,
|
||||
*,
|
||||
websocket: Optional["WebSocket"] = None):
|
||||
|
||||
super().__init__(connection_type, self_id, websocket=websocket)
|
||||
|
||||
@property
|
||||
@overrides(BaseBot)
|
||||
def type(self) -> str:
|
||||
@ -242,84 +234,84 @@ class Bot(BaseBot):
|
||||
return "cqhttp"
|
||||
|
||||
@classmethod
|
||||
def register(cls, driver: "Driver", config: "Config"):
|
||||
def register(cls, driver: Driver, config: "Config"):
|
||||
super().register(driver, config)
|
||||
cls.cqhttp_config = CQHTTPConfig(**config.dict())
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
async def check_permission(cls, driver: "Driver", connection_type: str,
|
||||
headers: dict, body: Optional[bytes]) -> str:
|
||||
async def check_permission(
|
||||
cls, driver: Driver,
|
||||
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
CQHTTP (OneBot) 协议鉴权。参考 `鉴权 <https://github.com/howmanybots/onebot/blob/master/v11/specs/communication/authorization.md>`_
|
||||
"""
|
||||
x_self_id = headers.get("x-self-id")
|
||||
x_signature = headers.get("x-signature")
|
||||
token = get_auth_bearer(headers.get("authorization"))
|
||||
x_self_id = request.headers.get("x-self-id")
|
||||
x_signature = request.headers.get("x-signature")
|
||||
token = get_auth_bearer(request.headers.get("authorization"))
|
||||
cqhttp_config = CQHTTPConfig(**driver.config.dict())
|
||||
|
||||
# 检查连接方式
|
||||
if connection_type not in ["http", "websocket"]:
|
||||
log("WARNING", "Unsupported connection type")
|
||||
raise RequestDenied(405, "Unsupported connection type")
|
||||
|
||||
# 检查self_id
|
||||
if not x_self_id:
|
||||
log("WARNING", "Missing X-Self-ID Header")
|
||||
raise RequestDenied(400, "Missing X-Self-ID Header")
|
||||
return None, HTTPResponse(400, b"Missing X-Self-ID Header")
|
||||
|
||||
# 检查签名
|
||||
secret = cqhttp_config.secret
|
||||
if secret and connection_type == "http":
|
||||
if secret and isinstance(request, HTTPRequest):
|
||||
if not x_signature:
|
||||
log("WARNING", "Missing Signature Header")
|
||||
raise RequestDenied(401, "Missing Signature")
|
||||
sig = hmac.new(secret.encode("utf-8"), body, "sha1").hexdigest()
|
||||
return None, HTTPResponse(401, b"Missing Signature")
|
||||
sig = hmac.new(secret.encode("utf-8"), request.body,
|
||||
"sha1").hexdigest()
|
||||
if x_signature != "sha1=" + sig:
|
||||
log("WARNING", "Signature Header is invalid")
|
||||
raise RequestDenied(403, "Signature is invalid")
|
||||
return None, HTTPResponse(403, b"Signature is invalid")
|
||||
|
||||
access_token = cqhttp_config.access_token
|
||||
if access_token and access_token != token and connection_type == "websocket":
|
||||
if access_token and access_token != token and isinstance(
|
||||
request, WebSocket):
|
||||
log(
|
||||
"WARNING", "Authorization Header is invalid"
|
||||
if token else "Missing Authorization Header")
|
||||
raise RequestDenied(
|
||||
403, "Authorization Header is invalid"
|
||||
if token else "Missing Authorization Header")
|
||||
return str(x_self_id)
|
||||
return None, HTTPResponse(
|
||||
403, b"Authorization Header is invalid"
|
||||
if token else b"Missing Authorization Header")
|
||||
return str(x_self_id), HTTPResponse(204, b'')
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def handle_message(self, message: dict):
|
||||
async def handle_message(self, message: bytes):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
调用 `_check_reply <#async-check-reply-bot-event>`_, `_check_at_me <#check-at-me-bot-event>`_, `_check_nickname <#check-nickname-bot-event>`_ 处理事件并转换为 `Event <#class-event>`_
|
||||
"""
|
||||
if not message:
|
||||
data = json.loads(message)
|
||||
|
||||
if not data:
|
||||
return
|
||||
|
||||
if "post_type" not in message:
|
||||
ResultStore.add_result(message)
|
||||
if "post_type" not in data:
|
||||
ResultStore.add_result(data)
|
||||
return
|
||||
|
||||
try:
|
||||
post_type = message['post_type']
|
||||
detail_type = message.get(f"{post_type}_type")
|
||||
post_type = data['post_type']
|
||||
detail_type = data.get(f"{post_type}_type")
|
||||
detail_type = f".{detail_type}" if detail_type else ""
|
||||
sub_type = message.get("sub_type")
|
||||
sub_type = data.get("sub_type")
|
||||
sub_type = f".{sub_type}" if sub_type else ""
|
||||
models = get_event_model(post_type + detail_type + sub_type)
|
||||
for model in models:
|
||||
try:
|
||||
event = model.parse_obj(message)
|
||||
event = model.parse_obj(data)
|
||||
break
|
||||
except Exception as e:
|
||||
log("DEBUG", "Event Parser Error", e)
|
||||
else:
|
||||
event = Event.parse_obj(message)
|
||||
event = Event.parse_obj(data)
|
||||
|
||||
# Check whether user is calling me
|
||||
await _check_reply(self, event)
|
||||
@ -329,25 +321,28 @@ class Bot(BaseBot):
|
||||
await handle_event(self, event)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Failed to handle event. Raw: {message}</bg #f8bbd0></r>"
|
||||
f"<r><bg #f8bbd0>Failed to handle event. Raw: {data}</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def _call_api(self, api: str, **data) -> Any:
|
||||
log("DEBUG", f"Calling API <y>{api}</y>")
|
||||
if self.connection_type == "websocket":
|
||||
if isinstance(self.request, WebSocket):
|
||||
seq = ResultStore.get_seq()
|
||||
await self.websocket.send({
|
||||
"action": api,
|
||||
"params": data,
|
||||
"echo": {
|
||||
"seq": seq
|
||||
}
|
||||
})
|
||||
json_data = json.dumps(
|
||||
{
|
||||
"action": api,
|
||||
"params": data,
|
||||
"echo": {
|
||||
"seq": seq
|
||||
}
|
||||
},
|
||||
cls=DataclassEncoder)
|
||||
await self.request.send(json_data)
|
||||
return _handle_api_result(await ResultStore.fetch(
|
||||
seq, self.config.api_timeout))
|
||||
|
||||
elif self.connection_type == "http":
|
||||
elif isinstance(self.request, HTTPRequest):
|
||||
api_root = self.config.api_root.get(self.self_id)
|
||||
if not api_root:
|
||||
raise ApiNotAvailable
|
||||
@ -431,7 +426,7 @@ class Bot(BaseBot):
|
||||
message, str) else message
|
||||
msg = message if isinstance(message, Message) else Message(message)
|
||||
|
||||
at_sender = at_sender and getattr(event, "user_id", None)
|
||||
at_sender = at_sender and bool(getattr(event, "user_id", None))
|
||||
|
||||
params = {}
|
||||
if getattr(event, "user_id", None):
|
||||
@ -449,8 +444,7 @@ class Bot(BaseBot):
|
||||
raise ValueError("Cannot guess message type to reply!")
|
||||
|
||||
if at_sender and params["message_type"] != "private":
|
||||
params["message"] = MessageSegment.at(params["user_id"]) + \
|
||||
MessageSegment.text(" ") + msg
|
||||
params["message"] = MessageSegment.at(params["user_id"]) + " " + msg
|
||||
else:
|
||||
params["message"] = msg
|
||||
return await self.send_msg(**params)
|
||||
|
Reference in New Issue
Block a user