🎨 change permission check from driver into adapter #46

This commit is contained in:
yanyongyu
2020-11-11 15:14:29 +08:00
parent 1f1f9cd7e6
commit b2a2234d5c
5 changed files with 130 additions and 76 deletions

View File

@ -15,12 +15,13 @@ import logging
import uvicorn
from fastapi.responses import Response
from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.typing import Optional, Callable, overrides
@ -127,97 +128,58 @@ class Driver(BaseDriver):
@overrides(BaseDriver)
async def _handle_http(self,
adapter: str,
data: dict = Body(...),
x_self_id: Optional[str] = Header(None),
x_signature: Optional[str] = Header(None),
auth: Optional[str] = Depends(get_auth_bearer)):
# 检查self_id
if not x_self_id:
logger.warning("Missing X-Self-ID Header")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing X-Self-ID Header")
# 检查签名
secret = self.config.secret
if secret:
if not x_signature:
logger.warning("Missing Signature Header")
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Signature")
sig = hmac.new(secret.encode("utf-8"),
json.dumps(data).encode(), "sha1").hexdigest()
if x_signature != "sha1=" + sig:
logger.warning("Signature Header is invalid")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail="Signature is invalid")
access_token = self.config.access_token
if access_token and access_token != auth:
logger.warning("Authorization Header is invalid"
if auth else "Missing Authorization Header")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail="Authorization Header is invalid"
if auth else "Missing Authorization Header")
request: Request,
data: dict = Body(...)):
if not isinstance(data, dict):
logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if adapter not in self._adapters:
logger.warning("Unknown adapter")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found")
# 创建 Bot 对象
BotClass = self._adapters[adapter]
headers = dict(request.headers)
try:
x_self_id = await BotClass.check_permission(self, "http", headers,
data)
except RequestDenied as e:
raise HTTPException(status_code=e.status_code,
detail=e.reason) from None
if x_self_id in self._clients:
logger.warning("There's already a reverse websocket api connection,"
"so the event may be handled twice.")
# 创建 Bot 对象
if adapter in self._adapters:
BotClass = self._adapters[adapter]
bot = BotClass(self, "http", self.config, x_self_id)
else:
logger.warning("Unknown adapter")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found")
bot = BotClass(self, "http", self.config, x_self_id)
asyncio.create_task(bot.handle_message(data))
return Response("", 204)
@overrides(BaseDriver)
async def _handle_ws_reverse(
self,
adapter: str,
websocket: FastAPIWebSocket,
x_self_id: str = Header(None),
auth: Optional[str] = Depends(get_auth_bearer)):
async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket):
ws = WebSocket(websocket)
access_token = self.config.access_token
if access_token and access_token != auth:
logger.warning("Authorization Header is invalid"
if auth else "Missing Authorization Header")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if not x_self_id:
logger.warning(f"Missing X-Self-ID Header")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if x_self_id in self._clients:
logger.warning(f"Connection Conflict: self_id {x_self_id}")
if adapter not in self._adapters:
logger.warning("Unknown adapter")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
# Create Bot Object
if adapter in self._adapters:
BotClass = self._adapters[adapter]
bot = BotClass(self,
"websocket",
self.config,
x_self_id,
websocket=ws)
else:
logger.warning("Unknown adapter")
BotClass = self._adapters[adapter]
headers = dict(websocket.headers)
try:
x_self_id = await BotClass.check_permission(self, "websocket",
headers, None)
except RequestDenied:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
await ws.accept()
self._clients[x_self_id] = bot
logger.opt(colors=True).info(