Feature: 细化 driver 职责类型 (#2296)

This commit is contained in:
Ju4tCode
2023-08-26 11:03:24 +08:00
committed by GitHub
parent 807a86371d
commit 2e635370bb
20 changed files with 632 additions and 284 deletions

View File

@ -8,8 +8,10 @@ from nonebug import NONEBOT_INIT_KWARGS
from werkzeug.serving import BaseWSGIServer, make_server
import nonebot
from nonebot.drivers import URL
from nonebot.config import Env
from fake_server import request_handler
from nonebot.drivers import URL, Driver
from nonebot import _resolve_combine_expr
os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
os.environ["CONFIG_OVERRIDE"] = "new"
@ -22,6 +24,17 @@ def pytest_configure(config: pytest.Config) -> None:
config.stash[NONEBOT_INIT_KWARGS] = {"config_from_init": "init"}
@pytest.fixture(name="driver")
def load_driver(request: pytest.FixtureRequest) -> Driver:
driver_name = getattr(request, "param", None)
global_driver = nonebot.get_driver()
if driver_name is None:
return global_driver
DriverClass = _resolve_combine_expr(driver_name)
return DriverClass(Env(environment=global_driver.env), global_driver.config)
@pytest.fixture(scope="session", autouse=True)
def load_plugin(nonebug_init: None) -> Set["Plugin"]:
# preload global plugins

View File

@ -0,0 +1,211 @@
from typing import Optional
from contextlib import asynccontextmanager
import pytest
from nonebug import App
from utils import FakeAdapter
from nonebot.adapters import Bot
from nonebot.drivers import (
URL,
Driver,
Request,
Response,
WebSocket,
HTTPServerSetup,
WebSocketServerSetup,
)
@pytest.mark.asyncio
async def test_adapter_connect(app: App, driver: Driver):
last_connect_bot: Optional[Bot] = None
last_disconnect_bot: Optional[Bot] = None
def _fake_bot_connect(bot: Bot):
nonlocal last_connect_bot
last_connect_bot = bot
def _fake_bot_disconnect(bot: Bot):
nonlocal last_disconnect_bot
last_disconnect_bot = bot
with pytest.MonkeyPatch.context() as m:
m.setattr(driver, "_bot_connect", _fake_bot_connect)
m.setattr(driver, "_bot_disconnect", _fake_bot_disconnect)
adapter = FakeAdapter(driver)
async with app.test_api() as ctx:
bot = ctx.create_bot(adapter=adapter)
assert last_connect_bot is bot
assert adapter.bots[bot.self_id] is bot
assert last_disconnect_bot is bot
assert bot.self_id not in adapter.bots
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
pytest.param("nonebot.drivers.quart:Driver", id="quart"),
pytest.param(
"nonebot.drivers.httpx:Driver",
id="httpx",
marks=pytest.mark.xfail(
reason="not a server", raises=TypeError, strict=True
),
),
pytest.param(
"nonebot.drivers.websockets:Driver",
id="websockets",
marks=pytest.mark.xfail(
reason="not a server", raises=TypeError, strict=True
),
),
pytest.param(
"nonebot.drivers.aiohttp:Driver",
id="aiohttp",
marks=pytest.mark.xfail(
reason="not a server", raises=TypeError, strict=True
),
),
],
indirect=True,
)
async def test_adapter_server(driver: Driver):
last_http_setup: Optional[HTTPServerSetup] = None
last_ws_setup: Optional[WebSocketServerSetup] = None
def _fake_setup_http_server(setup: HTTPServerSetup):
nonlocal last_http_setup
last_http_setup = setup
def _fake_setup_websocket_server(setup: WebSocketServerSetup):
nonlocal last_ws_setup
last_ws_setup = setup
with pytest.MonkeyPatch.context() as m:
m.setattr(driver, "setup_http_server", _fake_setup_http_server, raising=False)
m.setattr(
driver,
"setup_websocket_server",
_fake_setup_websocket_server,
raising=False,
)
async def handle_http(request: Request):
return Response(200, content="test")
async def handle_ws(ws: WebSocket):
...
adapter = FakeAdapter(driver)
setup = HTTPServerSetup(URL("/test"), "GET", "test", handle_http)
adapter.setup_http_server(setup)
assert last_http_setup is setup
setup = WebSocketServerSetup(URL("/test"), "test", handle_ws)
adapter.setup_websocket_server(setup)
assert last_ws_setup is setup
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param(
"nonebot.drivers.fastapi:Driver",
id="fastapi",
marks=pytest.mark.xfail(
reason="not a http client", raises=TypeError, strict=True
),
),
pytest.param(
"nonebot.drivers.quart:Driver",
id="quart",
marks=pytest.mark.xfail(
reason="not a http client", raises=TypeError, strict=True
),
),
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
pytest.param(
"nonebot.drivers.websockets:Driver",
id="websockets",
marks=pytest.mark.xfail(
reason="not a http client", raises=TypeError, strict=True
),
),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_adapter_http_client(driver: Driver):
last_request: Optional[Request] = None
async def _fake_request(request: Request):
nonlocal last_request
last_request = request
with pytest.MonkeyPatch.context() as m:
m.setattr(driver, "request", _fake_request, raising=False)
adapter = FakeAdapter(driver)
request = Request("GET", URL("/test"))
await adapter.request(request)
assert last_request is request
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param(
"nonebot.drivers.fastapi:Driver",
id="fastapi",
marks=pytest.mark.xfail(
reason="not a websocket client", raises=TypeError, strict=True
),
),
pytest.param(
"nonebot.drivers.quart:Driver",
id="quart",
marks=pytest.mark.xfail(
reason="not a websocket client", raises=TypeError, strict=True
),
),
pytest.param(
"nonebot.drivers.httpx:Driver",
id="httpx",
marks=pytest.mark.xfail(
reason="not a websocket client", raises=TypeError, strict=True
),
),
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_adapter_websocket_client(driver: Driver):
_fake_ws = object()
_last_request: Optional[Request] = None
@asynccontextmanager
async def _fake_websocket(setup: Request):
nonlocal _last_request
_last_request = setup
yield _fake_ws
with pytest.MonkeyPatch.context() as m:
m.setattr(driver, "websocket", _fake_websocket, raising=False)
adapter = FakeAdapter(driver)
request = Request("GET", URL("/test"))
async with adapter.websocket(request) as ws:
assert _last_request is request
assert ws is _fake_ws

