mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-30 22:46:40 +00:00 
			
		
		
		
	🎨 change permission check from driver into adapter #46
This commit is contained in:
		| @@ -11,7 +11,7 @@ from dataclasses import dataclass, field | |||||||
|  |  | ||||||
| from nonebot.config import Config | from nonebot.config import Config | ||||||
| from nonebot.typing import Driver, Message, WebSocket | from nonebot.typing import Driver, Message, WebSocket | ||||||
| from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable | from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseBot(abc.ABC): | class BaseBot(abc.ABC): | ||||||
| @@ -55,6 +55,13 @@ class BaseBot(abc.ABC): | |||||||
|         """Adapter 类型""" |         """Adapter 类型""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     @abc.abstractmethod | ||||||
|  |     async def check_permission(cls, driver: Driver, connection_type: str, | ||||||
|  |                                headers: dict, | ||||||
|  |                                body: Optional[dict]) -> Union[str, NoReturn]: | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     async def handle_message(self, message: dict): |     async def handle_message(self, message: dict): | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配 | |||||||
|  |  | ||||||
| import re | import re | ||||||
| import sys | import sys | ||||||
|  | import hmac | ||||||
|  | import json | ||||||
| import asyncio | import asyncio | ||||||
|  |  | ||||||
| import httpx | import httpx | ||||||
| @@ -19,10 +21,10 @@ import httpx | |||||||
| from nonebot.log import logger | from nonebot.log import logger | ||||||
| from nonebot.config import Config | from nonebot.config import Config | ||||||
| from nonebot.message import handle_event | from nonebot.message import handle_event | ||||||
| from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional |  | ||||||
| from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable |  | ||||||
| from nonebot.typing import overrides, Driver, WebSocket, NoReturn | from nonebot.typing import overrides, Driver, WebSocket, NoReturn | ||||||
|  | from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional | ||||||
| from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment | from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment | ||||||
|  | from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable | ||||||
|  |  | ||||||
|  |  | ||||||
| def log(level: str, message: str): | def log(level: str, message: str): | ||||||
| @@ -39,6 +41,16 @@ def log(level: str, message: str): | |||||||
|     return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message) |     return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_auth_bearer( | ||||||
|  |         access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]: | ||||||
|  |     if not access_token: | ||||||
|  |         return None | ||||||
|  |     scheme, _, param = access_token.partition(" ") | ||||||
|  |     if scheme.lower() not in ["bearer", "token"]: | ||||||
|  |         raise RequestDenied(401, "Not authenticated") | ||||||
|  |     return param | ||||||
|  |  | ||||||
|  |  | ||||||
| def escape(s: str, *, escape_comma: bool = True) -> str: | def escape(s: str, *, escape_comma: bool = True) -> str: | ||||||
|     """ |     """ | ||||||
|     :说明: |     :说明: | ||||||
| @@ -264,8 +276,6 @@ class Bot(BaseBot): | |||||||
|                  self_id: str, |                  self_id: str, | ||||||
|                  *, |                  *, | ||||||
|                  websocket: Optional[WebSocket] = None): |                  websocket: Optional[WebSocket] = None): | ||||||
|         if connection_type not in ["http", "websocket"]: |  | ||||||
|             raise ValueError("Unsupported connection type") |  | ||||||
|  |  | ||||||
|         super().__init__(driver, |         super().__init__(driver, | ||||||
|                          connection_type, |                          connection_type, | ||||||
| @@ -281,6 +291,47 @@ class Bot(BaseBot): | |||||||
|         """ |         """ | ||||||
|         return "cqhttp" |         return "cqhttp" | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     @overrides(BaseBot) | ||||||
|  |     async def check_permission(cls, driver: Driver, connection_type: str, | ||||||
|  |                                headers: dict, | ||||||
|  |                                body: Optional[dict]) -> Union[str, NoReturn]: | ||||||
|  |         x_self_id = headers.get("x-self-id") | ||||||
|  |         x_signature = headers.get("x-signature") | ||||||
|  |         access_token = get_auth_bearer(headers.get("authorization")) | ||||||
|  |  | ||||||
|  |         # 检查连接方式 | ||||||
|  |         if connection_type not in ["http", "websocket"]: | ||||||
|  |             log("WARNING", "Unsupported connection type") | ||||||
|  |             raise RequestDenied(405, "Unsupported connection type") | ||||||
|  |  | ||||||
|  |         # 检查self_id | ||||||
|  |         if not x_self_id: | ||||||
|  |             log("WARNING", "Missing X-Self-ID Header") | ||||||
|  |             raise RequestDenied(400, "Missing X-Self-ID Header") | ||||||
|  |  | ||||||
|  |         # 检查签名 | ||||||
|  |         secret = driver.config.secret | ||||||
|  |         if secret and connection_type == "http": | ||||||
|  |             if not x_signature: | ||||||
|  |                 log("WARNING", "Missing Signature Header") | ||||||
|  |                 raise RequestDenied(401, "Missing Signature") | ||||||
|  |             sig = hmac.new(secret.encode("utf-8"), | ||||||
|  |                            json.dumps(body).encode(), "sha1").hexdigest() | ||||||
|  |             if x_signature != "sha1=" + sig: | ||||||
|  |                 log("WARNING", "Signature Header is invalid") | ||||||
|  |                 raise RequestDenied(403, "Signature is invalid") | ||||||
|  |  | ||||||
|  |         access_token = driver.config.access_token | ||||||
|  |         if access_token and access_token != access_token: | ||||||
|  |             log( | ||||||
|  |                 "WARNING", "Authorization Header is invalid" | ||||||
|  |                 if access_token else "Missing Authorization Header") | ||||||
|  |             raise RequestDenied( | ||||||
|  |                 403, "Authorization Header is invalid" | ||||||
|  |                 if access_token else "Missing Authorization Header") | ||||||
|  |         return str(x_self_id) | ||||||
|  |  | ||||||
|     @overrides(BaseBot) |     @overrides(BaseBot) | ||||||
|     async def handle_message(self, message: dict): |     async def handle_message(self, message: dict): | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -9,6 +9,11 @@ def log(level: str, message: str): | |||||||
|     ... |     ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_auth_bearer( | ||||||
|  |         access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]: | ||||||
|  |     ... | ||||||
|  |  | ||||||
|  |  | ||||||
| def escape(s: str, *, escape_comma: bool = ...) -> str: | def escape(s: str, *, escape_comma: bool = ...) -> str: | ||||||
|     ... |     ... | ||||||
|  |  | ||||||
| @@ -69,6 +74,12 @@ class Bot(BaseBot): | |||||||
|     def type(self) -> str: |     def type(self) -> str: | ||||||
|         ... |         ... | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     async def check_permission(cls, driver: Driver, connection_type: str, | ||||||
|  |                                headers: dict, | ||||||
|  |                                body: Optional[dict]) -> Union[str, NoReturn]: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|     async def handle_message(self, message: dict): |     async def handle_message(self, message: dict): | ||||||
|         ... |         ... | ||||||
|  |  | ||||||
|   | |||||||
| @@ -15,12 +15,13 @@ import logging | |||||||
|  |  | ||||||
| import uvicorn | import uvicorn | ||||||
| from fastapi.responses import Response | from fastapi.responses import Response | ||||||
| from fastapi import Body, status, Header, FastAPI, Depends, HTTPException | from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException | ||||||
| from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket | from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket | ||||||
|  |  | ||||||
| from nonebot.log import logger | from nonebot.log import logger | ||||||
| from nonebot.config import Env, Config | from nonebot.config import Env, Config | ||||||
| from nonebot.utils import DataclassEncoder | from nonebot.utils import DataclassEncoder | ||||||
|  | from nonebot.exception import RequestDenied | ||||||
| from nonebot.drivers import BaseDriver, BaseWebSocket | from nonebot.drivers import BaseDriver, BaseWebSocket | ||||||
| from nonebot.typing import Optional, Callable, overrides | from nonebot.typing import Optional, Callable, overrides | ||||||
|  |  | ||||||
| @@ -127,97 +128,58 @@ class Driver(BaseDriver): | |||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     async def _handle_http(self, |     async def _handle_http(self, | ||||||
|                            adapter: str, |                            adapter: str, | ||||||
|                            data: dict = Body(...), |                            request: Request, | ||||||
|                            x_self_id: Optional[str] = Header(None), |                            data: dict = Body(...)): | ||||||
|                            x_signature: Optional[str] = Header(None), |  | ||||||
|                            auth: Optional[str] = Depends(get_auth_bearer)): |  | ||||||
|         # 检查self_id |  | ||||||
|         if not x_self_id: |  | ||||||
|             logger.warning("Missing X-Self-ID Header") |  | ||||||
|             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, |  | ||||||
|                                 detail="Missing X-Self-ID Header") |  | ||||||
|  |  | ||||||
|         # 检查签名 |  | ||||||
|         secret = self.config.secret |  | ||||||
|         if secret: |  | ||||||
|             if not x_signature: |  | ||||||
|                 logger.warning("Missing Signature Header") |  | ||||||
|                 raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|                                     detail="Missing Signature") |  | ||||||
|             sig = hmac.new(secret.encode("utf-8"), |  | ||||||
|                            json.dumps(data).encode(), "sha1").hexdigest() |  | ||||||
|             if x_signature != "sha1=" + sig: |  | ||||||
|                 logger.warning("Signature Header is invalid") |  | ||||||
|                 raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|                                     detail="Signature is invalid") |  | ||||||
|  |  | ||||||
|         access_token = self.config.access_token |  | ||||||
|         if access_token and access_token != auth: |  | ||||||
|             logger.warning("Authorization Header is invalid" |  | ||||||
|                            if auth else "Missing Authorization Header") |  | ||||||
|             raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|                                 detail="Authorization Header is invalid" |  | ||||||
|                                 if auth else "Missing Authorization Header") |  | ||||||
|  |  | ||||||
|         if not isinstance(data, dict): |         if not isinstance(data, dict): | ||||||
|             logger.warning("Data received is invalid") |             logger.warning("Data received is invalid") | ||||||
|             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) |             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) | ||||||
|  |  | ||||||
|  |         if adapter not in self._adapters: | ||||||
|  |             logger.warning("Unknown adapter") | ||||||
|  |             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, | ||||||
|  |                                 detail="adapter not found") | ||||||
|  |  | ||||||
|  |         # 创建 Bot 对象 | ||||||
|  |         BotClass = self._adapters[adapter] | ||||||
|  |         headers = dict(request.headers) | ||||||
|  |         try: | ||||||
|  |             x_self_id = await BotClass.check_permission(self, "http", headers, | ||||||
|  |                                                         data) | ||||||
|  |         except RequestDenied as e: | ||||||
|  |             raise HTTPException(status_code=e.status_code, | ||||||
|  |                                 detail=e.reason) from None | ||||||
|  |  | ||||||
|         if x_self_id in self._clients: |         if x_self_id in self._clients: | ||||||
|             logger.warning("There's already a reverse websocket api connection," |             logger.warning("There's already a reverse websocket api connection," | ||||||
|                            "so the event may be handled twice.") |                            "so the event may be handled twice.") | ||||||
|  |  | ||||||
|         # 创建 Bot 对象 |         bot = BotClass(self, "http", self.config, x_self_id) | ||||||
|         if adapter in self._adapters: |  | ||||||
|             BotClass = self._adapters[adapter] |  | ||||||
|             bot = BotClass(self, "http", self.config, x_self_id) |  | ||||||
|         else: |  | ||||||
|             logger.warning("Unknown adapter") |  | ||||||
|             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, |  | ||||||
|                                 detail="adapter not found") |  | ||||||
|  |  | ||||||
|         asyncio.create_task(bot.handle_message(data)) |         asyncio.create_task(bot.handle_message(data)) | ||||||
|         return Response("", 204) |         return Response("", 204) | ||||||
|  |  | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     async def _handle_ws_reverse( |     async def _handle_ws_reverse(self, adapter: str, | ||||||
|         self, |                                  websocket: FastAPIWebSocket): | ||||||
|         adapter: str, |  | ||||||
|         websocket: FastAPIWebSocket, |  | ||||||
|         x_self_id: str = Header(None), |  | ||||||
|         auth: Optional[str] = Depends(get_auth_bearer)): |  | ||||||
|         ws = WebSocket(websocket) |         ws = WebSocket(websocket) | ||||||
|  |  | ||||||
|         access_token = self.config.access_token |         if adapter not in self._adapters: | ||||||
|         if access_token and access_token != auth: |             logger.warning("Unknown adapter") | ||||||
|             logger.warning("Authorization Header is invalid" |  | ||||||
|                            if auth else "Missing Authorization Header") |  | ||||||
|             await ws.close(code=status.WS_1008_POLICY_VIOLATION) |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         if not x_self_id: |  | ||||||
|             logger.warning(f"Missing X-Self-ID Header") |  | ||||||
|             await ws.close(code=status.WS_1008_POLICY_VIOLATION) |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         if x_self_id in self._clients: |  | ||||||
|             logger.warning(f"Connection Conflict: self_id {x_self_id}") |  | ||||||
|             await ws.close(code=status.WS_1008_POLICY_VIOLATION) |             await ws.close(code=status.WS_1008_POLICY_VIOLATION) | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         # Create Bot Object |         # Create Bot Object | ||||||
|         if adapter in self._adapters: |         BotClass = self._adapters[adapter] | ||||||
|             BotClass = self._adapters[adapter] |         headers = dict(websocket.headers) | ||||||
|             bot = BotClass(self, |         try: | ||||||
|                            "websocket", |             x_self_id = await BotClass.check_permission(self, "websocket", | ||||||
|                            self.config, |                                                         headers, None) | ||||||
|                            x_self_id, |         except RequestDenied: | ||||||
|                            websocket=ws) |  | ||||||
|         else: |  | ||||||
|             logger.warning("Unknown adapter") |  | ||||||
|             await ws.close(code=status.WS_1008_POLICY_VIOLATION) |             await ws.close(code=status.WS_1008_POLICY_VIOLATION) | ||||||
|             return |             return | ||||||
|  |  | ||||||
|  |         bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws) | ||||||
|  |  | ||||||
|         await ws.accept() |         await ws.accept() | ||||||
|         self._clients[x_self_id] = bot |         self._clients[x_self_id] = bot | ||||||
|         logger.opt(colors=True).info( |         logger.opt(colors=True).info( | ||||||
|   | |||||||
| @@ -105,6 +105,29 @@ class StopPropagation(Exception): | |||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RequestDenied(Exception): | ||||||
|  |     """ | ||||||
|  |     :说明: | ||||||
|  |  | ||||||
|  |       Bot 连接请求不合法。 | ||||||
|  |  | ||||||
|  |     :参数: | ||||||
|  |  | ||||||
|  |       * ``status_code: int``: HTTP 状态码 | ||||||
|  |       * ``reason: str``: 拒绝原因 | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, status_code: int, reason: str): | ||||||
|  |         self.status_code = status_code | ||||||
|  |         self.reason = reason | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>" | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return self.__repr__() | ||||||
|  |  | ||||||
|  |  | ||||||
| class ApiNotAvailable(Exception): | class ApiNotAvailable(Exception): | ||||||
|     """ |     """ | ||||||
|     :说明: |     :说明: | ||||||
| @@ -131,7 +154,7 @@ class ActionFailed(Exception): | |||||||
|  |  | ||||||
|     :参数: |     :参数: | ||||||
|  |  | ||||||
|       * ``retcode``: 错误代码 |       * ``retcode: Optional[int]``: 错误代码 | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, retcode: Optional[int]): |     def __init__(self, retcode: Optional[int]): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user