cqhttp support forward websocket

This commit is contained in:
yanyongyu
2021-07-19 23:46:29 +08:00
parent 32787fdc1e
commit 04b3fda40c
5 changed files with 46 additions and 22 deletions

View File

@ -3,6 +3,7 @@ import sys
import hmac
import json
import asyncio
from urllib.parse import urlsplit
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
import httpx
@ -11,7 +12,8 @@ from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.utils import escape_tag, DataclassEncoder
from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from nonebot.drivers import Driver, ForwardDriver, ReverseDriver
from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from .utils import log, escape
from .config import Config as CQHTTPConfig
@ -237,6 +239,29 @@ class Bot(BaseBot):
def register(cls, driver: Driver, config: "Config"):
super().register(driver, config)
cls.cqhttp_config = CQHTTPConfig(**config.dict())
if not isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
logger.warning(
f"Current driver {cls.config.driver} don't support forward connections"
)
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
for self_id, url in cls.cqhttp_config.ws_urls.items():
try:
url_info = urlsplit(url)
headers = {
"authorization":
f"Bearer {cls.cqhttp_config.access_token}",
"host":
url_info.netloc if not url_info.port else
f"{url_info.netloc}:{url_info.port}",
}
driver.setup(
"cqhttp", self_id,
WebSocket("1.1", url_info.scheme, url_info.path,
url_info.query.encode("latin-1"), headers))
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} "
"in cqhttp forward websocket</bg></r>")
@classmethod
@overrides(BaseBot)

View File

@ -1,6 +1,6 @@
from typing import Optional
from typing import Dict, Optional
from pydantic import Field, BaseModel
from pydantic import Field, BaseModel, AnyUrl
# priority: alias > origin
@ -12,10 +12,13 @@ class Config(BaseModel):
- ``access_token`` / ``cqhttp_access_token``: CQHTTP 协议授权令牌
- ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令
- ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID、目标 URL 字典
"""
access_token: Optional[str] = Field(default=None,
alias="cqhttp_access_token")
secret: Optional[str] = Field(default=None, alias="cqhttp_secret")
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set,
alias="cqhttp_ws_urls")
class Config:
extra = "ignore"