Feature: 细化 driver 职责类型 (#2296)

This commit is contained in:
Ju4tCode
2023-08-26 11:03:24 +08:00
committed by GitHub
parent 807a86371d
commit 2e635370bb
20 changed files with 632 additions and 284 deletions

View File

@ -19,14 +19,14 @@ import logging
from functools import wraps
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, override
from typing import Type, Union, TypeVar, Callable, Awaitable, AsyncGenerator
from typing import TYPE_CHECKING, Union, TypeVar, Callable, Awaitable, AsyncGenerator
from nonebot.drivers import Request
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver
from nonebot.drivers import WebSocketClientMixin, combine_driver
try:
from websockets.exceptions import ConnectionClosed
@ -58,7 +58,7 @@ def catch_closed(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
return decorator
class Mixin(ForwardMixin):
class Mixin(WebSocketClientMixin):
"""Websockets Mixin"""
@property
@ -66,10 +66,6 @@ class Mixin(ForwardMixin):
def type(self) -> str:
return "websockets"
@override
async def request(self, setup: Request) -> Response:
return await super().request(setup)
@override
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
@ -133,5 +129,11 @@ class WebSocket(BaseWebSocket):
await self.websocket.send(data)
Driver: Type[ForwardDriver] = combine_driver(NoneDriver, Mixin) # type: ignore
"""Websockets Driver"""
if TYPE_CHECKING:
class Driver(Mixin, NoneDriver):
...
else:
Driver = combine_driver(NoneDriver, Mixin)
"""Websockets Driver"""