mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-30 22:46:40 +00:00 
			
		
		
		
	🐛 fix bugs in quart driver
This commit is contained in:
		| @@ -1,19 +1,21 @@ | ||||
| import asyncio | ||||
| from json.decoder import JSONDecodeError | ||||
| from logging import getLogger, warn | ||||
| from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar | ||||
| from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional, | ||||
|                     Type, TypeVar) | ||||
|  | ||||
| import uvicorn | ||||
|  | ||||
| from nonebot.config import Config as NoneBotConfig | ||||
| from nonebot.config import Env | ||||
| from nonebot.drivers import Driver as BaseDriver | ||||
| from nonebot.drivers import WebSocket as BaseWebSocket | ||||
| from nonebot.exception import RequestDenied | ||||
| from nonebot.log import LoguruHandler, logger | ||||
| from nonebot.log import logger | ||||
| from nonebot.typing import overrides | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from nonebot.adapters import Bot | ||||
| try: | ||||
|     from hypercorn.asyncio import serve | ||||
|     from hypercorn.config import Config as HypercornConfig | ||||
|     from quart import Quart, Request, Response | ||||
|     from quart import Websocket as QuartWebSocket | ||||
|     from quart import exceptions | ||||
| @@ -32,11 +34,21 @@ class Driver(BaseDriver): | ||||
|         super().__init__(env, config) | ||||
|  | ||||
|         self._server_app = Quart(self.__class__.__qualname__) | ||||
|         self._server_app.logger.handlers.clear() | ||||
|         self._server_app.logger.addHandler(LoguruHandler()) | ||||
|         self._server_app.route('/<adapter>/http', | ||||
|                                methods=['POST'])(self._handle_http) | ||||
|         self._server_app.websocket('/<adapter>/ws')(self._handle_ws_reverse) | ||||
|  | ||||
|     @overrides(BaseDriver) | ||||
|     def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs): | ||||
|         if name in self._adapters: | ||||
|             return | ||||
|  | ||||
|         super().register_adapter(name, adapter, **kwargs) | ||||
|  | ||||
|         @self.server_app.route(f'/{name}/http', endpoint=name + '_http') | ||||
|         async def _http_handler(): | ||||
|             await self._handle_http(name) | ||||
|  | ||||
|         @self.server_app.websocket(f'/{name}/ws', endpoint=name + '_ws') | ||||
|         async def _ws_handler(): | ||||
|             await self._handle_ws_reverse(name) | ||||
|  | ||||
|     @property | ||||
|     @overrides(BaseDriver) | ||||
| @@ -55,7 +67,7 @@ class Driver(BaseDriver): | ||||
|  | ||||
|     @property | ||||
|     @overrides(BaseDriver) | ||||
|     def loggers(self): | ||||
|     def logger(self): | ||||
|         return self._server_app.logger | ||||
|  | ||||
|     @overrides(BaseDriver) | ||||
| @@ -66,43 +78,60 @@ class Driver(BaseDriver): | ||||
|     def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable: | ||||
|         return self.server_app.after_serving(func)  # type: ignore | ||||
|  | ||||
|     @overrides(BaseDriver) | ||||
|     @overrides(BaseDriver) | ||||
|     def run(self, | ||||
|             host: Optional[str] = None, | ||||
|             port: Optional[int] = None, | ||||
|             *, | ||||
|             app: Optional[str] = None, | ||||
|             **kwargs): | ||||
|         super().run(host, port, **kwargs) | ||||
|         config = HypercornConfig() | ||||
|         for k, v in kwargs.items(): | ||||
|             if not hasattr(config, k): | ||||
|                 warn(f'Config {k!r} is not available for quart driver.') | ||||
|                 continue | ||||
|             setattr(config, k, v) | ||||
|         config.bind.append( | ||||
|             f'{host or self.config.host}:{port or self.config.port}') | ||||
|  | ||||
|         serve_task = asyncio.run_coroutine_threadsafe( | ||||
|             coro=serve(self.server_app, config), | ||||
|             loop=asyncio.get_running_loop(), | ||||
|         ) | ||||
|         try: | ||||
|             serve_task.result() | ||||
|         finally: | ||||
|             serve_task.cancel() | ||||
|         """使用 ``uvicorn`` 启动 Quart""" | ||||
|         super().run(host, port, app, **kwargs) | ||||
|         LOGGING_CONFIG = { | ||||
|             "version": 1, | ||||
|             "disable_existing_loggers": False, | ||||
|             "handlers": { | ||||
|                 "default": { | ||||
|                     "class": "nonebot.log.LoguruHandler", | ||||
|                 }, | ||||
|             }, | ||||
|             "loggers": { | ||||
|                 "uvicorn.error": { | ||||
|                     "handlers": ["default"], | ||||
|                     "level": "INFO" | ||||
|                 }, | ||||
|                 "uvicorn.access": { | ||||
|                     "handlers": ["default"], | ||||
|                     "level": "INFO", | ||||
|                 }, | ||||
|             }, | ||||
|         } | ||||
|         uvicorn.run(app or self.server_app, | ||||
|                     host=host or str(self.config.host), | ||||
|                     port=port or self.config.port, | ||||
|                     reload=bool(app) and self.config.debug, | ||||
|                     debug=self.config.debug, | ||||
|                     log_config=LOGGING_CONFIG, | ||||
|                     **kwargs) | ||||
|  | ||||
|     @overrides(BaseDriver) | ||||
|     async def _handle_http(self, adapter: str): | ||||
|         request: Request = _request | ||||
|  | ||||
|         try: | ||||
|             data: Dict[str, Any] = await request.get_json() | ||||
|         except Exception as e: | ||||
|             raise exceptions.BadRequest() | ||||
|  | ||||
|         if adapter not in self._adapters: | ||||
|             logger.warning(f'Unknown adapter {adapter}. ' | ||||
|                            'Please register the adapter before use.') | ||||
|             raise exceptions.NotFound() | ||||
|  | ||||
|         BotClass = self._adapters[adapter] | ||||
|         headers = dict(request.headers) | ||||
|         headers = {k: v for k, v in request.headers.items(lower=True)} | ||||
|  | ||||
|         try: | ||||
|             self_id = await BotClass.check_permission(self, 'http', headers, | ||||
|                                                       data) | ||||
| @@ -120,7 +149,6 @@ class Driver(BaseDriver): | ||||
|     @overrides(BaseDriver) | ||||
|     async def _handle_ws_reverse(self, adapter: str): | ||||
|         websocket: QuartWebSocket = _websocket | ||||
|  | ||||
|         if adapter not in self._adapters: | ||||
|             logger.warning( | ||||
|                 f'Unknown adapter {adapter}. Please register the adapter before use.' | ||||
| @@ -128,10 +156,12 @@ class Driver(BaseDriver): | ||||
|             raise exceptions.NotFound() | ||||
|  | ||||
|         BotClass = self._adapters[adapter] | ||||
|         headers = dict(websocket.headers) | ||||
|         headers = {k: v for k, v in websocket.headers.items(lower=True)} | ||||
|         try: | ||||
|             self_id = await BotClass.check_permission(self, 'ws', headers, None) | ||||
|             self_id = await BotClass.check_permission(self, 'websocket', | ||||
|                                                       headers, None) | ||||
|         except RequestDenied as e: | ||||
|             print(e.reason) | ||||
|             raise exceptions.HTTPException(status_code=e.status_code, | ||||
|                                            description=e.reason, | ||||
|                                            name='Request Denied') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user