View File

@ -1,15 +1,12 @@
import json
import asyncio
from typing import Any, Set, Optional, cast
from typing import Any, Set, Optional
import pytest
from nonebug import App
import nonebot
from nonebot.config import Env
from nonebot.adapters import Bot
from nonebot.params import Depends
from nonebot import _resolve_combine_expr
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
@ -18,25 +15,15 @@ from nonebot.drivers import (
Driver,
Request,
Response,
ASGIMixin,
WebSocket,
ForwardDriver,
ReverseDriver,
HTTPClientMixin,
HTTPServerSetup,
WebSocketClientMixin,
WebSocketServerSetup,
)
@pytest.fixture(name="driver")
def load_driver(request: pytest.FixtureRequest) -> Driver:
driver_name = getattr(request, "param", None)
global_driver = nonebot.get_driver()
if driver_name is None:
return global_driver
DriverClass = _resolve_combine_expr(driver_name)
return DriverClass(Env(environment=global_driver.env), global_driver.config)
@pytest.mark.asyncio
async def test_lifespan():
lifespan = Lifespan()
@ -80,7 +67,7 @@ async def test_lifespan():
indirect=True,
)
async def test_http_server(app: App, driver: Driver):
driver = cast(ReverseDriver, driver)
assert isinstance(driver, ASGIMixin)
async def _handle_http(request: Request) -> Response:
assert request.content in (b"test", "test")
@ -108,7 +95,7 @@ async def test_http_server(app: App, driver: Driver):
indirect=True,
)
async def test_websocket_server(app: App, driver: Driver):
driver = cast(ReverseDriver, driver)
assert isinstance(driver, ASGIMixin)
async def _handle_ws(ws: WebSocket) -> None:
await ws.accept()
@ -164,7 +151,7 @@ async def test_websocket_server(app: App, driver: Driver):
indirect=True,
)
async def test_cross_context(app: App, driver: Driver):
driver = cast(ReverseDriver, driver)
assert isinstance(driver, ASGIMixin)
ws: Optional[WebSocket] = None
ws_ready = asyncio.Event()
@ -221,7 +208,7 @@ async def test_cross_context(app: App, driver: Driver):
indirect=True,
)
async def test_http_client(driver: Driver, server_url: URL):
driver = cast(ForwardDriver, driver)
assert isinstance(driver, HTTPClientMixin)
# simple post with query, headers, cookies and content
request = Request(
@ -303,6 +290,19 @@ async def test_http_client(driver: Driver, server_url: URL):
await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_websocket_client(driver: Driver):
assert isinstance(driver, WebSocketClientMixin)
@pytest.mark.asyncio
@pytest.mark.parametrize(
("driver", "driver_type"),

View File

@ -2,7 +2,7 @@ import pytest
from nonebug import App
import nonebot
from nonebot.drivers import Driver, ReverseDriver
from nonebot.drivers import Driver, ASGIMixin, ReverseDriver
from nonebot import (
get_app,
get_bot,
@ -47,6 +47,7 @@ async def test_get_driver(app: App, monkeypatch: pytest.MonkeyPatch):
async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver()
assert isinstance(driver, ReverseDriver)
assert isinstance(driver, ASGIMixin)
assert get_asgi() == driver.asgi
@ -54,6 +55,7 @@ async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
async def test_get_app(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver()
assert isinstance(driver, ReverseDriver)
assert isinstance(driver, ASGIMixin)
assert get_app() == driver.server_app

View File

@ -1,8 +1,9 @@
from typing_extensions import override
from typing import Type, Union, Mapping, Iterable, Optional
from pydantic import Extra, create_model
from nonebot.adapters import Event, Message, MessageSegment
from nonebot.adapters import Bot, Event, Adapter, Message, MessageSegment
def escape_text(s: str, *, escape_comma: bool = True) -> str:
@ -12,11 +13,24 @@ def escape_text(s: str, *, escape_comma: bool = True) -> str:
return s
class FakeAdapter(Adapter):
@classmethod
@override
def get_name(cls) -> str:
return "fake"
@override
async def _call_api(self, bot: Bot, api: str, **data):
raise NotImplementedError
class FakeMessageSegment(MessageSegment["FakeMessage"]):
@classmethod
@override
def get_message_class(cls):
return FakeMessage
@override
def __str__(self) -> str:
return self.data["text"] if self.type == "text" else f"[fake:{self.type}]"
@ -32,16 +46,19 @@ class FakeMessageSegment(MessageSegment["FakeMessage"]):
def nested(content: "FakeMessage"):
return FakeMessageSegment("node", {"content": content})
@override
def is_text(self) -> bool:
return self.type == "text"
class FakeMessage(Message[FakeMessageSegment]):
@classmethod
@override
def get_segment_class(cls):
return FakeMessageSegment
@staticmethod
@override
def _construct(msg: Union[str, Iterable[Mapping]]):
if isinstance(msg, str):
yield FakeMessageSegment.text(msg)
@ -50,6 +67,7 @@ class FakeMessage(Message[FakeMessageSegment]):
yield FakeMessageSegment(**seg)
return
@override
def __add__(
self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]]
):
@ -71,30 +89,37 @@ def make_fake_event(
Base = _base or Event
class FakeEvent(Base, extra=Extra.forbid):
@override
def get_type(self) -> str:
return _type
@override
def get_event_name(self) -> str:
return _name
@override
def get_event_description(self) -> str:
return _description
@override
def get_user_id(self) -> str:
if _user_id is not None:
return _user_id
raise NotImplementedError
@override
def get_session_id(self) -> str:
if _session_id is not None:
return _session_id
raise NotImplementedError
@override
def get_message(self) -> "Message":
if _message is not None:
return _message
raise NotImplementedError
@override
def is_tome(self) -> bool:
return _to_me