mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-30 22:46:40 +00:00 
			
		
		
		
	🎨 change permission check from driver into adapter #46
This commit is contained in:
		| @@ -11,7 +11,7 @@ from dataclasses import dataclass, field | ||||
|  | ||||
| from nonebot.config import Config | ||||
| from nonebot.typing import Driver, Message, WebSocket | ||||
| from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable | ||||
| from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable | ||||
|  | ||||
|  | ||||
| class BaseBot(abc.ABC): | ||||
| @@ -55,6 +55,13 @@ class BaseBot(abc.ABC): | ||||
|         """Adapter 类型""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @classmethod | ||||
|     @abc.abstractmethod | ||||
|     async def check_permission(cls, driver: Driver, connection_type: str, | ||||
|                                headers: dict, | ||||
|                                body: Optional[dict]) -> Union[str, NoReturn]: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     async def handle_message(self, message: dict): | ||||
|         """ | ||||
|   | ||||
| @@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配 | ||||
|  | ||||
| import re | ||||
| import sys | ||||
| import hmac | ||||
| import json | ||||
| import asyncio | ||||
|  | ||||
| import httpx | ||||
| @@ -19,10 +21,10 @@ import httpx | ||||
| from nonebot.log import logger | ||||
| from nonebot.config import Config | ||||
| from nonebot.message import handle_event | ||||
| from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional | ||||
| from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable | ||||
| from nonebot.typing import overrides, Driver, WebSocket, NoReturn | ||||
| from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional | ||||
| from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment | ||||
| from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable | ||||
|  | ||||
|  | ||||
| def log(level: str, message: str): | ||||
| @@ -39,6 +41,16 @@ def log(level: str, message: str): | ||||
|     return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message) | ||||
|  | ||||
|  | ||||
| def get_auth_bearer( | ||||
|         access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]: | ||||
|     if not access_token: | ||||
|         return None | ||||
|     scheme, _, param = access_token.partition(" ") | ||||
|     if scheme.lower() not in ["bearer", "token"]: | ||||
|         raise RequestDenied(401, "Not authenticated") | ||||
|     return param | ||||
|  | ||||
|  | ||||
| def escape(s: str, *, escape_comma: bool = True) -> str: | ||||
|     """ | ||||
|     :说明: | ||||
| @@ -264,8 +276,6 @@ class Bot(BaseBot): | ||||
|                  self_id: str, | ||||
|                  *, | ||||
|                  websocket: Optional[WebSocket] = None): | ||||
|         if connection_type not in ["http", "websocket"]: | ||||
|             raise ValueError("Unsupported connection type") | ||||
|  | ||||
|         super().__init__(driver, | ||||
|                          connection_type, | ||||
| @@ -281,6 +291,47 @@ class Bot(BaseBot): | ||||
|         """ | ||||
|         return "cqhttp" | ||||
|  | ||||
|     @classmethod | ||||
|     @overrides(BaseBot) | ||||
|     async def check_permission(cls, driver: Driver, connection_type: str, | ||||
|                                headers: dict, | ||||
|                                body: Optional[dict]) -> Union[str, NoReturn]: | ||||
|         x_self_id = headers.get("x-self-id") | ||||
|         x_signature = headers.get("x-signature") | ||||
|         access_token = get_auth_bearer(headers.get("authorization")) | ||||
|  | ||||
|         # 检查连接方式 | ||||
|         if connection_type not in ["http", "websocket"]: | ||||
|             log("WARNING", "Unsupported connection type") | ||||
|             raise RequestDenied(405, "Unsupported connection type") | ||||
|  | ||||
|         # 检查self_id | ||||
|         if not x_self_id: | ||||
|             log("WARNING", "Missing X-Self-ID Header") | ||||
|             raise RequestDenied(400, "Missing X-Self-ID Header") | ||||
|  | ||||
|         # 检查签名 | ||||
|         secret = driver.config.secret | ||||
|         if secret and connection_type == "http": | ||||
|             if not x_signature: | ||||
|                 log("WARNING", "Missing Signature Header") | ||||
|                 raise RequestDenied(401, "Missing Signature") | ||||
|             sig = hmac.new(secret.encode("utf-8"), | ||||
|                            json.dumps(body).encode(), "sha1").hexdigest() | ||||
|             if x_signature != "sha1=" + sig: | ||||
|                 log("WARNING", "Signature Header is invalid") | ||||
|                 raise RequestDenied(403, "Signature is invalid") | ||||
|  | ||||
|         access_token = driver.config.access_token | ||||
|         if access_token and access_token != access_token: | ||||
|             log( | ||||
|                 "WARNING", "Authorization Header is invalid" | ||||
|                 if access_token else "Missing Authorization Header") | ||||
|             raise RequestDenied( | ||||
|                 403, "Authorization Header is invalid" | ||||
|                 if access_token else "Missing Authorization Header") | ||||
|         return str(x_self_id) | ||||
|  | ||||
|     @overrides(BaseBot) | ||||
|     async def handle_message(self, message: dict): | ||||
|         """ | ||||
|   | ||||
| @@ -9,6 +9,11 @@ def log(level: str, message: str): | ||||
|     ... | ||||
|  | ||||
|  | ||||
| def get_auth_bearer( | ||||
|         access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]: | ||||
|     ... | ||||
|  | ||||
|  | ||||
| def escape(s: str, *, escape_comma: bool = ...) -> str: | ||||
|     ... | ||||
|  | ||||
| @@ -69,6 +74,12 @@ class Bot(BaseBot): | ||||
|     def type(self) -> str: | ||||
|         ... | ||||
|  | ||||
|     @classmethod | ||||
|     async def check_permission(cls, driver: Driver, connection_type: str, | ||||
|                                headers: dict, | ||||
|                                body: Optional[dict]) -> Union[str, NoReturn]: | ||||
|         ... | ||||
|  | ||||
|     async def handle_message(self, message: dict): | ||||
|         ... | ||||
|  | ||||
|   | ||||
| @@ -15,12 +15,13 @@ import logging | ||||
|  | ||||
| import uvicorn | ||||
| from fastapi.responses import Response | ||||
| from fastapi import Body, status, Header, FastAPI, Depends, HTTPException | ||||
| from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException | ||||
| from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket | ||||
|  | ||||
| from nonebot.log import logger | ||||
| from nonebot.config import Env, Config | ||||
| from nonebot.utils import DataclassEncoder | ||||
| from nonebot.exception import RequestDenied | ||||
| from nonebot.drivers import BaseDriver, BaseWebSocket | ||||
| from nonebot.typing import Optional, Callable, overrides | ||||
|  | ||||
| @@ -127,97 +128,58 @@ class Driver(BaseDriver): | ||||
|     @overrides(BaseDriver) | ||||
|     async def _handle_http(self, | ||||
|                            adapter: str, | ||||
|                            data: dict = Body(...), | ||||
|                            x_self_id: Optional[str] = Header(None), | ||||
|                            x_signature: Optional[str] = Header(None), | ||||
|                            auth: Optional[str] = Depends(get_auth_bearer)): | ||||
|         # 检查self_id | ||||
|         if not x_self_id: | ||||
|             logger.warning("Missing X-Self-ID Header") | ||||
|             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, | ||||
|                                 detail="Missing X-Self-ID Header") | ||||
|  | ||||
|         # 检查签名 | ||||
|         secret = self.config.secret | ||||
|         if secret: | ||||
|             if not x_signature: | ||||
|                 logger.warning("Missing Signature Header") | ||||
|                 raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                                     detail="Missing Signature") | ||||
|             sig = hmac.new(secret.encode("utf-8"), | ||||
|                            json.dumps(data).encode(), "sha1").hexdigest() | ||||
|             if x_signature != "sha1=" + sig: | ||||
|                 logger.warning("Signature Header is invalid") | ||||
|                 raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, | ||||
|                                     detail="Signature is invalid") | ||||
|  | ||||
|         access_token = self.config.access_token | ||||
|         if access_token and access_token != auth: | ||||
|             logger.warning("Authorization Header is invalid" | ||||
|                            if auth else "Missing Authorization Header") | ||||
|             raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, | ||||
|                                 detail="Authorization Header is invalid" | ||||
|                                 if auth else "Missing Authorization Header") | ||||
|  | ||||
|                            request: Request, | ||||
|                            data: dict = Body(...)): | ||||
|         if not isinstance(data, dict): | ||||
|             logger.warning("Data received is invalid") | ||||
|             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) | ||||
|  | ||||
|         if adapter not in self._adapters: | ||||
|             logger.warning("Unknown adapter") | ||||
|             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, | ||||
|                                 detail="adapter not found") | ||||
|  | ||||
|         # 创建 Bot 对象 | ||||
|         BotClass = self._adapters[adapter] | ||||
|         headers = dict(request.headers) | ||||
|         try: | ||||
|             x_self_id = await BotClass.check_permission(self, "http", headers, | ||||
|                                                         data) | ||||
|         except RequestDenied as e: | ||||
|             raise HTTPException(status_code=e.status_code, | ||||
|                                 detail=e.reason) from None | ||||
|  | ||||
|         if x_self_id in self._clients: | ||||
|             logger.warning("There's already a reverse websocket api connection," | ||||
|                            "so the event may be handled twice.") | ||||
|  | ||||
|         # 创建 Bot 对象 | ||||
|         if adapter in self._adapters: | ||||
|             BotClass = self._adapters[adapter] | ||||
|             bot = BotClass(self, "http", self.config, x_self_id) | ||||
|         else: | ||||
|             logger.warning("Unknown adapter") | ||||
|             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, | ||||
|                                 detail="adapter not found") | ||||
|         bot = BotClass(self, "http", self.config, x_self_id) | ||||
|  | ||||
|         asyncio.create_task(bot.handle_message(data)) | ||||
|         return Response("", 204) | ||||
|  | ||||
|     @overrides(BaseDriver) | ||||
|     async def _handle_ws_reverse( | ||||
|         self, | ||||
|         adapter: str, | ||||
|         websocket: FastAPIWebSocket, | ||||
|         x_self_id: str = Header(None), | ||||
|         auth: Optional[str] = Depends(get_auth_bearer)): | ||||
|     async def _handle_ws_reverse(self, adapter: str, | ||||
|                                  websocket: FastAPIWebSocket): | ||||
|         ws = WebSocket(websocket) | ||||
|  | ||||
|         access_token = self.config.access_token | ||||
|         if access_token and access_token != auth: | ||||
|             logger.warning("Authorization Header is invalid" | ||||
|                            if auth else "Missing Authorization Header") | ||||
|             await ws.close(code=status.WS_1008_POLICY_VIOLATION) | ||||
|             return | ||||
|  | ||||
|         if not x_self_id: | ||||
|             logger.warning(f"Missing X-Self-ID Header") | ||||
|             await ws.close(code=status.WS_1008_POLICY_VIOLATION) | ||||
|             return | ||||
|  | ||||
|         if x_self_id in self._clients: | ||||
|             logger.warning(f"Connection Conflict: self_id {x_self_id}") | ||||
|         if adapter not in self._adapters: | ||||
|             logger.warning("Unknown adapter") | ||||
|             await ws.close(code=status.WS_1008_POLICY_VIOLATION) | ||||
|             return | ||||
|  | ||||
|         # Create Bot Object | ||||
|         if adapter in self._adapters: | ||||
|             BotClass = self._adapters[adapter] | ||||
|             bot = BotClass(self, | ||||
|                            "websocket", | ||||
|                            self.config, | ||||
|                            x_self_id, | ||||
|                            websocket=ws) | ||||
|         else: | ||||
|             logger.warning("Unknown adapter") | ||||
|         BotClass = self._adapters[adapter] | ||||
|         headers = dict(websocket.headers) | ||||
|         try: | ||||
|             x_self_id = await BotClass.check_permission(self, "websocket", | ||||
|                                                         headers, None) | ||||
|         except RequestDenied: | ||||
|             await ws.close(code=status.WS_1008_POLICY_VIOLATION) | ||||
|             return | ||||
|  | ||||
|         bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws) | ||||
|  | ||||
|         await ws.accept() | ||||
|         self._clients[x_self_id] = bot | ||||
|         logger.opt(colors=True).info( | ||||
|   | ||||
| @@ -105,6 +105,29 @@ class StopPropagation(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class RequestDenied(Exception): | ||||
|     """ | ||||
|     :说明: | ||||
|  | ||||
|       Bot 连接请求不合法。 | ||||
|  | ||||
|     :参数: | ||||
|  | ||||
|       * ``status_code: int``: HTTP 状态码 | ||||
|       * ``reason: str``: 拒绝原因 | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, status_code: int, reason: str): | ||||
|         self.status_code = status_code | ||||
|         self.reason = reason | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>" | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.__repr__() | ||||
|  | ||||
|  | ||||
| class ApiNotAvailable(Exception): | ||||
|     """ | ||||
|     :说明: | ||||
| @@ -131,7 +154,7 @@ class ActionFailed(Exception): | ||||
|  | ||||
|     :参数: | ||||
|  | ||||
|       * ``retcode``: 错误代码 | ||||
|       * ``retcode: Optional[int]``: 错误代码 | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, retcode: Optional[int]): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user