🎨 change permission check from driver into adapter #46

This commit is contained in:
yanyongyu
2020-11-11 15:14:29 +08:00
parent 1f1f9cd7e6
commit b2a2234d5c
5 changed files with 130 additions and 76 deletions

View File

@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配
import re
import sys
import hmac
import json
import asyncio
import httpx
@ -19,10 +21,10 @@ import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable
def log(level: str, message: str):
@ -39,6 +41,16 @@ def log(level: str, message: str):
return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message)
def get_auth_bearer(
access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]:
if not access_token:
return None
scheme, _, param = access_token.partition(" ")
if scheme.lower() not in ["bearer", "token"]:
raise RequestDenied(401, "Not authenticated")
return param
def escape(s: str, *, escape_comma: bool = True) -> str:
"""
:说明:
@ -264,8 +276,6 @@ class Bot(BaseBot):
self_id: str,
*,
websocket: Optional[WebSocket] = None):
if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type")
super().__init__(driver,
connection_type,
@ -281,6 +291,47 @@ class Bot(BaseBot):
"""
return "cqhttp"
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
x_self_id = headers.get("x-self-id")
x_signature = headers.get("x-signature")
access_token = get_auth_bearer(headers.get("authorization"))
# 检查连接方式
if connection_type not in ["http", "websocket"]:
log("WARNING", "Unsupported connection type")
raise RequestDenied(405, "Unsupported connection type")
# 检查self_id
if not x_self_id:
log("WARNING", "Missing X-Self-ID Header")
raise RequestDenied(400, "Missing X-Self-ID Header")
# 检查签名
secret = driver.config.secret
if secret and connection_type == "http":
if not x_signature:
log("WARNING", "Missing Signature Header")
raise RequestDenied(401, "Missing Signature")
sig = hmac.new(secret.encode("utf-8"),
json.dumps(body).encode(), "sha1").hexdigest()
if x_signature != "sha1=" + sig:
log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid")
access_token = driver.config.access_token
if access_token and access_token != access_token:
log(
"WARNING", "Authorization Header is invalid"
if access_token else "Missing Authorization Header")
raise RequestDenied(
403, "Authorization Header is invalid"
if access_token else "Missing Authorization Header")
return str(x_self_id)
@overrides(BaseBot)
async def handle_message(self, message: dict):
"""

View File

@ -9,6 +9,11 @@ def log(level: str, message: str):
...
def get_auth_bearer(
access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]:
...
def escape(s: str, *, escape_comma: bool = ...) -> str:
...
@ -69,6 +74,12 @@ class Bot(BaseBot):
def type(self) -> str:
...
@classmethod
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
...
async def handle_message(self, message: dict):
...