mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-07 04:26:45 +00:00
⚗️ new driver combine expr support
This commit is contained in:
@ -252,7 +252,13 @@ class ReverseDriver(Driver):
|
||||
|
||||
|
||||
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
|
||||
class CombinedDriver(driver, *mixins, ForwardDriver): # type: ignore
|
||||
# check first
|
||||
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
|
||||
assert all(
|
||||
map(lambda m: issubclass(m, ForwardMixin), mixins)
|
||||
), "`mixins` must be subclass of ForwardMixin"
|
||||
|
||||
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return (
|
||||
|
@ -4,9 +4,9 @@ import threading
|
||||
from typing import Set, Callable, Awaitable
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.drivers import Driver
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.drivers import ForwardDriver
|
||||
|
||||
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
||||
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
||||
@ -16,11 +16,7 @@ HANDLED_SIGNALS = (
|
||||
)
|
||||
|
||||
|
||||
class BlockDriver(ForwardDriver):
|
||||
"""
|
||||
AIOHTTP 驱动框架
|
||||
"""
|
||||
|
||||
class BlockDriver(Driver):
|
||||
def __init__(self, env: Env, config: Config):
|
||||
super().__init__(env, config)
|
||||
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
||||
@ -29,18 +25,18 @@ class BlockDriver(ForwardDriver):
|
||||
self.force_exit: bool = False
|
||||
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def type(self) -> str:
|
||||
"""驱动名称: ``block_driver``"""
|
||||
return "block_driver"
|
||||
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def logger(self):
|
||||
"""block driver 使用的 logger"""
|
||||
return logger
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
@ -54,7 +50,7 @@ class BlockDriver(ForwardDriver):
|
||||
self.startup_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
@ -68,7 +64,7 @@ class BlockDriver(ForwardDriver):
|
||||
self.shutdown_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def run(self, *args, **kwargs):
|
||||
"""启动 block driver"""
|
||||
super().run(*args, **kwargs)
|
||||
|
@ -19,7 +19,7 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
class AiohttpMixin(ForwardMixin):
|
||||
class Mixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
@ -114,4 +114,4 @@ class WebSocket(BaseWebSocket):
|
||||
await self.websocket.send_bytes(data)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, AiohttpMixin)
|
||||
Driver = combine_driver(BlockDriver, Mixin)
|
||||
|
@ -22,22 +22,10 @@ from starlette.websockets import WebSocket, WebSocketState
|
||||
from nonebot.config import Env
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.drivers.httpx import HttpxMixin
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers.websockets import WebSocketsMixin
|
||||
from nonebot.drivers import (
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
combine_driver,
|
||||
)
|
||||
|
||||
try:
|
||||
from nonebot.drivers.aiohttp import AiohttpMixin
|
||||
except ImportError:
|
||||
AiohttpMixin = None
|
||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
@ -317,8 +305,3 @@ class FastAPIWebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send({"type": "websocket.send", "bytes": data})
|
||||
|
||||
|
||||
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
||||
if AiohttpMixin:
|
||||
AiohttpDriver = combine_driver(Driver, AiohttpMixin)
|
||||
|
@ -1,5 +1,3 @@
|
||||
import httpx
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import (
|
||||
@ -11,8 +9,13 @@ from nonebot.drivers import (
|
||||
combine_driver,
|
||||
)
|
||||
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
raise ImportError("Please install httpx by using `pip install nonebot2[httpx]`")
|
||||
|
||||
class HttpxMixin(ForwardMixin):
|
||||
|
||||
class Mixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
@ -39,7 +42,7 @@ class HttpxMixin(ForwardMixin):
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def websocket(self, setup: Request) -> WebSocket:
|
||||
return await super(HttpxMixin, self).websocket(setup)
|
||||
return await super(Mixin, self).websocket(setup)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, HttpxMixin)
|
||||
Driver = combine_driver(BlockDriver, Mixin)
|
||||
|
@ -17,17 +17,10 @@ from nonebot.config import Env
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.drivers.httpx import HttpxMixin
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers.websockets import WebSocketsMixin
|
||||
from nonebot.drivers import (
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
combine_driver,
|
||||
)
|
||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||
|
||||
try:
|
||||
from quart import request as _request
|
||||
@ -295,6 +288,3 @@ class WebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes):
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
||||
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
||||
|
@ -1,7 +1,5 @@
|
||||
import logging
|
||||
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.log import LoguruHandler
|
||||
from nonebot.drivers import Request, Response
|
||||
@ -9,11 +7,18 @@ from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import ForwardMixin, combine_driver
|
||||
|
||||
try:
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install websockets by using `pip install nonebot2[websockets]`"
|
||||
)
|
||||
|
||||
logger = logging.Logger("websockets.client", "INFO")
|
||||
logger.addHandler(LoguruHandler())
|
||||
|
||||
|
||||
class WebSocketsMixin(ForwardMixin):
|
||||
class Mixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
@ -21,7 +26,7 @@ class WebSocketsMixin(ForwardMixin):
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def request(self, setup: Request) -> Response:
|
||||
return await super(WebSocketsMixin, self).request(setup)
|
||||
return await super(Mixin, self).request(setup)
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def websocket(self, setup: Request) -> "WebSocket":
|
||||
@ -75,4 +80,4 @@ class WebSocket(BaseWebSocket):
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, WebSocketsMixin)
|
||||
Driver = combine_driver(BlockDriver, Mixin)
|
||||
|
Reference in New Issue
Block a user