mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-31 15:06:42 +00:00 
			
		
		
		
	🐛 fix bugs in quart driver
This commit is contained in:
		| @@ -1,19 +1,21 @@ | |||||||
| import asyncio | import asyncio | ||||||
| from json.decoder import JSONDecodeError | from json.decoder import JSONDecodeError | ||||||
| from logging import getLogger, warn | from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional, | ||||||
| from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar |                     Type, TypeVar) | ||||||
|  |  | ||||||
|  | import uvicorn | ||||||
|  |  | ||||||
| from nonebot.config import Config as NoneBotConfig | from nonebot.config import Config as NoneBotConfig | ||||||
| from nonebot.config import Env | from nonebot.config import Env | ||||||
| from nonebot.drivers import Driver as BaseDriver | from nonebot.drivers import Driver as BaseDriver | ||||||
| from nonebot.drivers import WebSocket as BaseWebSocket | from nonebot.drivers import WebSocket as BaseWebSocket | ||||||
| from nonebot.exception import RequestDenied | from nonebot.exception import RequestDenied | ||||||
| from nonebot.log import LoguruHandler, logger | from nonebot.log import logger | ||||||
| from nonebot.typing import overrides | from nonebot.typing import overrides | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from nonebot.adapters import Bot | ||||||
| try: | try: | ||||||
|     from hypercorn.asyncio import serve |  | ||||||
|     from hypercorn.config import Config as HypercornConfig |  | ||||||
|     from quart import Quart, Request, Response |     from quart import Quart, Request, Response | ||||||
|     from quart import Websocket as QuartWebSocket |     from quart import Websocket as QuartWebSocket | ||||||
|     from quart import exceptions |     from quart import exceptions | ||||||
| @@ -32,11 +34,21 @@ class Driver(BaseDriver): | |||||||
|         super().__init__(env, config) |         super().__init__(env, config) | ||||||
|  |  | ||||||
|         self._server_app = Quart(self.__class__.__qualname__) |         self._server_app = Quart(self.__class__.__qualname__) | ||||||
|         self._server_app.logger.handlers.clear() |  | ||||||
|         self._server_app.logger.addHandler(LoguruHandler()) |     @overrides(BaseDriver) | ||||||
|         self._server_app.route('/<adapter>/http', |     def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs): | ||||||
|                                methods=['POST'])(self._handle_http) |         if name in self._adapters: | ||||||
|         self._server_app.websocket('/<adapter>/ws')(self._handle_ws_reverse) |             return | ||||||
|  |  | ||||||
|  |         super().register_adapter(name, adapter, **kwargs) | ||||||
|  |  | ||||||
|  |         @self.server_app.route(f'/{name}/http', endpoint=name + '_http') | ||||||
|  |         async def _http_handler(): | ||||||
|  |             await self._handle_http(name) | ||||||
|  |  | ||||||
|  |         @self.server_app.websocket(f'/{name}/ws', endpoint=name + '_ws') | ||||||
|  |         async def _ws_handler(): | ||||||
|  |             await self._handle_ws_reverse(name) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
| @@ -55,7 +67,7 @@ class Driver(BaseDriver): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     def loggers(self): |     def logger(self): | ||||||
|         return self._server_app.logger |         return self._server_app.logger | ||||||
|  |  | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
| @@ -66,43 +78,60 @@ class Driver(BaseDriver): | |||||||
|     def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable: |     def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable: | ||||||
|         return self.server_app.after_serving(func)  # type: ignore |         return self.server_app.after_serving(func)  # type: ignore | ||||||
|  |  | ||||||
|  |     @overrides(BaseDriver) | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     def run(self, |     def run(self, | ||||||
|             host: Optional[str] = None, |             host: Optional[str] = None, | ||||||
|             port: Optional[int] = None, |             port: Optional[int] = None, | ||||||
|  |             *, | ||||||
|  |             app: Optional[str] = None, | ||||||
|             **kwargs): |             **kwargs): | ||||||
|         super().run(host, port, **kwargs) |         """使用 ``uvicorn`` 启动 Quart""" | ||||||
|         config = HypercornConfig() |         super().run(host, port, app, **kwargs) | ||||||
|         for k, v in kwargs.items(): |         LOGGING_CONFIG = { | ||||||
|             if not hasattr(config, k): |             "version": 1, | ||||||
|                 warn(f'Config {k!r} is not available for quart driver.') |             "disable_existing_loggers": False, | ||||||
|                 continue |             "handlers": { | ||||||
|             setattr(config, k, v) |                 "default": { | ||||||
|         config.bind.append( |                     "class": "nonebot.log.LoguruHandler", | ||||||
|             f'{host or self.config.host}:{port or self.config.port}') |                 }, | ||||||
|  |             }, | ||||||
|         serve_task = asyncio.run_coroutine_threadsafe( |             "loggers": { | ||||||
|             coro=serve(self.server_app, config), |                 "uvicorn.error": { | ||||||
|             loop=asyncio.get_running_loop(), |                     "handlers": ["default"], | ||||||
|         ) |                     "level": "INFO" | ||||||
|         try: |                 }, | ||||||
|             serve_task.result() |                 "uvicorn.access": { | ||||||
|         finally: |                     "handlers": ["default"], | ||||||
|             serve_task.cancel() |                     "level": "INFO", | ||||||
|  |                 }, | ||||||
|  |             }, | ||||||
|  |         } | ||||||
|  |         uvicorn.run(app or self.server_app, | ||||||
|  |                     host=host or str(self.config.host), | ||||||
|  |                     port=port or self.config.port, | ||||||
|  |                     reload=bool(app) and self.config.debug, | ||||||
|  |                     debug=self.config.debug, | ||||||
|  |                     log_config=LOGGING_CONFIG, | ||||||
|  |                     **kwargs) | ||||||
|  |  | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     async def _handle_http(self, adapter: str): |     async def _handle_http(self, adapter: str): | ||||||
|         request: Request = _request |         request: Request = _request | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             data: Dict[str, Any] = await request.get_json() |             data: Dict[str, Any] = await request.get_json() | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise exceptions.BadRequest() |             raise exceptions.BadRequest() | ||||||
|  |  | ||||||
|         if adapter not in self._adapters: |         if adapter not in self._adapters: | ||||||
|             logger.warning(f'Unknown adapter {adapter}. ' |             logger.warning(f'Unknown adapter {adapter}. ' | ||||||
|                            'Please register the adapter before use.') |                            'Please register the adapter before use.') | ||||||
|             raise exceptions.NotFound() |             raise exceptions.NotFound() | ||||||
|  |  | ||||||
|         BotClass = self._adapters[adapter] |         BotClass = self._adapters[adapter] | ||||||
|         headers = dict(request.headers) |         headers = {k: v for k, v in request.headers.items(lower=True)} | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             self_id = await BotClass.check_permission(self, 'http', headers, |             self_id = await BotClass.check_permission(self, 'http', headers, | ||||||
|                                                       data) |                                                       data) | ||||||
| @@ -120,7 +149,6 @@ class Driver(BaseDriver): | |||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     async def _handle_ws_reverse(self, adapter: str): |     async def _handle_ws_reverse(self, adapter: str): | ||||||
|         websocket: QuartWebSocket = _websocket |         websocket: QuartWebSocket = _websocket | ||||||
|  |  | ||||||
|         if adapter not in self._adapters: |         if adapter not in self._adapters: | ||||||
|             logger.warning( |             logger.warning( | ||||||
|                 f'Unknown adapter {adapter}. Please register the adapter before use.' |                 f'Unknown adapter {adapter}. Please register the adapter before use.' | ||||||
| @@ -128,10 +156,12 @@ class Driver(BaseDriver): | |||||||
|             raise exceptions.NotFound() |             raise exceptions.NotFound() | ||||||
|  |  | ||||||
|         BotClass = self._adapters[adapter] |         BotClass = self._adapters[adapter] | ||||||
|         headers = dict(websocket.headers) |         headers = {k: v for k, v in websocket.headers.items(lower=True)} | ||||||
|         try: |         try: | ||||||
|             self_id = await BotClass.check_permission(self, 'ws', headers, None) |             self_id = await BotClass.check_permission(self, 'websocket', | ||||||
|  |                                                       headers, None) | ||||||
|         except RequestDenied as e: |         except RequestDenied as e: | ||||||
|  |             print(e.reason) | ||||||
|             raise exceptions.HTTPException(status_code=e.status_code, |             raise exceptions.HTTPException(status_code=e.status_code, | ||||||
|                                            description=e.reason, |                                            description=e.reason, | ||||||
|                                            name='Request Denied') |                                            name='Request Denied') | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| DRIVER=nonebot.drivers.fastapi | DRIVER=nonebot.drivers.quart | ||||||
| HOST=0.0.0.0 | HOST=0.0.0.0 | ||||||
| PORT=2333 | PORT=2333 | ||||||
| DEBUG=true | DEBUG=true | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user