⚗️ new driver combine expr support

This commit is contained in:
yanyongyu
2021-12-23 17:20:26 +08:00
parent b9f1890d80
commit 8fb394e4c3
11 changed files with 83 additions and 68 deletions

View File

@ -252,7 +252,13 @@ class ReverseDriver(Driver):
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
class CombinedDriver(driver, *mixins, ForwardDriver): # type: ignore
# check first
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
assert all(
map(lambda m: issubclass(m, ForwardMixin), mixins)
), "`mixins` must be subclass of ForwardMixin"
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
@property
def type(self) -> str:
return (

View File

@ -4,9 +4,9 @@ import threading
from typing import Set, Callable, Awaitable
from nonebot.log import logger
from nonebot.drivers import Driver
from nonebot.typing import overrides
from nonebot.config import Env, Config
from nonebot.drivers import ForwardDriver
STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
@ -16,11 +16,7 @@ HANDLED_SIGNALS = (
)
class BlockDriver(ForwardDriver):
"""
AIOHTTP 驱动框架
"""
class BlockDriver(Driver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set()
@ -29,18 +25,18 @@ class BlockDriver(ForwardDriver):
self.force_exit: bool = False
@property
@overrides(ForwardDriver)
@overrides(Driver)
def type(self) -> str:
"""驱动名称: ``block_driver``"""
return "block_driver"
@property
@overrides(ForwardDriver)
@overrides(Driver)
def logger(self):
"""block driver 使用的 logger"""
return logger
@overrides(ForwardDriver)
@overrides(Driver)
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
"""
:说明:
@ -54,7 +50,7 @@ class BlockDriver(ForwardDriver):
self.startup_funcs.add(func)
return func
@overrides(ForwardDriver)
@overrides(Driver)
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
"""
:说明:
@ -68,7 +64,7 @@ class BlockDriver(ForwardDriver):
self.shutdown_funcs.add(func)
return func
@overrides(ForwardDriver)
@overrides(Driver)
def run(self, *args, **kwargs):
"""启动 block driver"""
super().run(*args, **kwargs)

View File

@ -19,7 +19,7 @@ except ImportError:
) from None
class AiohttpMixin(ForwardMixin):
class Mixin(ForwardMixin):
@property
@overrides(ForwardMixin)
def type(self) -> str:
@ -114,4 +114,4 @@ class WebSocket(BaseWebSocket):
await self.websocket.send_bytes(data)
Driver = combine_driver(BlockDriver, AiohttpMixin)
Driver = combine_driver(BlockDriver, Mixin)

View File

@ -22,22 +22,10 @@ from starlette.websockets import WebSocket, WebSocketState
from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.websockets import WebSocketsMixin
from nonebot.drivers import (
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
try:
from nonebot.drivers.aiohttp import AiohttpMixin
except ImportError:
AiohttpMixin = None
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
class Config(BaseSettings):
@ -317,8 +305,3 @@ class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data})
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
if AiohttpMixin:
AiohttpDriver = combine_driver(Driver, AiohttpMixin)

View File

@ -1,5 +1,3 @@
import httpx
from nonebot.typing import overrides
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import (
@ -11,8 +9,13 @@ from nonebot.drivers import (
combine_driver,
)
try:
import httpx
except ImportError:
raise ImportError("Please install httpx by using `pip install nonebot2[httpx]`")
class HttpxMixin(ForwardMixin):
class Mixin(ForwardMixin):
@property
@overrides(ForwardMixin)
def type(self) -> str:
@ -39,7 +42,7 @@ class HttpxMixin(ForwardMixin):
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> WebSocket:
return await super(HttpxMixin, self).websocket(setup)
return await super(Mixin, self).websocket(setup)
Driver = combine_driver(BlockDriver, HttpxMixin)
Driver = combine_driver(BlockDriver, Mixin)

View File

@ -17,17 +17,10 @@ from nonebot.config import Env
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.websockets import WebSocketsMixin
from nonebot.drivers import (
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
try:
from quart import request as _request
@ -295,6 +288,3 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes):
await self.websocket.send(data)
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)

View File

@ -1,7 +1,5 @@
import logging
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.typing import overrides
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
@ -9,11 +7,18 @@ from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ForwardMixin, combine_driver
try:
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ImportError:
raise ImportError(
"Please install websockets by using `pip install nonebot2[websockets]`"
)
logger = logging.Logger("websockets.client", "INFO")
logger.addHandler(LoguruHandler())
class WebSocketsMixin(ForwardMixin):
class Mixin(ForwardMixin):
@property
@overrides(ForwardMixin)
def type(self) -> str:
@ -21,7 +26,7 @@ class WebSocketsMixin(ForwardMixin):
@overrides(ForwardMixin)
async def request(self, setup: Request) -> Response:
return await super(WebSocketsMixin, self).request(setup)
return await super(Mixin, self).request(setup)
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket":
@ -75,4 +80,4 @@ class WebSocket(BaseWebSocket):
await self.websocket.send(data)
Driver = combine_driver(BlockDriver, WebSocketsMixin)
Driver = combine_driver(BlockDriver, Mixin)