mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-17 19:40:44 +00:00
🎨 change permission check from driver into adapter #46
This commit is contained in:
@ -15,12 +15,13 @@ import logging
|
||||
|
||||
import uvicorn
|
||||
from fastapi.responses import Response
|
||||
from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
|
||||
from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException
|
||||
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||
from nonebot.typing import Optional, Callable, overrides
|
||||
|
||||
@ -127,97 +128,58 @@ class Driver(BaseDriver):
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_http(self,
|
||||
adapter: str,
|
||||
data: dict = Body(...),
|
||||
x_self_id: Optional[str] = Header(None),
|
||||
x_signature: Optional[str] = Header(None),
|
||||
auth: Optional[str] = Depends(get_auth_bearer)):
|
||||
# 检查self_id
|
||||
if not x_self_id:
|
||||
logger.warning("Missing X-Self-ID Header")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing X-Self-ID Header")
|
||||
|
||||
# 检查签名
|
||||
secret = self.config.secret
|
||||
if secret:
|
||||
if not x_signature:
|
||||
logger.warning("Missing Signature Header")
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Signature")
|
||||
sig = hmac.new(secret.encode("utf-8"),
|
||||
json.dumps(data).encode(), "sha1").hexdigest()
|
||||
if x_signature != "sha1=" + sig:
|
||||
logger.warning("Signature Header is invalid")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Signature is invalid")
|
||||
|
||||
access_token = self.config.access_token
|
||||
if access_token and access_token != auth:
|
||||
logger.warning("Authorization Header is invalid"
|
||||
if auth else "Missing Authorization Header")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Authorization Header is invalid"
|
||||
if auth else "Missing Authorization Header")
|
||||
|
||||
request: Request,
|
||||
data: dict = Body(...)):
|
||||
if not isinstance(data, dict):
|
||||
logger.warning("Data received is invalid")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if adapter not in self._adapters:
|
||||
logger.warning("Unknown adapter")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="adapter not found")
|
||||
|
||||
# 创建 Bot 对象
|
||||
BotClass = self._adapters[adapter]
|
||||
headers = dict(request.headers)
|
||||
try:
|
||||
x_self_id = await BotClass.check_permission(self, "http", headers,
|
||||
data)
|
||||
except RequestDenied as e:
|
||||
raise HTTPException(status_code=e.status_code,
|
||||
detail=e.reason) from None
|
||||
|
||||
if x_self_id in self._clients:
|
||||
logger.warning("There's already a reverse websocket api connection,"
|
||||
"so the event may be handled twice.")
|
||||
|
||||
# 创建 Bot 对象
|
||||
if adapter in self._adapters:
|
||||
BotClass = self._adapters[adapter]
|
||||
bot = BotClass(self, "http", self.config, x_self_id)
|
||||
else:
|
||||
logger.warning("Unknown adapter")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="adapter not found")
|
||||
bot = BotClass(self, "http", self.config, x_self_id)
|
||||
|
||||
asyncio.create_task(bot.handle_message(data))
|
||||
return Response("", 204)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_ws_reverse(
|
||||
self,
|
||||
adapter: str,
|
||||
websocket: FastAPIWebSocket,
|
||||
x_self_id: str = Header(None),
|
||||
auth: Optional[str] = Depends(get_auth_bearer)):
|
||||
async def _handle_ws_reverse(self, adapter: str,
|
||||
websocket: FastAPIWebSocket):
|
||||
ws = WebSocket(websocket)
|
||||
|
||||
access_token = self.config.access_token
|
||||
if access_token and access_token != auth:
|
||||
logger.warning("Authorization Header is invalid"
|
||||
if auth else "Missing Authorization Header")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
if not x_self_id:
|
||||
logger.warning(f"Missing X-Self-ID Header")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
if x_self_id in self._clients:
|
||||
logger.warning(f"Connection Conflict: self_id {x_self_id}")
|
||||
if adapter not in self._adapters:
|
||||
logger.warning("Unknown adapter")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
# Create Bot Object
|
||||
if adapter in self._adapters:
|
||||
BotClass = self._adapters[adapter]
|
||||
bot = BotClass(self,
|
||||
"websocket",
|
||||
self.config,
|
||||
x_self_id,
|
||||
websocket=ws)
|
||||
else:
|
||||
logger.warning("Unknown adapter")
|
||||
BotClass = self._adapters[adapter]
|
||||
headers = dict(websocket.headers)
|
||||
try:
|
||||
x_self_id = await BotClass.check_permission(self, "websocket",
|
||||
headers, None)
|
||||
except RequestDenied:
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
|
||||
|
||||
await ws.accept()
|
||||
self._clients[x_self_id] = bot
|
||||
logger.opt(colors=True).info(
|
||||
|
Reference in New Issue
Block a user