mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-17 03:20:54 +00:00
add types
This commit is contained in:
@ -4,17 +4,17 @@
|
||||
import abc
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from nonebot.config import Config
|
||||
from nonebot.adapters import BaseBot
|
||||
from nonebot.typing import Dict, Optional
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.typing import Bot, Dict, Optional
|
||||
|
||||
|
||||
class BaseDriver(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, config: Config):
|
||||
def __init__(self, env: Env, config: Config):
|
||||
self.env = env.environment
|
||||
self.config = config
|
||||
self._clients: Dict[int, BaseBot] = {}
|
||||
self._clients: Dict[int, Bot] = {}
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@ -32,7 +32,7 @@ class BaseDriver(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def bots(self) -> Dict[int, BaseBot]:
|
||||
def bots(self) -> Dict[int, Bot]:
|
||||
return self._clients
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -59,7 +59,6 @@ class BaseWebSocket(object):
|
||||
self._websocket = websocket
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def websocket(self):
|
||||
return self._websocket
|
||||
|
||||
|
@ -1,26 +1,28 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import logging
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket
|
||||
from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Config
|
||||
from nonebot.adapters import BaseBot
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.typing import Optional, overrides
|
||||
from nonebot.adapters.cqhttp import Bot as CQBot
|
||||
from nonebot.typing import Dict, Optional, overrides
|
||||
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||
|
||||
|
||||
class Driver(BaseDriver):
|
||||
|
||||
def __init__(self, config: Config):
|
||||
super().__init__(config)
|
||||
def __init__(self, env: Env, config: Config):
|
||||
super().__init__(env, config)
|
||||
|
||||
self._server_app = FastAPI(
|
||||
debug=config.debug,
|
||||
@ -94,21 +96,28 @@ class Driver(BaseDriver):
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_http(self,
|
||||
adapter: str,
|
||||
response: Response,
|
||||
data: dict = Body(...),
|
||||
x_self_id: int = Header(None),
|
||||
access_token: str = OAuth2PasswordBearer(
|
||||
"/", auto_error=False)):
|
||||
# TODO: Check authorization
|
||||
logger.debug(f"Received message: {data}")
|
||||
|
||||
# Create Bot Object
|
||||
if adapter == "cqhttp":
|
||||
bot = CQBot("http", self.config)
|
||||
await bot.handle_message(data)
|
||||
bot = CQBot("http", self.config, x_self_id)
|
||||
else:
|
||||
response.status_code = status.HTTP_404_NOT_FOUND
|
||||
return {"status": 404, "message": "adapter not found"}
|
||||
|
||||
await bot.handle_message(data)
|
||||
return {"status": 200, "message": "success"}
|
||||
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_ws_reverse(self,
|
||||
adapter: str,
|
||||
websocket: FastAPIWebSocket,
|
||||
self_id: int = Header(None),
|
||||
x_self_id: int = Header(None),
|
||||
access_token: str = OAuth2PasswordBearer(
|
||||
"/", auto_error=False)):
|
||||
websocket = WebSocket(websocket)
|
||||
@ -117,13 +126,16 @@ class Driver(BaseDriver):
|
||||
|
||||
# Create Bot Object
|
||||
if adapter == "coolq":
|
||||
bot = CQBot("websocket", self.config, self_id, websocket=websocket)
|
||||
bot = CQBot("websocket",
|
||||
self.config,
|
||||
x_self_id,
|
||||
websocket=websocket)
|
||||
else:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
self._clients[self_id] = bot
|
||||
self._clients[x_self_id] = bot
|
||||
|
||||
while not websocket.closed:
|
||||
data = await websocket.receive()
|
||||
@ -133,7 +145,7 @@ class Driver(BaseDriver):
|
||||
|
||||
await bot.handle_message(data)
|
||||
|
||||
del self._clients[self_id]
|
||||
del self._clients[x_self_id]
|
||||
|
||||
|
||||
class WebSocket(BaseWebSocket):
|
||||
@ -172,4 +184,5 @@ class WebSocket(BaseWebSocket):
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: dict) -> None:
|
||||
await self.websocket.send_json(data)
|
||||
text = json.dumps(data, cls=DataclassEncoder)
|
||||
await self.websocket.send({"type": "websocket.send", "text": text})
|
||||
|
Reference in New Issue
Block a user