mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-30 22:46:40 +00:00 
			
		
		
		
	🎨 change permission check from driver into adapter #46
This commit is contained in:
		| @@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配 | ||||
|  | ||||
| import re | ||||
| import sys | ||||
| import hmac | ||||
| import json | ||||
| import asyncio | ||||
|  | ||||
| import httpx | ||||
| @@ -19,10 +21,10 @@ import httpx | ||||
| from nonebot.log import logger | ||||
| from nonebot.config import Config | ||||
| from nonebot.message import handle_event | ||||
| from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional | ||||
| from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable | ||||
| from nonebot.typing import overrides, Driver, WebSocket, NoReturn | ||||
| from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional | ||||
| from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment | ||||
| from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable | ||||
|  | ||||
|  | ||||
| def log(level: str, message: str): | ||||
| @@ -39,6 +41,16 @@ def log(level: str, message: str): | ||||
|     return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message) | ||||
|  | ||||
|  | ||||
| def get_auth_bearer( | ||||
|         access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]: | ||||
|     if not access_token: | ||||
|         return None | ||||
|     scheme, _, param = access_token.partition(" ") | ||||
|     if scheme.lower() not in ["bearer", "token"]: | ||||
|         raise RequestDenied(401, "Not authenticated") | ||||
|     return param | ||||
|  | ||||
|  | ||||
| def escape(s: str, *, escape_comma: bool = True) -> str: | ||||
|     """ | ||||
|     :说明: | ||||
| @@ -264,8 +276,6 @@ class Bot(BaseBot): | ||||
|                  self_id: str, | ||||
|                  *, | ||||
|                  websocket: Optional[WebSocket] = None): | ||||
|         if connection_type not in ["http", "websocket"]: | ||||
|             raise ValueError("Unsupported connection type") | ||||
|  | ||||
|         super().__init__(driver, | ||||
|                          connection_type, | ||||
| @@ -281,6 +291,47 @@ class Bot(BaseBot): | ||||
|         """ | ||||
|         return "cqhttp" | ||||
|  | ||||
|     @classmethod | ||||
|     @overrides(BaseBot) | ||||
|     async def check_permission(cls, driver: Driver, connection_type: str, | ||||
|                                headers: dict, | ||||
|                                body: Optional[dict]) -> Union[str, NoReturn]: | ||||
|         x_self_id = headers.get("x-self-id") | ||||
|         x_signature = headers.get("x-signature") | ||||
|         access_token = get_auth_bearer(headers.get("authorization")) | ||||
|  | ||||
|         # 检查连接方式 | ||||
|         if connection_type not in ["http", "websocket"]: | ||||
|             log("WARNING", "Unsupported connection type") | ||||
|             raise RequestDenied(405, "Unsupported connection type") | ||||
|  | ||||
|         # 检查self_id | ||||
|         if not x_self_id: | ||||
|             log("WARNING", "Missing X-Self-ID Header") | ||||
|             raise RequestDenied(400, "Missing X-Self-ID Header") | ||||
|  | ||||
|         # 检查签名 | ||||
|         secret = driver.config.secret | ||||
|         if secret and connection_type == "http": | ||||
|             if not x_signature: | ||||
|                 log("WARNING", "Missing Signature Header") | ||||
|                 raise RequestDenied(401, "Missing Signature") | ||||
|             sig = hmac.new(secret.encode("utf-8"), | ||||
|                            json.dumps(body).encode(), "sha1").hexdigest() | ||||
|             if x_signature != "sha1=" + sig: | ||||
|                 log("WARNING", "Signature Header is invalid") | ||||
|                 raise RequestDenied(403, "Signature is invalid") | ||||
|  | ||||
|         access_token = driver.config.access_token | ||||
|         if access_token and access_token != access_token: | ||||
|             log( | ||||
|                 "WARNING", "Authorization Header is invalid" | ||||
|                 if access_token else "Missing Authorization Header") | ||||
|             raise RequestDenied( | ||||
|                 403, "Authorization Header is invalid" | ||||
|                 if access_token else "Missing Authorization Header") | ||||
|         return str(x_self_id) | ||||
|  | ||||
|     @overrides(BaseBot) | ||||
|     async def handle_message(self, message: dict): | ||||
|         """ | ||||
|   | ||||
| @@ -9,6 +9,11 @@ def log(level: str, message: str): | ||||
|     ... | ||||
|  | ||||
|  | ||||
| def get_auth_bearer( | ||||
|         access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]: | ||||
|     ... | ||||
|  | ||||
|  | ||||
| def escape(s: str, *, escape_comma: bool = ...) -> str: | ||||
|     ... | ||||
|  | ||||
| @@ -69,6 +74,12 @@ class Bot(BaseBot): | ||||
|     def type(self) -> str: | ||||
|         ... | ||||
|  | ||||
|     @classmethod | ||||
|     async def check_permission(cls, driver: Driver, connection_type: str, | ||||
|                                headers: dict, | ||||
|                                body: Optional[dict]) -> Union[str, NoReturn]: | ||||
|         ... | ||||
|  | ||||
|     async def handle_message(self, message: dict): | ||||
|         ... | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user