diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 623d23e9..25928d87 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -14,10 +14,10 @@ from typing import Any, Dict, Union, TypeVar, Mapping, Optional, Callable, Itera from pydantic import BaseModel -from nonebot.config import Config from nonebot.utils import DataclassEncoder if TYPE_CHECKING: + from nonebot.config import Config from nonebot.drivers import Driver, WebSocket @@ -26,29 +26,26 @@ class Bot(abc.ABC): Bot 基类。用于处理上报消息,并提供 API 调用接口。 """ + driver: "Driver" + """Driver 对象""" + config: "Config" + """Config 配置对象""" + @abc.abstractmethod def __init__(self, - driver: "Driver", connection_type: str, - config: Config, self_id: str, *, websocket: Optional["WebSocket"] = None): """ :参数: - * ``driver: Driver``: Driver 对象 * ``connection_type: str``: http 或者 websocket - * ``config: Config``: Config 对象 * ``self_id: str``: 机器人 ID * ``websocket: Optional[WebSocket]``: Websocket 连接对象 """ - self.driver = driver - """Driver 对象""" self.connection_type = connection_type """连接类型""" - self.config = config - """Config 配置对象""" self.self_id = self_id """机器人 ID""" self.websocket = websocket @@ -63,6 +60,16 @@ class Bot(abc.ABC): """Adapter 类型""" raise NotImplementedError + @classmethod + def register(cls, driver: "Driver", config: "Config"): + """ + :说明: + + `register` 方法会在 `driver.register_adapter` 时被调用,用于初始化相关配置 + """ + cls.driver = driver + cls.config = config + @classmethod @abc.abstractmethod async def check_permission(cls, driver: "Driver", connection_type: str, diff --git a/nonebot/adapters/cqhttp/bot.py b/nonebot/adapters/cqhttp/bot.py index efc9a2d8..62fb4aad 100644 --- a/nonebot/adapters/cqhttp/bot.py +++ b/nonebot/adapters/cqhttp/bot.py @@ -7,7 +7,6 @@ from typing import Any, Dict, Union, Optional, TYPE_CHECKING import httpx from nonebot.log import logger -from nonebot.config import Config from nonebot.typing import overrides from nonebot.message import handle_event from nonebot.adapters import Bot as BaseBot @@ -20,6 +19,7 @@ from .event import Reply, Event, MessageEvent, get_event_model from .exception import NetworkError, ApiNotAvailable, ActionFailed if TYPE_CHECKING: + from nonebot.config import Config from nonebot.drivers import Driver, WebSocket @@ -218,22 +218,15 @@ class Bot(BaseBot): """ CQHTTP 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。 """ + cqhttp_config: CQHTTPConfig def __init__(self, - driver: "Driver", connection_type: str, - config: Config, self_id: str, *, websocket: Optional["WebSocket"] = None): - self.cqhttp_config = CQHTTPConfig(**config.dict()) - - super().__init__(driver, - connection_type, - config, - self_id, - websocket=websocket) + super().__init__(connection_type, self_id, websocket=websocket) @property @overrides(BaseBot) @@ -243,6 +236,11 @@ class Bot(BaseBot): """ return "cqhttp" + @classmethod + def register(cls, driver: "Driver", config: "Config"): + super().register(driver, config) + cls.cqhttp_config = CQHTTPConfig(**config.dict()) + @classmethod @overrides(BaseBot) async def check_permission(cls, driver: "Driver", connection_type: str, @@ -268,7 +266,7 @@ class Bot(BaseBot): raise RequestDenied(400, "Missing X-Self-ID Header") # 检查签名 - secret = cqhttp_config.cqhttp_secret + secret = cqhttp_config.secret if secret and connection_type == "http": if not x_signature: log("WARNING", "Missing Signature Header") @@ -279,7 +277,7 @@ class Bot(BaseBot): log("WARNING", "Signature Header is invalid") raise RequestDenied(403, "Signature is invalid") - access_token = cqhttp_config.cqhttp_access_token + access_token = cqhttp_config.access_token if access_token and access_token != token: log( "WARNING", "Authorization Header is invalid" @@ -378,9 +376,9 @@ class Bot(BaseBot): api_root += "/" headers = {} - if self.cqhttp_config.cqhttp_access_token is not None: + if self.cqhttp_config.access_token is not None: headers[ - "Authorization"] = "Bearer " + self.cqhttp_config.cqhttp_access_token + "Authorization"] = "Bearer " + self.cqhttp_config.access_token try: async with httpx.AsyncClient(headers=headers) as client: diff --git a/nonebot/adapters/cqhttp/config.py b/nonebot/adapters/cqhttp/config.py index ff30f172..c537170a 100644 --- a/nonebot/adapters/cqhttp/config.py +++ b/nonebot/adapters/cqhttp/config.py @@ -1,12 +1,13 @@ from typing import Optional -from pydantic import Field, BaseSettings +from pydantic import Field, BaseModel -class Config(BaseSettings): - cqhttp_access_token: Optional[str] = Field(default=None, - alias="access_token") - cqhttp_secret: Optional[str] = Field(default=None, alias="secret") +# priority: alias > origin +class Config(BaseModel): + access_token: Optional[str] = Field(default=None, + alias="cqhttp_access_token") + secret: Optional[str] = Field(default=None, alias="cqhttp_secret") class Config: extra = "ignore" diff --git a/nonebot/adapters/ding/bot.py b/nonebot/adapters/ding/bot.py index e46febc5..0b87f44c 100644 --- a/nonebot/adapters/ding/bot.py +++ b/nonebot/adapters/ding/bot.py @@ -5,18 +5,19 @@ from typing import Any, Union, Optional, TYPE_CHECKING import httpx from nonebot.log import logger -from nonebot.config import Config from nonebot.typing import overrides from nonebot.message import handle_event from nonebot.adapters import Bot as BaseBot from nonebot.exception import RequestDenied from .utils import log +from .config import Config as DingConfig from .message import Message, MessageSegment from .exception import NetworkError, ApiNotAvailable, ActionFailed, SessionExpired -from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent, ConversationType +from .event import MessageEvent, PrivateMessageEvent, GroupMessageEvent, ConversationType if TYPE_CHECKING: + from nonebot.config import Config from nonebot.drivers import Driver SEND_BY_SESSION_WEBHOOK = "send_by_sessionWebhook" @@ -26,11 +27,11 @@ class Bot(BaseBot): """ 钉钉 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。 """ + ding_config: DingConfig - def __init__(self, driver: "Driver", connection_type: str, config: Config, - self_id: str, **kwargs): + def __init__(self, connection_type: str, self_id: str, **kwargs): - super().__init__(driver, connection_type, config, self_id, **kwargs) + super().__init__(connection_type, self_id, **kwargs) @property def type(self) -> str: @@ -39,6 +40,11 @@ class Bot(BaseBot): """ return "ding" + @classmethod + def register(cls, driver: "Driver", config: "Config"): + super().register(driver, config) + cls.ding_config = DingConfig(**config.dict()) + @classmethod @overrides(BaseBot) async def check_permission(cls, driver: "Driver", connection_type: str, @@ -61,7 +67,7 @@ class Bot(BaseBot): raise RequestDenied(400, "Missing `timestamp` Header") # 检查 sign - secret = driver.config.secret + secret = cls.ding_config.secret if secret: if not sign: log("WARNING", "Missing Signature Header") @@ -156,7 +162,7 @@ class Bot(BaseBot): async with httpx.AsyncClient(headers=headers) as client: response = await client.post( target, - params={"access_token": self.config.access_token}, + params={"access_token": self.ding_config.access_token}, json=message._produce(), timeout=self.config.api_timeout) diff --git a/nonebot/adapters/ding/config.py b/nonebot/adapters/ding/config.py new file mode 100644 index 00000000..5b334a94 --- /dev/null +++ b/nonebot/adapters/ding/config.py @@ -0,0 +1,11 @@ +from typing import Optional + +from pydantic import Field, BaseModel + + +class Config(BaseModel): + secret: Optional[str] = Field(default=None, alias="ding_secret") + access_token: Optional[str] = Field(default=None, alias="ding_access_token") + + class Config: + extra = "ignore" diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 7e95ee91..986d59a3 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -62,8 +62,7 @@ class Driver(abc.ABC): :说明: 已连接的 Bot """ - @classmethod - def register_adapter(cls, name: str, adapter: Type["Bot"]): + def register_adapter(self, name: str, adapter: Type["Bot"]): """ :说明: @@ -74,7 +73,8 @@ class Driver(abc.ABC): * ``name: str``: 适配器名称,用于在连接时进行识别 * ``adapter: Type[Bot]``: 适配器 Class """ - cls._adapters[name] = adapter + self._adapters[name] = adapter + adapter.register(self, self.config) logger.opt( colors=True).debug(f'Succeeded to load adapter "{name}"')