Feature: 细化 driver 职责类型 (#2296)

This commit is contained in:
Ju4tCode
2023-08-26 11:03:24 +08:00
committed by GitHub
parent 807a86371d
commit 2e635370bb
20 changed files with 632 additions and 284 deletions

View File

@ -1,8 +1,9 @@
from typing_extensions import override
from typing import Type, Union, Mapping, Iterable, Optional
from pydantic import Extra, create_model
from nonebot.adapters import Event, Message, MessageSegment
from nonebot.adapters import Bot, Event, Adapter, Message, MessageSegment
def escape_text(s: str, *, escape_comma: bool = True) -> str:
@ -12,11 +13,24 @@ def escape_text(s: str, *, escape_comma: bool = True) -> str:
return s
class FakeAdapter(Adapter):
@classmethod
@override
def get_name(cls) -> str:
return "fake"
@override
async def _call_api(self, bot: Bot, api: str, **data):
raise NotImplementedError
class FakeMessageSegment(MessageSegment["FakeMessage"]):
@classmethod
@override
def get_message_class(cls):
return FakeMessage
@override
def __str__(self) -> str:
return self.data["text"] if self.type == "text" else f"[fake:{self.type}]"
@ -32,16 +46,19 @@ class FakeMessageSegment(MessageSegment["FakeMessage"]):
def nested(content: "FakeMessage"):
return FakeMessageSegment("node", {"content": content})
@override
def is_text(self) -> bool:
return self.type == "text"
class FakeMessage(Message[FakeMessageSegment]):
@classmethod
@override
def get_segment_class(cls):
return FakeMessageSegment
@staticmethod
@override
def _construct(msg: Union[str, Iterable[Mapping]]):
if isinstance(msg, str):
yield FakeMessageSegment.text(msg)
@ -50,6 +67,7 @@ class FakeMessage(Message[FakeMessageSegment]):
yield FakeMessageSegment(**seg)
return
@override
def __add__(
self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]]
):
@ -71,30 +89,37 @@ def make_fake_event(
Base = _base or Event
class FakeEvent(Base, extra=Extra.forbid):
@override
def get_type(self) -> str:
return _type
@override
def get_event_name(self) -> str:
return _name
@override
def get_event_description(self) -> str:
return _description
@override
def get_user_id(self) -> str:
if _user_id is not None:
return _user_id
raise NotImplementedError
@override
def get_session_id(self) -> str:
if _session_id is not None:
return _session_id
raise NotImplementedError
@override
def get_message(self) -> "Message":
if _message is not None:
return _message
raise NotImplementedError
@override
def is_tome(self) -> bool:
return _to_me