mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-16 19:11:00 +00:00
✨ Feature: 细化 driver 职责类型 (#2296)
This commit is contained in:
@ -8,8 +8,10 @@ from nonebug import NONEBOT_INIT_KWARGS
|
||||
from werkzeug.serving import BaseWSGIServer, make_server
|
||||
|
||||
import nonebot
|
||||
from nonebot.drivers import URL
|
||||
from nonebot.config import Env
|
||||
from fake_server import request_handler
|
||||
from nonebot.drivers import URL, Driver
|
||||
from nonebot import _resolve_combine_expr
|
||||
|
||||
os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
|
||||
os.environ["CONFIG_OVERRIDE"] = "new"
|
||||
@ -22,6 +24,17 @@ def pytest_configure(config: pytest.Config) -> None:
|
||||
config.stash[NONEBOT_INIT_KWARGS] = {"config_from_init": "init"}
|
||||
|
||||
|
||||
@pytest.fixture(name="driver")
|
||||
def load_driver(request: pytest.FixtureRequest) -> Driver:
|
||||
driver_name = getattr(request, "param", None)
|
||||
global_driver = nonebot.get_driver()
|
||||
if driver_name is None:
|
||||
return global_driver
|
||||
|
||||
DriverClass = _resolve_combine_expr(driver_name)
|
||||
return DriverClass(Env(environment=global_driver.env), global_driver.config)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def load_plugin(nonebug_init: None) -> Set["Plugin"]:
|
||||
# preload global plugins
|
||||
|
211
tests/test_adapters/test_adapter.py
Normal file
211
tests/test_adapters/test_adapter.py
Normal file
@ -0,0 +1,211 @@
|
||||
from typing import Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import pytest
|
||||
from nonebug import App
|
||||
|
||||
from utils import FakeAdapter
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.drivers import (
|
||||
URL,
|
||||
Driver,
|
||||
Request,
|
||||
Response,
|
||||
WebSocket,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_connect(app: App, driver: Driver):
|
||||
last_connect_bot: Optional[Bot] = None
|
||||
last_disconnect_bot: Optional[Bot] = None
|
||||
|
||||
def _fake_bot_connect(bot: Bot):
|
||||
nonlocal last_connect_bot
|
||||
last_connect_bot = bot
|
||||
|
||||
def _fake_bot_disconnect(bot: Bot):
|
||||
nonlocal last_disconnect_bot
|
||||
last_disconnect_bot = bot
|
||||
|
||||
with pytest.MonkeyPatch.context() as m:
|
||||
m.setattr(driver, "_bot_connect", _fake_bot_connect)
|
||||
m.setattr(driver, "_bot_disconnect", _fake_bot_disconnect)
|
||||
|
||||
adapter = FakeAdapter(driver)
|
||||
|
||||
async with app.test_api() as ctx:
|
||||
bot = ctx.create_bot(adapter=adapter)
|
||||
assert last_connect_bot is bot
|
||||
assert adapter.bots[bot.self_id] is bot
|
||||
|
||||
assert last_disconnect_bot is bot
|
||||
assert bot.self_id not in adapter.bots
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
[
|
||||
pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
|
||||
pytest.param("nonebot.drivers.quart:Driver", id="quart"),
|
||||
pytest.param(
|
||||
"nonebot.drivers.httpx:Driver",
|
||||
id="httpx",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a server", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
"nonebot.drivers.websockets:Driver",
|
||||
id="websockets",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a server", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
"nonebot.drivers.aiohttp:Driver",
|
||||
id="aiohttp",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a server", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_adapter_server(driver: Driver):
|
||||
last_http_setup: Optional[HTTPServerSetup] = None
|
||||
last_ws_setup: Optional[WebSocketServerSetup] = None
|
||||
|
||||
def _fake_setup_http_server(setup: HTTPServerSetup):
|
||||
nonlocal last_http_setup
|
||||
last_http_setup = setup
|
||||
|
||||
def _fake_setup_websocket_server(setup: WebSocketServerSetup):
|
||||
nonlocal last_ws_setup
|
||||
last_ws_setup = setup
|
||||
|
||||
with pytest.MonkeyPatch.context() as m:
|
||||
m.setattr(driver, "setup_http_server", _fake_setup_http_server, raising=False)
|
||||
m.setattr(
|
||||
driver,
|
||||
"setup_websocket_server",
|
||||
_fake_setup_websocket_server,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
async def handle_http(request: Request):
|
||||
return Response(200, content="test")
|
||||
|
||||
async def handle_ws(ws: WebSocket):
|
||||
...
|
||||
|
||||
adapter = FakeAdapter(driver)
|
||||
|
||||
setup = HTTPServerSetup(URL("/test"), "GET", "test", handle_http)
|
||||
adapter.setup_http_server(setup)
|
||||
assert last_http_setup is setup
|
||||
|
||||
setup = WebSocketServerSetup(URL("/test"), "test", handle_ws)
|
||||
adapter.setup_websocket_server(setup)
|
||||
assert last_ws_setup is setup
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
[
|
||||
pytest.param(
|
||||
"nonebot.drivers.fastapi:Driver",
|
||||
id="fastapi",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a http client", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
"nonebot.drivers.quart:Driver",
|
||||
id="quart",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a http client", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
|
||||
pytest.param(
|
||||
"nonebot.drivers.websockets:Driver",
|
||||
id="websockets",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a http client", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_adapter_http_client(driver: Driver):
|
||||
last_request: Optional[Request] = None
|
||||
|
||||
async def _fake_request(request: Request):
|
||||
nonlocal last_request
|
||||
last_request = request
|
||||
|
||||
with pytest.MonkeyPatch.context() as m:
|
||||
m.setattr(driver, "request", _fake_request, raising=False)
|
||||
|
||||
adapter = FakeAdapter(driver)
|
||||
|
||||
request = Request("GET", URL("/test"))
|
||||
await adapter.request(request)
|
||||
assert last_request is request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
[
|
||||
pytest.param(
|
||||
"nonebot.drivers.fastapi:Driver",
|
||||
id="fastapi",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a websocket client", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
"nonebot.drivers.quart:Driver",
|
||||
id="quart",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a websocket client", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
"nonebot.drivers.httpx:Driver",
|
||||
id="httpx",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="not a websocket client", raises=TypeError, strict=True
|
||||
),
|
||||
),
|
||||
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
|
||||
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_adapter_websocket_client(driver: Driver):
|
||||
_fake_ws = object()
|
||||
_last_request: Optional[Request] = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_websocket(setup: Request):
|
||||
nonlocal _last_request
|
||||
_last_request = setup
|
||||
yield _fake_ws
|
||||
|
||||
with pytest.MonkeyPatch.context() as m:
|
||||
m.setattr(driver, "websocket", _fake_websocket, raising=False)
|
||||
|
||||
adapter = FakeAdapter(driver)
|
||||
|
||||
request = Request("GET", URL("/test"))
|
||||
async with adapter.websocket(request) as ws:
|
||||
assert _last_request is request
|
||||
assert ws is _fake_ws
|
@ -1,15 +1,12 @@
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any, Set, Optional, cast
|
||||
from typing import Any, Set, Optional
|
||||
|
||||
import pytest
|
||||
from nonebug import App
|
||||
|
||||
import nonebot
|
||||
from nonebot.config import Env
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.params import Depends
|
||||
from nonebot import _resolve_combine_expr
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.drivers._lifespan import Lifespan
|
||||
@ -18,25 +15,15 @@ from nonebot.drivers import (
|
||||
Driver,
|
||||
Request,
|
||||
Response,
|
||||
ASGIMixin,
|
||||
WebSocket,
|
||||
ForwardDriver,
|
||||
ReverseDriver,
|
||||
HTTPClientMixin,
|
||||
HTTPServerSetup,
|
||||
WebSocketClientMixin,
|
||||
WebSocketServerSetup,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="driver")
|
||||
def load_driver(request: pytest.FixtureRequest) -> Driver:
|
||||
driver_name = getattr(request, "param", None)
|
||||
global_driver = nonebot.get_driver()
|
||||
if driver_name is None:
|
||||
return global_driver
|
||||
|
||||
DriverClass = _resolve_combine_expr(driver_name)
|
||||
return DriverClass(Env(environment=global_driver.env), global_driver.config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan():
|
||||
lifespan = Lifespan()
|
||||
@ -80,7 +67,7 @@ async def test_lifespan():
|
||||
indirect=True,
|
||||
)
|
||||
async def test_http_server(app: App, driver: Driver):
|
||||
driver = cast(ReverseDriver, driver)
|
||||
assert isinstance(driver, ASGIMixin)
|
||||
|
||||
async def _handle_http(request: Request) -> Response:
|
||||
assert request.content in (b"test", "test")
|
||||
@ -108,7 +95,7 @@ async def test_http_server(app: App, driver: Driver):
|
||||
indirect=True,
|
||||
)
|
||||
async def test_websocket_server(app: App, driver: Driver):
|
||||
driver = cast(ReverseDriver, driver)
|
||||
assert isinstance(driver, ASGIMixin)
|
||||
|
||||
async def _handle_ws(ws: WebSocket) -> None:
|
||||
await ws.accept()
|
||||
@ -164,7 +151,7 @@ async def test_websocket_server(app: App, driver: Driver):
|
||||
indirect=True,
|
||||
)
|
||||
async def test_cross_context(app: App, driver: Driver):
|
||||
driver = cast(ReverseDriver, driver)
|
||||
assert isinstance(driver, ASGIMixin)
|
||||
|
||||
ws: Optional[WebSocket] = None
|
||||
ws_ready = asyncio.Event()
|
||||
@ -221,7 +208,7 @@ async def test_cross_context(app: App, driver: Driver):
|
||||
indirect=True,
|
||||
)
|
||||
async def test_http_client(driver: Driver, server_url: URL):
|
||||
driver = cast(ForwardDriver, driver)
|
||||
assert isinstance(driver, HTTPClientMixin)
|
||||
|
||||
# simple post with query, headers, cookies and content
|
||||
request = Request(
|
||||
@ -303,6 +290,19 @@ async def test_http_client(driver: Driver, server_url: URL):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
[
|
||||
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
|
||||
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_websocket_client(driver: Driver):
|
||||
assert isinstance(driver, WebSocketClientMixin)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("driver", "driver_type"),
|
||||
|
@ -2,7 +2,7 @@ import pytest
|
||||
from nonebug import App
|
||||
|
||||
import nonebot
|
||||
from nonebot.drivers import Driver, ReverseDriver
|
||||
from nonebot.drivers import Driver, ASGIMixin, ReverseDriver
|
||||
from nonebot import (
|
||||
get_app,
|
||||
get_bot,
|
||||
@ -47,6 +47,7 @@ async def test_get_driver(app: App, monkeypatch: pytest.MonkeyPatch):
|
||||
async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
|
||||
driver = get_driver()
|
||||
assert isinstance(driver, ReverseDriver)
|
||||
assert isinstance(driver, ASGIMixin)
|
||||
assert get_asgi() == driver.asgi
|
||||
|
||||
|
||||
@ -54,6 +55,7 @@ async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
|
||||
async def test_get_app(app: App, monkeypatch: pytest.MonkeyPatch):
|
||||
driver = get_driver()
|
||||
assert isinstance(driver, ReverseDriver)
|
||||
assert isinstance(driver, ASGIMixin)
|
||||
assert get_app() == driver.server_app
|
||||
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
from typing_extensions import override
|
||||
from typing import Type, Union, Mapping, Iterable, Optional
|
||||
|
||||
from pydantic import Extra, create_model
|
||||
|
||||
from nonebot.adapters import Event, Message, MessageSegment
|
||||
from nonebot.adapters import Bot, Event, Adapter, Message, MessageSegment
|
||||
|
||||
|
||||
def escape_text(s: str, *, escape_comma: bool = True) -> str:
|
||||
@ -12,11 +13,24 @@ def escape_text(s: str, *, escape_comma: bool = True) -> str:
|
||||
return s
|
||||
|
||||
|
||||
class FakeAdapter(Adapter):
|
||||
@classmethod
|
||||
@override
|
||||
def get_name(cls) -> str:
|
||||
return "fake"
|
||||
|
||||
@override
|
||||
async def _call_api(self, bot: Bot, api: str, **data):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FakeMessageSegment(MessageSegment["FakeMessage"]):
|
||||
@classmethod
|
||||
@override
|
||||
def get_message_class(cls):
|
||||
return FakeMessage
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.data["text"] if self.type == "text" else f"[fake:{self.type}]"
|
||||
|
||||
@ -32,16 +46,19 @@ class FakeMessageSegment(MessageSegment["FakeMessage"]):
|
||||
def nested(content: "FakeMessage"):
|
||||
return FakeMessageSegment("node", {"content": content})
|
||||
|
||||
@override
|
||||
def is_text(self) -> bool:
|
||||
return self.type == "text"
|
||||
|
||||
|
||||
class FakeMessage(Message[FakeMessageSegment]):
|
||||
@classmethod
|
||||
@override
|
||||
def get_segment_class(cls):
|
||||
return FakeMessageSegment
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def _construct(msg: Union[str, Iterable[Mapping]]):
|
||||
if isinstance(msg, str):
|
||||
yield FakeMessageSegment.text(msg)
|
||||
@ -50,6 +67,7 @@ class FakeMessage(Message[FakeMessageSegment]):
|
||||
yield FakeMessageSegment(**seg)
|
||||
return
|
||||
|
||||
@override
|
||||
def __add__(
|
||||
self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]]
|
||||
):
|
||||
@ -71,30 +89,37 @@ def make_fake_event(
|
||||
Base = _base or Event
|
||||
|
||||
class FakeEvent(Base, extra=Extra.forbid):
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return _type
|
||||
|
||||
@override
|
||||
def get_event_name(self) -> str:
|
||||
return _name
|
||||
|
||||
@override
|
||||
def get_event_description(self) -> str:
|
||||
return _description
|
||||
|
||||
@override
|
||||
def get_user_id(self) -> str:
|
||||
if _user_id is not None:
|
||||
return _user_id
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def get_session_id(self) -> str:
|
||||
if _session_id is not None:
|
||||
return _session_id
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def get_message(self) -> "Message":
|
||||
if _message is not None:
|
||||
return _message
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def is_tome(self) -> bool:
|
||||
return _to_me
|
||||
|
||||
|
Reference in New Issue
Block a user